13_7:グラフを追加

最後に、メインのJavaScriptのforループの外側にさらにループを追加し、訓練回数だけ繰り返すようにします。また学習の進み具合の分かるグラフや、訓練のスピード値を追加します。

このリンクをクリックすると、最終のカートポールサンプルの動作を見ることができます。

HTML

<div class="container">
  <div>
    <canvas id="cart-pole-canvas" height="150px" width="500px"></canvas>
    <div>
      <!-- 追加 -->
      <span id="iteration-info"></span>
      <span>:</span>
      <span id="game-info"></span>
      <span>:</span>
      <span id="step-info"></span>
    </div>
    <div>
      <span class="status-span">訓練のスピード:</span>
      <span id="train-speed" class="status-span"></span>
    </div>
  </div>
  <div id="chart" class="graph-area"></div>
</div>

メインのJavaScript

document.addEventListener('DOMContentLoaded', async() => {

    const gameInfo = document.getElementById('game-info');
    const stepInfo = document.getElementById('step-info');

    // 追加
    const iterationInfo = document.getElementById('iteration-info');

    let currentActions_;

    const buildModel = () => {
        const model = tf.sequential();
        model.add(tf.layers.dense({
            units: 4,
            activation: 'elu',
            inputShape: [4]
        }));
        model.add(tf.layers.dense({
            units: 1
        }));
        model.summary();
        return model;
    }

    const getLogitsAndActions = (inputs) => {
        return tf.tidy(() => {
            const logits = model.predict(inputs);
            const leftProb = tf.sigmoid(logits);
            const rightProb = tf.sub(1, leftProb)
            const leftRightProbs = tf.concat([leftProb, rightProb], 1);
            const actions = tf.multinomial(leftRightProbs, 1, null, true);
            return [logits, actions];
        });
    }

    const getGradientsAndSaveActions = (inputTensor) => {
        const f = () => tf.tidy(() => {
            const [logits, actions] = getLogitsAndActions(inputTensor);
            currentActions_ = actions.dataSync();
            const labels = tf.sub(1, tf.tensor2d(currentActions_, actions.shape, 'float32'));

            return tf.losses.sigmoidCrossEntropy(labels, logits).asScalar();

        });
        return tf.variableGrads(f);
    }

    const pushGradients = (record, gradients) => {
        for (const key in gradients) {
            if (key in record) {
                record[key].push(gradients[key]);
            }
            else {
                record[key] = [gradients[key]];
            }
        }
    }

    const maxStepsPerGame = 500;
    const numGames = 20;

    const cartPoleSystem = new CartPole(true);
    const model = buildModel();

    const learningRate = 0.05;
    const optimizer = tf.train.adam(learningRate);
    const discountRate = 0.95;

    // 訓練を繰り返す回数
    const trainIterations = 20;
    // 開始時刻
    let t0 = new Date().getTime();
    // 総学習回数
    let xstep = 0;

    for (let l = 0; l < trainIterations; ++l) {
        //
        const allGradients = [];
        const allRewards = [];
        const gameSteps = [];
        //
        let maxStemNum = 0;

        iterationInfo.textContent = 1 + l + '回めの繰り返し';

        for (let i = 0; i < numGames; ++i) {

            gameInfo.textContent = i + 1 + '回めのゲーム';
            cartPoleSystem.setRandomState();
            const gameRewards = [];
            const gameGradients = [];

            for (let j = 0; j < maxStepsPerGame; ++j) {
                stepInfo.textContent = j + 1 + '回めのステップ';

                const gradients = tf.tidy(() => {
                    const inputTensor = cartPoleSystem.getStateTensor();
                    const {
                        value, grads
                    } = getGradientsAndSaveActions(inputTensor);
                    return grads;
                });

                pushGradients(gameGradients, gradients);
                const action = currentActions_[0];
                const isDone = cartPoleSystem.update(action);
                await maybeRenderDuringTraining(cartPoleSystem);

                if (isDone) {
                    gameRewards.push(0);
                    break;
                }
                else {
                    gameRewards.push(1);
                }

                //
                xstep++;
                maxStemNum = j;

            } // 内側のループの終わり

            //
            plot(xstep, maxStemNum);

            gameSteps.push(gameRewards.length);
            pushGradients(allGradients, gameGradients);
            allRewards.push(gameRewards);

            tf.tidy(() => {
                const normalizedRewards = discountAndNormalizeRewards(allRewards, discountRate);
                const gradients = scaleAndAverageGradients(allGradients, normalizedRewards);
                optimizer.applyGradients(gradients);
            });
        } // 外側のループの終わり

        tf.dispose(allGradients);
        // [訓練のスピード]の表示
        const t1 = new Date().getTime();
        const stepsPerSecond = sum(gameSteps) / ((t1 - t0) / 1e3);
        t0 = t1;
        const trainSpeed = document.getElementById('train-speed').textContent = `${stepsPerSecond.toFixed(1)} steps/s`
        await tf.nextFrame();
    } // 一番外のループの終わり
}, false);

ユーティリティー関数

/**
 * 数値の配列の平均を計算する
 *
 * @param {number[]} xs
 * @returns {number} `xs`の算術平均
 */
const mean = (xs) => {
    return sum(xs) / xs.length;
}

/**
 * 数値の配列の合計を計算する
 *
 * @param {number[]} xs
 * @returns {number} `xs`の合計
 * @throws `xs`が空ならエラー
 */
const sum = (xs) => {
    if (xs.length === 0) {
        throw new Error('xsは空でない配列であることが期待される。');
    }
    else {
        return xs.reduce((x, prev) => prev + x);
    }
}

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA