BT

InfoQ ホームページ ニュース TensorFlow.jsマシンラーニングライブラリを使って、Chrome Dinosaur Gameのプレー方法をコンピュータに教える

TensorFlow.jsマシンラーニングライブラリを使って、Chrome Dinosaur Gameのプレー方法をコンピュータに教える

ブックマーク

原文(投稿日:2019/04/30)へのリンク

Fritzの刊行するHeartBeatは先日、GoogleのマシンラーニングライブラリであるTensorFlow.jsをブラウザで使って、Chrome Dinosaur Gameのプレーをコンピュータに教える方法について解説した、Aayush Arora氏による記事を公開した。

Chrome Dinosaur Game(T-Rex Gameとも呼ばれる)は5年前、あるユーザがインターネットから切断された状態でWebサイトにアクセスしようとしたときに、Chromeブラウザに現れたものだ。Chrome Dinosaur Gameは単純な無限ランナーで、プレーヤはサボテンを飛び越えたり、障害物を潜り抜けたりする。コントロールは基本的で、スペースバーを押すとジャンプし、下向きの矢印を押すとしゃがみ込む。目標はできる限り長く生き残ることであり、プレーヤが障害物を乗り越えた時間がタイマで測定される。

4年後、GoogleがついにChrome恐竜ゲームの起源を説明する

ゲームの性質を考慮して選択された機能セットは、ゲームのスピード 、現れる障害物の とティラノサウルスからの距離だ。コンピュータはこれら3つの変数をマップして、ジャンプするかしないか、2つの判断のどちらを選択するかを学習する((ゲームのオリジナルバージョンでは、恐竜がしゃがみ込むこともできるが、今回の決定リストではモデル化されない)。コンピュータは試行錯誤によって学習し、ゲームに失敗するたびにトレーニングデータを収集し、蓄積された経験を用いてゲームを再開する。

Tensorflow.jsは、マシンラーニングライブラリとして使用されている。TensorFlowのチュートリアルでは、マシンラーニングの実装で従うべき手順を明確にしている。

  1. 入力データをロードし、フォーマットし、視覚化する
  2. モデルのアーキテクチャを定義する
  3. トレーニング用のデータを用意する
  4. モデルをトレーニングする
  5. 予測する

今回の例では、トレーニングデータを使用しないで開始するため、最初のステップは実質的に空である。2番目のステップでArora氏は、逐次モデルをベースとして、いずれもシグモイド起動関数(sigmoid activation function)を備えた、入力レイヤと出力レイヤを持つニューラルネットワークを使用した。最初のレイヤは、ゲーム速度、現れる障害物の幅、ティラノサウルスからの距離という、前述の3つの予測変数を持ち、2番目と最後のレイヤの入力として機能する6つのユニットを計算する。最後のレイヤには2つの出力があり、それぞれの値がジャンプする確率、あるいはジャンプしない確率に対応する。

import  *  as  tf  from  '@tensorflow/tfjs';

dino.model  =  tf.sequential();
dino.model.add(tf.layers.dense({
  inputShape:[3],
  activation:'sigmoid',
  units:6
}))

dino.model.add(tf.layers.dense({
  inputShape:[6],
  activation:'sigmoid',
  units:2
}))

3番目のステップでは、入力データを、TensorFlow.jsが処理できるテンソル(tensor)に変換する。

dino.model.fit(
  tf.tensor2d(dino.training.inputs), 
  tf.tensor2d(dino.training.labels)
);

3番目のステップにはシャッフリング(shuffling)が実装されていないので、最初は空であるトレーニングセットに対して、コンピュータがゲームに失敗するたびに、段階的に入力が追加されていく。ここでの正規化は、トレーニングセット内の出力値を0から1の間に設定することで実現する。実際には、ティラノサウルスが障害物の回避に失敗した場合、対応する3入力(ゲーム速度、現れる障害物の幅、ティラノサウルスからの距離)は[1, 0][0, 1]のいずれかにマッピングされて 、第2レイヤの出力を符号化する。ティラノサウルスがジャンプして障害物の回避に失敗した場合、適切な決定はジャンプしないことである: [1, 0]。逆に、ティラノサウルスがジャンプせずに障害物にぶつかった場合には、ジャンプすることが適切な決定となる: [0, 1]

4番目のステップとして、トレーニングデータが利用可能になると、モデルはmeanSquaredError損失関数とAdamオプティマイザを使って、学習率0.1でトレーニングされる(Adamオプティマイザは実際には非常に効果的で、設定を必要としない)。

dino.model.compile({
  loss:'meanSquaredError',
  optimizer: tf.train.adam(0.1)
})

5番目のステップは、ゲームの繰り返し中に発生する。ゲームが進行し、3入力の新たな値が計算されると、予測が実行されて、実行すべきタイミングであれば(例えば、ティラノサウルスがジャンプ中でなければ)、"ジャンプする/しない"の判断が行われる。

if (!dino.jumping) {
  // whenever the dino is not jumping decide whether it needs to jump or not
  let action = 0;// variable for action 1 for jump 0 for not
  
  // call model.predict on the state vecotr after converting it to tensor2d object
  const prediction = dino.model.predict(tf.tensor2d([convertStateToVector(state)]));

  // the predict function returns a tensor we get the data in a promise as result
  // and based don result decide the action
  const predictionPromise = prediction.data();
  
  predictionPromise.then((result) => {
  // converting prediction to action
  if (result[1] > result[0]) {
  // we want to jump
  action = 1;
  // set last jumping state to current state
  dino.lastJumpingState = state;
  } else {
  // set running state to current state
  dino.lastRunningState = state;
  }
  
resolve(action);
});

Fritzは、iOSおよびAndroid開発者向けのマシンラーニングプラットフォームである。TensorFlow.jsは、Apache 2.0ライセンスの下で利用可能なオープンソースソフトウェアである。コントリビューションとフィードバックは、TensorFlowのGitHubプロジェクトを通して受け入れられている。いずれもTensorFlowのコントリビューションガイドラインに従うことが必要だ。

この記事に星をつける

おすすめ度
スタイル

こんにちは

コメントするには InfoQアカウントの登録 または が必要です。InfoQ に登録するとさまざまなことができます。

アカウント登録をしてInfoQをお楽しみください。

HTML: a,b,br,blockquote,i,li,pre,u,ul,p

コミュニティコメント

HTML: a,b,br,blockquote,i,li,pre,u,ul,p

HTML: a,b,br,blockquote,i,li,pre,u,ul,p

BT

あなたのプロファイルは最新ですか?プロフィールを確認してアップデートしてください。

Eメールを変更すると確認のメールが配信されます。

会社名:
役職:
組織規模:
国:
都道府県:
新しいメールアドレスに確認用のメールを送信します。このポップアップ画面は自動的に閉じられます。