3-1 基本 その1 TensorFlow.jsの作業の流れ

次のコードは、ブラウザで開くと自動的にTensorFlow.jsのモデルが作成され学習を始めます。学習が終わり、ボタンをクリックすると、テキストフィールドに入力された数値(20)に対応する数値をデベロッパーツールのコンソールに出力します。

TensorFlow.jsのモデルには、[-1, 0, 1, 2, 3, 4]と[-3, -1, 1, 3, 5, 7]という配列がデータとして与えられます。前述したとおり、その関係性はy = 2x – 1です。テキストフィールドの数値は20なので、yは2 x 20 -1 = 39です。TensorFlow.jsは今の場合、38.4103889465332と推測しているので、概ね合っていることが分かります。

<!doctype html>
<html lang="ja">
<head>
    <meta charset="utf-8">
    <title>START</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    <style>
    </style>
	</head>
  <body>
    <h3>START</h3>
    <p>y = 2x - 1という分かりやすい関係性を例に、TensorFlow.jsでの作業の流れをつかむ</p>
    <input type="button" id="predict-button" value="推論">
    <input type="number" id="number-input" value="20">
  <script>
    document.addEventListener('DOMContentLoaded',async()=>{

      // モデルを非同期で作成する
      const buildModel = async()=>{
        const model = tf.sequential();
        model.add(tf.layers.dense({units: 1, inputShape: [1]}));
        model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });
        return model;
      }

      info('データ取得開始');
      // getData()が値を返すまで次に進まない
     // 訓練用データと教師用データ用の変数
      const [xs, ys] = await getData();
      info('データ取得完了');

      info('モデル構築開始');
      // buildModel()が値を返すまで次に進まない
      const model = await buildModel();
      info('モデル構築完了');

      info('モデル訓練開始');
      // モデルは与えられたxsとysから、その関係性を探る。
      // 具体的には、y = 2x-1 の係数を探る
      await model.fit(xs, ys, {epochs: 250});
      info('モデル訓練完了');

      // 訓練が終わると、訓練済みモデルを使った推論が行える
      document.getElementById('predict-button').addEventListener('click', ()=>{
        const inputNumber = document.getElementById('number-input').value;
        // 数値入力フィールドに入力された数値から結果を推論する
        const res = model.predict(tf.tensor2d([inputNumber], [1, 1]));
        console.log(res.dataSync());
      }, false);
    },false);

    // データを取得する。データは多くの場合外部にある。
    const getData= async()=>{
      // 訓練用データ
      const trainData = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
      // 教師用データ
      const labelData = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
      return [trainData, labelData];
    }
    const info =(msg)=>{
      console.log(msg);
    }
  </script>
  </body>
</html>

以降では、ここに記述したコードを順に見ていきます。

“3-1 基本 その1 TensorFlow.jsの作業の流れ” への2件の返信

  1. 有益な情報をありがとうございます。

    当方の環境だけかもしれませんが、推論の
    const res = model.predict(tf.tensor2d([inputNumber], [1, 1]));
    部分で下記のようなエラーが出ました。(tensorflow.js v1.1.2)

    Uncaught Error: Argument ‘x’ passed to ‘slice2d’ must be numeric tensor, but got string tensor

    解決策としては
    [inputNumber*1]
    として数値にすることでエラーが消えました。

    1. コードを書いた時期にはエラーは出ていなかったので、tensorflow.jsのバージョンアップにともなうものかと想像します。
      とは言え、数値を取る関数に渡す引数のタイプが、確かに数値であることを確認するのは重要なことなので、型変換しておくとエラーは回避できます。

      // const inputNumber = document.getElementById(‘number-input’).value;
      let inputNumber = document.getElementById(‘number-input’).value;
      // inputNumber = Number(inputNumber);
      inputNumber = parseInt(inputNumber, 10);
      // この結果はnuberであればよい
      console.log(typeof inputNumber);

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

CAPTCHA