start()関数では、buildModel()の後、train(model, xsTrain, ysTrain, xsValidation, ysValidation);を実行しているので、次のtrain()関数が呼び出されます。xsTrainとysTrainは訓練用の、xsValidationとysValidationは検証用のデータです。
// モデルを訓練する
const train = (model, xsTrain, ysTrain, xsValidation, ysValidation) => {
const lossValues = [];
const epochs = param.epochs;
model.fit(xsTrain, ysTrain, {
// batchSize値は変更できる。デフォルトは32
batchSize: 32,
epochs: epochs,
validationData: [xsValidation, ysValidation],
callbacks: {
onEpochEnd: async(epoch, logs) => {
const lossData = [epoch, logs.loss];
// 検証用データの損失値もグラフで描く
const val_lossData = [epoch, logs.val_loss];
addGraphData(lossData[0], lossData[1], val_lossData[0], val_lossData[1]);
await tf.nextFrame();
}
}
}).then(() => {
console.log('モデルの訓練完了');
検証用のデータは、モデル(tf.Sequentialクラスのインスタンス)のfit()メソッドに渡すオブジェクトのvalidationDataプロパティの値として、配列で指定します。
すると、onEpochEnd()コールバック関数で、コールバック関数に渡される引数のlogsのval_lossプロパティで、検証用データに関する損失値にアクセスできるようになります。
上記コードでは、グラフの曲線を描くaddGraphData()関数に、訓練用データの損失値とともに、検証用データの損失値も渡しています。下図はその2つの曲線の例です。