MNISTの画像とラベルデータの読み取りと取得ができたので、次はMNISTの4つのファイルを開いて必要なデータを取得し、それらをtf.Tensorオブジェクトに変換するまでを見ていきます。
4つのファイルの読み取りは1つずつ行う方法もありますが、input要素にmultiple属性を付けると、読み取りをいっしょに開始することができます。
<input type="file" id="input-file" multiple>
またファイルの読み取りは、画像とラベルの違いはあれ、似たような処理を行うので、その機能を持ったクラスを定義し、そのインスタンスそれぞれに画像かラベルのデータを持たせる方法が考えられます。
以下はそのためのMNISTDataクラスのコードです。
class MNISTData {
constructor(name) {
this.fileName = name;
this.data = null;
}
getFileName() {
return this.fileName;
}
setFileName(name) {
this.fileName = name;
}
getData() {
return this.data;
}
setData(data) {
this.data = data;
}
async loadData(e, index) {
const files = e.target.files;
await new Promise((resolve, reject) => {
const reader = new FileReader();
// ファイルの読み取りに成功したら解決する
reader.onload = (e) => {
resolve(e.target.result);
}
// MNISTファイルをArrayBufferとして読み取る
reader.readAsArrayBuffer(files[index]);
}).then((buffer) => {
// ファイル名によってラベルデータか画像データかを判別し、処理を分ける
if (this.getFileName().indexOf('labels') != -1) {
console.log('これはラベルデータ');
const dataView = new DataView(buffer);
const itemNum = dataView.getInt32(4);
// ラベルデータを入れる配列
const labelArray = [];
for (let i = 0; i < itemNum; i++) {
labelArray.push(dataView.getUint8(8 + i));
}
// 自分のデータに設定
this.setData(labelArray);
} else {
console.log('これは画像データ')
const dataView = new DataView(buffer);
// 含まれる画像の数
const itemNum = dataView.getInt32(4);
//画像の高さ
const rows = dataView.getInt32(8);
// 画像の幅
const columns = dataView.getInt32(12);
// 画像データを入れる配列
const dataArray = [];
let offset = 0;
// 含まれる画像数分だけ繰り返す
for (let i = 0; i < itemNum; i++) {
dataArray.push([]);
for (let j = 0; j < rows; j++) {
for (let k = 0; k < columns; k++) {
// dataArray配列にdataViewからの画像データを追加
// 画像データは0016から始まるので、その分を足しておく
dataArray[i].push(dataView.getUint8(16 + offset) / 255);
offset++;
}
}
}
// 自分のデータに設定。tf.Tensorオブジェクトに変換しやすいように、flat()で平滑化しておく
this.setData(dataArray.flat());
}
});
}
}
MNISTDataクラスのインスタンスは、ユーザーがファイルを開こうとするときに作成します。インスタンスの主な機能はloadData(e, index)メソッドです。インスタンスは、ただファイルを読み取れ、と命令されても、4つあるどれを読み取ればよいのか分からないので、loadData()の引数のeとindexを使って、読み取るべきファイルを特定します。
eはユーザーがファイルを開こうとしたときのイベントオブジェクトです。インデックスは、あらかじめメインページで作成しておいたファイル名の配列のインデックスです。
以下はメインページのJavaScriptコードです。ここでは、読み込む4つのファイル名を持つ配列namesを作成しています。MNISTDataインスタンスのloadData()には、この配列のインデックスを渡します。
// 画像とラベルデータを保持するMNISTDataオブジェクト用変数
let testImages, testLabels, trainImages, trainLabels;
// 訓練用とテスト用の画像とテスト用tf.Tensorオブジェクトを保持する変数
let trainXS, trainYS, testXS, testYS;
const inputFile = document.getElementById('input-file');
inputFile.addEventListener('change', async(e) => {
// 読み込むバイナリファイルの名前。順番を固定する
const names = ['t10k-images.idx3-ubyte', 't10k-labels.idx1-ubyte', 'train-images.idx3-ubyte', 'train-labels.idx1-ubyte'];
const files = e.target.files;
//console.log(files.length);
for (let i = 0; i < files.length; i++) {
const fileName = files[i].name;
// 't10k-images.idx3-ubyte'の場合
if (fileName === names[0]) {
testImages = new MNISTData(fileName);
console.log(testImages.getFileName());
await testImages.loadData(e, i);
console.log('テスト用画像データを読み込んだ');
// 't10k-labels.idx1-ubyte'の場合
}
else if (fileName === names[1]) {
testLabels = new MNISTData(fileName);
console.log(testLabels.getFileName());
await testLabels.loadData(e, i);
console.log('テスト用ラベルデータを読み込んだ');
// 'train-images.idx3-ubyte'の場合
}
else if (fileName === names[2]) {
trainImages = new MNISTData(fileName);
console.log(trainImages.getFileName());
await trainImages.loadData(e, i);
console.log('訓練用画像データを読み込んだ');
} // 'train-labels.idx1-ubyte'の場合
else if (fileName === names[3]) {
trainLabels = new MNISTData(fileName);
console.log(trainLabels.getFileName());
await trainLabels.loadData(e, i);
console.log('訓練用ラベルデータを読み込んだ');
}
}
console.log('tf.Tensorオブジェクトに変換開始');
// 訓練用の画像tf.Tensorとラベルtf.Tensorオブジェクトを、適切なMNISTDataオブジェクトを使って作成する。
trainXS = tf.tensor4d(trainImages.getData(), [60000, 28, 28, 1]);
trainYS = tf.oneHot(tf.tensor1d(trainLabels.getData(), 'int32'), 10);
// テスト用
testXS = tf.tensor4d(testImages.getData(), [10000, 28, 28, 1]);
testYS = tf.oneHot(tf.tensor1d(testLabels.getData(), 'int32'), 10);
console.log('tf.Tensorオブジェクトに変換終了');
// 画像tf.Tensorのチェック用関数
const xsDataCheck = (xs, begin) => {
const sliced = xs.slice([begin, 0, 0, 0], [1, 28, 28, 1]);
const canvas = document.createElement('canvas');
canvas.width = 28;
canvas.height = 28;
tf.toPixels(tf.reshape(sliced, [28, 28, 1]), canvas);
return canvas;
}
// ラベルtf.Tensorのチェック用関数
const ysDataCheck = (ys, begin) => {
const sliced = ys.slice([begin, 0], [1, 10]);
return sliced.argMax(1).dataSync()[0] + ' ';
}
console.log('データチェック開始');
const div = document.getElementById('image-container');
div.appendChild(xsDataCheck(trainXS, 0));
div.appendChild(xsDataCheck(trainXS, 1));
const p = document.getElementById('info');
p.textContent += ysDataCheck(trainYS, 0);
p.textContent += ysDataCheck(trainYS, 1);
console.log('データチェック終了');
trainXS.dispose();
trainYS.dispose();
testXS.dispose();
testYS.dispose();
}, false);
4つのインスタンスがそれぞれに指定されたデータを読み取り、自身のデータとして保持できたら、tf.Tensorオブジェクトの変換に移ります。画像データは、後で見ていく畳み込みニューラルネットワーク(CNN)のレイヤーに渡しやすくするように、tf.tensor4d()で作成します。shapeの[60000,28,28,1]は、サイズが28×28でチャンネルが1(つまりモノクロ)のデータが60,000個ある、という意味です。
trainXS = tf.tensor4d(trainImages.getData(),[60000,28,28,1]);
ラベルデータは、tf.oneHot()関数で作成します。これはone-hotと呼ばれる形式の配列を作成します。one-hotは、1つの要素だけが1で、後は全部0である配列で、1である要素のインデックスが意味を持ちます。以下はone-hotの例です。
const labelArray = [5, 0, 1];
const oh = tf.oneHot(tf.tensor1d(labelArray, 'int32'), 10);
oh.print();
以下が出力されます。
[[0, 0, 0, 0, 0, 1, 0, 0, 0, 0], -> 要素10個、1のインデックスは5
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], -> 要素は10個、1のインデックスは0
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]] -> 要素は10個、1のインデックスは1
ざっくり言うと、手書き数字認識アプリはその訓練で、ある[1, 28, 28, 1]の画像データが[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]で、次の[1, 28, 28, 1]の画像データが[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]である、ということを60,000サンプルにわたって教えられ、そこにある関係性を探ります。ラベルデータを10個の要素を持つone-hotの配列にするのは、そのための準備です。
その後では確認として、変換したtf.Tensorオブジェクトの値を、画像データは画像として、ラベルデータは数値として出力しています。下図は実行結果です。