最後に、メインのJavaScriptのforループの外側にさらにループを追加し、訓練回数だけ繰り返すようにします。また学習の進み具合の分かるグラフや、訓練のスピード値を追加します。
このリンクをクリックすると、最終のカートポールサンプルの動作を見ることができます。
HTML
<div class="container">
<div>
<canvas id="cart-pole-canvas" height="150px" width="500px"></canvas>
<div>
<!-- 追加 -->
<span id="iteration-info"></span>
<span>:</span>
<span id="game-info"></span>
<span>:</span>
<span id="step-info"></span>
</div>
<div>
<span class="status-span">訓練のスピード:</span>
<span id="train-speed" class="status-span"></span>
</div>
</div>
<div id="chart" class="graph-area"></div>
</div>
メインのJavaScript
document.addEventListener('DOMContentLoaded', async() => {
const gameInfo = document.getElementById('game-info');
const stepInfo = document.getElementById('step-info');
// 追加
const iterationInfo = document.getElementById('iteration-info');
let currentActions_;
const buildModel = () => {
const model = tf.sequential();
model.add(tf.layers.dense({
units: 4,
activation: 'elu',
inputShape: [4]
}));
model.add(tf.layers.dense({
units: 1
}));
model.summary();
return model;
}
const getLogitsAndActions = (inputs) => {
return tf.tidy(() => {
const logits = model.predict(inputs);
const leftProb = tf.sigmoid(logits);
const rightProb = tf.sub(1, leftProb)
const leftRightProbs = tf.concat([leftProb, rightProb], 1);
const actions = tf.multinomial(leftRightProbs, 1, null, true);
return [logits, actions];
});
}
const getGradientsAndSaveActions = (inputTensor) => {
const f = () => tf.tidy(() => {
const [logits, actions] = getLogitsAndActions(inputTensor);
currentActions_ = actions.dataSync();
const labels = tf.sub(1, tf.tensor2d(currentActions_, actions.shape, 'float32'));
return tf.losses.sigmoidCrossEntropy(labels, logits).asScalar();
});
return tf.variableGrads(f);
}
const pushGradients = (record, gradients) => {
for (const key in gradients) {
if (key in record) {
record[key].push(gradients[key]);
}
else {
record[key] = [gradients[key]];
}
}
}
const maxStepsPerGame = 500;
const numGames = 20;
const cartPoleSystem = new CartPole(true);
const model = buildModel();
const learningRate = 0.05;
const optimizer = tf.train.adam(learningRate);
const discountRate = 0.95;
// 訓練を繰り返す回数
const trainIterations = 20;
// 開始時刻
let t0 = new Date().getTime();
// 総学習回数
let xstep = 0;
for (let l = 0; l < trainIterations; ++l) {
//
const allGradients = [];
const allRewards = [];
const gameSteps = [];
//
let maxStemNum = 0;
iterationInfo.textContent = 1 + l + '回めの繰り返し';
for (let i = 0; i < numGames; ++i) {
gameInfo.textContent = i + 1 + '回めのゲーム';
cartPoleSystem.setRandomState();
const gameRewards = [];
const gameGradients = [];
for (let j = 0; j < maxStepsPerGame; ++j) {
stepInfo.textContent = j + 1 + '回めのステップ';
const gradients = tf.tidy(() => {
const inputTensor = cartPoleSystem.getStateTensor();
const {
value, grads
} = getGradientsAndSaveActions(inputTensor);
return grads;
});
pushGradients(gameGradients, gradients);
const action = currentActions_[0];
const isDone = cartPoleSystem.update(action);
await maybeRenderDuringTraining(cartPoleSystem);
if (isDone) {
gameRewards.push(0);
break;
}
else {
gameRewards.push(1);
}
//
xstep++;
maxStemNum = j;
} // 内側のループの終わり
//
plot(xstep, maxStemNum);
gameSteps.push(gameRewards.length);
pushGradients(allGradients, gameGradients);
allRewards.push(gameRewards);
tf.tidy(() => {
const normalizedRewards = discountAndNormalizeRewards(allRewards, discountRate);
const gradients = scaleAndAverageGradients(allGradients, normalizedRewards);
optimizer.applyGradients(gradients);
});
} // 外側のループの終わり
tf.dispose(allGradients);
// [訓練のスピード]の表示
const t1 = new Date().getTime();
const stepsPerSecond = sum(gameSteps) / ((t1 - t0) / 1e3);
t0 = t1;
const trainSpeed = document.getElementById('train-speed').textContent = `${stepsPerSecond.toFixed(1)} steps/s`
await tf.nextFrame();
} // 一番外のループの終わり
}, false);
ユーティリティー関数
/**
* 数値の配列の平均を計算する
*
* @param {number[]} xs
* @returns {number} `xs`の算術平均
*/
const mean = (xs) => {
return sum(xs) / xs.length;
}
/**
* 数値の配列の合計を計算する
*
* @param {number[]} xs
* @returns {number} `xs`の合計
* @throws `xs`が空ならエラー
*/
const sum = (xs) => {
if (xs.length === 0) {
throw new Error('xsは空でない配列であることが期待される。');
}
else {
return xs.reduce((x, prev) => prev + x);
}
}