3-3 基本 その3 視覚化

機械学習で扱うのは、目に見える色の変化やオブジェクトの変形でなく、小さな数値の微細な変化なので、扱うデータや、訓練の進捗状況を視覚化することが重要です。

機械学習の環境が進んでいるPythonには、多くの機械学習向けライブラリが存在し、データの取得1つをとっても実にスマートに簡単にできますが、TensorFlow.jsでは残念ながら、現時点ではそこまで環境整備が進んでおらず、視覚化は自分で行う必要があります。

グラフの描画には、たとえば<canvas>の使用が思いつきますが、もう少し便利なJavaScriptライブラリも多数存在します。本記事では、plotly.jsというライブラリを使用します。

plotly.jsは次のようにして読み込みます。

 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>

次の例では、左に訓練用と教師データのグラフを、右に訓練の進捗状況のグラフを表示しています。

左のグラフでは、横軸に訓練用(xs)、縦軸に教師用(ys)をとっています。これらの値は元々、y = 2x -1の関係性に一致するので、プロットした青丸はy = 2x -1の直線上に並びます。

右のグラフでは、横軸にepochs値を、縦軸に損失値をとっています。損失値は訓練の開始時には大きいですが、訓練が進むにつれ急激に小さくなっていることが分かります。

これは、学習がうまく進んでいることを示しています

訓練が終わると、左のグラフ上で、xs値に対する推測値をオレンジの線で結んでいます。学習がうまく進んでいるので、オレンジの線はほぼ青い丸の上に重なります。

そして最後、左のグラフ上に、未知の20に対する推測値を大きな緑色の丸でプロットしています。

以下は全コードです。

<!doctype html>
<html lang="ja">
<head>
    <meta charset="utf-8">
    <title>START</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
    <style>
      .graph-area{
        margin: 10px;
        width: 500px;
        height:500px;
        border: 1px solid black;
      }
      .container{
        display: flex;
      }
    </style>
	</head>
  <body>
    <h3>START</h3>
    <p>データと損失値の視覚化</p>
    <div class="container">
        <div id="chart1" class="graph-area"></div>
        <div id="chart2" class="graph-area"></div>
    </div>
  <script>
    document.addEventListener('DOMContentLoaded',async()=>{
      // モデルを非同期で作成する
      const buildModel = async()=>{
        const model = tf.sequential();
        model.add(tf.layers.dense({units: 1, inputShape: [1]}));
        model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });
        return model;
      }

      info('データ取得開始');
      // getData()が値を返すまで次に進まない
      // 訓練用データと教師用データ用の変数
      const [xs, ys] = await getData();
      info('データ取得完了');

      // データをグラフに丸でプロット
      addPoint(xs.dataSync(), ys.dataSync(), 8, true);

      info('モデル構築開始');
      // buildModel()が値を返すまで次に進まない
      const model = await buildModel();
      info('モデル構築完了');
      info('モデル訓練開始');
      // モデルは与えられたxsとysから、その関係性を探る。
      // 具体的には、y = 2x-1 の係数を探る
      await model.fit(xs, ys, {
        epochs: 250,
        // 1回の訓練が終わるたびに呼び出されるコールバック関数
        callbacks: {
          onEpochEnd: async (epoch, logs) => {
            // 繰り返しの回数を横軸に、損失値縦軸にグラフを描く
            addGraphData(epoch, logs.loss)
            // 画面がフリーズしないように次のフレームに進む
            await tf.nextFrame();
        }
      }})
      // tf.Model.fit()はPromiseを返すので、訓練が終わったら、then()が呼び出される。
      .then(()=>{
        info('モデル訓練完了');
        // tf.Modelが内部に持っている係数の値を調べる
        const a = model.trainableWeights[0].read(); // ax + b のaの値
        const b = model.trainableWeights[1].read(); // ax + b のbの値
        console.log(a.dataSync());
        console.log(b.dataSync());
        // tf.tidy()を使うと、指定した関数の実行後、
        // 関数内にあるtf.Tensorオブジェクトが占めるGPUメモリを解放できる。
        tf.tidy(()=>{
          // 訓練済みのモデルにxsを渡して、各値の推論値を持つtf.Tensorオブジェクト得る
          const predictOut = model.predict(xs);
          // それらの実際の値を得る
          const predictData = predictOut.dataSync();
          // 推論結果から線を描く
          addLine(xs.dataSync(),predictData);

          // xsが[20]の場合の推論結果をグラフに丸で描く
          const res = model.predict(tf.tensor2d([20], [1, 1]));
          addPoint([20], res.dataSync(), 12, false);
        });
        // tf.Tensorオブジェクトを破棄する
        xs.dispose();
        ys.dispose();
      });

    },false);

    // データを取得する。データは多くの場合外部にある。
    const getData= async()=>{
      // 訓練用データ
      const trainData = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
      // 教師用データ
      const labelData = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
      return [trainData, labelData];
    }

    const info =(msg)=>{
      console.log(msg);
    }

    // 訓練の回数と損失値の変化をグラフで描く
    let xData1 = [];
    let yData1 = [];
    const trace = {
      x: xData1,
      y: yData1,
      type: 'scatter'
    };
    const data = [trace];
    const layout = {
      xaxis: {
        title: 'epochs'
      },
      yaxis: {
        title: 'Loss'
      },
      title: '減少していく損失値の表示'
    };

    const addGraphData = (a, b) => {
      xData1.push(a);
      yData1.push(b);
      Plotly.newPlot('chart2', data, layout,{ displayModeBar: false });
    }

    // グラフに線を追加する
    const addLine = (xData, yData) => {
      const trace = {
        x: xData,
        y: yData,
        mode: 'lines',
        type: 'scatter'
      };
      Plotly.addTraces('chart1', [trace]);
    }

    // グラフに丸を描く
    const addPoint = (xData, yData, markerSize, isnewPlot) => {
      const trace = {
        x: xData,
        y: yData,
        mode: 'markers',
        type: 'scatter',
        marker: { size: markerSize }
      };
      const layout = {
        xaxis: {
          range: [-5, 25],
          title: 'xsデータ'
        },
        yaxis: {
          range: [-5, 50],
          title: 'ysデータ'
        },
        title: 'データと推論結果の表示',
        showlegend: false
      }
      if(isnewPlot){
        // 初めてグラフを描く
        Plotly.newPlot('chart1', [trace],layout,{ displayModeBar: false });
      }else{
        // 描いたグラフにデータを追加する
        Plotly.addTraces('chart1', [trace]);
      }
    }
  </script>
  </body>
</html>

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA