tf.Sequentialオブジェクトのfit()メソッドを呼び出すと、モデルを訓練することができます。fit()メソッドには、訓練がうまく進んでいるかを見ることのできる方法が提供されています。
訓練の進捗状況を確認するには、fit()メソッドの第2引数に指定するオブジェクトにcallbacksプロパティを指定し、その値となるオブジェクトで、onEpochEndプロパティを指定します。
onEpochEndは訓練ごとの終わりに呼び出されるコールバック関数で、引数としてepochとlogsを受け取ります。epochを調べるとその時点での回数が分かり、logs.lossからはその時点での損失が分かります。この損失が小さくなっていけばいくほど訓練はうまく進んでいることになります。
次の例では、onEpochEnd関数の実行時に、epochとlogs.lossをコンソールに出力しています。またその際には、ブラウザの画面がフリーズしたように見えないように、tf.nextFrame()を呼び出しています。
tf.nextFrame()関数は、内部でwindow.requestAnimationFrame()を用いてブラウザ画面のフレームを次に送り、Promiseを返します。
ここまでのコードで、モデルの訓練の進捗が具体的な数値で分かるようになります。
info('モデル訓練開始');
// モデルは与えられたxsとysから、その関係性を探る。
// 具体的には、y = 2x-1 の係数を探る
await model.fit(xs, ys, {
epochs: 250,
// 1回の訓練が終わるたびに呼び出されるコールバック関数
callbacks: {
onEpochEnd: async (epoch, logs) => {
// 繰り返しの回数と損失値を出力
console.log('繰り返し回数'+ epoch);
console.log('ロス' + logs.loss)
// 画面がフリーズしないように次のフレームに進む
await tf.nextFrame();
}
}})
// tf.Sequential.fit()はPromiseを返すので、訓練が終わったら、then()が呼び出される。
.then(()=>{
info('モデル訓練完了');
// tf.Modelが内部に持っている係数の値を調べる
const a = model.trainableWeights[0].read(); // ax + b のaの値
const b = model.trainableWeights[1].read(); // ax + b のbの値
console.log(a.dataSync());
console.log(b.dataSync());
// tf.tidy()を使うと、指定した関数の実行後、
// 関数内にあるtf.Tensorオブジェクトが占めるGPUメモリを解放できる。
tf.tidy(()=>{
const res = model.predict(tf.tensor2d([20], [1, 1]));
console.log(res.dataSync());
});
// tf.Tensorオブジェクトを破棄する
xs.dispose();
ys.dispose();
});
前の例では、ボタンのクリックで推論を行いましたが、上記コードでは tf.Sequential.fit()がPromiseを返すことを利用して、訓練が終わったら、自動的に推論に進むように変更しています。
Promiseのthen()メソッドに渡している無名関数の記述で覚えておきたいのが、モデルが確定させた推論の要点にたどり着く方法です。model.trainableWeights[0].read()では、ax + b のaに当たる推測値に、model.trainableWeights[1].read()では、ax + bのbに当たる推測値にたどり着くことができます。
今回の訓練では、aが1.96230…、bが-0.88312…と推測しているので、y = 2x -1の2と-1にほぼ一致しているのが分かります。
その後のtf.tidy()関数は、不要になったtf.TensorオブジェクトをGPUメモリから削除します。fit()は自動的にtf.Tensorオブジェクトを削除しますが、predict()はそうしないので、推測値を得たらすぐにtf.Tensorオブジェクトが占めていたメモリを解放するのがベストプラクティスとされています。
tf.memory().numTensorsプロパティを調べると、その時点でのtf.Tensorオブジェクトの個数が分かります。
またtf.Tensorオブジェクトのdispose()メソッドを呼び出すと、そのtf.TensorオブジェクトをGPUメモリから削除できます。