ml5.jsのWord2vecサンプルは「ml5-examples/p5js/Word2Vec」にあります。以下はサンプルで使用されているデータ(wordvecs10000.json)を使用した、ml5.word2vecの詳しい検証です。
目次
ml5.word2vec()メソッド
ml5.jsのWord2vecオブジェクトは、ml5.word2vec()メソッドに単語ベクトルを記述したJSONファイルへのパスを指定して作成します。
word2Vec = ml5.word2vec('data/wordvecs10000.json');
JSONファイルが読み込まれ、モデルの準備が整うと、Word2vecオブジェクトのreadyプロパティ(これはPromise)のthen()メソッドが呼び出されます。then()メソッドの引数である関数にはWord2vecオブジェクトが渡されます。
word2Vec.ready.then((w2v) => {
console.log(w2v.model.in.dataSync());
console.log(w2v.model.in.dataSync().length) // 300個
});
Word2vecオブジェクトのmodelプロパティは使用しているモデルです。これにJSONファイルに書かれた単語名(in、for、thatなど)をドットでつなげ、さらにそこからTensorFlow.jsのdataSync()メソッドを呼び出すと、その単語のベクトルが参照できます。下図は”in”の出力結果です。これは配列で300個の要素を持っていることが分かります。
ml5.word2vec()にはまた、第2引数としてコールバック関数を渡すことができます。これはモデルの準備が整ったタイミングで呼び出されます。
word2Vec = ml5.word2vec('data/wordvecs10000.json', modelLoaded);
modelLoaded()関数では、Word2vecオブジェクトの準備が整っているので、Word2vecの機能を使用することができます(nearest()やadd()、subtract()など)。
const modelLoaded = () => {
nearest();
}
const nearest = () => {
word2Vec.nearest(input, ? max, ? callback);
}
nearest(input, ?max, ?callback)
Word2vecの機能として最も分かりやすいのがnearest()メソッドです。これは第1引数として渡されたinputに最も似ている単語を、max分(デフォルトは10)だけ探し求め、その結果の配列をコールバック関数の引数に渡します。
以下はword2Vec.nearest()に単語’evening’を渡して、これに最も近い単語を10個探し出すnearest()関数の例です。
const nearest = () => {
const t1 = 'evening';
text('[' + t1 + ']', 10, 20)
// .nearest(input, ?max, ?callback)
word2Vec.nearest(t1, 10, (error, resArray) => {
if (resArray) {
console.log(resArray.length); // maxパラメータの値
resArray.forEach((elm, index) => {
console.log(elm.word);
console.log(elm.distance);
text(index + ': ' + elm.word, elm.distance * 50, index * 20 + 70);
});
}
else {
console.log('単語ベクトルは見つからなかった。');
}
});
}
ここでは、p5.jsのtext()関数を使っているので、p5.jsのキャンバスに下図の左が描かれ、コンソールに下図の右が出力されます。
nearest()メソッドのコールバック関数には、max個の要素を持つ配列が渡されるので、これを調べます。要素のオブジェクトはwordとdistanceというプロパティを持っています。wordはnearest()メソッドが探し出した単語で、distanceはその距離、つまり近さ、隔たりを意味する数値です。ここではtext()関数の第2引数(描画するx位置)にelm.distance * 50を指定しているので、nearest()が近い意味だと判断した単語ほど左に描画されます。
nearest()メソッドはまず、お題であるeveningを、モデルの作成時に読み込んだJSONファイルのデータから探してそのベクトル(要素数300)を読み取り、「9_3:似ているか? 似ていないか?」で述べたような方法で計算して、最も似ているベクトルを求め、その単語を特定します。
add(inputs, ?max, ?callback)
add()メソッドは、配列の形で渡された単語を加算します。計算が終わると、max個分の結果の配列を、callbackに渡して呼び出します。結果の配列には、加算した合計に最も近いベクトルを持つ単語がmax個だけ含まれます。
要領はnearest()と同じです。inputsには加算したい単語を配列に入れて渡します。次のコードでは’bicycle’と’engine’を足しています。
const add = () => {
const t1 = 'bicycle';
const t2 = 'engine';
text('[' + t1 + ' + ' + t2 + ']', 10, 20)
word2Vec.add([t1, t2], 10, (error, resArray) => {
if (resArray) {
console.log(resArray.length); // maxパラメータの値
resArray.forEach((elm, index) => {
console.log(elm.word);
console.log(elm.distance);
text(index + ': ' + elm.word, elm.distance * index * 10 + 50, index * 20 + 70)
});
}
else {
console.log('単語ベクトルは見つからなかった。');
}
});
}
なおここで、bicycleの代わりに複数形のbicyclesを指定すると、次のようなエラーがコンソールに表示されます。これは、入力に指定された’bicycles’のベクトルが見つからないというエラーです。これは次に示す方法で回避できます。
subtract(inputs, ?max, ?callback)
subtract([t1, t2])の場合、t1からt2を引きます。後はadd()と同じです。
const subtract = () => { const t1 = 'bikes'; const t2 = 'engine'; text('[' + t1 + ' - ' + t2 + ']', 10, 20); word2Vec.subtract([t1, t2]) .then((resArray) => { resArray.forEach((elm, index) => { console.log(elm.word); console.log(elm.distance); text(index + ': ' + elm.word, elm.distance * index * 10 + 50, index * 20 + 70) }); }).catch(error => console.log(error.message)); }
bikesからengineを引くとbicycleとはお見事です。ただし単数形のbikeでは違う結果が出ます。前出のコンソールのエラー表示はpromise.then().catch()で回避することができます(「Promiseを使う」)。
average(inputs, ?max, ?callback)
使い方はadd()などと同じで、渡された単語ベクトルの平均を計算します。2つの単語を渡す場合、これは”その中間”と考えることができます。
const average = () => { const t1 = 'human'; const t2 = 'monkey'; text('[' + t1 + ' と ' + t2 + ' の中間]', 10, 20); word2Vec.average([t1, t2]) .then((resArray) => { resArray.forEach((elm, index) => { console.log(elm.word); console.log(elm.distance); text(index + ': ' + elm.word, elm.distance * index * 10 + 50, index * 20 + 70); }); }).catch(error => console.log(error.message)); }
humanとmonkeyの中間はape(類人猿)とかgorilla(ゴリラ)と言っています。
getRandomWord(?callback)
1単語がランダムに得られます。
const getRandom = () => { word2Vec.getRandomWord(). then((result) => { console.log(result); text(result, 10, 20); }).catch(error => console.log(error.message)); }