0
import pandas as pd
import matplotlib.pyplot as plt
 
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
 
spark = SparkSession.builder.getOrCreate()

我使用以下代码设置管道

pre_pipe = Pipeline(stages=[feature_indexer, assembler]).fit(stroke_df)
train = pre_pipe.transform(stroke_df)
train.persist()

train.select('features', 'stroke').show(10, truncate = False)

我现在有一个新数据集

new_data = spark.createDataFrame(
    data = [['Female', 42.0, 1, 0,'No','Private', 'Urban', 182.1, 26.8, 'smokes'],
            ['Female', 64.0, 1, 1, 'Yes', 'Self-employed', 'Rural', 171.5, 32.5, 'formerly smoked'], 
            ['Male', 37.0, 0, 0, 'Yes', 'Private', 'Rural', 79.2, 18.4, 'Unknown'],
            ['Male', 72.0, 0, 1, 'No', 'Govt_job', 'Urban', 125.7, 19.4, 'never smoked']],
    schema = (
    'gender STRING, age DOUBLE, hypertension INTEGER, heart_disease INTEGER, ever_married STRING, work_type STRING, '  
      'residence_type STRING, avg_glucose_level DOUBLE, bmi DOUBLE, smoking_status STRING'
)            
)
new_data.show()

new_pred = pre_pipe.transform(new_data)
new_pred.select("probability", "prediction")

但是当我试图通过将我之前创建的管道应用到新数据集来获得预测和概率时,我收到以下错误

AnalysisException:无法解析' probability'给定输入列:[年龄,avg_glucose_level,bmi,ever_married,ever_married_ix,特征,性别,gender_ix,heart_disease,heart_disease_ix,高血压,高血压_ix,residence_type,residence_type_ix,吸烟状态,吸烟状态_ix,工作类型,工作类型_ix];;

我无法弄清楚我做错了什么,请帮忙

4

0 回答 0