機械学習で扱うのは、目に見える色の変化やオブジェクトの変形でなく、小さな数値の微細な変化なので、扱うデータや、訓練の進捗状況を視覚化することが重要です。
機械学習の環境が進んでいるPythonには、多くの機械学習向けライブラリが存在し、データの取得1つをとっても実にスマートに簡単にできますが、TensorFlow.jsでは残念ながら、現時点ではそこまで環境整備が進んでおらず、視覚化は自分で行う必要があります。
グラフの描画には、たとえば<canvas>の使用が思いつきますが、もう少し便利なJavaScriptライブラリも多数存在します。本記事では、plotly.jsというライブラリを使用します。
plotly.jsは次のようにして読み込みます。
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
次の例では、左に訓練用と教師データのグラフを、右に訓練の進捗状況のグラフを表示しています。
左のグラフでは、横軸に訓練用(xs)、縦軸に教師用(ys)をとっています。これらの値は元々、y = 2x -1の関係性に一致するので、プロットした青丸はy = 2x -1の直線上に並びます。
右のグラフでは、横軸にepochs値を、縦軸に損失値をとっています。損失値は訓練の開始時には大きいですが、訓練が進むにつれ急激に小さくなっていることが分かります。
これは、学習がうまく進んでいることを示しています。
訓練が終わると、左のグラフ上で、xs値に対する推測値をオレンジの線で結んでいます。学習がうまく進んでいるので、オレンジの線はほぼ青い丸の上に重なります。
そして最後、左のグラフ上に、未知の20に対する推測値を大きな緑色の丸でプロットしています。
以下は全コードです。
<!doctype html>
<html lang="ja">
<head>
<meta charset="utf-8">
<title>START</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<style>
.graph-area{
margin: 10px;
width: 500px;
height:500px;
border: 1px solid black;
}
.container{
display: flex;
}
</style>
</head>
<body>
<h3>START</h3>
<p>データと損失値の視覚化</p>
<div class="container">
<div id="chart1" class="graph-area"></div>
<div id="chart2" class="graph-area"></div>
</div>
<script>
document.addEventListener('DOMContentLoaded',async()=>{
// モデルを非同期で作成する
const buildModel = async()=>{
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });
return model;
}
info('データ取得開始');
// getData()が値を返すまで次に進まない
// 訓練用データと教師用データ用の変数
const [xs, ys] = await getData();
info('データ取得完了');
// データをグラフに丸でプロット
addPoint(xs.dataSync(), ys.dataSync(), 8, true);
info('モデル構築開始');
// buildModel()が値を返すまで次に進まない
const model = await buildModel();
info('モデル構築完了');
info('モデル訓練開始');
// モデルは与えられたxsとysから、その関係性を探る。
// 具体的には、y = 2x-1 の係数を探る
await model.fit(xs, ys, {
epochs: 250,
// 1回の訓練が終わるたびに呼び出されるコールバック関数
callbacks: {
onEpochEnd: async (epoch, logs) => {
// 繰り返しの回数を横軸に、損失値縦軸にグラフを描く
addGraphData(epoch, logs.loss)
// 画面がフリーズしないように次のフレームに進む
await tf.nextFrame();
}
}})
// tf.Model.fit()はPromiseを返すので、訓練が終わったら、then()が呼び出される。
.then(()=>{
info('モデル訓練完了');
// tf.Modelが内部に持っている係数の値を調べる
const a = model.trainableWeights[0].read(); // ax + b のaの値
const b = model.trainableWeights[1].read(); // ax + b のbの値
console.log(a.dataSync());
console.log(b.dataSync());
// tf.tidy()を使うと、指定した関数の実行後、
// 関数内にあるtf.Tensorオブジェクトが占めるGPUメモリを解放できる。
tf.tidy(()=>{
// 訓練済みのモデルにxsを渡して、各値の推論値を持つtf.Tensorオブジェクト得る
const predictOut = model.predict(xs);
// それらの実際の値を得る
const predictData = predictOut.dataSync();
// 推論結果から線を描く
addLine(xs.dataSync(),predictData);
// xsが[20]の場合の推論結果をグラフに丸で描く
const res = model.predict(tf.tensor2d([20], [1, 1]));
addPoint([20], res.dataSync(), 12, false);
});
// tf.Tensorオブジェクトを破棄する
xs.dispose();
ys.dispose();
});
},false);
// データを取得する。データは多くの場合外部にある。
const getData= async()=>{
// 訓練用データ
const trainData = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
// 教師用データ
const labelData = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
return [trainData, labelData];
}
const info =(msg)=>{
console.log(msg);
}
// 訓練の回数と損失値の変化をグラフで描く
let xData1 = [];
let yData1 = [];
const trace = {
x: xData1,
y: yData1,
type: 'scatter'
};
const data = [trace];
const layout = {
xaxis: {
title: 'epochs'
},
yaxis: {
title: 'Loss'
},
title: '減少していく損失値の表示'
};
const addGraphData = (a, b) => {
xData1.push(a);
yData1.push(b);
Plotly.newPlot('chart2', data, layout,{ displayModeBar: false });
}
// グラフに線を追加する
const addLine = (xData, yData) => {
const trace = {
x: xData,
y: yData,
mode: 'lines',
type: 'scatter'
};
Plotly.addTraces('chart1', [trace]);
}
// グラフに丸を描く
const addPoint = (xData, yData, markerSize, isnewPlot) => {
const trace = {
x: xData,
y: yData,
mode: 'markers',
type: 'scatter',
marker: { size: markerSize }
};
const layout = {
xaxis: {
range: [-5, 25],
title: 'xsデータ'
},
yaxis: {
range: [-5, 50],
title: 'ysデータ'
},
title: 'データと推論結果の表示',
showlegend: false
}
if(isnewPlot){
// 初めてグラフを描く
Plotly.newPlot('chart1', [trace],layout,{ displayModeBar: false });
}else{
// 描いたグラフにデータを追加する
Plotly.addTraces('chart1', [trace]);
}
}
</script>
</body>
</html>