以降では、「TensorFlow.js Example:Reinforcement Learning: Cart Pole」サンプルのコードを参考に、カートポールサンプルのシンプルバージョンを作成していきます。
これは、言い換えると、複雑に絡み合っているように見える「TensorFlow.js Example:Reinforcement Learning: Cart Pole」サンプルのコードを解きほぐし、シンプルな形にまとめていく作業です。
まずは、カートポールを動くようにします。これにより、モデルの訓練やゲーム、ステップなどの回数がどういうものか、目で確認できるようになります。
HTMLは次の簡単なものを使用します。カートポールのアニメーションは、このキャンバス上に描画します。
HTML
<canvas id="cart-pole-canvas" height="150px" width="500px"></canvas>
<div>
<span id="game-info"></span>
<span>:</span>
<span id="step-info"></span>
</div>
CSSでは、キャンバスに枠線を設定しておきます。
#cart-pole-canvas {
border: 1px solid black;
}
JavaScriptではまず、CartPoleクラスを作成します。下記コードは、「tfjs-examples/cart-pole/cart_pole.js」で公開されているものと、基本的に同じです。このJavaScriptはCartPole.jsという名前でHTMLファイルに読み込みます。
/**
* http://incompleteideas.net/book/code/pole.cにもとづく実装
*/
/**
* カートポールシステムシミュレーター
*
* 制御理論の観点から言うと、このシステムには4つの状態変数がある。
*
* - x: カートの1D位置
* - xDot: カートの速度
* - theta: ポールの角度(ラジアン単位)。値0は垂直位置に対応する。
* (垂直に立っている)。
* - thetaDot: ポールの角速度
*
* システムは単一の行動によって制御される。
*
* - 左方か右方への力
*/
class CartPole {
/**
* CartPoleのコンストラクタ
*/
constructor() {
// システムを特徴づける定数
this.gravity = 9.8;
this.massCart = 1.0;
this.massPole = 0.1;
this.totalMass = this.massCart + this.massPole;
this.cartWidth = 0.2;
this.cartHeight = 0.1;
this.length = 0.5;
this.poleMoment = this.massPole * this.length;
this.forceMag = 10.0;
this.tau = 0.02; // 状態を更新する間隔の秒数
// しきい値。これを超えると、シミュレーションは失敗したことになる。
this.xThreshold = 2.4;
this.thetaTheshold = 12 / 360 * 2 * Math.PI;
this.setRandomState();
}
/**
* カートポールシステムの状態をランダムに設定する。
*/
setRandomState() {
// カートポールシステムの制御理論状態変数
// カートの位置。メートル
this.x = Math.random() - 0.5;
// カートの速度
this.xDot = (Math.random() - 0.5) * 1;
// ポールの角度、ラジアン
this.theta = (Math.random() - 0.5) * 2 * (6 / 360 * 2 * Math.PI);
// ポールの角速度
this.thetaDot = (Math.random() - 0.5) * 0.5;
}
/**
* 現在の状態を、shapeが[1, 4]のtf.Tensorとして取得する。
*/
getStateTensor() {
return tf.tensor2d([
[this.x, this.xDot, this.theta, this.thetaDot]
]);
}
/**
* カートポールシステムを行動を使って更新する。
* @param {number} action `action`の符号がすべて
* 0以上の値は、一定の大きさを持つ右方への力になる。
* 0未満の値は、同じ一定の大きさを持つ左方への力になる。
*/
update(action) {
const force = action > 0 ? this.forceMag : -this.forceMag;
const cosTheta = Math.cos(this.theta);
const sinTheta = Math.sin(this.theta);
const temp = (force + this.poleMoment * this.thetaDot * this.thetaDot * sinTheta) / this.totalMass;
const thetaAcc = (this.gravity * sinTheta - cosTheta * temp) / (this.length * (4 / 3 - this.massPole * cosTheta * cosTheta / this.totalMass));
const xAcc = temp - this.poleMoment * thetaAcc * cosTheta / this.totalMass;
// 4つの状態変数を、オイラー法を使って更新する。
this.x += this.tau * this.xDot;
this.xDot += this.tau * xAcc;
this.theta += this.tau * this.thetaDot;
this.thetaDot += this.tau * thetaAcc;
return this.isDone();
}
/**
* このシミュレーションを終えるかどうかを決める。
*
* シミュレーションは、`x`(カートの位置)が範囲を超えるか、
* `theta`(ポールの角度)が範囲を超えるかすると終わり。
*
* @returns {bool} シミュレーションが終わりかどうか
*/
isDone() {
return this.x < -this.xThreshold || this.x > this.xThreshold ||
this.theta < -this.thetaTheshold || this.theta > this.thetaTheshold;
}
}
また、カートポールをキャンバスに描画するJavaScriptコードをrender.jsという名前で保存して、HTMLファイルに読み込みます。これは「tfjs-examples/cart-pole」サンプルのui.jsファイルに書かれているrenderCartPole()関数とほとんど同じです。
/**
* システムの現在の状態をHTMLのキャンバスにレンダリングする。
*
* @param {CartPole} cartPole レンダリングするCartPoleシステムのインスタンス
* @param {HTMLCanvasElement} canvas レンダリングを行うHTMLCanvasElementのインスタンス
*/
function renderCartPole(cartPole) {
const canvas = document.getElementById('cart-pole-canvas');
if (!canvas.style.display) {
canvas.style.display = 'block';
}
const X_MIN = -cartPole.xThreshold;
const X_MAX = cartPole.xThreshold;
const xRange = X_MAX - X_MIN;
const scale = canvas.width / xRange;
const context = canvas.getContext('2d');
context.clearRect(0, 0, canvas.width, canvas.height);
const halfW = canvas.width / 2;
// カートを描画
const railY = canvas.height * 0.8;
const cartW = cartPole.cartWidth * scale;
const cartH = cartPole.cartHeight * scale;
const cartX = cartPole.x * scale + halfW;
context.beginPath();
context.strokeStyle = '#000000';
context.lineWidth = 2;
context.rect(cartX - cartW / 2, railY - cartH / 2, cartW, cartH);
context.stroke();
// カートの下に車輪を描画
const wheelRadius = cartH / 4;
for (const offsetX of[-1, 1]) {
context.beginPath();
context.lineWidth = 2;
context.arc(cartX - cartW / 4 * offsetX, railY + cartH / 2 + wheelRadius,
wheelRadius, 0, 2 * Math.PI);
context.stroke();
}
// ポールを描画
const angle = cartPole.theta + Math.PI / 2;
const poleTopX =
halfW + scale * (cartPole.x + Math.cos(angle) * cartPole.length);
const poleTopY = railY -
scale * (cartPole.cartHeight / 2 + Math.sin(angle) * cartPole.length);
context.beginPath();
context.strokeStyle = '#ffa500';
context.lineWidth = 6;
context.moveTo(cartX, railY - cartH / 2);
context.lineTo(poleTopX, poleTopY);
context.stroke();
// 地面を描画
const groundY = railY + cartH / 2 + wheelRadius * 2;
context.beginPath();
context.strokeStyle = '#000000';
context.lineWidth = 1;
context.moveTo(0, groundY);
context.lineTo(canvas.width, groundY);
context.stroke();
const nDivisions = 40;
for (let i = 0; i < nDivisions; ++i) {
const x0 = canvas.width / nDivisions * i;
const x1 = x0 + canvas.width / nDivisions / 2;
const y0 = groundY + canvas.width / nDivisions / 2;
const y1 = groundY;
context.beginPath();
context.moveTo(x0, y0);
context.lineTo(x1, y1);
context.stroke();
}
// 左右の端を描画
const limitTopY = groundY - canvas.height / 2;
context.beginPath();
context.strokeStyle = '#ff0000';
context.lineWidth = 2;
context.moveTo(1, groundY);
context.lineTo(1, limitTopY);
context.stroke();
context.beginPath();
context.moveTo(canvas.width - 1, groundY);
context.lineTo(canvas.width - 1, limitTopY);
context.stroke();
}
またui.jsファイルに次のmaybeRenderDuringTraining()関数を記述し、HTMLファイルに読み込みます。これも、tfjs-examples/cart-pole」サンプルのui.jsファイルに書かれているものとほとんど同じです。
const maybeRenderDuringTraining = async(cartPole) => {
renderCartPole(cartPole);
await tf.nextFrame(); // UIスレッドをブロックしない。
}
そしてメインのJavaScriptコードには以下を記述します。
const gameInfo = document.getElementById('game-info');
const stepInfo = document.getElementById('step-info');
// 暫定的にカートポールを動かすための関数
// 左方と右方の力に相当する0と1をランダムに生成して返す
const getAction = () => {
let action;
const randomNum = Math.random();
if (randomNum > 0.5) {
action = 1;
}
else {
action = 0;
}
return action;
}
// 1ゲーム当たりのステップ回数
// タイムステップ、時間ステップ:1回当たりに進める時間間隔。
const maxStepsPerGame = 100;
// カートポールがmaxStepsPerGame回アクションを取ると、次のゲームにすすむ。
const numGames = 2;
// CartPoleクラスのインスタンス
const cartPoleSystem = new CartPole(true);
// ゲームをnumGames回、繰り返す。
for (let i = 0; i < numGames; ++i) {
gameInfo.textContent = i + 1 + '回めのゲーム';
// カートポールシステムの状態は、全ゲームの開始時、ランダムに初期化する
cartPoleSystem.setRandomState();
// maxStepsPerGame回、カートポールは自分の状態を調べ、それにもとづいたアクションを取る
// ただしここでは暫定的に、ランダムに決めている。
for (let j = 0; j < maxStepsPerGame; ++j) {
stepInfo.textContent = j + 1 + '回めのステップ';
tf.tidy(() => {
// カートの位置、速度、ポールの角度、角速度を表すtf.Tensor
const inputTensor = cartPoleSystem.getStateTensor();
// 各変数の値を出力する。
// Float32Array(4) [1.1674411296844482, 2.146995782852173, -3.294525146484375, -6.860813617706299]
console.log(inputTensor.dataSync());
});
// 暫定的に、actionは0か1
const action = getAction();
// カートポールのupdate()にactionを渡して4つの状態変数を更新する。
// カートポールのポールが倒れているか、カートが範囲を超えているかどうかを調べる。
const isDone = cartPoleSystem.update(action);
// カートポールを、カートポールのプロパティを使って、#cart-pole-canvas <canvas>に描画
await maybeRenderDuringTraining(cartPoleSystem);
}
}
HTMLファイルをブラウザで開くと、次の動画で示すような、カートポールの短いアニメーションが実行されます。
メインのJavaScriptの内側のforループでは、maxStepsPerGame(100)回、CartPoleクラスのインスタンスであるcartPoleSystemのgetStateTensor()メソッドが呼び出され、カートポールのカートの位置、速度、ポールの角度、角速度を表すtf.Tensorが返されます。これはFloat32Arrayで、たとえば[1.1674411296844482, 2.146995782852173, -3.294525146484375, -6.860813617706299]といった4つの数値を持っています。
本来はgetStateTensor()メソッドが返したinputTensorを処理して、アクションを求めるのですが、ここでは暫定の処置として、ただランダムに1か0を返す関数を使って、変数actionを決めています。そしてこのactionをcartPoleSystemのupdate()メソッドに渡して、カートポールを更新しています。
そして最後、その状態にあるカートポールをmaybeRenderDuringTraining()関数に渡して、HTMLのキャンバスに描画しています。内側のforループはここまでで、これがステップと言われる部分です。
ステップがmaxStepsPerGame回繰り返されると、1ゲーム終了です。今は変数numGamesに2を指定しているので、ゲームは外側のループによって2回繰り返され、終わります。
カートポールは、ランダムに返される1か0によって左右に動きますいたり、ポールが回ったりします。