前回までで、モデルはテスト用データに対し97%ほどの精度で正解することが分かりました。以下では、テスト用データの画像と、モデルの出した答えを視覚的に比較します。
下図は実行結果の例です。データは100個表示しており、不正解の赤は2つなので、100-2から正解率は98%だと類推できます。
メインのコードでは、train()関数の呼び出し後、showPredictions()関数を呼び出します。
await train();
showPredictions();
showPredictions()は次の関数です。これもTensorFlow.jsのmnistサンプルを参考にしているので、英文コメントの日本語訳もつけておきます。
const showPredictions = () => {
// 表示するデータ数
const testExamples = 100;
// testXSの頭からtestExamples個分のテスト用画像データを取得
const testExamplesXS = testXS.slice([0, 0, 0, 0], [testExamples, 28, 28, 1]);
// testYSの頭からtestExamples個分のテスト用ラベルデータを取得
const testExamplesYS = testYS.slice([0, 0], [testExamples, 10]);
const output = model.predict(testExamplesXS);
// tf.argMax()は、指定された軸(axis)に沿った最大値のインデックスを返す。
// この例のようなカテゴリ分類の作業では通常、クラスをone-hotベクトルで表す。
// one-hotベクトルは、1出力クラスごとに1要素を持つ1Dベクトル。
// ベクトルの値は、1要素だけが1でほかは全部0(例:[0, 0, 0, 1, 0])。
// model.predict()からの出力は確率分布になるので、argMax()を使って、
// 最も高い確率を持つベクトル要素のインデックスを得る。これがモデルの出した予想になる。
// 0.75が最大値なので、そのインデックスは3。要素の値を合計すると1になる。
// argmax([0.07, 0.1, 0.03, 0.75, 0.05]) == 3
// dataSync()はtf.Tensor値をGPUから同期的にダウンロードするので、CPUで動作する通常のJavaScriptコードが
// 使用できる。コードの動作をブロックしたくない場合には、非同期的に動作するdata()を使用する。
const axis = 1;
const labels = Array.from(testExamplesYS.argMax(axis).dataSync());
const predictions = Array.from(output.argMax(axis).dataSync());
// 答え合わせを表示する
showTestResults(testExamplesXS, predictions, labels);
// 後始末
testXS.dispose();
testYS.dispose();
testExamplesYS.dispose();
}
最後に呼び出しているshowTestResults(testExamplesXS, predictions, labels)は、次の関数です。
// 答え合わせの結果を表示する
const showTestResults = (testExamplesXS, predictions, labels) => {
// testExamplesXSのshapeは、[サンプル数, 28, 28, 1]
// testExamplesは表示するサンプルの数
const testExamples = testExamplesXS.shape[0];
new Promise((resolve, reject) => {
for (let i = 0; i < testExamples; i++) {
// testExamplesXSのi個めを取り出す
const sliced = testExamplesXS.slice([i, 0, 0, 0], [1, 28, 28, 1]);
// 描画するキャンバス
const canvas = document.createElement('canvas');
// [1, 28, 28, 1]を[28, 28]にreshape
const imgTf = sliced.reshape([28, 28]);
// イメージtf.Tensorオブジェクトをcanvasに描画する
tf.toPixels(imgTf, canvas);
// 予測値とラベル値の答え合わせ
const pred = document.createElement('div');
pred.className = 'pred-div';
// i個めの予測
const prediction = predictions[i];
// i個めの答え
const label = labels[i];
// 予測と答えが一致しているなら
if (prediction === label) {
// <div>の背景色を緑色にする
pred.classList.add('pred-correct');
// そうでないなら
}
else {
// <div>の背景色を赤色にする
pred.classList.add('pred-incorrect');
}
// <div>の文字を予測の数字にする
pred.innerText = `${prediction}`;
// 画面に表示
document.querySelector('.image-div').appendChild(canvas);
document.querySelector('.label-div').appendChild(pred);
}
}).then(() => {
testExamplesXS.dispose();
});
}