7-6 モデルの推測結果の表示

前回までで、モデルはテスト用データに対し97%ほどの精度で正解することが分かりました。以下では、テスト用データの画像と、モデルの出した答えを視覚的に比較します。

下図は実行結果の例です。データは100個表示しており、不正解の赤は2つなので、100-2から正解率は98%だと類推できます。

メインのコードでは、train()関数の呼び出し後、showPredictions()関数を呼び出します。

await train();
showPredictions();

showPredictions()は次の関数です。これもTensorFlow.jsのmnistサンプルを参考にしているので、英文コメントの日本語訳もつけておきます。

const showPredictions = () => {
    // 表示するデータ数
    const testExamples = 100;
    // testXSの頭からtestExamples個分のテスト用画像データを取得
    const testExamplesXS = testXS.slice([0, 0, 0, 0], [testExamples, 28, 28, 1]);
    // testYSの頭からtestExamples個分のテスト用ラベルデータを取得
    const testExamplesYS = testYS.slice([0, 0], [testExamples, 10]);

    const output = model.predict(testExamplesXS);
    // tf.argMax()は、指定された軸(axis)に沿った最大値のインデックスを返す。
    // この例のようなカテゴリ分類の作業では通常、クラスをone-hotベクトルで表す。
    // one-hotベクトルは、1出力クラスごとに1要素を持つ1Dベクトル。
    // ベクトルの値は、1要素だけが1でほかは全部0(例:[0, 0, 0, 1, 0])。
    // model.predict()からの出力は確率分布になるので、argMax()を使って、
    // 最も高い確率を持つベクトル要素のインデックスを得る。これがモデルの出した予想になる。
    // 0.75が最大値なので、そのインデックスは3。要素の値を合計すると1になる。
    // argmax([0.07, 0.1, 0.03, 0.75, 0.05]) == 3

    // dataSync()はtf.Tensor値をGPUから同期的にダウンロードするので、CPUで動作する通常のJavaScriptコードが
    // 使用できる。コードの動作をブロックしたくない場合には、非同期的に動作するdata()を使用する。
    const axis = 1;
    const labels = Array.from(testExamplesYS.argMax(axis).dataSync());
    const predictions = Array.from(output.argMax(axis).dataSync());
    // 答え合わせを表示する
    showTestResults(testExamplesXS, predictions, labels);
    // 後始末
    testXS.dispose();
    testYS.dispose();
    testExamplesYS.dispose();
}

最後に呼び出しているshowTestResults(testExamplesXS, predictions, labels)は、次の関数です。

// 答え合わせの結果を表示する
const showTestResults = (testExamplesXS, predictions, labels) => {
    // testExamplesXSのshapeは、[サンプル数, 28, 28, 1]
    // testExamplesは表示するサンプルの数
    const testExamples = testExamplesXS.shape[0];

    new Promise((resolve, reject) => {
        for (let i = 0; i < testExamples; i++) {
            // testExamplesXSのi個めを取り出す
            const sliced = testExamplesXS.slice([i, 0, 0, 0], [1, 28, 28, 1]);
            // 描画するキャンバス
            const canvas = document.createElement('canvas');
            // [1, 28, 28, 1]を[28, 28]にreshape
            const imgTf = sliced.reshape([28, 28]);
            // イメージtf.Tensorオブジェクトをcanvasに描画する
            tf.toPixels(imgTf, canvas);

            // 予測値とラベル値の答え合わせ
            const pred = document.createElement('div');
            pred.className = 'pred-div';
            // i個めの予測
            const prediction = predictions[i];
            // i個めの答え
            const label = labels[i];
            // 予測と答えが一致しているなら
            if (prediction === label) {
                // <div>の背景色を緑色にする
                pred.classList.add('pred-correct');
                // そうでないなら
            }
            else {
                // <div>の背景色を赤色にする
                pred.classList.add('pred-incorrect');
            }
            // <div>の文字を予測の数字にする
            pred.innerText = `${prediction}`;
            // 画面に表示
            document.querySelector('.image-div').appendChild(canvas);
            document.querySelector('.label-div').appendChild(pred);
        }
    }).then(() => {
        testExamplesXS.dispose();
    });
}

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA