データが揃ったので、モデルを作成して訓練します。モデルは「boston-housing」サンプルのindex.jsファイルに記述されている3つのものを使用します。
model.jsファイルを作成して、次のコードを記述します。
const LEARNING_RATE = 0.01;
/**
* 線形回帰モデルを構築して返す。
*
* @returns {tf.Sequential} 線形回帰モデル
*/
const linearRegressionModel = () => {
const model = tf.sequential();
model.add(tf.layers.dense({
inputShape: [BOSTONDATA_NUMFEATURES],
units: 1
}));
model.summary();
model.compile({
optimizer: tf.train.sgd(LEARNING_RATE),
loss: 'meanSquaredError'
});
return model;
};
/**
* 多層パーセプトロン回帰モデルを構築して返す。
* 隠れ層を1つ、sigmoid関数によって活性化される10のユニットを持つ。
*
* @returns {tf.Sequential} 多層パーセプトロン回帰モデル
*/
const multiLayerPerceptronRegressionModel1Hidden = () => {
const model = tf.sequential();
model.add(tf.layers.dense({
inputShape: [BOSTONDATA_NUMFEATURES],
units: 50,
activation: 'sigmoid',
kernelInitializer: 'leCunNormal' // 切断正規分布
}));
model.add(tf.layers.dense({
units: 1
}));
model.summary();
model.compile({
optimizer: tf.train.sgd(LEARNING_RATE),
loss: 'meanSquaredError'
});
return model;
};
/**
* 多層パーセプトロン回帰モデルを構築して返す。
* 隠れ層を1つ、sigmoid関数によって活性化される50のユニットを持つ。
*
* @returns {tf.Sequential} 多層パーセプトロン回帰モデル
*/
const multiLayerPerceptronRegressionModel2Hidden = () => {
const model = tf.sequential();
// 1層め
model.add(tf.layers.dense({
inputShape: [BOSTONDATA_NUMFEATURES],
units: 50,
activation: 'sigmoid',
kernelInitializer: 'leCunNormal'
}));
// 2層め
model.add(tf.layers.dense({
units: 50,
activation: 'sigmoid',
kernelInitializer: 'leCunNormal'
}));
model.add(tf.layers.dense({
units: 1
}));
model.summary();
model.compile({
optimizer: tf.train.sgd(LEARNING_RATE),
loss: 'meanSquaredError'
});
return model;
};
BOSTONDATA_NUMFEATURESはdata.jsで宣言した、12を参照する変数です。12は説明変数の数(特徴量)です。
下図は3つのモデルからsummary()メソッドを呼び出した結果のスクリーンショットです。
グラフとUI関係のコードをそれぞれ、graph.js、ui.jsという名前で作成します。
下記のgraph.js で使用しているのは、この「TensofrFlow.jsことはじめ」で最初からずっと使っているplotly.jsという名前のグラフ描画用ライブラリです。
let xData1 = [];
let yData1 = [];
let xData2 = [];
let yData2 = [];
const plot = (a, b, c, d) => {
xData1.push(a);
yData1.push(b);
xData2.push(c);
yData2.push(d);
const trace1 = {
x: xData1,
y: yData1,
type: 'scatter',
name: 'trainLoss'
};
const trace2 = {
x: xData2,
y: yData2,
type: 'scatter',
name: 'valLoss'
};
const layout = {
xaxis: {
title: 'epoch'
},
yaxis: {
title: 'loss'
},
title: '多変数回帰'
};
Plotly.newPlot('chart', [trace1, trace2], layout, {
displayModeBar: false
});
}
ui.jsには次のコードを記述します。
const updateStatus = (message) => {
document.getElementById('status').value = message;
};
const updateBaselineStatus = (message) => {
document.getElementById('baselineStatus').value = message;
};