注:今回の記事は完全にプログラマ向けの解説記事です
ソースコードの閲覧、ダウンロードは此方からどうぞ
GitHub - qhapaq-49/tf_reinforcement: tensorflowを使った簡単(300行弱)なreinforcement learning
【今回作りたいもの】
囲碁やポーカーのAIで度々注目されているディープラーニングを使った強化学習。時代の先端を走るゲームAI開発者的には是非覚えておきたいスキルの一つです。といっても、強化学習の動作原理自体は下記の図のようにシンプルなものです。本稿では下記図の流れを一通り搭載したスタンドアロンで動く強化学習ルーチンを紹介します(上述のgithubのコードを見ながら読まれることをオススメします)。
【本稿で扱うゲームのルール】
本稿ではニューラルネットで動く競りゲームのAIを作ります。競りゲームとは
・初期所持金10のプレイヤー2人が4回の競りを行う
・各競りでは1〜4の勝利点がランダムで出てくる。同じ勝利点は二度出ない。
・競りで買値を宣言できるのは一度だけ。高値を宣言した側が勝利点を得られる。引き分けの場合の再試合はなし
・4回の競りの後で勝利点が高いほうが勝利
というゲームです。私が風呂に入りながら勝手に考えたゲームです。
【使い方】
Step1. ランダムムーブで教師データを作る
playout = seriGame("","")
playout.playout(30000,"hoge",True,True)
によってランダムムーブの対局を行わせて教師データを作ります。
なお、教師データは勝った方の手を良い手としてその手の採択率を上げるというシンプルな作りになっています。
Step2. 学習させる
sa = seriAgent()
sa.loadnofile()
sa.learnbycsv("hoge",200,"model/test_model")
Step1で作られた棋譜を元に勝った時の手の採択率を上げるように学習します。
Step3. 学習させたモデル同士を戦わせて棋譜を作る(以下、Step1-3を繰り返す)
playout = seriGame("model/test_model","model/test_model")
playout.playout(30000,"hoge2",False,False)
【コード中の特に重要(かつ、tensorflowにちなんだ)部分の解説】
1.ネットワークの定義
def design_model(self):
graph_t = tf.Graph()
with graph_t.as_default():
X = tf.placeholder(tf.float32, [None, glob_inpXdim])
W1 = tf.Variable(tf.truncated_normal([glob_inpXdim, self.hiddendim], stddev=0.01), name='W1')
B1 = tf.Variable(tf.zeros([self.hiddendim]), name='B1')
H1 = tf.nn.relu(tf.matmul(X, W1) + B1)
W2 = tf.Variable(tf.random_normal([self.hiddendim, self.outdim], stddev=0.01), name='W2')
B2 = tf.Variable(tf.zeros([self.outdim]), name='B2')
Y = tf.nn.softmax(tf.matmul(H1, W2) + B2)
tf.add_to_collection('vars', W1)
tf.add_to_collection('vars', B1)
tf.add_to_collection('vars', W2)
tf.add_to_collection('vars', B2)
t = tf.placeholder(tf.float32, shape=[None, self.outdim])
entropy = -tf.reduce_sum(t*tf.log(tf.clip_by_value(Y,1e-10,1.0)))
learnfunc = tf.train.AdamOptimizer(0.05).minimize(entropy)
model = {'X': X, 'Y': Y, 't' : t, 'ent' : entropy, 'learnfunc' : learnfunc}
self.sess = tf.Session()
self.saver = tf.train.Saver()
self.sess.run(tf.global_variables_initializer()[f:id:qhapaq:20170719231154p:plain][f:id:qhapaq:20170818173131p:plain])
この処理によってネットワークの構築、ネットワークを使って実際に計算する部分(session)と
計算結果を保存するためのsaverが定義できます。
この辺の処理は解りにくい(そして私の理解も怪しい)ですが、graphが住所、variablesが家の設計図、sessionが建築士と考えると良いでしょう。上記関数は特定の住所(graph)に対し、どういったネットワーク(variable)が入るかを指定したものを、建築士(self.sess)に渡したと解釈できます。graphにぶら下げる形でvariableとsessを定義するのがミソ(with graph_t.as_default())のようです。建築士に住所を指定して教えこむことで、複数のsessを動かすときに変な衝突の発生を避けられるようです。
2.ネットワークの書き込み、読みこみ
TensorFlowでmodelを保存して復元する - test.py
を合わせて読むことをオススメします(本作のコードの元にもなってます)
def learnbycsv(self, fname, stepnum,outname):
data = pd.read_csv(fname,header=None, dtype=float)
x = data.iloc[:,0:glob_inpXdim]
y = data.iloc[:,glob_inpXdim:63]
for i in range(stepnum):
self.sess.run(self.ln, feed_dict={self.X : x, self.t : y})
if i%20 == 0:
ent = self.sess.run(self.ent, feed_dict={self.X: x, self.t : y})
print("epoch " + str(i)+" ,entropy = " + str(ent))
self.saver.save(self.sess, outname)
saverはsessを呼び出し、sessからgraphの内部変数を引っ張りだすことで学習結果を保存してくれます。
逆にデータの読み出しは
def loadfile(self,fname):
self.model = self.design_model()
self.X, self.Y, self.t, self.ent, self.ln = self.model['X'], self.model['Y'], self.model['t'], self.model['ent'], self.model['learnfunc']
print("load agent from : " + fname)
self.saver.restore(self.sess, fname)
でいけます。
3.雑なまとめ
これらの関係を図にするとこんな感じになります。
【あとがき:なんでこんなものを態々公開したか】
一言で言えば、このコードを書いた際に教えて欲しかったことを教えてくれる記事がなかったからです。tensorflow+reinforcement leaningの記事は沢山あるのですが、その多くはgymなどの洗練されすぎたパッケージを使っていて、「結局、これを自前のゲームに適用するにはどうすんのよ」という問の答えは自力で探す必要がありました。更に悪いことに、多くの教材はmnistのように外部から教師を引っ張ってきて、一度学習させたらオシマイとなっており、複数の学習モデルの比較や、強化学習といった学習結果を次の学習にフィードバックさせる類の学習への適用は困難でした。特に、学習データの読み書きや、複数のsessionの切り替え、誤差関数の自作については悲しいくらいに記事が少なく(または、サンプルコードやっつけたらオシマイの記事に埋もれてしまい)、数学に関係ない実装面でのエラー潰しに酷く苛々させられたものです。
これに対し、本コードは300行未満のシンプル、かつ、スタンドアロン(tensorflowやnumpy以外の外部パッケージは使ってない)なコードでありながら、複数セッションの呼び出しやデータの読み書き、思考部と対戦部の切り分けとクラス化といった基本的な機能が搭載されています。本コードを動かし、読み解くことで、皆様もtensorflowを使った強化学習に取り組めるのではないかと期待しています。
# batch normalizationやcnnなどの高度な機能は一切積んでいませんが、そのへんの技はググればすぐ出てくるので、ggrksなのです
しかし同時に、本稿はtensorflow歴一週間にも満たない野良開発者の雑コードの解説であり、至るところに間違いもあると思います。悪く言えば未完成なのですが、よく言えばインタラクティブな記事だと思っています。間違いやアドバイスがありましたら、教えていただけると幸いです。
【謝辞】
本稿の執筆に当たり、コンピュータ囲碁開発者の方々にアドバイスをいただきました。
これらのアドバイスがなかったら、恐らくこの記事は完成しなかったでしょう。心よりお礼申し上げます。