次のコードは、ブラウザで開くと自動的にTensorFlow.jsのモデルが作成され学習を始めます。学習が終わり、ボタンをクリックすると、テキストフィールドに入力された数値(20)に対応する数値をデベロッパーツールのコンソールに出力します。
TensorFlow.jsのモデルには、[-1, 0, 1, 2, 3, 4]と[-3, -1, 1, 3, 5, 7]という配列がデータとして与えられます。前述したとおり、その関係性はy = 2x – 1です。テキストフィールドの数値は20なので、yは2 x 20 -1 = 39です。TensorFlow.jsは今の場合、38.4103889465332と推測しているので、概ね合っていることが分かります。
<!doctype html>
<html lang="ja">
<head>
<meta charset="utf-8">
<title>START</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<style>
</style>
</head>
<body>
<h3>START</h3>
<p>y = 2x - 1という分かりやすい関係性を例に、TensorFlow.jsでの作業の流れをつかむ</p>
<input type="button" id="predict-button" value="推論">
<input type="number" id="number-input" value="20">
<script>
document.addEventListener('DOMContentLoaded',async()=>{
// モデルを非同期で作成する
const buildModel = async()=>{
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });
return model;
}
info('データ取得開始');
// getData()が値を返すまで次に進まない
// 訓練用データと教師用データ用の変数
const [xs, ys] = await getData();
info('データ取得完了');
info('モデル構築開始');
// buildModel()が値を返すまで次に進まない
const model = await buildModel();
info('モデル構築完了');
info('モデル訓練開始');
// モデルは与えられたxsとysから、その関係性を探る。
// 具体的には、y = 2x-1 の係数を探る
await model.fit(xs, ys, {epochs: 250});
info('モデル訓練完了');
// 訓練が終わると、訓練済みモデルを使った推論が行える
document.getElementById('predict-button').addEventListener('click', ()=>{
const inputNumber = document.getElementById('number-input').value;
// 数値入力フィールドに入力された数値から結果を推論する
const res = model.predict(tf.tensor2d([inputNumber], [1, 1]));
console.log(res.dataSync());
}, false);
},false);
// データを取得する。データは多くの場合外部にある。
const getData= async()=>{
// 訓練用データ
const trainData = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
// 教師用データ
const labelData = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
return [trainData, labelData];
}
const info =(msg)=>{
console.log(msg);
}
</script>
</body>
</html>
以降では、ここに記述したコードを順に見ていきます。