13_3 カートポールを動かす

以降では、「TensorFlow.js Example:Reinforcement Learning: Cart Pole」サンプルのコードを参考に、カートポールサンプルのシンプルバージョンを作成していきます。

これは、言い換えると、複雑に絡み合っているように見える「TensorFlow.js Example:Reinforcement Learning: Cart Pole」サンプルのコードを解きほぐし、シンプルな形にまとめていく作業です。

まずは、カートポールを動くようにします。これにより、モデルの訓練やゲーム、ステップなどの回数がどういうものか、目で確認できるようになります。

HTMLは次の簡単なものを使用します。カートポールのアニメーションは、このキャンバス上に描画します。
HTML

<canvas id="cart-pole-canvas" height="150px" width="500px"></canvas>
<div>
  <span id="game-info"></span>
  <span>:</span>
  <span id="step-info"></span>
</div>

CSSでは、キャンバスに枠線を設定しておきます。

#cart-pole-canvas {
    border: 1px solid black;
}

JavaScriptではまず、CartPoleクラスを作成します。下記コードは、「tfjs-examples/cart-pole/cart_pole.js」で公開されているものと、基本的に同じです。このJavaScriptはCartPole.jsという名前でHTMLファイルに読み込みます。

/**
 * http://incompleteideas.net/book/code/pole.cにもとづく実装
 */

/**
 * カートポールシステムシミュレーター
 *
 * 制御理論の観点から言うと、このシステムには4つの状態変数がある。
 *
 *   - x: カートの1D位置
 *   - xDot: カートの速度
 *   - theta: ポールの角度(ラジアン単位)。値0は垂直位置に対応する。
 *     (垂直に立っている)。
 *   - thetaDot: ポールの角速度
 *
 * システムは単一の行動によって制御される。
 *
 *   - 左方か右方への力
 */
class CartPole {
    /**
     * CartPoleのコンストラクタ
     */
    constructor() {
        // システムを特徴づける定数
        this.gravity = 9.8;
        this.massCart = 1.0;
        this.massPole = 0.1;
        this.totalMass = this.massCart + this.massPole;
        this.cartWidth = 0.2;
        this.cartHeight = 0.1;
        this.length = 0.5;
        this.poleMoment = this.massPole * this.length;
        this.forceMag = 10.0;
        this.tau = 0.02; // 状態を更新する間隔の秒数

        // しきい値。これを超えると、シミュレーションは失敗したことになる。
        this.xThreshold = 2.4;
        this.thetaTheshold = 12 / 360 * 2 * Math.PI;

        this.setRandomState();
    }

    /**
     * カートポールシステムの状態をランダムに設定する。
     */
    setRandomState() {
        // カートポールシステムの制御理論状態変数
        // カートの位置。メートル
        this.x = Math.random() - 0.5;
        // カートの速度
        this.xDot = (Math.random() - 0.5) * 1;
        // ポールの角度、ラジアン
        this.theta = (Math.random() - 0.5) * 2 * (6 / 360 * 2 * Math.PI);
        // ポールの角速度
        this.thetaDot = (Math.random() - 0.5) * 0.5;
    }

    /**
     * 現在の状態を、shapeが[1, 4]のtf.Tensorとして取得する。
     */
    getStateTensor() {
        return tf.tensor2d([
            [this.x, this.xDot, this.theta, this.thetaDot]
        ]);
    }

    /**
     * カートポールシステムを行動を使って更新する。
     * @param {number} action `action`の符号がすべて
     *   0以上の値は、一定の大きさを持つ右方への力になる。
     *    0未満の値は、同じ一定の大きさを持つ左方への力になる。
     */
    update(action) {
        const force = action > 0 ? this.forceMag : -this.forceMag;

        const cosTheta = Math.cos(this.theta);
        const sinTheta = Math.sin(this.theta);

        const temp = (force + this.poleMoment * this.thetaDot * this.thetaDot * sinTheta) / this.totalMass;
        const thetaAcc = (this.gravity * sinTheta - cosTheta * temp) / (this.length * (4 / 3 - this.massPole * cosTheta * cosTheta / this.totalMass));
        const xAcc = temp - this.poleMoment * thetaAcc * cosTheta / this.totalMass;

        // 4つの状態変数を、オイラー法を使って更新する。
        this.x += this.tau * this.xDot;
        this.xDot += this.tau * xAcc;
        this.theta += this.tau * this.thetaDot;
        this.thetaDot += this.tau * thetaAcc;

        return this.isDone();
    }

    /**
     * このシミュレーションを終えるかどうかを決める。
     *
     * シミュレーションは、`x`(カートの位置)が範囲を超えるか、
     * `theta`(ポールの角度)が範囲を超えるかすると終わり。
     *
     * @returns {bool} シミュレーションが終わりかどうか
     */
    isDone() {
        return this.x < -this.xThreshold || this.x > this.xThreshold ||
            this.theta < -this.thetaTheshold || this.theta > this.thetaTheshold;
    }
}

また、カートポールをキャンバスに描画するJavaScriptコードをrender.jsという名前で保存して、HTMLファイルに読み込みます。これは「tfjs-examples/cart-pole」サンプルのui.jsファイルに書かれているrenderCartPole()関数とほとんど同じです。

/**
 * システムの現在の状態をHTMLのキャンバスにレンダリングする。
 *
 * @param {CartPole} cartPole レンダリングするCartPoleシステムのインスタンス
 * @param {HTMLCanvasElement} canvas レンダリングを行うHTMLCanvasElementのインスタンス
 */
function renderCartPole(cartPole) {

    const canvas = document.getElementById('cart-pole-canvas');
    if (!canvas.style.display) {
        canvas.style.display = 'block';
    }
    const X_MIN = -cartPole.xThreshold;
    const X_MAX = cartPole.xThreshold;
    const xRange = X_MAX - X_MIN;
    const scale = canvas.width / xRange;

    const context = canvas.getContext('2d');
    context.clearRect(0, 0, canvas.width, canvas.height);
    const halfW = canvas.width / 2;

    // カートを描画
    const railY = canvas.height * 0.8;
    const cartW = cartPole.cartWidth * scale;
    const cartH = cartPole.cartHeight * scale;

    const cartX = cartPole.x * scale + halfW;

    context.beginPath();
    context.strokeStyle = '#000000';
    context.lineWidth = 2;
    context.rect(cartX - cartW / 2, railY - cartH / 2, cartW, cartH);
    context.stroke();

    // カートの下に車輪を描画
    const wheelRadius = cartH / 4;
    for (const offsetX of[-1, 1]) {
        context.beginPath();
        context.lineWidth = 2;
        context.arc(cartX - cartW / 4 * offsetX, railY + cartH / 2 + wheelRadius,
            wheelRadius, 0, 2 * Math.PI);
        context.stroke();
    }

    // ポールを描画
    const angle = cartPole.theta + Math.PI / 2;
    const poleTopX =
        halfW + scale * (cartPole.x + Math.cos(angle) * cartPole.length);
    const poleTopY = railY -
        scale * (cartPole.cartHeight / 2 + Math.sin(angle) * cartPole.length);
    context.beginPath();
    context.strokeStyle = '#ffa500';
    context.lineWidth = 6;
    context.moveTo(cartX, railY - cartH / 2);
    context.lineTo(poleTopX, poleTopY);
    context.stroke();

    // 地面を描画
    const groundY = railY + cartH / 2 + wheelRadius * 2;
    context.beginPath();
    context.strokeStyle = '#000000';
    context.lineWidth = 1;
    context.moveTo(0, groundY);
    context.lineTo(canvas.width, groundY);
    context.stroke();

    const nDivisions = 40;
    for (let i = 0; i < nDivisions; ++i) {
        const x0 = canvas.width / nDivisions * i;
        const x1 = x0 + canvas.width / nDivisions / 2;
        const y0 = groundY + canvas.width / nDivisions / 2;
        const y1 = groundY;
        context.beginPath();
        context.moveTo(x0, y0);
        context.lineTo(x1, y1);
        context.stroke();
    }

    // 左右の端を描画
    const limitTopY = groundY - canvas.height / 2;
    context.beginPath();
    context.strokeStyle = '#ff0000';
    context.lineWidth = 2;
    context.moveTo(1, groundY);
    context.lineTo(1, limitTopY);
    context.stroke();
    context.beginPath();
    context.moveTo(canvas.width - 1, groundY);
    context.lineTo(canvas.width - 1, limitTopY);
    context.stroke();
}

またui.jsファイルに次のmaybeRenderDuringTraining()関数を記述し、HTMLファイルに読み込みます。これも、tfjs-examples/cart-pole」サンプルのui.jsファイルに書かれているものとほとんど同じです。

const maybeRenderDuringTraining = async(cartPole) => {
    renderCartPole(cartPole);
    await tf.nextFrame(); // UIスレッドをブロックしない。
}

そしてメインのJavaScriptコードには以下を記述します。

const gameInfo = document.getElementById('game-info');
const stepInfo = document.getElementById('step-info');

// 暫定的にカートポールを動かすための関数
// 左方と右方の力に相当する0と1をランダムに生成して返す
const getAction = () => {
        let action;
        const randomNum = Math.random();
        if (randomNum > 0.5) {
            action = 1;
        }
        else {
            action = 0;
        }
        return action;
    }
    // 1ゲーム当たりのステップ回数
    // タイムステップ、時間ステップ:1回当たりに進める時間間隔。
const maxStepsPerGame = 100;

// カートポールがmaxStepsPerGame回アクションを取ると、次のゲームにすすむ。
const numGames = 2;
// CartPoleクラスのインスタンス
const cartPoleSystem = new CartPole(true);

// ゲームをnumGames回、繰り返す。
for (let i = 0; i < numGames; ++i) {
    gameInfo.textContent = i + 1 + '回めのゲーム';
    // カートポールシステムの状態は、全ゲームの開始時、ランダムに初期化する
    cartPoleSystem.setRandomState();

    // maxStepsPerGame回、カートポールは自分の状態を調べ、それにもとづいたアクションを取る
    // ただしここでは暫定的に、ランダムに決めている。
    for (let j = 0; j < maxStepsPerGame; ++j) {
        stepInfo.textContent = j + 1 + '回めのステップ';
        tf.tidy(() => {
            // カートの位置、速度、ポールの角度、角速度を表すtf.Tensor
            const inputTensor = cartPoleSystem.getStateTensor();
            // 各変数の値を出力する。
            // Float32Array(4) [1.1674411296844482, 2.146995782852173, -3.294525146484375, -6.860813617706299]
            console.log(inputTensor.dataSync());
        });
        // 暫定的に、actionは0か1
        const action = getAction();
        // カートポールのupdate()にactionを渡して4つの状態変数を更新する。
        // カートポールのポールが倒れているか、カートが範囲を超えているかどうかを調べる。
        const isDone = cartPoleSystem.update(action);

        // カートポールを、カートポールのプロパティを使って、#cart-pole-canvas <canvas>に描画
        await maybeRenderDuringTraining(cartPoleSystem);
    }
}

HTMLファイルをブラウザで開くと、次の動画で示すような、カートポールの短いアニメーションが実行されます。

メインのJavaScriptの内側のforループでは、maxStepsPerGame(100)回、CartPoleクラスのインスタンスであるcartPoleSystemのgetStateTensor()メソッドが呼び出され、カートポールのカートの位置、速度、ポールの角度、角速度を表すtf.Tensorが返されます。これはFloat32Arrayで、たとえば[1.1674411296844482, 2.146995782852173, -3.294525146484375, -6.860813617706299]といった4つの数値を持っています。

本来はgetStateTensor()メソッドが返したinputTensorを処理して、アクションを求めるのですが、ここでは暫定の処置として、ただランダムに1か0を返す関数を使って、変数actionを決めています。そしてこのactionをcartPoleSystemのupdate()メソッドに渡して、カートポールを更新しています。

そして最後、その状態にあるカートポールをmaybeRenderDuringTraining()関数に渡して、HTMLのキャンバスに描画しています。内側のforループはここまでで、これがステップと言われる部分です。

ステップがmaxStepsPerGame回繰り返されると、1ゲーム終了です。今は変数numGamesに2を指定しているので、ゲームは外側のループによって2回繰り返され、終わります。

カートポールは、ランダムに返される1か0によって左右に動きますいたり、ポールが回ったりします。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

CAPTCHA