5_1_3 モデルの訓練

start()関数では、buildModel()の後、train(model, xsTrain, ysTrain, xsValidation, ysValidation);を実行しているので、次のtrain()関数が呼び出されます。xsTrainとysTrainは訓練用の、xsValidationとysValidationは検証用のデータです。

// モデルを訓練する
const train = (model, xsTrain, ysTrain, xsValidation, ysValidation) => {
        const lossValues = [];
        const epochs = param.epochs;
        model.fit(xsTrain, ysTrain, {
                // batchSize値は変更できる。デフォルトは32
                batchSize: 32,
                epochs: epochs,
                validationData: [xsValidation, ysValidation],
                callbacks: {
                    onEpochEnd: async(epoch, logs) => {
                        const lossData = [epoch, logs.loss];
                        // 検証用データの損失値もグラフで描く
                        const val_lossData = [epoch, logs.val_loss];
                        addGraphData(lossData[0], lossData[1], val_lossData[0], val_lossData[1]);
                        await tf.nextFrame();
                    }
                }
            }).then(() => {
                    console.log('モデルの訓練完了');

検証用のデータは、モデル(tf.Sequentialクラスのインスタンス)のfit()メソッドに渡すオブジェクトのvalidationDataプロパティの値として、配列で指定します。

すると、onEpochEnd()コールバック関数で、コールバック関数に渡される引数のlogsのval_lossプロパティで、検証用データに関する損失値にアクセスできるようになります。

上記コードでは、グラフの曲線を描くaddGraphData()関数に、訓練用データの損失値とともに、検証用データの損失値も渡しています。下図はその2つの曲線の例です。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA