12_7 重みパラメータを探る

6_1 モデルの中身 その1」で見た重みの数値は、tf.layers.LayerクラスのgetWeights()メソッドで調べることができます。

// 最初のレイヤー
const firstLayer = model.layers[0];
// レイヤーの重み配列
const kernelArray = firstLayer.getWeights();
console.log(kernelArray.length);
console.log(kernelArray[0].dataSync()); // -> レイヤーの重みパラメータと同じ
console.log(kernelArray[1].dataSync());

// data()を使って非同期で求める
const promise = model.layers[0].getWeights()[0].data();
promise.then((kernel) => {
    console.log(kernel);
});

getWeights()はtf.Tensorオブジェクトが入った配列を返すので、tf.TensorオブジェクトからdataSync()かdata()メソッドを呼び出すことで、重みの具体的な数値を得ることができます。dataSync()は同期的に、data()は非同期的に動作します。data()を使用する場合には、返されるPromiseからthen()メソッドを使用します。

ここまで見てきたボストンデータセットの場合、重みの数値は、モデルが住宅価格の学習を進めるに当たって、説明変数の各データをどれだけ重要視しているかを表します。

具体的に言うと、下図のような数値が得られた場合には、モデルは「幹線道路へのアクセス」を「平均部屋数」よりも重視し、また「学校中退率」はほかのどの属性よりも重視せずに、住宅価格を予想した、ということになります。ただし以降のコードが利用できるのは、linearRegressionModel()で作成した単純なモデルだけです。

これは、次のコードで実現できます。

describeKerenelElements.jsファイル

// 12個の説明変数が持つ意味の日本語訳
const featureDescriptions = [
    '犯罪発生率', '広い宅地の割合', '非小売業の土地面積の割合', '川に近いかどうか',
    '窒素酸化物の濃度', '平均部屋数', '古い家の割合',
    '通勤距離', '幹線道路へのアクセス', '所得税率', '1教室当たりの生徒数',
    '学校中退率'
];

/**
 * 現在の線形の重みを人間が読めるように説明する
 *
 * @param {Array} kernel 長さが12の、浮動小数点数の配列。1つの値が1つの特徴に対応している。
 * @returns {List} オブジェクトのリスト。文字列の特徴名とその重みの値を持つ。
 */

// kernelは数値の要素を12個持つ配列
const describeKerenelElements = (kernel) => {
    // kernelの要素数が12個でない場合は、エラーを発生させる。
    tf.util.assert(kernel.length === 12, `kernel must be a array of length 12, got ${kernel.length}`);
    const outList = [];
    for (let idx = 0; idx < kernel.length; idx++) {
        // featureDescriptions配列とkernel配列の同じインデックスにある値を配列outListに追加する
        outList.push({
            description: featureDescriptions[idx],
            value: kernel[idx]
        });
    }
    return outList;
}

ui.jsに次のコードを追加します。

const NUM_TOP_WEIGHTS_TO_DISPLAY = 12;
/**
 * 単純な線形モデルが学習した、重みに関する情報を表示する出力領域を更新する。
 *
 * @param {List} weightsList 'value':数値と'description':文字列を持つオブジェクトのリスト
 */
const updateWeightDescription = (weightsList) => {
    // 重みのオブジェクトを、絶対値の降順で並び替える。
    weightsList.sort((a, b) => Math.abs(b.value) - Math.abs(a.value));
    const table = document.getElementById('myTable');
    // テーブルの中身を一度クリア
    table.innerHTML = '';
    // テーブルに新しい行を加える
    weightsList.forEach((weight, i) => {
        if (i < NUM_TOP_WEIGHTS_TO_DISPLAY) {
            let row = table.insertRow(-1);
            let cell1 = row.insertCell(0); // 文字
            let cell2 = row.insertCell(1); // 数値
            // 重みの正負によって色変え
            if (weight.value < 0) {
                cell2.setAttribute('class', 'negativeWeight');
            }
            else {
                cell2.setAttribute('class', 'positiveWeight');
            }
            // 文字と数値を表示する
            cell1.innerHTML = weight.description;
            cell2.innerHTML = weight.value.toFixed(4);
        }
    });
};

// https://www.w3schools.com/howto/tryit.asp?filename=tryhow_js_sort_table_number
const sortTable = () => {

    const table = document.getElementById("myTable");
    let switching = true;
    let shouldSwitch = false;
    let i2 = 0;
    /*Make a loop that will continue until
    no switching has been done:*/
    while (switching) {
        //start by saying: no switching is done:
        switching = false;
        const rows = table.rows;
        /*Loop through all table rows (except the
        first, which contains table headers):*/
        for (let i = 0; i < (rows.length - 1); i++) {
            i2 = i;
            //start by saying there should be no switching:
            shouldSwitch = false;
            /*Get the two elements you want to compare,
            one from current row and one from the next:*/
            const x = rows[i].getElementsByTagName("TD")[1];
            const y = rows[i + 1].getElementsByTagName("TD")[1];
            //check if the two rows should switch place:
            if (Number(x.innerHTML) < Number(y.innerHTML)) {
                //if so, mark as a switch and break the loop:
                shouldSwitch = true;
                break;
            }
        }
        if (shouldSwitch) {
            /*If a switch has been marked, make the switch
            and mark that a switch has been done:*/
            rows[i2].parentNode.insertBefore(rows[i2 + 1], rows[i2]);
            switching = true;
        }
    }
}

メインのJavaScriptコードに記述した、model.fit()内のonEpochEndコールバック関数に、次のコードを追加します(await tf.nextFrame();の前)。

// 12個の重みを取り出し、数値を表示する
model.layers[0].getWeights()[0].data().then(kernelAsArr => {
    const weightsList = describeKerenelElements(kernelAsArr);
    updateWeightDescription(weightsList);
});

そしてmodel.fit()の実行後のupdateStatus(‘訓練終了…’);の後に、次のコードを追加します。

setTimeout(sortTable, 1000);

この後の例では、次のようなCSSを使用しています。

.negativeWeight {
    color: #cc0000;
}
.positiveWeight {
    color: #00aa00;
}
table {
    margin-left: 10px;
}
td {
    border: thin gray solid;
}

ブラウザでHTMLファイルを開くと、モデルの訓練が始まり、モデルがどの要因をどれくらい重視して住宅の予想価格を決めようとしているかを、重みの数値を通して知ることができます。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

CAPTCHA