最後に、手書き数字認識アプリの全コードを紹介しておきます。ただしUIを考慮していないバージョンです。
まず画面左上の[ファイル選択ボタン]でtrain-images.idx3-ubyteとtrain-labels.idx1-ubyteを読み込み、モデルの訓練を開始します。
訓練が終わると、画面左の大きなキャンバスに文字が描けるようになります。マウスで数字を描くと、マウスアップのタイミングでモデルが数字の画像を読み取って、それがいくつなのかを推測します。結果は右の棒グラフで確率として示します。
HTMLファイル
<!doctype html>
<html lang="ja">
<head>
<meta charset="utf-8">
<title>手書き数字認識</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<script src="js/MNISTData.js"></script>
<script src="js/model.js"></script>
<script src="js/graph.js"></script>
<script src="js/draw.js"></script>
<style>
.container {
display: flex;
}
canvas {
background-color: #FFF;
border: 1px solid black;
}
.drawing {
width: 400px;
}
#graph-area {
width: 600px;
height: 400px;
border: 1px solid black;
}
.drawing {
width: 400px;
}
input {
width: 100px;
margin: 10px;
margin-left: 150px;
}
#pred-div {
width: 27px;
height: 27px;
border: 1px solid black;
text-align: center;
margin: 1.5px;
}
</style>
</head>
<body>
<h3>手書き数字認識アプリ</h3>
<h5>左のキャンバスに描いた0から9までの数字1文字が何かをリアルタイムに認識する</h5>
<input type="file" id="input-file" multiple>
<div class="container">
<div class="drawing">
<canvas id="canvas" width="256" height="256"></canvas>
<canvas id="small-canvas" width="28" height="28"></canvas>
<div class="controller">
<input type="button" id="erase-button" value="Erase">
</div>
</div>
<div id="pred-div"></div>
<div id="graph-area"></div>
</div>
<script>
document.addEventListener('DOMContentLoaded', async() => {
document.getElementById('input-file').addEventListener('change', async(e) => {
let testImages, testLabels, trainImages, trainLabels;
let trainXS, trainYS;
let model;
// 読み込むバイナリファイルの名前。順番を固定する
// 訓練用データのみ
const names = ['train-images.idx3-ubyte', 'train-labels.idx1-ubyte'];
const files = e.target.files;
console.log("データの読み込み開始");
for (let i = 0; i < files.length; i++) {
const fileName = files[i].name;
if (fileName === names[0]) {
trainImages = new MNISTData(fileName);
await trainImages.loadData(e, i);
}
else if (fileName === names[1]) {
trainLabels = new MNISTData(fileName);
await trainLabels.loadData(e, i);
}
}
console.log("データの読み込み終了");
console.log('tf.Tensorオブジェクトに変換開始');
// 訓練用の画像tf.Tensorとラベルtf.Tensorオブジェクトを、適切なMNISTDataオブジェクトを使って作成する。
trainXS = tf.tensor4d(trainImages.getData(), [60000, 28, 28, 1]);
trainYS = tf.oneHot(tf.tensor1d(trainLabels.getData(), 'int32'), 10);
console.log('tf.Tensorオブジェクトに変換終了');
model = buildModel();
const train = async() => {
console.log('モデルの訓練開始');
// バッチサイズも重要なハイパーパラメータの1つで、訓練中、モデルの重みを更新するときに与えるひとまとめのサンプル数、
// つまりバッチを定義する。値が低すぎると少ないサンプルで重みを更新するので、うまく一般化できない。
// バッチサイズが大きいとそれだけ大きなメモリリソースが必要になるので、良好なパフォーマンスが約束されない。
//
const batchSize = 64;
// 訓練時の過学習の監視のため、訓練用データの最後15%を検証用に残す。
const validationSplit = 0.15;
// 訓練全体の繰り返し回数
const trainEpochs = 1;
// 訓練全体を通してのバッチ処理回数を保持する変数
let trainBatchCount = 0;
// 合計のバッチ数 = 60000 * (0.85 / 64) * 3 = 2390
const totalNumBatches = Math.ceil(trainXS.shape[0] * (1 - validationSplit) / batchSize) * trainEpochs;
// 長い時間のかかる、モデルを訓練するfit()の呼び出し中にコールバックを設定しているので、
// 訓練の進捗に合わせて、ページに損失値と精度値をプロットできる。
await model.fit(trainXS, trainYS, {
batchSize,
validationSplit,
epochs: trainEpochs,
callbacks: {
// 1バッチ処理後に呼び出される
// batchSize=64の場合には、64サンプルを処理した後呼び出される。
// 今の場合なら、796回(51,000/64)呼び出される
onBatchEnd: async(batch, logs) => {
trainBatchCount++;
console.log(`訓練中... (` + `${(trainBatchCount / totalNumBatches * 100).toFixed(1)}%` + ` 完了)`);
await tf.nextFrame();
}
}
});
trainXS.dispose();
trainYS.dispose();
}
await train();
console.log('訓練終了');
// 棒グラフを描画する
barChart([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
console.log('グラフ描画');
// リスナーを設定する
setListeners(model);
console.log('イベントリスナー設定');
// キャンバス描画に必要な設定を行う
setDraw()
}, false);
}, false);
</script>
<script src="js/listeners.js"></script>
</body>
</html>
draw.js
// キャンバスにマウスで文字を描く
const setDraw = () => {
// マウス操作で線を描画するための基本的なコード
const canvas = document.getElementById("canvas");
const context = canvas.getContext("2d");
// 描く線を太くする
context.lineWidth = 15;
// 線の色は白
context.strokeStyle = 'white';
// キャンバスの背景は黒
context.fillRect(0, 0, canvas.width, canvas.height);
// 描画開始位置をメモする変数
let startX = 0,
startY = 0;
// 描画終了位置をメモする変数
let endX = 0,
endY = 0;
// ブラウザの座標をキャンバスの座標に変換
const winXYtoCanvasXY = (x, y) => {
const box = canvas.getBoundingClientRect();
const tempX = (x - box.left);
const tempY = (y - box.top);
return {
x: tempX,
y: tempY
};
}
// mousedown時に、描画の開始位置を決める
const setStartXY = (e) => {
const loc = winXYtoCanvasXY(e.clientX, e.clientY);
startX = loc.x;
startY = loc.y;
context.beginPath();
context.moveTo(startX, startY);
// mousemove発生の監視を開始
canvas.addEventListener("mousemove", setEndXY);
// mouseup発生の監視を開始
canvas.addEventListener("mouseup", stopDrawFreeLine);
// mousedown発生の監視を解除
canvas.removeEventListener("mousedown", setStartXY);
}
// mousemove時に、描画の終了位置を決める
const setEndXY = (event) => {
const loc = winXYtoCanvasXY(event.clientX, event.clientY);
endX = loc.x;
endY = loc.y;
context.lineTo(endX, endY);
context.stroke();
}
const stopDrawFreeLine = () => {
//console.log("mouseup");
// 後始末
canvas.removeEventListener("mousemove", setEndXY);
canvas.removeEventListener("mouseup", stopDrawFreeLine);
// 再スタート
canvas.addEventListener("mousedown", setStartXY);
}
// mousedown発生の監視を開始
canvas.addEventListener("mousedown", setStartXY);
console.log('setDraw');
}
graph.js
// 手書き数字認識アプリで呼び出す。
// 0 -> 9の棒グラフを描く
const barChart = (barArray) => {
const data = [
{
x: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
y: [barArray[0], barArray[1], barArray[2], barArray[3], barArray[4], barArray[5], barArray[6], barArray[7], barArray[8], barArray[9]],
type: 'bar'
}
];
const layout = {
width: 600,
height: 400,
autorange: false,
nticks: 10,
xaxis: {
range: [-2, 12],
title: '0から9までの数字'
},
yaxis: {
range: [0, 1],
title: '%'
},
title: '手書き数字がいくつかの推測確率'
};
Plotly.newPlot('graph-area', data, layout);
}
listeners.js
const setListeners = (model) => {
console.log('イベントリスナー設定');
// 大きなキャンバスのマウスアップで呼び出される
const imageProcess = () => {
// 数字を描く大きなキャンバス 256 x 256
const canvas = document.getElementById("canvas");
// CNNに流すため、大きなキャンバスに描いた画像の縮小に使用する小さなキャンバス
// 28 x 28
const smallCanvas = document.getElementById("small-canvas");
const smallContext = smallCanvas.getContext("2d");
const sw = smallCanvas.width;
const sh = smallCanvas.height;
// 大きなキャンバスに描いた画像のイメージ化に使用する
const image = new Image();
// キャンバスに画像が表示されたら
image.onload = () => {
// 大きなキャンバスの画像を、小さなキャンバスに縮小して描く
smallContext.clearRect(0, 0, sw, sh);
smallContext.drawImage(image, 0, 0, sw, sh);
// そこからイメージデータを得る
const imageData = smallContext.getImageData(0, 0, sw, sh);
// イメージデータをtf.Tensorオブジェクトに変換する
const tfImage = tf.fromPixels(imageData, 1);
// バイリニア補間法で画像サイズを28x28に変更し、データ型をfloat32に変える。
const smallTfImage = tf.image.resizeBilinear(tfImage, [28, 28]).cast('float32');
//console.log(smallTfImage.shape); // [28, 28, 1]
// predict()に渡すため、shapeを[1, 28, 28, 1]に変更し、255で割る
const smallTfImageReshaped = smallTfImage.expandDims(0).div(tf.scalar(255));
//console.log(smallTfImage.expandDims(0).shape); // [1, 28, 28, 1]
// モデルの予測確率
const output = model.predict(smallTfImageReshaped);
// 予測確率を元に棒グラフを描画する。
barChart(Array.from(output.dataSync()));
// 予測確率から最大値を得て、それを確定結果とする
const axis = 1;
const predictions = Array.from(output.argMax(axis).dataSync());
const prediction = predictions[0];
// モデルが予測した数字をグラフの左に表示する。
document.getElementById('pred-div').innerText = `${prediction}`;
tfImage.dispose();
smallTfImage.dispose();
smallTfImageReshaped.dispose();
}
// 大きなキャンバスに描いた画像をImageオブジェクトに割り当てる、
image.src = canvas.toDataURL();
}
// キャンバス上でのマウスアップイベントで、imageProcessを呼び出す。
document.getElementById("canvas").addEventListener("mouseup", imageProcess);
// [Erase]ボタンのクリックでキャンバスをクリアする
document.getElementById('erase-button').addEventListener('click', () => {
const canvas = document.getElementById("canvas");
const context = canvas.getContext("2d");
context.clearRect(0, 0, canvas.width, canvas.height);
context.lineWidth = 10;
context.strokeStyle = 'white';
context.fillRect(0, 0, canvas.width, canvas.height)
}, false);
}
MNISTData.jsとmodel.jsは前のもと変わりません。