7-4 MNISTデータをtf.Tensorオブジェクトに変換する

MNISTの画像とラベルデータの読み取りと取得ができたので、次はMNISTの4つのファイルを開いて必要なデータを取得し、それらをtf.Tensorオブジェクトに変換するまでを見ていきます。

4つのファイルの読み取りは1つずつ行う方法もありますが、input要素にmultiple属性を付けると、読み取りをいっしょに開始することができます。

 <input type="file" id="input-file" multiple>

またファイルの読み取りは、画像とラベルの違いはあれ、似たような処理を行うので、その機能を持ったクラスを定義し、そのインスタンスそれぞれに画像かラベルのデータを持たせる方法が考えられます。

以下はそのためのMNISTDataクラスのコードです。

class MNISTData {
  constructor(name) {
    this.fileName = name;
    this.data = null;
  }

  getFileName() {
    return this.fileName;
  }

  setFileName(name) {
    this.fileName = name;
  }

  getData() {
    return this.data;
  }

  setData(data) {
    this.data = data;
  }

  async loadData(e, index) {
    const files = e.target.files;
    await new Promise((resolve, reject) => {
      const reader = new FileReader();
      // ファイルの読み取りに成功したら解決する
      reader.onload = (e) => {
        resolve(e.target.result);
      }
      //  MNISTファイルをArrayBufferとして読み取る
      reader.readAsArrayBuffer(files[index]);
    }).then((buffer) => {
      // ファイル名によってラベルデータか画像データかを判別し、処理を分ける
      if (this.getFileName().indexOf('labels') != -1) {
        console.log('これはラベルデータ');
        const dataView = new DataView(buffer);
        const itemNum = dataView.getInt32(4);
        // ラベルデータを入れる配列
        const labelArray = [];
        for (let i = 0; i < itemNum; i++) {
          labelArray.push(dataView.getUint8(8 + i));
        }
        // 自分のデータに設定
        this.setData(labelArray);
      } else {
        console.log('これは画像データ')
        const dataView = new DataView(buffer);
        // 含まれる画像の数
        const itemNum = dataView.getInt32(4);
        //画像の高さ
        const rows = dataView.getInt32(8);
        // 画像の幅
        const columns = dataView.getInt32(12);
        // 画像データを入れる配列
        const dataArray = [];
        let offset = 0;
        // 含まれる画像数分だけ繰り返す
        for (let i = 0; i < itemNum; i++) {
          dataArray.push([]);
          for (let j = 0; j < rows; j++) {
            for (let k = 0; k < columns; k++) {
              // dataArray配列にdataViewからの画像データを追加
              // 画像データは0016から始まるので、その分を足しておく
              dataArray[i].push(dataView.getUint8(16 + offset) / 255);
              offset++;
            }
          }
        }
        // 自分のデータに設定。tf.Tensorオブジェクトに変換しやすいように、flat()で平滑化しておく
        this.setData(dataArray.flat());
      }
    });
  }
}

MNISTDataクラスのインスタンスは、ユーザーがファイルを開こうとするときに作成します。インスタンスの主な機能はloadData(e, index)メソッドです。インスタンスは、ただファイルを読み取れ、と命令されても、4つあるどれを読み取ればよいのか分からないので、loadData()の引数のeとindexを使って、読み取るべきファイルを特定します。

eはユーザーがファイルを開こうとしたときのイベントオブジェクトです。インデックスは、あらかじめメインページで作成しておいたファイル名の配列のインデックスです。

以下はメインページのJavaScriptコードです。ここでは、読み込む4つのファイル名を持つ配列namesを作成しています。MNISTDataインスタンスのloadData()には、この配列のインデックスを渡します。

// 画像とラベルデータを保持するMNISTDataオブジェクト用変数
let testImages, testLabels, trainImages, trainLabels;
// 訓練用とテスト用の画像とテスト用tf.Tensorオブジェクトを保持する変数
let trainXS, trainYS, testXS, testYS;

const inputFile = document.getElementById('input-file');
inputFile.addEventListener('change', async(e) => {
    // 読み込むバイナリファイルの名前。順番を固定する
    const names = ['t10k-images.idx3-ubyte', 't10k-labels.idx1-ubyte', 'train-images.idx3-ubyte', 'train-labels.idx1-ubyte'];
    const files = e.target.files;
    //console.log(files.length);
    for (let i = 0; i < files.length; i++) {
        const fileName = files[i].name;
        // 't10k-images.idx3-ubyte'の場合
        if (fileName === names[0]) {
            testImages = new MNISTData(fileName);
            console.log(testImages.getFileName());
            await testImages.loadData(e, i);
            console.log('テスト用画像データを読み込んだ');
            // 't10k-labels.idx1-ubyte'の場合
        }
        else if (fileName === names[1]) {
            testLabels = new MNISTData(fileName);
            console.log(testLabels.getFileName());
            await testLabels.loadData(e, i);
            console.log('テスト用ラベルデータを読み込んだ');
            // 'train-images.idx3-ubyte'の場合
        }
        else if (fileName === names[2]) {
            trainImages = new MNISTData(fileName);
            console.log(trainImages.getFileName());
            await trainImages.loadData(e, i);
            console.log('訓練用画像データを読み込んだ');
        } // 'train-labels.idx1-ubyte'の場合
        else if (fileName === names[3]) {
            trainLabels = new MNISTData(fileName);
            console.log(trainLabels.getFileName());
            await trainLabels.loadData(e, i);
            console.log('訓練用ラベルデータを読み込んだ');
        }
    }
    console.log('tf.Tensorオブジェクトに変換開始');
    // 訓練用の画像tf.Tensorとラベルtf.Tensorオブジェクトを、適切なMNISTDataオブジェクトを使って作成する。
    trainXS = tf.tensor4d(trainImages.getData(), [60000, 28, 28, 1]);
    trainYS = tf.oneHot(tf.tensor1d(trainLabels.getData(), 'int32'), 10);
    // テスト用
    testXS = tf.tensor4d(testImages.getData(), [10000, 28, 28, 1]);
    testYS = tf.oneHot(tf.tensor1d(testLabels.getData(), 'int32'), 10);
    console.log('tf.Tensorオブジェクトに変換終了');

    // 画像tf.Tensorのチェック用関数
    const xsDataCheck = (xs, begin) => {
        const sliced = xs.slice([begin, 0, 0, 0], [1, 28, 28, 1]);
        const canvas = document.createElement('canvas');
        canvas.width = 28;
        canvas.height = 28;
        tf.toPixels(tf.reshape(sliced, [28, 28, 1]), canvas);
        return canvas;
    }

    // ラベルtf.Tensorのチェック用関数
    const ysDataCheck = (ys, begin) => {
        const sliced = ys.slice([begin, 0], [1, 10]);
        return sliced.argMax(1).dataSync()[0] + ' ';
    }

    console.log('データチェック開始');
    const div = document.getElementById('image-container');
    div.appendChild(xsDataCheck(trainXS, 0));
    div.appendChild(xsDataCheck(trainXS, 1));

    const p = document.getElementById('info');
    p.textContent += ysDataCheck(trainYS, 0);
    p.textContent += ysDataCheck(trainYS, 1);
    console.log('データチェック終了');

    trainXS.dispose();
    trainYS.dispose();
    testXS.dispose();
    testYS.dispose();

}, false);

4つのインスタンスがそれぞれに指定されたデータを読み取り、自身のデータとして保持できたら、tf.Tensorオブジェクトの変換に移ります。画像データは、後で見ていく畳み込みニューラルネットワーク(CNN)のレイヤーに渡しやすくするように、tf.tensor4d()で作成します。shapeの[60000,28,28,1]は、サイズが28×28でチャンネルが1(つまりモノクロ)のデータが60,000個ある、という意味です。

trainXS = tf.tensor4d(trainImages.getData(),[60000,28,28,1]);

ラベルデータは、tf.oneHot()関数で作成します。これはone-hotと呼ばれる形式の配列を作成します。one-hotは、1つの要素だけが1で、後は全部0である配列で、1である要素のインデックスが意味を持ちます。以下はone-hotの例です。

const labelArray = [5, 0, 1];
const oh = tf.oneHot(tf.tensor1d(labelArray, 'int32'), 10);
oh.print();

以下が出力されます。

[[0, 0, 0, 0, 0, 1, 0, 0, 0, 0], -> 要素10個、1のインデックスは5
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], -> 要素は10個、1のインデックスは0
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]] -> 要素は10個、1のインデックスは1

ざっくり言うと、手書き数字認識アプリはその訓練で、ある[1, 28, 28, 1]の画像データが[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]で、次の[1, 28, 28, 1]の画像データが[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]である、ということを60,000サンプルにわたって教えられ、そこにある関係性を探ります。ラベルデータを10個の要素を持つone-hotの配列にするのは、そのための準備です。

その後では確認として、変換したtf.Tensorオブジェクトの値を、画像データは画像として、ラベルデータは数値として出力しています。下図は実行結果です。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

CAPTCHA