5_1 2特徴量2クラス

ここからは、いよいよ「機械学習」っぽくなっていきます。

2特徴量2クラスというのは、あやめデータの場合でいうと、たとえば、がく片の長さと幅の2つの特徴量から、satosaとversicolorに分けることです。出来の良い学習済みモデルが作成できると、モデルはsatosaかversicolorのどちらかを言い当てられるようになります。

ただし以降の例では、モデルの訓練時に新たに、検証用データというものを取り入れるので、それに注力できるよう、データ数を少なくしています。検証用データは、訓練用データとは別のデータで、訓練時にモデルに与えます。モデルが学習するのは訓練用データだけで、モデルはその時々の学習時点で、カンニングせずに、検証用データを取り込んで損失値を出力します。

上図の左のグラフは、がく片の長さをx軸、幅をy軸として、satosaかversicolorのデータ8個をプロットしたものです。右のグラフの青い曲線は訓練用データに対する損失値の減少を、オレンジの曲線は検証データに対する損失値の減少を示しています。

その右にはテスト用のボタンが縦に並んでいて、クリックによって、学習の済んだモデルの知らないデータを与え、推論させることができます。当たればオレンジ色に、はずれれば灰色に、ボタンの色が変わります。

以下は上図で占めたWebアプリの全コードです。TensFlow.JSのIrisサンプルを参考にしてはいまずが、TensFlow.JSのサンプルは全般に複雑で難解なので、できるだけ簡単に分かりやすくしています。とは言え、十分に長いです。

<!doctype html>
<html lang="ja">
<head>
    <meta charset="utf-8">
    <title>2特徴量2クラス</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>

    <style>
      .graph-area{
        margin: 10px;
        width: 500px;
        height:500px;
        border: 1px solid black;
      }
      .container{
        display: flex;
      }
      .controller{
        buttons: 10px;
      }
      ul{
        list-style-type: none;
        margin: 0;
        padding: 0;
      }

      #data-select li a{
        display: block;
        border: 1px solid #9F99A3;
        background-color: #EEEEEE;
        padding: 3px 10px;
        text-decoration: none;
        color: #333;
        width: 150px;
        margin: 2px 0px;
        text-align: left;
        font-size: 14px;
      }
    </style>
	</head>
  <body>
    <h3>2特徴量2クラス</h3>
    <p>がく片の長さと幅(2特徴量)でsatosaとversicolorを分類する</p>
    <div class="container">
        <div id="chart1" class="graph-area"></div>
        <div id="chart2" class="graph-area"></div>

    <div class="buttons">
      <ul id="data-select">
        <li><a href="javascript:void(0)" data-value1="5.4" data-value2="3.0" data-value3="1">[5.4, 3.0] versicolor</a></li>
        <li><a href="javascript:void(0)" data-value1="4.8" data-value2="3.0" data-value3="0">[4.8, 3.0] satosa</a></li>
        <li><a href="javascript:void(0)" data-value1="5.1" data-value2="2.5" data-value3="1">[5.1, 2.5] versicolor</a></li>
        <li><a href="javascript:void(0)" data-value1="5.5" data-value2="4.2" data-value3="0">[5.5, 4.2] satosa</a></li>
        <li><a href="javascript:void(0)" data-value1="6.2" data-value2="3.4" data-value3="2">[6.2, 3.4] virginica</a></li>
      </ul>
      <p id="info"></p>
    </div>
  </div>
  <script>
    document.addEventListener('DOMContentLoaded',async()=>{

      // satosaデータの1個めから8個めまでの、
      // がく片の長さと幅を選び出したもの(2特徴量)
      const satosaData =[
        [5.1, 3.5], [4.9, 3.0], [4.7, 3.2],
        [4.6, 3.1], [5.0, 3.6], [5.4, 3.9],
        [4.6, 3.4], [5.0, 3.4]
      ];

      // versicolorデータの1個めから8個めまでの、
        // がく片の長さと幅を選び出したもの(2特徴量)
      const versicolorData =[
        [7.0, 3.2], [6.4, 3.2], [6.9, 3.1],
        [5.5, 2.3], [6.5, 2.8], [5.7, 2.8],
        [6.3, 3.3], [4.9, 2.4]
      ];

      // クラスは2つ
      const IRIS_CLASSES = ['setosa', 'versicolor'];

      // 与えられた配列とインデックス番号から、グラフの描画に適した
      // xとy値の配列を返す
      const getXYData = async(dataArray, indexNumber) =>{
        return dataArray.map(arr=> arr[indexNumber]);
      }

      // satosaのxとyデータ
      let satosaXdata = await getXYData(satosaData, 0);
      let satosaYdata = await getXYData(satosaData, 1);
      //console.log(satosaXdata)
      //console.log(satosaYdata)

      // versicolorのxとyデータ
      let versicolorXdata = await getXYData(versicolorData, 0);
      let versicolorYdata = await getXYData(versicolorData, 1);
      //console.log(versicolorXdata)
      //console.log(versicolorYdata)

      // Plotlyを使ってsatosaとversicolorの分布図を一気に描く
      const plot = async(xData1, yData1, xData2, yData2) => {
        const trace1 = {
          x: xData1,
          y: yData1,
          mode: 'markers',
          type: 'scatter',
          name: 'satosa'
        };

        const trace2 = {
          x: xData2,
          y: yData2,
          mode: 'markers',
          type: 'scatter',
          name: 'versicolor'
        };

        const data = [trace1, trace2];
        const layout = {
          xaxis: {
            range: [0, 8],
            title: 'がく片の長さ'
        },
          yaxis: {
            range: [0, 8],
            title: 'がく片の幅'
        },
          title: 'satosaとversicolor'
        };
        Plotly.newPlot('chart1', data, layout, { displayModeBar: false });
      }
      // 画面左に分布図を描く
      plot( satosaXdata, satosaYdata, versicolorXdata, versicolorYdata);

      // 右のグラフの準備
      let xData1 = [];
      let yData1 = [];
      let xData2 = [];
      let yData2 = [];

      // 訓練が進むたびに訓練と検証の進み具合を描画する
      const addGraphData = (a, b, c, d) => {

        xData1.push(a);
        yData1.push(b);
        xData2.push(c);
        yData2.push(d);

        const trace1 = {
          x: xData1,
          y: yData1,
          type: 'scatter',
          name: '訓練'
        };
        const trace2 = {
          x: xData2,
          y: yData2,
          type: 'scatter',
          name: '検証'
        };
        const layout = {
          xaxis: {
            range: [0, param.epochs],
            title: 'epoch'
        },
          yaxis: {
            range: [0, 1],
            title: 'loss'
        },
          title: '損失値'
        };

        Plotly.newPlot('chart2', [trace1, trace2], layout, { displayModeBar: false });
      }

      // 訓練用データ
      // がく片の長さと幅の2特徴量
      const trainXData = satosaData.concat(versicolorData);
      // 教師用データ
      const trainYData = [
        // satosa
        [1,0], [1,0], [1,0], [1,0], [1,0], [1,0], [1,0], [1,0],
        // versicolor
        [0,1], [0,1], [0,1], [0,1], [0,1], [0,1], [0,1], [0,1]
      ];

      // 検証用データ
      const validationXData =[
        // satosaデータの9個めから10個めまでを選び出したもの
        [4.4, 2.9], [4.9, 3.1],
          // versicolorデータの9個めから10個めまでを選び出したもの
        [6.6, 2.9], [5.2, 2.7]
      ];

      const validationYData = [
        // satosaデータ
        [1,0], [1,0],
        // versicolor
        [0,1], [0,1]
      ];

      // 訓練と検証用のデータをtf.Tensorオブジェクトに変換し配列に入れて返す
      const getData = async()=>{
        const xsTrain = tf.tensor2d(trainXData, [16, 2]);
        const ysTrain = tf.tensor2d(trainYData, [16, 2]);
        const xsValidation = tf.tensor2d(validationXData, [4, 2]);
        const ysValidation = tf.tensor2d(validationYData, [4, 2]);
        return [xsTrain, ysTrain, xsValidation, ysValidation]
      }

      // モデルの構築
      const buildModel = async()=>{
        // TensFlow.JSのIrisサンプルで使用されているモデル
        // https://github.com/tensorflow/tfjs-examples/tree/master/iris
        const model = tf.sequential();
        model.add(tf.layers.dense(
            {units: 10, activation: 'sigmoid', inputShape: [2]}));
        // 出力は2クラス
        model.add(tf.layers.dense({units: 2, activation: 'softmax'}));

        const learningRate = 0.01;
        const optimizer = tf.train.adam(learningRate);
        model.compile({
            optimizer: optimizer,
            // 二値交差エントロピー誤差関数
            loss: 'binaryCrossentropy',
            metrics: ['accuracy'],
        });
        return model;
      }

      const param = {
        epochs:200
      }

      // モデルを訓練する
      const train =(model, xsTrain, ysTrain, xsValidation, ysValidation)=>{
        const lossValues = [];
        const epochs = param.epochs;
        model.fit(xsTrain, ysTrain, {
          // batchSize値は変更できる。デフォルトは32
          batchSize: 32,
          epochs: epochs,
          validationData: [xsValidation, ysValidation],
          callbacks: {
            onEpochEnd: async (epoch, logs) => {
              const lossData = [epoch,logs.loss];
              // 検証用データの損失値もグラフで描く
              const val_lossData = [epoch,logs.val_loss];
              addGraphData(lossData[0], lossData[1],val_lossData[0],val_lossData[1]);
              await tf.nextFrame();
            }
          }
        }).then(()=>{
          console.log('モデルの訓練完了');

          // <ul>要素がクリックされたら、
          document.getElementById('data-select').addEventListener('click', (e)=>{
            tf.tidy(() => {
              // e.targetはクリックされた<a>要素
              // <a href="javascript:void(0)" data-value1="5.1" data-value2="2.5" data-value3="1">[5.1, 2.5] versicolor</a
              // console.log(e.target);
              // クリックされた<a>要素のカスタムデータ属性を調べ、値を得る
              const lengthVal =  Number.parseInt((e.target.dataset["value1"]),10);  // がく片の長さ
              const widthVal =  Number(e.target.dataset["value2"]);                 // がく片の幅
              // tf.TensorFlowオブジェクトに変換する
              const data = tf.tensor2d([lengthVal, widthVal], [1, 2]);
              // 訓練済みモデルで推論し、結果を得る
              const predictOut = model.predict(data);
              // 結果は確率の配列
              predictOut.print(); // 例:[[0.1198563, 0.8801436],] -> versicolorの確率が高いと予測している

              let axis = 0;
              predictOut.argMax(axis).print();  // [0, 0] -> axis=0は、縦向きに読む
              axis =1;
              predictOut.argMax(axis).print();  // [1] -> axis=1は、横向きに読む

              // 予測結果の配列を横向きに読んだ場合での、最大値のインデックス位置を得る
              const predictOutMaxIndex = predictOut.argMax(axis);
              predictOutMaxIndex.print();       // [1]
              // その値
              const predictOutMaxIndexData = predictOutMaxIndex.dataSync()[0];
              console.log(predictOutMaxIndexData) // 1
              const whichClass = IRIS_CLASSES[predictOutMaxIndexData];
              console.log(whichClass)             // versicolor

              // e.target.dataset["value3"]が参照するのは、クリックされた<a>要素のdata-value3属性の値
              // それは元々、IRIS_CLASSES配列のインデックス値に対応させてある。
              // <a href="javascript:void(0)" data-value1="5.1" data-value2="2.5" data-value3="1">[5.1, 2.5] versicolor</a

              const answer = IRIS_CLASSES[Number(e.target.dataset["value3"])];  // satosaかversicolor

              let result = '';
              // 答え合わせをして、正解不正解に応じたUI操作を行う
              if(isCorrect(answer, whichClass)){
                result='正解';
                e.target.style.backgroundColor='coral'
              }else{
                result='不正解';
                e.target.style.backgroundColor='gray'
              }
              const message = "このあやめの種類は" + answer + '。機械学習の予測は' + whichClass + '。よって'+ result;
              document.getElementById('info').textContent = message;
            });
          }, false);
          // 後始末
          xsTrain.dispose();
          ysTrain.dispose();
          xsValidation.dispose();
          ysValidation.dispose();
        });

        // 与えられたaとbが一致するかどうか
        const isCorrect =(a,b)=>{
          if(a===b){
            return 1;
          }else{
            return 0;
          }
        }
      }

      // ページが読み込まれたら、モデルの訓練まで進む
      const start =async()=>{
        const [xsTrain, ysTrain, xsValidation, ysValidation] = await getData();
        const model = await buildModel();
        train(model, xsTrain, ysTrain, xsValidation, ysValidation);
      }

      start();
    }, false);
  </script>
  </body>
</html>

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA