感谢 cthrash 提供的扩展帮助并在聊天中与我交谈。使用他的帖子以及一些故障排除,我已经找到了适合我的方法。代码非常笨重,但它只是为了测试并确保我能够做到这一点。要回答这个问题:
Nuget 包和类
使用 cthrash 的帖子,我能够安装训练和预测 nuget 包,它们是此特定应用程序的正确包。我需要以下课程:
Microsoft.Azure.CognitiveServices.Vision.CustomVision.Prediction
Microsoft.Azure.CognitiveServices.Vision.CustomVision.Prediction.Models
Microsoft.Azure.CognitiveServices.Vision.CustomVision.Training
Microsoft.Azure.CognitiveServices.Vision.CustomVision.Training.Models
端点根
按照此处的一些步骤,我确定端点 URL 只需作为根,而不是自定义视觉门户中提供的完整 URL。例如,
https://southcentralus.api.cognitive.microsoft.com/customvision/v2.0/Prediction/
改为
https://southcentralus.api.cognitive.microsoft.com
我使用了自定义视觉门户中的密钥和端点,并进行了更改,我能够同时使用训练和预测客户端来拉动项目和迭代。
获取项目 ID
CustomVisionPredictionClient.PredictImageAsync
如果在门户中未设置默认迭代,您需要一个Guid
项目 id 和一个迭代 id,以便使用。
我测试了两种获取项目ID的方法,
使用门户中的项目 ID 字符串
- 从项目设置下的门户中获取项目 ID 字符串。
- 第一个参数
PredictImageAsync
通过
Guid.Parse(projectId)
使用培训客户端
- 创建一个新的
CustomVisionTrainingClient
获取使用<Project>
清单
TrainingClient.GetProjects().ToList()
就我而言,我只有一个项目,所以我只需要第一个元素。
Guid projectId = projects[0].Id
获取迭代 ID
要获取项目的迭代 id,您需要CustomVisionTrainingClient
.
- 创建客户端
- 获取使用
<Iteration>
清单
client.GetIterations(projectId).ToList()
- 就我而言,我只有一次迭代,所以我只需要第一个元素。
Guid iterationId = iterations[0].Id
我现在可以使用我的模型对图像进行分类。在下面的代码中,fileStream 是传递给模型的图像流。
public async Task<string> Predict(Stream fileStream)
{
string projectId = "";
//string trainingEndpoint = "https://southcentralus.api.cognitive.microsoft.com/customvision/v2.2/Training/";
string trainingEndpoint = "https://southcentralus.api.cognitive.microsoft.com/";
string trainingKey = "";
//string predictionEndpoint = "https://southcentralus.api.cognitive.microsoft.com/customvision/v2.0/Prediction/";
string predictionEndpoint = "https://southcentralus.api.cognitive.microsoft.com";
string predictionKey = "";
CustomVisionTrainingClient trainingClient = new CustomVisionTrainingClient
{
ApiKey = trainingKey,
Endpoint = trainingEndpoint
};
List<Project> projects = new List<Project>();
try
{
projects = trainingClient.GetProjects().ToList();
}
catch(Exception ex)
{
Debug.WriteLine("Unable to get projects:\n\n" + ex.Message);
return "Unable to obtain projects.";
}
Guid ProjectId = Guid.Empty;
if(projects.Count > 0)
{
ProjectId = projects[0].Id;
}
if (ProjectId == Guid.Empty)
{
Debug.WriteLine("Unable to obtain project ID");
return "Unable to obtain project id.";
}
List<Iteration> iterations = new List<Iteration>();
try
{
iterations = trainingClient.GetIterations(ProjectId).ToList();
}
catch(Exception ex)
{
Debug.WriteLine("Unable to obtain iterations.");
return "Unable to obtain iterations.";
}
foreach(Iteration itr in iterations)
{
Debug.WriteLine(itr.Name + "\t" + itr.Id + "\n");
}
Guid iteration = Guid.Empty;
if(iterations.Count > 0)
{
iteration = iterations[0].Id;
}
if(iteration == Guid.Empty)
{
Debug.WriteLine("Unable to obtain project iteration.");
return "Unable to obtain project iteration";
}
CustomVisionPredictionClient predictionClient = new CustomVisionPredictionClient
{
ApiKey = predictionKey,
Endpoint = predictionEndpoint
};
var result = await predictionClient.PredictImageAsync(Guid.Parse(projectId), fileStream, iteration);
string resultStr = string.Empty;
foreach(PredictionModel pred in result.Predictions)
{
if(pred.Probability >= 0.85)
resultStr += pred.TagName + " ";
}
return resultStr;
}