0

我正在测试一个最简单的示例来学习使用自定义 IEnumerable 数据视图而不是从文本文件中加载传统数据来进行情绪分析。我创建了一个 TestData 和 TrainingData 列表,其中包含一些示例回顾,以便通过遵循 github 和文档上提供的示例轻松学习。但是缺少一些东西,我创建的模型无法正常工作......它只是给出了错误的结果,因为一切都是积极的。

主要的

private static string ModelPath = @"C:\ML\SentimentModel.zip";

void Main()
{
    var mlContext = new MLContext(seed: 1);
    var trainingData = GetTrainingData();
    var testData = GetTestData();

    BuildTrainEvaluateAndSaveModel(mlContext, trainingData, testData);
    TestPrediction(mlContext);
}

测试和培训

private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext, List<SentimentData> trainingData, List<SentimentData> testData)
{
    // STEP 1: Common data loading configuration
    IDataView trainingDataView = mlContext.Data.ReadFromEnumerable(trainingData);
    IDataView testDataView = mlContext.Data.ReadFromEnumerable(trainingData);

    // STEP 2: Common data process configuration with pipeline data transformations          
    var dataProcessPipeline = mlContext.Transforms.Text.FeaturizeText(outputColumnName: DefaultColumnNames.Features, inputColumnName: nameof(SentimentData.Text));

    // STEP 3: Set the training algorithm, then create and config the modelBuilder                            
    var trainer = mlContext.BinaryClassification.Trainers.FastTree(labelColumn: DefaultColumnNames.Label, featureColumn: DefaultColumnNames.Features);
    var trainingPipeline = dataProcessPipeline.Append(trainer);

    // STEP 4: Train the model fitting to the DataSet
    Console.WriteLine("=============== Training the model ===============");
    ITransformer trainedModel = trainingPipeline.Fit(trainingDataView);

    // STEP 6: Save/persist the trained model to a .ZIP file

    using (var fs = new FileStream(ModelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
        mlContext.Model.Save(trainedModel, fs);

    Console.WriteLine("The model is saved to {0}", ModelPath);

    return trainedModel;
}

private void TestPrediction(MLContext mlContext)
{
    var testData = GetTestData();
    ITransformer trainedModel;
    using (var stream = new FileStream(ModelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
    {
        trainedModel = mlContext.Model.Load(stream);
    }
    var engine = trainedModel.CreatePredictionEngine<SentimentData, SentimentPrediction>(mlContext);
    foreach(var test in testData)
    {
        var result = engine.Predict(test);
        Console.WriteLine($"Prediction : {(Convert.ToBoolean(result.Prediction) ? "Negative" : "Postive")} | Actual: {test.Expected} | Text : {test.Text}");
    }
}

模型和训练/测试数据

public List<SentimentData> GetTrainingData()
{
    return new List<SentimentData>
            {
                new SentimentData
                {
                    Label = true,
                    Text = "Good service."
                },
                new SentimentData
                {
                    Label = true,
                    Text = "Very good service"
                },
                new SentimentData
                {
                    Label = true,
                    Text = "Amazing service"
                },
                new SentimentData
                {
                    Label = true,
                    Text = "Great staff, will visit again. thanks for the gift"
                },
                new SentimentData
                {
                    Label = false,
                    Text = "Bad staff, bad service. Will never visit this hotel"
                },
                new SentimentData
                {
                    Label = false,
                    Text = "The service was very bad"
                },
                new SentimentData
                {
                    Label = false,
                    Text = "Hotel location is worst"
                }
            };
}

public List<SentimentData> GetTestData()
{
    return new List<SentimentData>
            {
                new SentimentData
                {
                    Label = true,
                    Text = "Worst hotel in New York",
                    Expected = "Negative"
                },
                new SentimentData
                {
                    Label = true,
                    Text = "I ordered pizza and recieved Wine. Bad staff",
                    Expected = "Negative"
                },
                new SentimentData
                {
                    Label = true,
                    Text = "The hotel was so amazing, and they givena bag to me on gift",
                    Expected = "Positive"
                },
                new SentimentData
                {
                    Label = true,
                    Text = "The hotel staff was great, will visit again",
                    Expected = "Positive"
                }
            };
}

public class SentimentData
{
    public bool Label { get; set; }
    public string Text { get; set; }

    // Additional property for testing purpose
    public string Expected {get; set;}
}

public class SentimentPrediction
{
    [ColumnName("PredictedLabel")]
    public bool Prediction { get; set; }
    public float Probability { get; set; }
    public float Score { get; set; }
}

在此处输入图像描述

4

0 回答 0