AlphaZero 深層学習・強化学習・探索 人工知能プログラミング実践入門
布留川英一著 2019年 ボーンデジタル
第6章では、3目並べを題材に、実際にAlphaZeroを実装する。AlphaZeroは「二人零和有限確定完全情報ゲーム」においては、どんなに新しいゲームだったり、ルールが変更されても、人間を越え無敵になってしまう、言ってみれば魔王のようなアルゴリズムであるということが理解出来た。
サンプルプログラムは8つのソースファイルに分かれており、メインのtrain_cycle.pyをJupyterNotebookのterminalでPythonコマンドから実行すると、GoogleColab上からではなくても動いた。ただし、非常に時間がかかる。1回の学習に際し、500回の自己対戦を行うところを200回に数を減らして、1時間程度を要した。CPUが、Corei5-2410M、メモリ4G、GPU無しという非力な環境であることも影響している。また、学習サイクルが4周目ぐらいに達したところでpythonがエラーを起こし止まってしまう。それまでの学習成果は、学習の結果より強いモデルになればモデルが上書きされるので、無駄にはならない。
プログラムの構造は以下の通り。
- game.py→ゲーム部分の実装
- dual_network.py→学習を行うネットワーク本体の定義
- pv_mcts.py→モンテカルロ木探索。自己対戦の時に手を探索する
- self_play.py→自己対戦を行い、学習データを蓄積する。一手毎にpv_mcts.pyのサブルーチンを呼び出し、次の手の確率分布を計算する
- train_network.py→自己対戦が終わったら、蓄積した学習データで学習を行う
- evaluate_network.py→学習した結果、前より強くなったかを判定
- evaluate_best_player.py→強くなったら、他の方法と対戦して勝率を算出
- train_cycle.py→上記プログラムを呼び出す。
上記を動かし、強くなったかどうかなどの結果をおいかけることは普通に出来るであろう。作ったモデルの勝率例が以下の通りで、最強のAlphaBetaと互角、モンテカルロ(MCTS)より勝率が上回っているので、ちょっと強くなっている。
Evaluate 10/10
VS_Random 0.85
Evaluate 10/10
VS_AlphaBeta 0.5
Evaluate 10/10
VS_MCTS 0.6
ここで、実際にAlphaZeroの中身に迫ろうとすると、なんども読み返し、何が行われているのかを読み取らなくてはならないと思う。以前読んだ「直感DeepLearning」等も引っ張り出した。特に、モデル(ネットワーク本体)の説明があっさりしていると感じた。自分なりの理解を以下に書いていく。
モデルは、ある局面を入力されると、その時の取るべき方策と価値を出力する。方策とは、次の手の確率分布である。価値は、分かりにくいがその局面から進んだ時に先手が勝ったか後手が勝ったかである。方策と価値の二つを出力するので、デュアルネットワークと読んでいる。AlphaZeroの前身のAlphaGoでは、別のネットワークだったのが、AlphaZeroで一つにまとまったとのこと。なぜデュアルネットワークになっているのかの明確な説明は無い。
モデルの構造は、最初に畳み込み層があってから、残差ブロックと呼ばれる畳み込み等の組み合わせモデルが16回繰り返される長大なものである。ブログ筆者は理解するのに時間がかかったが、このモデルは、一つのネットワークで、どんな局面がinputされてもその局面での最適解を出力するように学習する(局面ごとに最適化されたネットワークがある訳では無いんですね…)。
また、「二人零和有限確定完全情報ゲーム」であるため、ゲームの経緯は関係なく、とある局面での最適解を出力するように学習すれば良い。そのため、学習データは、自己対戦での一局面とその時の方策(取るべき次の手の確率分布)、またゲームが終わった時に勝ったか負けたかというデータになる。一ゲームで8局面ぐらいあるとして、自己対戦で決着がついた時の勝敗をその8局面全部に「価値」として付与するのか、明記されていないが、そうなっていると思われる。
すなわち、自己対戦を500回やって、局面データ(局面+次の手の確率分布+勝ち負け)を集め、それをネットワークに学習させる、自己対戦の時にはモンテカルロ法で解空間を探索し次の手の確率分布を更新する、というのがAlphaZeroの全体像であろう。
まだ本質に迫れていない部分もあり、なぜデュアルネットワークにするのかを始め、もう少し詳しい説明があると良かったように思う。最初の畳み込みでKernelが128個あるのも、入力の次元(3×3×2)と比べると多いなぁという気がするが、128個の理由は記載が無い。kerasのConv2dルーチンが使われているので、それがどういうものなのか、本家をあたりたい。
おまけとして、ソースにTensorboardを呼び出すコールバックを追加して、ネットワーク構造を可視化したものを載せておく。残差ブロックが16個あると縦長すぎるので、途中は省略した。最初に3個、後半に1個残差ブロックが載っており、途中の12個が省略となっている。
最後に、つまらない誤記の指摘。正誤表に載っていなかった。
P213 本文下から4行目のDN_WIDTHは、DN_FILTERS
P213 同、DN_HEIGHTは、DN_RESIDUAL_NUM
がそれぞれ正しいであろう。プログラムコード、サンプルプログラムとも直っているので、実害は無い。