12_6 3つのモデルの訓練と評価、推測

モデルを訓練して、評価、推測します。

HTMLファイルでは、関係するJavaScriptコードを次の順番で読み込みます。

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<script src="https://fastcdn.org/Papa-Parse/4.1.2/papaparse.min.js"></script>
<script src="js/data.js"></script>
<script src="js/normalization.js"></script>
<script src="js/models.js"></script>
<script src="js/graph.js"></script>
<script src="js/ui.js"></script>

HTML要素は次のように記述します。

<div class="container">
  <div id="chart" class="graph-area"></div>
  <div class="no-flex">
    <textarea rows="3" cols="40" readonly="true" id="status">データの読み込み中...</textarea>
    <textarea rows="3" cols="40" readonly="true" id="baselineStatus">基準はまだ算出されていない...</textarea>
    <ul>
      <li><span>予測価格</span><span>-></span><span>元データの価格</span></li>
      <li id="res1"><span>〇〇〇〇</span><span>-></span><span></span></li>
      <li id="res2"><span>〇〇〇〇</span><span>-></span><span></span></li>
      <li id="res3"><span>〇〇〇〇</span><span>-></span><span></span></li>
    </ul>
  </div>
</div>

メインのJavaScriptコードには次のものを記述します。

// Googleの提供する4つのCSVファイルをダウンロードし、配列に入れる。
const [trainFeatures, trainTarget, testFeatures, testTarget] =
await Promise.all([
    loadCsv(TRAIN_FEATURES_FN), loadCsv(TRAIN_TARGET_FN),
    loadCsv(TEST_FEATURES_FN), loadCsv(TEST_TARGET_FN)
]);

// シャッフルする。
shuffle(trainFeatures, trainTarget);
shuffle(testFeatures, testTarget);

updateStatus('データの読み込み終了、tf.Tensorに変換');

// CSVからのデータをtf.Tensorオブジェクトに変換する
const trainFeaturesTF = tf.tensor2d(trainFeatures, [333, 12], 'float32');
const trainTargetTF = tf.tensor2d(trainTarget, [333, 1], 'float32');
const testFeaturesTF = tf.tensor2d(testFeatures, [173, 12], 'float32');
const testTargetTF = tf.tensor2d(testTarget, [173, 1], 'float32');

// trainFeaturesTFデータの各列の平均と標準偏差を計算する
let { dataMean, dataStd } = determineMeanAndStddev(trainFeaturesTF);
const trainFeaturesNormalized = normalizeTensor(trainFeaturesTF, dataMean, dataStd);
const testFeaturesNormalized = normalizeTensor(testFeaturesTF, dataMean, dataStd);

// 評価用とテスト用に分ける
const [testFeaturesNormalized_evaluate, testFeaturesNormalized_test] = testFeaturesNormalized.split([170, 3]);
const [testTargetTF_evaluate, testTargetTF_test] = testTargetTF.split([170, 3]);

updateStatus('変換終了、モデルを作成して訓練開始');

// モデルがクリアすべき損失のベースラインを計算する
const computeBaseline = () => {
    // 訓練用データに含まれる住宅の平均価格
    const avgPrice = tf.mean(trainTargetTF);
    console.log(`平均: ${avgPrice.dataSync()}`);
    // 平均2乗誤差で得られた値を誤差のベースライン(基準線)とする
    // これをクリアしないモデルはNG
    const baseline = tf.mean(tf.pow(tf.sub(testTargetTF, avgPrice), 2));
    const baselineMsg = `損失の基準線 (平均2乗誤差): ${ baseline.dataSync()[0].toFixed(2) }`;
    updateBaselineStatus(baselineMsg);
};

computeBaseline();
// 線形回帰モデル
const model = linearRegressionModel();
// 多層パーセプトロン回帰モデル(隠れ層1つ)
// const model = multiLayerPerceptronRegressionModel1Hidden();
// 多層パーセプトロン回帰モデル(隠れ層2つ)
//const model = multiLayerPerceptronRegressionModel2Hidden();

// モデルの訓練に使用するハイパーパラメータ
const NUM_EPOCHS = 200;
const BATCH_SIZE = 40;
const LEARNING_RATE = 0.01;

let trainLoss;
let valLoss;
updateStatus('訓練中...');
// モデルを訓練する
await model.fit(trainFeaturesNormalized, trainTargetTF, {
    batchSize: BATCH_SIZE,
    epochs: NUM_EPOCHS,
    validationSplit: 0.2, // 検証にまわすデータの割合
    callbacks: {
        // epochの終了時ごとに呼び出されるコールバック関数
        onEpochEnd: async(epoch, logs) => {
            trainLoss = logs.loss;
            valLoss = logs.val_loss;
            // グラフにプロット
            plot(epoch, logs.loss, epoch, logs.val_loss);
            await tf.nextFrame();
        }
    }
});

updateStatus('評価用データで評価...');
const result = model.evaluate(testFeaturesNormalized_evaluate, testTargetTF_evaluate, {
    batchSize: BATCH_SIZE
});
const testLoss = result.dataSync()[0];
updateStatus(
    `訓練用データの最終的な損失: ${trainLoss.toFixed(4)}\n` +
    `検証用データの最終的な損失: ${valLoss.toFixed(4)}\n` +
    `評価用データの損失: ${testLoss.toFixed(4)}`);

// テスト用データを取り出す
const x1 = testFeaturesNormalized_test.slice(0, 1);
const x2 = testFeaturesNormalized_test.slice(1, 1);
const x3 = testFeaturesNormalized_test.slice(2, 1);

// モデルに推測させる
const predict1 = model.predict(x1);
const predict2 = model.predict(x2);
const predict3 = model.predict(x3);

// 推測結果を表示する。合わせて比較のため、元データの価格も表示する
const span_predict1 = document.getElementById('res1').firstElementChild;
span_predict1.textContent = predict1.dataSync()[0].toFixed(4);
span_predict1.nextElementSibling.nextElementSibling.textContent = testTargetTF_test.slice(0, 1).dataSync()[0].toFixed(4);

const span_predict2 = document.getElementById('res2').firstElementChild;
span_predict2.textContent = predict2.dataSync()[0].toFixed(4);
span_predict2.nextElementSibling.nextElementSibling.textContent = testTargetTF_test.slice(1, 1).dataSync()[0].toFixed(4);

const span_predict3 = document.getElementById('res3').firstElementChild;
span_predict3.textContent = predict3.dataSync()[0].toFixed(4);
span_predict3.nextElementSibling.nextElementSibling.textContent = testTargetTF_test.slice(2, 1).dataSync()[0].toFixed(4);

作成するモデルを変更するには、100行め当たりからの次のコードのコメントを適宜操作します。

// 線形回帰モデル
const model = linearRegressionModel();
// 多層パーセプトロン回帰モデル(隠れ層1つ)
// const model = multiLayerPerceptronRegressionModel1Hidden();
// 多層パーセプトロン回帰モデル(隠れ層2つ)
//const model = multiLayerPerceptronRegressionModel2Hidden();

下図は上から、線形回帰モデル、多層パーセプトロン回帰モデル(隠れ層1つ)、多層パーセプトロン回帰モデル(隠れ層2つ)を作成し、ここまでのコードで実行した結果です。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA