ここからは、いよいよ「機械学習」っぽくなっていきます。
2特徴量2クラスというのは、あやめデータの場合でいうと、たとえば、がく片の長さと幅の2つの特徴量から、satosaとversicolorに分けることです。出来の良い学習済みモデルが作成できると、モデルはsatosaかversicolorのどちらかを言い当てられるようになります。
ただし以降の例では、モデルの訓練時に新たに、検証用データというものを取り入れるので、それに注力できるよう、データ数を少なくしています。検証用データは、訓練用データとは別のデータで、訓練時にモデルに与えます。モデルが学習するのは訓練用データだけで、モデルはその時々の学習時点で、カンニングせずに、検証用データを取り込んで損失値を出力します。
上図の左のグラフは、がく片の長さをx軸、幅をy軸として、satosaかversicolorのデータ8個をプロットしたものです。右のグラフの青い曲線は訓練用データに対する損失値の減少を、オレンジの曲線は検証データに対する損失値の減少を示しています。
その右にはテスト用のボタンが縦に並んでいて、クリックによって、学習の済んだモデルの知らないデータを与え、推論させることができます。当たればオレンジ色に、はずれれば灰色に、ボタンの色が変わります。
以下は上図で占めたWebアプリの全コードです。TensFlow.JSのIrisサンプルを参考にしてはいまずが、TensFlow.JSのサンプルは全般に複雑で難解なので、できるだけ簡単に分かりやすくしています。とは言え、十分に長いです。
<!doctype html>
<html lang="ja">
<head>
<meta charset="utf-8">
<title>2特徴量2クラス</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<style>
.graph-area{
margin: 10px;
width: 500px;
height:500px;
border: 1px solid black;
}
.container{
display: flex;
}
.controller{
buttons: 10px;
}
ul{
list-style-type: none;
margin: 0;
padding: 0;
}
#data-select li a{
display: block;
border: 1px solid #9F99A3;
background-color: #EEEEEE;
padding: 3px 10px;
text-decoration: none;
color: #333;
width: 150px;
margin: 2px 0px;
text-align: left;
font-size: 14px;
}
</style>
</head>
<body>
<h3>2特徴量2クラス</h3>
<p>がく片の長さと幅(2特徴量)でsatosaとversicolorを分類する</p>
<div class="container">
<div id="chart1" class="graph-area"></div>
<div id="chart2" class="graph-area"></div>
<div class="buttons">
<ul id="data-select">
<li><a href="javascript:void(0)" data-value1="5.4" data-value2="3.0" data-value3="1">[5.4, 3.0] versicolor</a></li>
<li><a href="javascript:void(0)" data-value1="4.8" data-value2="3.0" data-value3="0">[4.8, 3.0] satosa</a></li>
<li><a href="javascript:void(0)" data-value1="5.1" data-value2="2.5" data-value3="1">[5.1, 2.5] versicolor</a></li>
<li><a href="javascript:void(0)" data-value1="5.5" data-value2="4.2" data-value3="0">[5.5, 4.2] satosa</a></li>
<li><a href="javascript:void(0)" data-value1="6.2" data-value2="3.4" data-value3="2">[6.2, 3.4] virginica</a></li>
</ul>
<p id="info"></p>
</div>
</div>
<script>
document.addEventListener('DOMContentLoaded',async()=>{
// satosaデータの1個めから8個めまでの、
// がく片の長さと幅を選び出したもの(2特徴量)
const satosaData =[
[5.1, 3.5], [4.9, 3.0], [4.7, 3.2],
[4.6, 3.1], [5.0, 3.6], [5.4, 3.9],
[4.6, 3.4], [5.0, 3.4]
];
// versicolorデータの1個めから8個めまでの、
// がく片の長さと幅を選び出したもの(2特徴量)
const versicolorData =[
[7.0, 3.2], [6.4, 3.2], [6.9, 3.1],
[5.5, 2.3], [6.5, 2.8], [5.7, 2.8],
[6.3, 3.3], [4.9, 2.4]
];
// クラスは2つ
const IRIS_CLASSES = ['setosa', 'versicolor'];
// 与えられた配列とインデックス番号から、グラフの描画に適した
// xとy値の配列を返す
const getXYData = async(dataArray, indexNumber) =>{
return dataArray.map(arr=> arr[indexNumber]);
}
// satosaのxとyデータ
let satosaXdata = await getXYData(satosaData, 0);
let satosaYdata = await getXYData(satosaData, 1);
//console.log(satosaXdata)
//console.log(satosaYdata)
// versicolorのxとyデータ
let versicolorXdata = await getXYData(versicolorData, 0);
let versicolorYdata = await getXYData(versicolorData, 1);
//console.log(versicolorXdata)
//console.log(versicolorYdata)
// Plotlyを使ってsatosaとversicolorの分布図を一気に描く
const plot = async(xData1, yData1, xData2, yData2) => {
const trace1 = {
x: xData1,
y: yData1,
mode: 'markers',
type: 'scatter',
name: 'satosa'
};
const trace2 = {
x: xData2,
y: yData2,
mode: 'markers',
type: 'scatter',
name: 'versicolor'
};
const data = [trace1, trace2];
const layout = {
xaxis: {
range: [0, 8],
title: 'がく片の長さ'
},
yaxis: {
range: [0, 8],
title: 'がく片の幅'
},
title: 'satosaとversicolor'
};
Plotly.newPlot('chart1', data, layout, { displayModeBar: false });
}
// 画面左に分布図を描く
plot( satosaXdata, satosaYdata, versicolorXdata, versicolorYdata);
// 右のグラフの準備
let xData1 = [];
let yData1 = [];
let xData2 = [];
let yData2 = [];
// 訓練が進むたびに訓練と検証の進み具合を描画する
const addGraphData = (a, b, c, d) => {
xData1.push(a);
yData1.push(b);
xData2.push(c);
yData2.push(d);
const trace1 = {
x: xData1,
y: yData1,
type: 'scatter',
name: '訓練'
};
const trace2 = {
x: xData2,
y: yData2,
type: 'scatter',
name: '検証'
};
const layout = {
xaxis: {
range: [0, param.epochs],
title: 'epoch'
},
yaxis: {
range: [0, 1],
title: 'loss'
},
title: '損失値'
};
Plotly.newPlot('chart2', [trace1, trace2], layout, { displayModeBar: false });
}
// 訓練用データ
// がく片の長さと幅の2特徴量
const trainXData = satosaData.concat(versicolorData);
// 教師用データ
const trainYData = [
// satosa
[1,0], [1,0], [1,0], [1,0], [1,0], [1,0], [1,0], [1,0],
// versicolor
[0,1], [0,1], [0,1], [0,1], [0,1], [0,1], [0,1], [0,1]
];
// 検証用データ
const validationXData =[
// satosaデータの9個めから10個めまでを選び出したもの
[4.4, 2.9], [4.9, 3.1],
// versicolorデータの9個めから10個めまでを選び出したもの
[6.6, 2.9], [5.2, 2.7]
];
const validationYData = [
// satosaデータ
[1,0], [1,0],
// versicolor
[0,1], [0,1]
];
// 訓練と検証用のデータをtf.Tensorオブジェクトに変換し配列に入れて返す
const getData = async()=>{
const xsTrain = tf.tensor2d(trainXData, [16, 2]);
const ysTrain = tf.tensor2d(trainYData, [16, 2]);
const xsValidation = tf.tensor2d(validationXData, [4, 2]);
const ysValidation = tf.tensor2d(validationYData, [4, 2]);
return [xsTrain, ysTrain, xsValidation, ysValidation]
}
// モデルの構築
const buildModel = async()=>{
// TensFlow.JSのIrisサンプルで使用されているモデル
// https://github.com/tensorflow/tfjs-examples/tree/master/iris
const model = tf.sequential();
model.add(tf.layers.dense(
{units: 10, activation: 'sigmoid', inputShape: [2]}));
// 出力は2クラス
model.add(tf.layers.dense({units: 2, activation: 'softmax'}));
const learningRate = 0.01;
const optimizer = tf.train.adam(learningRate);
model.compile({
optimizer: optimizer,
// 二値交差エントロピー誤差関数
loss: 'binaryCrossentropy',
metrics: ['accuracy'],
});
return model;
}
const param = {
epochs:200
}
// モデルを訓練する
const train =(model, xsTrain, ysTrain, xsValidation, ysValidation)=>{
const lossValues = [];
const epochs = param.epochs;
model.fit(xsTrain, ysTrain, {
// batchSize値は変更できる。デフォルトは32
batchSize: 32,
epochs: epochs,
validationData: [xsValidation, ysValidation],
callbacks: {
onEpochEnd: async (epoch, logs) => {
const lossData = [epoch,logs.loss];
// 検証用データの損失値もグラフで描く
const val_lossData = [epoch,logs.val_loss];
addGraphData(lossData[0], lossData[1],val_lossData[0],val_lossData[1]);
await tf.nextFrame();
}
}
}).then(()=>{
console.log('モデルの訓練完了');
// <ul>要素がクリックされたら、
document.getElementById('data-select').addEventListener('click', (e)=>{
tf.tidy(() => {
// e.targetはクリックされた<a>要素
// <a href="javascript:void(0)" data-value1="5.1" data-value2="2.5" data-value3="1">[5.1, 2.5] versicolor</a
// console.log(e.target);
// クリックされた<a>要素のカスタムデータ属性を調べ、値を得る
const lengthVal = Number.parseInt((e.target.dataset["value1"]),10); // がく片の長さ
const widthVal = Number(e.target.dataset["value2"]); // がく片の幅
// tf.TensorFlowオブジェクトに変換する
const data = tf.tensor2d([lengthVal, widthVal], [1, 2]);
// 訓練済みモデルで推論し、結果を得る
const predictOut = model.predict(data);
// 結果は確率の配列
predictOut.print(); // 例:[[0.1198563, 0.8801436],] -> versicolorの確率が高いと予測している
let axis = 0;
predictOut.argMax(axis).print(); // [0, 0] -> axis=0は、縦向きに読む
axis =1;
predictOut.argMax(axis).print(); // [1] -> axis=1は、横向きに読む
// 予測結果の配列を横向きに読んだ場合での、最大値のインデックス位置を得る
const predictOutMaxIndex = predictOut.argMax(axis);
predictOutMaxIndex.print(); // [1]
// その値
const predictOutMaxIndexData = predictOutMaxIndex.dataSync()[0];
console.log(predictOutMaxIndexData) // 1
const whichClass = IRIS_CLASSES[predictOutMaxIndexData];
console.log(whichClass) // versicolor
// e.target.dataset["value3"]が参照するのは、クリックされた<a>要素のdata-value3属性の値
// それは元々、IRIS_CLASSES配列のインデックス値に対応させてある。
// <a href="javascript:void(0)" data-value1="5.1" data-value2="2.5" data-value3="1">[5.1, 2.5] versicolor</a
const answer = IRIS_CLASSES[Number(e.target.dataset["value3"])]; // satosaかversicolor
let result = '';
// 答え合わせをして、正解不正解に応じたUI操作を行う
if(isCorrect(answer, whichClass)){
result='正解';
e.target.style.backgroundColor='coral'
}else{
result='不正解';
e.target.style.backgroundColor='gray'
}
const message = "このあやめの種類は" + answer + '。機械学習の予測は' + whichClass + '。よって'+ result;
document.getElementById('info').textContent = message;
});
}, false);
// 後始末
xsTrain.dispose();
ysTrain.dispose();
xsValidation.dispose();
ysValidation.dispose();
});
// 与えられたaとbが一致するかどうか
const isCorrect =(a,b)=>{
if(a===b){
return 1;
}else{
return 0;
}
}
}
// ページが読み込まれたら、モデルの訓練まで進む
const start =async()=>{
const [xsTrain, ysTrain, xsValidation, ysValidation] = await getData();
const model = await buildModel();
train(model, xsTrain, ysTrain, xsValidation, ysValidation);
}
start();
}, false);
</script>
</body>
</html>