目次
機械学習と深層学習 Pythonによるシミュレーション の第2章のコード qlearning.py
について、 わかりやすく書き換えてみました。
環境
- Python 3.6.5 Anaconda
問題の設定
2.2.3 の迷路抜けの学習プログラムです。
書籍と同じように、 リスト(q_value
)にQ値を格納します。 1階層目の値1
1階層目 |
0 | |||||||
---|---|---|---|---|---|---|---|---|
2階層目 | 1 | 2 | ||||||
3階層目 | 3 | 4 | 5 | 6 | ||||
4階層目 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
そして書籍では インデックス 14 がゴールとなっていました。 今回は GOAL_POSITION
という定数にゴールのインデックスを格納します。 ゴールを変更して実験することも可能です。
コード
変数名や条件分岐の記述を少し変えて、理解しやすくしたつもりです。 階層構造などを柔軟に変えられるようにすることもできなくはないですが、コード量が膨れあがるのでやっていません。
THRESHOLD
という定数を導入しており、 この閾値の回数だけ q_value
が変化しなければ収束したものとして学習を終了し、 収束値に最初になったときのイテレーションインデックスを出力します。 ALPHA
, GAMMA
, EPSILON
の値によって収束する回数が変化します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import math import random ALPHA = 0.1 # learning rate GAMMA = 0.9 EPSILON = 0.8 # randomize value LEVEL = 3 GOAL_POSITION = 14 THRESHOLD = 100 def select_a(previous_position, q_value): """ Select activity """ if random.random() < EPSILON: if random.randint(0, 1) == 0: return 2 * previous_position + 1 else: if q_value[2 * previous_position + 1] > q_value[2 * previous_position + 2]: return 2 * previous_position + 1 return 2 * previous_position + 2 def update_q(position, q_value): """ Calculate updated Q Value """ if position < 2 ** LEVEL - 1: # if not in last level qmax = max( q_value[2 * position + 1], q_value[2 * position + 2]) return q_value[position] + int( ALPHA * (GAMMA * qmax - q_value[position])) # if in last level if position == GOAL_POSITION: return q_value[position] + int( ALPHA * (1000 - q_value[position])) #elif position == 11: # return q_value[position] + int( # ALPHA * (500 - q_value[position])) return q_value[position] def solve(seed, time): random.seed(seed) q_value = [ random.randint(0, 100) for i in range(2 ** (LEVEL + 1) - 1) ] convergence_time = 0 eq_count = 0 print(q_value) for i in range(time): previous_q_value = q_value position = 0 # initial state for j in range(3): position = select_a(position, q_value) q_value[position] = update_q(position, q_value) print(q_value) if q_value == previous_q_value: if convergence_time == 0: convergence_time = i eq_count += 1 if eq_count > THRESHOLD: print(convergence_time) break else: eq_count = 0 convergence_time = 0 if __name__ == "__main__": solve(32767, 1000) |