3-1 基本 その1 TensorFlow.jsの作業の流れ

次のコードは、ブラウザで開くと自動的にTensorFlow.jsのモデルが作成され学習を始めます。学習が終わり、ボタンをクリックすると、テキストフィールドに入力された数値(20)に対応する数値をデベロッパーツールのコンソールに出力します。

TensorFlow.jsのモデルには、[-1, 0, 1, 2, 3, 4]と[-3, -1, 1, 3, 5, 7]という配列がデータとして与えられます。前述したとおり、その関係性はy = 2x – 1です。テキストフィールドの数値は20なので、yは2 x 20 -1 = 39です。TensorFlow.jsは今の場合、38.4103889465332と推測しているので、概ね合っていることが分かります。

<!doctype html>
<html lang="ja">
<head>
    <meta charset="utf-8">
    <title>START</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    <style>
    </style>
	</head>
  <body>
    <h3>START</h3>
    <p>y = 2x - 1という分かりやすい関係性を例に、TensorFlow.jsでの作業の流れをつかむ</p>
    <input type="button" id="predict-button" value="推論">
    <input type="number" id="number-input" value="20">
  <script>
    document.addEventListener('DOMContentLoaded',async()=>{

      // モデルを非同期で作成する
      const buildModel = async()=>{
        const model = tf.sequential();
        model.add(tf.layers.dense({units: 1, inputShape: [1]}));
        model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });
        return model;
      }

      info('データ取得開始');
      // getData()が値を返すまで次に進まない
     // 訓練用データと教師用データ用の変数
      const [xs, ys] = await getData();
      info('データ取得完了');

      info('モデル構築開始');
      // buildModel()が値を返すまで次に進まない
      const model = await buildModel();
      info('モデル構築完了');

      info('モデル訓練開始');
      // モデルは与えられたxsとysから、その関係性を探る。
      // 具体的には、y = 2x-1 の係数を探る
      await model.fit(xs, ys, {epochs: 250});
      info('モデル訓練完了');

      // 訓練が終わると、訓練済みモデルを使った推論が行える
      document.getElementById('predict-button').addEventListener('click', ()=>{
        const inputNumber = document.getElementById('number-input').value;
        // 数値入力フィールドに入力された数値から結果を推論する
        const res = model.predict(tf.tensor2d([inputNumber], [1, 1]));
        console.log(res.dataSync());
      }, false);
    },false);

    // データを取得する。データは多くの場合外部にある。
    const getData= async()=>{
      // 訓練用データ
      const trainData = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
      // 教師用データ
      const labelData = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
      return [trainData, labelData];
    }
    const info =(msg)=>{
      console.log(msg);
    }
  </script>
  </body>
</html>

以降では、ここに記述したコードを順に見ていきます。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA