0

我在一个云函数(Cloud Function)中调用部署在 Google Cloud AI Platform 上的模型的 predict 方法,结果收到错误,提示 `HistGradientBoostingClassifier` 没有属性 `n_features_`。

我查看了源码:`HistGradientBoostingClassifier.predict(self, X)`(gradient_boosting.py 第 1100 行)会调用 `self.predict_proba(X)`(第 1114 行),后者又调用 `_raw_predict`(第 1130 行),而 `_raw_predict` 内部访问了 `self.n_features_`(第 646 行)。然而,在 `HistGradientBoostingClassifier` 的基类 `BaseHistGradientBoosting` 的 `fit` 方法中(第 143 行),根据训练数据集 X 赋值的属性名是 `_n_features`,而不是 `n_features_`。

我通过 Python API 客户端调用 predict 方法:`service.projects().predict(name=name, body={'instances': instances}).execute()`

是 `HistGradientBoostingClassifier` 本来就不会有这个属性,还是我的训练集有问题,导致我第一次用下面两个调用创建模型和版本时,训练集没有被传入 `fit` 方法?`ml.projects().models().create(parent=project_id_model, body=model_request_dict).execute()` 和 `ml.projects().models().versions().create(parent=project_id_version, body=version_request_dict).execute()`

这是云函数日志中的完整错误:

Traceback (most recent call last): 
File "/env/local/lib/python3.7/site-packages/google/cloud/functions/worker_v2.py", line 449, in run_background_function _function_handler.invoke_user_function(event_object) 
File "/env/local/lib/python3.7/site-packages/google/cloud/functions/worker_v2.py", line 268, in invoke_user_function return call_user_function(request_or_event) 
File "/env/local/lib/python3.7/site-packages/google/cloud/functions/worker_v2.py", line 265, in call_user_function event_context.Context(**request_or_event.context)) 
File "/user_code/main.py", line 41, in request_prediction raise RuntimeError(response['error']) 
RuntimeError: Prediction failed: Exception during sklearn prediction: 'HistGradientBoostingClassifier' object has no attribute 'n_features_' 

发送预测请求的完整云函数:

from google.cloud import storage
from google.cloud import firestore
import os
import pandas
import numpy as np
import math
import json

def _fields_to_instance(fields):
    """Flatten a Firestore event's ``fields`` map into a single feature row.

    Each field arrives as a typed-value dict (e.g. ``{"doubleValue": 22.6}``);
    every value in it is appended, in the map's iteration order. Values are
    coerced to ``float`` because Firestore's JSON encoding sends
    ``integerValue`` as a string, and a single string in the row must not turn
    the instance into text. ``None`` is passed through unchanged.
    """
    row = []
    for typed_value in fields.values():
        for raw in typed_value.values():
            row.append(None if raw is None else float(raw))
    return row


def request_prediction(event, context):
    """Request an online prediction for a Firestore document and store the result.

    The raw event is written to Cloud Storage for debugging, the document's
    fields are sent as one instance to the deployed AI Platform model version,
    and a ``predictions`` map is written back to Firestore.

    Args:
        event: Firestore trigger payload; ``event['value']['fields']`` holds
            the document's typed field values.
        context: Cloud Functions event metadata (only written to the debug blob).

    Raises:
        RuntimeError: if the prediction response contains an ``error`` entry.
    """
    # Imported locally: googleapiclient is not among this script's top-level imports.
    import googleapiclient.discovery

    PROJECT_ID = 'algae-model'
    MODEL_NAME = 'AlgaePredictor'
    VERSION_NAME = 'v1'

    ## setup file to store prediction data in cloud storage
    storage_client = storage.Client()
    bucket = storage_client.get_bucket('algae-mod-bucket1')
    resultBlob = bucket.blob('prediction-data')
    resultBlob.upload_from_string(str(event) + "\n\n\n" + str(context))

    ## parse data from firestore into one instance (a single feature row)
    # NOTE(review): feature order follows the event's field-map order; confirm
    # it matches the column order the model was trained on.
    instances = [_fields_to_instance(event['value']['fields'])]

    ## call predict method on model
    service = googleapiclient.discovery.build('ml', 'v1')
    name = 'projects/{}/models/{}/versions/{}'.format(PROJECT_ID, MODEL_NAME, VERSION_NAME)
    response = service.projects().predict(name=name, body={'instances': instances}).execute()

    ## handle the response
    if 'error' in response:
        raise RuntimeError(response['error'])

    # TODO(review): placeholder outputs; the model's output in `response` is not used yet.
    features = ['microcystin', 'feature2', 'anotherFeature']
    values = [100, 2.3, 312]
    # Pair features with values positionally; mismatched lengths store no predictions.
    predictions = dict(zip(features, values)) if len(features) == len(values) else {}

    db = firestore.Client()
    doc_ref = db.collection(u'users').document(u'user-id/alert_prediction/create_time')
    # set() without merge overwrites the whole document, so no delete() is needed first.
    doc_ref.set({
        u'predictions': predictions
    })

用于创建/部署模型和版本的云函数:

import requests
from googleapiclient.discovery import build
from google.cloud import storage
import pandas as pd
import saginawBayFileSetup
import saginawBayModeling
from sklearn.ensemble import HistGradientBoostingClassifier


def _parse_alert_levels(fields):
    """Return ``{field_name: float}`` for every ``alert_level*`` field.

    Each field is a typed-value dict keyed by its value type
    (e.g. ``{"doubleValue": 2.5}``); the value under its first key is used.
    """
    alert_levels = {}
    for key, typed_value in fields.items():
        if str(key).startswith("alert_level"):
            value_type = next(iter(typed_value))  # e.g. "doubleValue"
            alert_levels[key] = float(typed_value[value_type])
    return alert_levels


def init_model(event, context):
    """Train the algae model and create its AI Platform model and version.

    The source spreadsheet is read from Cloud Storage, cleaned with the alert
    thresholds carried by the triggering Firestore event, passed to
    ``saginawBayModeling.gen_model``, and then an AI Platform model and a
    version pointing at ``gs://<bucket>/model`` are created.

    Args:
        event: Firestore trigger payload; ``event['value']['fields']`` must
            contain ``alert_level_1``, ``alert_level_2`` and ``alert_level_3``.
        context: Cloud Functions event metadata (unused).

    Raises:
        KeyError: if any of the three required alert-level fields is missing.
    """
    # `service_account` was referenced but never imported; google-auth is a
    # dependency of google-api-python-client, which this script already uses.
    from google.oauth2 import service_account

    # Setup access to Google Cloud services
    service_account_file = 'service-account.json'
    SCOPES = ['https://www.googleapis.com/auth/cloud-platform']
    BUCKET_NAME = 'algae-mod-bucket1'
    credentials = service_account.Credentials.from_service_account_file(service_account_file, scopes=SCOPES)
    ml = build('ml', 'v1', credentials=credentials)

    # get training set; these erroneous cell values are read as NaN
    dataCleaningValues = ['', ' ', 'nd', '.', 'n/a']
    df = pd.read_excel(
        'gs://{}/source-data/source-data-saginaw-bay-habs-2012-2019MergedFinal.xlsx'.format(BUCKET_NAME),
        parse_dates=[['Date', 'Time']],
        sheet_name='saginaw_bay_habs_2012_2019',
        na_values=dataCleaningValues,
    )
    print(df)

    # run setup with custom alert levels
    alertLevels = _parse_alert_levels(event['value']['fields'])
    cleaned_data = saginawBayFileSetup.run_setup(
        df, alertLevels["alert_level_1"], alertLevels["alert_level_2"], alertLevels["alert_level_3"])
    # NOTE(review): presumably trains the model and writes it under the
    # deploymentUri used below -- confirm.
    saginawBayModeling.gen_model(cleaned_data)

    PROJECT_NAME = "algae-model"
    MODEL_NAME = "AlgaePredictor"
    VERSION_NAME = "v1"
    # NOTE: the serving runtime must load the model with the same scikit-learn
    # version that pickled it. The prediction error "'HistGradientBoostingClassifier'
    # object has no attribute 'n_features_'" indicates the model was saved by a
    # scikit-learn that stores `_n_features` while the runtime's version reads
    # `n_features_`. Pick a runtime version whose scikit-learn matches the
    # training environment, or pin training to the runtime's scikit-learn.
    RUNTIME_VERSION = "2.2"

    project_id_model = 'projects/{}'.format(PROJECT_NAME)
    project_id_version = 'projects/{}/models/{}'.format(PROJECT_NAME, MODEL_NAME)

    ## Create version and model
    # NOTE(review): creation will likely fail if the model/version already
    # exists; redeploying needs a new VERSION_NAME or a prior
    # versions().delete(name=...) that has finished.
    model_request_dict = {
        "name": MODEL_NAME,
        "description": "Algae prediction model",
        "regions": [
            "us-central1"
        ],
        "onlinePredictionLogging": True,
        "onlinePredictionConsoleLogging": True
    }
    version_request_dict = {
        "name": VERSION_NAME,
        "deploymentUri": "gs://{}/model".format(BUCKET_NAME),
        "runtimeVersion": RUNTIME_VERSION,
        "pythonVersion": "3.7",
        # Declared explicitly instead of letting the service infer it from the files.
        "framework": "SCIKIT_LEARN",
    }
    response = ml.projects().models().create(parent=project_id_model, body=model_request_dict).execute()
    print(response)
    response = ml.projects().models().versions().create(parent=project_id_version, body=version_request_dict).execute()
    print(response)
4

0 回答 0