5_2_3 TensofFlow.jsのモデルの作成と訓練

今行おうとしているのは、[22,23,23,…,34,35,35]という最高気温と、[303,313,323,…,463,443,483]と並ぶビールの売上本数のデータから、では、統計が取られていない30度の日にはビールは一体何本売れそうかを推測する、ということです。

これは、統計学で回帰(Regression)と呼ばれる手法で、機械学習でも同様に回帰と呼ばれています。具体的に言うと、データをグラフに描いてその傾向を考え、そこから線を描いて具体的な線上の数値を割り出す、というようなことです。

では、この回帰問題に適したモデルはどのようなものがよいのでしょう? 機械学習を本当に学ぶなら根本から勉強する必要がありますが、そうするとTensorFlow.jsではなく、Python言語のTensorFlowやさらに進めたKerasを学ぶべきでしょう。

そこまでしないで何とかしたい場合には、Kerasのドキュメントが参考になります。

ドキュメントの「SequentialモデルでKerasを始めてみよう」ページの「コンパイル」項には、次の記述がされています。

# 平均二乗誤差を最小化する回帰問題の場合
model.compile(optimizer=’rmsprop’,
loss=’mse’)

これを参考にすると、ごく単純なTensorFlow.jsのモデルは次のように記述できます。前述のWebアプリではこのモデルを使っています。

const buildModel = async()=>{
    const model = tf.sequential();
    model.add(tf.layers.dense({units: 1, inputShape: [1]}));
    const learningRate = 0.01;
    const optimizer = tf.train.rmsprop(learningRate);
    model.compile({loss: 'meanSquaredError', optimizer: optimizer});
    return model;
}

JavaScriptの配列からtf.Tensorオブジェクトに変換する関数は次のように定義しています。

// CSVからのデータ配列を受け取り、tf.Tensorに変換して配列にまとめて返す
const getData= (dataArray)=>{
  // 訓練用データ
  const trainData = tf.tensor2d(dataArray[0], [25, 1]);
  // 教師用データ
  const labelData = tf.tensor2d(dataArray[1], [25, 1]);
  return [trainData, labelData];
}

変数dataArrayには、0番めに最高気温のデータが、1番めに売上本数のデータが入っています。これらをtf.tensor2d()でtf.Tensorオブジェクトに変換し、配列に入れて呼び出し元に返します。

最後に、plot()関数で呼び出したtrain()関数です。

// モデルを訓練し、訓練終了後に30度の日の売上予想を推測する。
const train = async(beer) => {
    // 訓練用と教師用データを得る
    const [xs, ys] = await getData(beer);
    xs.print();
    ys.print();
    // モデルを構築する
    const model = await buildModel();
    // 訓練を繰り返す回数
    const epochs = 2500;
    // tf.Model.fit()はPromiseを返す
    model.fit(xs, ys, {
        epochs: epochs,
        // 訓練中に呼び出されるコールバック関数
        callbacks: {
            onEpochEnd: async(epoch, logs) => {
                // 繰り返し回数と損失を出力
                console.log(epoch);
                console.log(logs.loss);
                // グラフに損失値の変化を描く
                plotLoss(epoch, logs.loss)
                    // 画面がフリーズしないように次のフレームに進む
                await tf.nextFrame();
            }
        }
        // 訓練が終わったら
    }).then(() => {
        // 実行後、GPUにあるtf.Tensorオブジェクトを消去し、メモリを解放する
        tf.tidy(() => {
            // モデルが推測した係数を使って、y = ax + bの直線を描いてみる
            // y = ax + b のa
            console.log(model.trainableWeights[0].read().dataSync());
            // y = ax + b のb
            console.log(model.trainableWeights[1].read().dataSync());
            // 係数aとbを取り出す
            const a = model.trainableWeights[0].read().dataSync()[0];
            const b = model.trainableWeights[1].read().dataSync()[0];

            let yData = [];
            let xData = 0;
            const len = temperatureData.length;

            for (let i = 0; i < len; i++) {
                // y = ax + bを計算し、結果を配列に追加する
                let y = temperatureData[i] * a + b;
                yData.push((y));
            }
            // 直線を描く
            addLine(temperatureData, yData);

            // 30度のときの売上数を予測
            const pred = model.predict(tf.tensor2d([30], [1, 1]));
            // モデルが推論した予測数
            pred.print();

            // 予測数を小数第2位の数値に整えて表示する
            const predNum = Math.floor((pred.dataSync() * 100)) / 100;
            document.getElementById('result').textContent = '最高気温が30度の日は、' + predNum + '個売れると予測した。';
            // グラフ上に点を追加
            addPoint([30], [predNum]);
        });
        xs.dispose();
        ys.dispose();
    });
}

左のグラフに描く直線について言うと、訓練を終えたモデルはmodel.trainableWeightsプロパティに、今の場合で言う係数の推測値を持っています。このプロパティはその名前から、機械学習の参考書で言うところの重み変数を保持する配列であろうと思われます。それを読み取り(read())、値を同期的にダウンロード(dataSync())します。数値は配列に入っているので[0]で値を取得します。

// 係数aとbを取り出す
const a = model.trainableWeights[0].read().dataSync()[0];
const b = model.trainableWeights[1].read().dataSync()[0];

y = ax + bのaとbが得られたので、Plotyに渡す横軸と縦軸の数値の配列を作成します。横軸用の配列はtemperatureDataがそのまま流用できます。

let yData = [];
let xData = 0;
const len = temperatureData.length;

for (let i = 0; i < len; i++) {
    // y = ax + bを計算し、結果を配列に追加する
    let y = temperatureData[i] * a + b;
    yData.push((y));
}
// 直線を描く
addLine(temperatureData, yData);

結果は良好のように見えます。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

CAPTCHA