Q学習-最良経路を学習するスクリプト書いた (powered by Python)

概要

講義の課題でQ学習について実装してみたので、スクリプト等を晒してみる.

          #   #   #   #   #   #   #
          #   S   0   0 -10   0   #
          #   0 -10   0   0   0   #
          #   0 -10   0 -10   0   #
          #   0   0   0 -10   0   #
          #   0 -10   0   0 100   #
          #   #   #   #   #   #   #

こんな感じの迷路において、S(start地点)からより良い報酬("100")までの経路をQ学習を用いて学習させるという話.


Q学習-概要

Q学習(-がくしゅう、英: Q-learning)は、機械学習分野における強化学習の一種である。Q学習は機械学習手法の方策オフ型TD学習の一つである。Q学習は有限マルコフ決定過程において全ての状態が十分にサンプリングできるようなエピソードを無限回試行した場合、最適な評価値に収束することが理論的に証明されている。実際の問題に対してこの条件を満たすことは困難ではあるが、この証明はQ学習の有効性を示す要素の一つとして挙げられる。

http://ja.wikipedia.org/wiki/Q%E5%AD%A6%E7%BF%92

強化学習で有名な手法らしい. TD(Temporal difference:時間差分)学習の一つらしい.

参考

キモはQ値の更新式!

Q(s_t, a) \leftarrow Q(s_t, a) + \alpha\left[r_{t+1} + \gamma\max_pQ(s_{t+1}, p) - Q(s_t,a)\right]
この式を掴めれば、Q学習はものにできたようなものだと勝手に思っている.

簡単に、この式がどういったことを表しているのかを書いてみる. ( \alpha が学習率、 \gamma が割引率と言われている.)
この式の目的は、ある状態 s_t における最良な行動 a を選ぶための基準を、各状態における各行動での評価値 Q(s_t, a) を更新するような処理をしたい. 行動の重みづけをするようなイメージ.
基本的なアイディアは、良い報酬につながるような行動を選ぶようにしたいので、良い報酬を得られる行動 a は良い行動. その報酬を得られる行動ができる状態に行けるための行動もちょっぴり良い行動といったような感じで良い報酬に近づいていく行動に対して、良い重みづけをしていくようなイメージ. 悪い行動に対しても同様に悪い重みを与える.

学習率 \alpha

\alphaが大きくなるほど、更新時の影響が強くなる.

割引率 \gamma

\gamma次の行動の影響度を表す. 値が大きいほど、次の状態での行動が効いてくる.


学習させた結果

1エピソードを報酬が得られるまでの行動遷移として、1000エピソードの全行動分、Q値を更新させた.
その時の行動選択アルゴリズムはε-greedy法を用いた. (スクリプトは長くなってしまったので記事の後ろの方に載せました.)

以下の図は、その1000エピソード分学習したQ値を用いて、今度はgreedy法を適用して行動選択させてみた様子.
ちゃんとスタート地点から最短距離で最高の報酬100ポイントに向かっているのがわかる.
ちなみに実行のたびに経路は変わる.(ε-greedyのランダム性のため。)

--- 盤面情報 "#": 壁, "S": スタート地点, "@": 今いる座標, "数値": その座標での報酬 ---

----- Dump Field:  -----
          #   #   #   #   #   #   #
          #   S   0   0 -10   0   #
          #   0 -10   0   0   0   #
          #   0 -10   0 -10   0   #
          #   0   0   0 -10   0   #
          #   0 -10   0   0 100   #
          #   #   #   #   #   #   #

----- Dump Field: (2, 1) -----
          #   #   #   #   #   #   #
          #   S   @   0 -10   0   #
          #   0 -10   0   0   0   #
          #   0 -10   0 -10   0   #
          #   0   0   0 -10   0   #
          #   0 -10   0   0 100   #
          #   #   #   #   #   #   #
        state: (1, 1) -> action:(2, 1)

----- Dump Field: (3, 1) -----
          #   #   #   #   #   #   #
          #   S   0   @ -10   0   #
          #   0 -10   0   0   0   #
          #   0 -10   0 -10   0   #
          #   0   0   0 -10   0   #
          #   0 -10   0   0 100   #
          #   #   #   #   #   #   #
        state: (2, 1) -> action:(3, 1)

----- Dump Field: (3, 2) -----
          #   #   #   #   #   #   #
          #   S   0   0 -10   0   #
          #   0 -10   @   0   0   #
          #   0 -10   0 -10   0   #
          #   0   0   0 -10   0   #
          #   0 -10   0   0 100   #
          #   #   #   #   #   #   #
        state: (3, 1) -> action:(3, 2)

----- Dump Field: (3, 3) -----
          #   #   #   #   #   #   #
          #   S   0   0 -10   0   #
          #   0 -10   0   0   0   #
          #   0 -10   @ -10   0   #
          #   0   0   0 -10   0   #
          #   0 -10   0   0 100   #
          #   #   #   #   #   #   #
        state: (3, 2) -> action:(3, 3)

----- Dump Field: (3, 4) -----
          #   #   #   #   #   #   #
          #   S   0   0 -10   0   #
          #   0 -10   0   0   0   #
          #   0 -10   0 -10   0   #
          #   0   0   @ -10   0   #
          #   0 -10   0   0 100   #
          #   #   #   #   #   #   #
        state: (3, 3) -> action:(3, 4)

----- Dump Field: (3, 5) -----
          #   #   #   #   #   #   #
          #   S   0   0 -10   0   #
          #   0 -10   0   0   0   #
          #   0 -10   0 -10   0   #
          #   0   0   0 -10   0   #
          #   0 -10   @   0 100   #
          #   #   #   #   #   #   #
        state: (3, 4) -> action:(3, 5)

----- Dump Field: (4, 5) -----
          #   #   #   #   #   #   #
          #   S   0   0 -10   0   #
          #   0 -10   0   0   0   #
          #   0 -10   0 -10   0   #
          #   0   0   0 -10   0   #
          #   0 -10   0   @ 100   #
          #   #   #   #   #   #   #
        state: (3, 5) -> action:(4, 5)

----- Dump Field: (5, 5) -----
          #   #   #   #   #   #   #
          #   S   0   0 -10   0   #
          #   0 -10   0   0   0   #
          #   0 -10   0 -10   0   #
          #   0   0   0 -10   0   #
          #   0 -10   0   0   @   #
          #   #   #   #   #   #   #
        state: (4, 5) -> action:(5, 5)

スクリプトを晒す

スクリプトの処理内容:
  1. 最初に盤面情報を出力
  2. LEARNING_COUNT(1000回)分 エピソード経過させ、Q値を学習.
  3. 学習後、Q値の情報を出力.
  4. greedy法で学習したQ値の観点で1 エピソードの行動選択をさせ、その遷移を出力.
スクリプト

最後にスクリプト170行ほど。冗長になってしまった感は否めない。もっとスマートに書けるアドバイス等していただけると嬉しいです。



#!/usr/bin/env python
# coding: utf-8

import sys
import copy
import random

# "S": Start地点, "#": 壁, "数値": 報酬
RAW_Field = """
#,#,#,#,#,#,#
#,S,0,0,-10,0,#
#,0,-10,0,0,0,#
#,0,-10,0,-10,0,#
#,0,0,0,-10,0,#
#,0,-10,0,0,100,#
#,#,#,#,#,#,#
"""


# 定数
ALPHA = 0.2 # LEARNING RATIO
GAMMA  = 0.9 # DISCOUNT RATIO
E_GREEDY_RATIO = 0.2
LEARNING_COUNT = 1000


class Field(object):
        """ Fieldに関するクラス """

        def __init__(self, raw_field=RAW_Field):
                self.raw_field = raw_field
                self.set_field_data()
                self.start_point = self.get_start_point()

        def set_field_data(self):
                """ 文字列のfieldデータ(raw_field)を2次元配列(field_data)に格納 """
                self.field_data = []
                for line in self.raw_field.split("\n"):
                        if line.strip() != "":
                                self.field_data.append(line.split(","))

        def display(self, point=None):

                """ Fieldの情報を出力する. """
                field_data = copy.deepcopy(self.field_data)
                if not point is None:
                        x, y = point
                        field_data[y][x] = "@"
                else:
                        point = ""
                print "----- Dump Field: %s -----" % str(point)
                for line in field_data:
                        print "\t" + "%3s " * len(line) % tuple(line)

        def get_actions(self, point):
                """ 引数で指定した座標から移動できる座標リストを獲得する. """
                x, y = point
                if self.field_data[y][x] == "#": sys.exit("Field.get_actions() ERROR: 壁を指定している.(x, y)=(%d, %d)" % (x, y))
                around_map = [(x, y-1), (x, y+1), (x-1 , y), (x+1, y)]
                return [(_x, _y) for _x, _y in around_map if self.field_data[_y][_x] != "#"]

        def get_val(self, point):
                """ 指定した座標のfieldの値を返す. エピソード終了判定もする. """
                x, y = point
                try:
                         v = float(self.field_data[y][x])
                         if v == 0.0: return v, False
                         else: return v, True
                except ValueError:
                        if self.field_data[y][x] == "S": return 0.0, False # start地点の時
                        sys.exit("Field.get_val() ERROR: 壁を指定している.(x, y)=(%d, %d)" % (x, y))

        def get_start_point(self):
                """ Field中の Start地点:"S" の座標を返す """
                for y, line in enumerate(self.field_data):
                        for x, v in enumerate(line):
                                if v == "S":
                                        return (x, y)
                sys.exit("Field.set_start_point() ERROR: FieldにStart地点がありません.")


class QLearning(object):
        """ class for Q Learning """

        def __init__(self, map_obj):
                self.Qvalue = {}
                self.Field = map_obj

        def learn(self, greedy_flg=False):
                """ 1エピソードを消化する. """
                state = self.Field.start_point
                #print "----- Episode -----"
                while True:
                        if greedy_flg:
                                action = self.choose_action_greedy(state)
                                self.Field.display(action)
                                print "\tstate: %s -> action:%s\n" % (state, action)
                        else: #default (Learning Mode)
                                action = self.choose_action(state)
                        if self.update_Qvalue(state, action):
                                break # finish this episode
                        else:
                                state = action # continue

        def update_Qvalue(self, state, action):
                """ Q値の更新を行う. """
                # 更新式:
                #       Q(s, a) <- Q(s, a) + alpha * {r(s, a) + gamma max{Q(s`, a`)} -  Q(s,a)}
                #               Q(s, a): 状態sにおける行動aを取った時のQ値      Q_s_a
                #               r(s, a): 状態sにおける報酬      r_s_a
                #               max{Q(s`, a`) 次の状態s`が取りうる行動a`の中で最大のQ値 mQ_s_a)
                Q_s_a = self.get_Qvalue(state, action)
                mQ_s_a = max([self.get_Qvalue(action, n_action) for n_action in self.Field.get_actions(action)])
                r_s_a, finish_flg = self.Field.get_val(action)
                # calculate
                q_value = Q_s_a + ALPHA * ( r_s_a +  GAMMA * mQ_s_a - Q_s_a)
                # update
                self.set_Qvalue(state, action, q_value)
                return finish_flg


        def get_Qvalue(self, state, action):
                """ Q(s,a)を取得する. s:state, a:action """
                try:
                        return self.Qvalue[state][action]
                except KeyError:
                        return 0.0

        def set_Qvalue(self, state, action, q_value):
                """ Q値に値を代入する. """
                self.Qvalue.setdefault(state,{})
                self.Qvalue[state][action] = q_value

        def choose_action(self, state):
                """ e-greedy法で行動を決める. """
                if E_GREEDY_RATIO < random.random():
                        #ランダムに行動選択
                        return random.choice(self.Field.get_actions(state))
                else:
                        # greedy法を適用する
                        return self.choose_action_greedy(state)

        def choose_action_greedy(self, state):
                """ greedy法で行動を決める. Q(s,a)の観点から選択. """
                best_actions = []
                max_q_value = -1
                for a in self.Field.get_actions(state):
                        q_value = self.get_Qvalue(state, a)
                        if q_value > max_q_value:
                                best_actions = [a,]
                                max_q_value = q_value
                        elif q_value == max_q_value:
                                best_actions.append(a)
                return random.choice(best_actions) # Q値の最大値が複数存在する場合はその中からランダムに選択

        def dump_Qvalue(self):
                """ Q値をdumpする. """
                print "##### Dump Qvalue #####"
                for i, s in enumerate(self.Qvalue.keys()):
                        for a in self.Qvalue[s].keys():
                                print "\t\tQ(s, a): Q(%s, %s): %s" % (str(s), str(a), str(self.Qvalue[s][a]))
                        if i != len(self.Qvalue.keys())-1: print '\t----- next state -----'


if __name__ == "__main__":

        # display Field information
        Field().display()
        # create QLearning object
        QL = QLearning(Field())
        # Learning Phase
        for i in range(LEARNING_COUNT):
                QL.learn() # Learning 1 episode
        # After Learning
        QL.dump_Qvalue() # Q値出力
        QL.learn(greedy_flg=True) # 学習結果をgreedy法で行動選択させてみる


# End of File #

感想

コーディングしていて、想定よりしていたよりさっとかけた自分に驚いた。
少し実装力ついてきたかも。日々のアリ本の鍛錬や勉強会のための写経などがいろいろと力になっているのを感じた.
今後もアリ本など日々の鍛練を続けていきたい。Pythonももっと詳しくなりたいな〜。

おわり.

まだ「よいお年を」とは言わないでおこう。自分への制約:年内に最低一本エントリー書くことをここに宣言する。