こんにちは。 コンサルティング & テクノロジー部の吾郷 です。 前回の記事では、//Build 2017 でリリースされました、Microsoft Azure の 機械学習系サービス の Cognitive Service に新しく増えた Custom Vision Service を紹介しました。 これを利用した Bot を作ってみようと思います。 (前回の記事 を先に見て頂けると、読みやすいと思います。) 完成のイメージとしては、みんなで作成する高山植物図鑑 Bot みたいになるはずです。
開発環境
開発環境は、以下のとおりです。 ・Visual Studio 2017 Version 15.2 ・Bot Framework Bot Framework を利用した開発環境の作成は、以下のドキュメントを参考にしました。https://docs.microsoft.com/en-us/bot-framework/dotnet/bot-builder-dotnet-quickstart また、松永さんの記事「Bot × Recommendations APIを試してみたい。~Bot Framework編~ 」も参考にしました。
完成イメージ
完成イメージは以下のようなシナリオです。 1.高山植物の画像を要求 2.すでに学習済モデルで識別し、一番スコアが高い植物名(タグ名)を返事する。 3.あっているか確認し、あっていれば終わり。違っていれば、学習用データに使っていいか確認。 4.学習用データとして使ってもいい場合は、この植物名を聞く。 5.教えてもらった植物名を画像にタグ付けし、Trainを実行。
ちなみに、Custom Vision Service の Train の実行条件に、タグ毎に5枚以上の学習データが必要ですが、 今回は、この考慮はしてないです。 ちゃんと実装する場合は、5枚以下のタグがある場合は、一度どこかのデータストアに一時保存しておく必要があります。
実装の説明
今回は、Custom Vision Service を利用するところを中心に作り方を紹介し、Bot Framework 周りのことは割愛します。 なお、プログラムは「さいごに」に掲載しています。
プロジェクトの作成
Visual Studio を起動し、新しいプロジェクト → Bot Application を開きます。 起動したら、Nuget パッケージマネージャで、Custom Vision Service のライブラリをインストールします。
ユーザが見せてくれた画像の識別
MessageReceivedAsyncで実装しています。 処理の概要は、こんな感じです。
画像識別だけなら、Key は Prediction Key を使います。 Training Key を使って SDK から Prediction Key を取得することもできるみたいですが、 今回は何のキーを使うのか明示したかったので、この構成としています。
画像の識別は以下のコード1行です。 var cvResult = await predictionEp.PredictImageAsync(projctId, photoStream); この後に、一番スコアが高いものを、Bot に回答させてます。 また、スコアが0.8以上なら自信満々に、0.8未満なら自信なさげにと、ちょっと人格っぽいのをつけてみました。 この辺の表現がうまくできると、人間味のある Bot になるんでしょうね。
外れていた場合に学習するために、識別用に登録した画像のIDを覚えておきます。 context.ConversationData.SetValue<Guid>("PredictImageId", cvResult.Id);
あってるかユーザに確認する。 PromptDialog.Confirm(context, AfterMessageAsync, "あってますか?", "???", promptStyle: PromptStyle.Auto); 今回は Yes or No の2択で確認します。2択以外が応答されると、「???」となっちゃいます。
画像の学習
LearningMessageReceivedAsyncで実装しています。 処理の概要は、こんな感じです。
画像の学習に関連する処理を行う場合は、Training Key を使います。
まずはタグの有無をチェックしてタグがなければ登録します。
識別したときに登録された画像IDを取得します。
画像IDとタグの紐づけをします。 紐づけする情報は、ImageIdCreateBatch型に入れてあげる必要があります。
Train を実行します。 本当はこれを実行する前に、すべてのタグが5枚以上になっているかチェックが必要です。 Train が成功したら、次の画像識別をされた場合に学習後のモデルを利用するために IsDefault を True にしておきます。
さいごに
今回利用した SDK ですが、ドキュメントがないんですよね。。。。 でも、メソッド名からだいたい処理内容は予測つけられますので、是非、みなさんも Custom Vision Service を使った Bot を作ってみてください。 最後に、RootDialog.cs の全体のソースを掲載します。 冗長なところはあしからず。
using System;
using System.Threading.Tasks;
using Microsoft.Bot.Builder.Dialogs;
using Microsoft.Bot.Connector;
using Microsoft.Cognitive.CustomVision.Models;
using Microsoft.Cognitive.CustomVision;
using System.Net.Http;
using System.Threading;
namespace botcvs.Dialogs
{
[Serializable]
public class RootDialog : IDialog
{
public Task StartAsync(IDialogContext context)
{
//context.Wait(MessageReceivedAsync);
context.Wait(MessageHelloAsync);
return Task.CompletedTask;
}
private async Task MessageHelloAsync(IDialogContext context, IAwaitable result)
{
var activity = await result as Activity;
if (activity.Type == "message")
{
if (activity.Text.Length > 0)
{
await context.PostAsync($"高山植物の画像を見せてください。");
context.Wait(MessageReceivedAsync);
}
}
}
private async Task MessageReceivedAsync(IDialogContext context, IAwaitable result)
{
var activity = await result as Activity;
if (activity.Type == "message")
{
// 画像をつけてもらえたか確認
if (activity.Attachments?.Count != 0)
{
// Custom Vision API を使う準備
var predictionCred = new PredictionEndpointCredentials("[PREDICTON KEY]");
var predictionEp = new PredictionEndpoint(predictionCred);
var projctId = new Guid("7c0feea9-f34f-40cd-a99f-4f282c945011");
// 送られてきた画像を Stream として取得
var photoUrl = activity.Attachments[0].ContentUrl;
var client = new HttpClient();
var photoStream = await client.GetStreamAsync(photoUrl);
// 画像を判定
var cvResult = await predictionEp.PredictImageAsync(projctId, photoStream);
// 一番高いスコアを取得
var tag = "";
var score = 0.00;
foreach (var item in cvResult.Predictions)
{
if (score < item.Probability)
{
score = item.Probability;
tag = item.Tag;
}
}
var msg = "";
if (score >= 0.8)
{
// scoreが0.8 を超えてたら自信満々に回答
msg = tag + "です!";
}
else
{
// scoreが0.8 未満は半疑問形
msg = tag + "かな?";
}
await context.PostAsync(msg);
// 後で学習するときに向けてIDを抑えておく
context.ConversationData.SetValue("PredictImageId", cvResult.Id);
// 確認のメッセージ
PromptDialog.Confirm(
context,
AfterMessageAsync,
"あってますか?",
"???",
promptStyle: PromptStyle.Auto);
}
else
{
// 画像を再度要求
await context.PostAsync($"高山植物の画像を見せてください。");
context.Wait(MessageReceivedAsync);
}
}
}
public async Task AfterMessageAsync(IDialogContext context, IAwaitable argument)
{
var confirm = await argument;
if (confirm)
{
await context.PostAsync("よかったです。また画像を見せてくださいね。");
context.Wait(MessageReceivedAsync);
}
else
{
await context.PostAsync("もう少し勉強します。");
PromptDialog.Confirm(
context,
AfterLearningAsync,
"この植物の名前を教えて頂けますか?",
"この植物の名前を教えて頂けますか?",
promptStyle: PromptStyle.Auto);
}
}
public async Task AfterLearningAsync(IDialogContext context, IAwaitable argument)
{
var confirm = await argument;
if (confirm)
{
await context.PostAsync("ありがとうございます。");
await context.PostAsync("この植物の名前は何というのですか?");
// 学習用のダイアログに遷移
context.Wait(LearningMessageReceivedAsync);
}
else
{
await context.PostAsync("そうですか。。。残念です。");
await context.PostAsync("また、素敵な植物の画像を見せてくださいね。");
context.Wait(MessageReceivedAsync);
}
}
private async Task LearningMessageReceivedAsync(IDialogContext context, IAwaitable result)
{
var activity = await result as Activity;
if (activity.Text.Length > 0)
{
// Custom Vision API を使う準備
var trainingCredentials = new TrainingApiCredentials("[TRAIN KEY]");
var trainingApi = new TrainingApi(trainingCredentials);
var projctId = new Guid("7c0feea9-f34f-40cd-a99f-4f282c945011");
Guid newTagGuid = Guid.Empty;
try
{
// すでにあるタグかチェック
var taglist = trainingApi.GetTags(projctId);
foreach (var tagname in taglist.Tags)
{
if (tagname.Name == activity.Text)
{
newTagGuid = tagname.Id;
break;
}
}
// タグIDが設定されていない場合は、タグを作成する。
if (newTagGuid == Guid.Empty)
{
var newTag = trainingApi.CreateTag(projctId, activity.Text);
newTagGuid = newTag.Id;
}
// 学習済の画像のIDを取得する。
var predictionId = context.ConversationData.GetValue("PredictImageId");
// 画像とタグを組み合わせて登録する
var imageIdTags = new ImageIdCreateBatch();
imageIdTags.TagIds = new[] { newTagGuid };
imageIdTags.Ids = new[] { predictionId };
var retCreatetagu = trainingApi.CreateImagesFromPredictions(projctId, imageIdTags);
if (retCreatetagu.IsBatchSuccessful == true)
{
// Trainを実行する。(タグの画像が5枚以下の場合はエラーとなるけど、ひとまず無視)
var iteration = trainingApi.TrainProject(projctId);
while (iteration.Status == "Training")
{
Thread.Sleep(1000);
// Re-query the iteration to get it's updated status
iteration = trainingApi.GetIteration(projctId, iteration.Id);
}
iteration.IsDefault = true;
trainingApi.UpdateIteration(projctId, iteration.Id, iteration);
await context.PostAsync("学習完了!!");
context.Wait(MessageReceivedAsync);
}
else
{
await context.PostAsync("** タグの最低数5枚に達していないのでTrainできませんでした **");
context.Wait(MessageReceivedAsync);
}
}
catch (Exception ex)
{
//await context.PostAsync(ex.Message);
context.Wait(LearningMessageReceivedAsync);
}
}
else
{
await context.PostAsync("もう一度教えてください。");
context.Wait(LearningMessageReceivedAsync);
}
}
}
}