AlphaGo Zero Explained



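I have already gone over the overall AlphaGo Zero pipeline. The next step is to understand how it is actually implemented. For my graduation project I chose to build a Gomoku (five-in-a-row) AI with AlphaGo Zero, and I found some quite good code online written by someone who had done this before. My plan is to first understand that code, and then try to improve the algorithm's performance.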

The first part to implement is MCTS. The original comments are written in English; here I will go through the code step by step.

First, the node class TreeNode:

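```python
class TreeNode(object):
    """A node in the MCTS tree."""

    def __init__(self, parent, prior_p):
        self._parent = parent   # parent node (None for the root)
        self._children = {}     # map from action to child TreeNode
        self._n_visits = 0      # visit count N
        self._Q = 0             # mean action value Q
        self._u = 0             # exploration bonus u
        self._P = prior_p       # prior probability P

    # Method bodies are shown one by one below.
    def select(self, c_puct): ...
    def expand(self, action_priors): ...
    def update(self, leaf_value): ...
    def update_recursive(self, leaf_value): ...
    def get_value(self, c_puct): ...
    def is_leaf(self): ...
    def is_root(self): ...
```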

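TreeNode initializes a few values: the parent node, the child nodes, the visit count, the Q value and the u value, and the prior probability. It also defines the following methods: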

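```python
import numpy as np

# TreeNode methods
def select(self, c_puct):
    """Return the (action, child) pair whose child has the largest Q + u."""
    return max(self._children.items(),
               key=lambda act_node: act_node[1].get_value(c_puct))

def get_value(self, c_puct):
    """Compute the exploration bonus u and return Q + u for this node."""
    self._u = (c_puct * self._P *
               np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
    return self._Q + self._u
```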

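select(): the selection step. Among the child nodes it picks the one with the largest Q + u. get_value() computes the exploration bonus as u = c_puct * P * sqrt(N_parent) / (1 + N), so a child with a high prior P and few visits N gets a large bonus, and the bonus shrinks as the child is visited more. c_puct is a constant we choose ourselves; it is discussed below.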
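To make the formula concrete, here is a small worked computation of Q + u. The child statistics are hypothetical numbers chosen for illustration, c_puct = 5 matches the default of the MCTS class below, and puct_score is a helper name introduced only for this sketch.

```python
import math


def puct_score(q, prior, parent_visits, visits, c_puct=5.0):
    """Q + u with u = c_puct * P * sqrt(N_parent) / (1 + N), as in get_value()."""
    u = c_puct * prior * math.sqrt(parent_visits) / (1 + visits)
    return q + u


# Hypothetical statistics for the three children of a node visited 100 times.
children = {
    "a": {"q": 0.10, "prior": 0.6, "visits": 60},
    "b": {"q": 0.30, "prior": 0.3, "visits": 30},
    "c": {"q": 0.00, "prior": 0.1, "visits": 0},
}
scores = {act: puct_score(s["q"], s["prior"], 100, s["visits"])
          for act, s in children.items()}
best = max(scores, key=scores.get)
print(best)  # "c": never visited, so its bonus is u = 5 * 0.1 * 10 / 1 = 5.0

# The same child's bonus shrinks as it is visited more often.
bonus = [puct_score(0.0, 0.5, 100, n) for n in (0, 1, 9)]  # [25.0, 12.5, 2.5]
```

Even though "b" has the best Q so far, the unvisited "c" is selected next; this is how the u term forces exploration before the search settles on the highest-value move.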

```python
def expand(self, action_priors):
    for action, prob in action_priors:
        if action not in self._children:
            self._children[action] = TreeNode(self, prob)
```

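expand(): the expansion step. The input action_priors is a list of (action, prior probability) tuples covering all legal moves, i.e. the positions where a stone can be placed in the current position. The method creates a child node for every action that is not yet a child of the current node, using the given probability as that child's prior.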
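A minimal sketch of what expand() receives and produces. It assumes moves are encoded as flat board indices on a hypothetical 3x3 board and uses uniform priors in place of real network output; the stripped-down TreeNode keeps only the fields expand() touches.

```python
class TreeNode:
    """Stripped-down node: only what expand() needs."""

    def __init__(self, parent, prior_p):
        self._parent = parent
        self._children = {}
        self._P = prior_p

    def expand(self, action_priors):
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)


# Hypothetical 3x3 board: moves are flat indices 0..8, squares 0 and 4 are taken.
legal_moves = [1, 2, 3, 5, 6, 7, 8]
uniform = 1.0 / len(legal_moves)
action_priors = [(move, uniform) for move in legal_moves]  # (action, prior) pairs

root = TreeNode(None, 1.0)
root.expand(action_priors)
root.expand([(1, 0.9)])        # existing children are left untouched
print(sorted(root._children))  # [1, 2, 3, 5, 6, 7, 8]
```

The second call shows the effect of the `if action not in self._children` check: expanding again never overwrites a child or the statistics it has already collected.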

```python
def update(self, leaf_value):
    # Count visit.
    self._n_visits += 1
    # Update Q, a running average of values for all visits.
    self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits

def update_recursive(self, leaf_value):
    # If it is not root, this node's parent should be updated first.
    if self._parent:
        self._parent.update_recursive(-leaf_value)
    self.update(leaf_value)
```

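update_recursive(): the backpropagation step. Starting from this node, it updates the node itself and all of its ancestors. Because each node updates its parent before itself, the updates actually run from the root downward. The value is negated at every level (-leaf_value) because adjacent levels correspond to moves by opposite players, so a good result for one side is a bad result for the other. update() increments the visit count and keeps Q as the running average of all values backed up through the node.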
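The following sketch traces two backpropagations along a hypothetical three-node path to show both the running average and the sign flip. It re-declares only the parts of TreeNode that update() and update_recursive() use.

```python
class TreeNode:
    """Stripped-down node: only the statistics touched by backpropagation."""

    def __init__(self, parent):
        self._parent = parent
        self._n_visits = 0
        self._Q = 0.0

    def update(self, leaf_value):
        self._n_visits += 1
        # Incremental mean: Q_n = Q_(n-1) + (v - Q_(n-1)) / n
        self._Q += (leaf_value - self._Q) / self._n_visits

    def update_recursive(self, leaf_value):
        # Ancestors first, with the sign flipped at every level,
        # because consecutive levels are moves by opposite players.
        if self._parent:
            self._parent.update_recursive(-leaf_value)
        self.update(leaf_value)


# Hypothetical path: root -> child -> grandchild
root = TreeNode(None)
child = TreeNode(root)
grandchild = TreeNode(child)

grandchild.update_recursive(1.0)  # playout 1: +1 from the grandchild's point of view
grandchild.update_recursive(0.0)  # playout 2: a draw

print(root._Q, child._Q, grandchild._Q)  # 0.5 -0.5 0.5
```

After the first playout the Q values are +1, -1, +1 down the path; the draw then pulls each of them halfway toward 0, which is exactly the running average over two visits.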


```python
def is_leaf(self):
    # A leaf has not been expanded yet.
    return self._children == {}

def is_root(self):
    # Only the root has no parent.
    return self._parent is None
```


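The functions above cover three of the four MCTS steps (steps 1, 2 and 4: selection, expansion and backpropagation). The most important third step, simulation, is still missing; it is implemented in the MCTS class. Note that in this code the simulation is not a random rollout: the leaf is evaluated directly by the policy-value network.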

```python
import copy

import numpy as np


def softmax(x):
    # Helper used by get_move_probs (not shown in the original excerpt).
    probs = np.exp(x - np.max(x))
    probs /= np.sum(probs)
    return probs


class MCTS(object):
    def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
        self._root = TreeNode(None, 1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct
        self._n_playout = n_playout

    def _playout(self, state):
        node = self._root
        while True:
            if node.is_leaf():
                break
            # Greedily select next move.
            action, node = node.select(self._c_puct)
            state.do_move(action)

        # Evaluate the leaf using a network which outputs a list of (action, probability)
        # tuples p and also a score v in [-1, 1] for the current player.
        action_probs, leaf_value = self._policy(state)
        # Check for end of game.
        end, winner = state.game_end()
        if not end:
            node.expand(action_probs)
        else:
            # for end state, return the "true" leaf_value
            if winner == -1:  # tie
                leaf_value = 0.0
            else:
                leaf_value = 1.0 if winner == state.get_current_player() else -1.0

        # Update value and visit count of nodes in this traversal.
        node.update_recursive(-leaf_value)

    def get_move_probs(self, state, temp=1e-3):
        for _ in range(self._n_playout):
            state_copy = copy.deepcopy(state)
            self._playout(state_copy)

        act_visits = [(act, node._n_visits)
                      for act, node in self._root._children.items()]
        acts, visits = zip(*act_visits)
        # + 1e-10 keeps np.log finite for children that were never visited
        act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10))
        return acts, act_probs

    def update_with_move(self, last_move):
        if last_move in self._root._children:
            self._root = self._root._children[last_move]
            self._root._parent = None
        else:
            self._root = TreeNode(None, 1.0)

    def __str__(self):
        return "MCTS"
```


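policy_value_fn: the policy-value function in use. It takes the current board state and returns a list of (action, prob) tuples together with a score in [-1, 1] for the current player. c_puct: controls the trade-off between exploration and exploitation; the larger it is, the more the search relies on the prior probabilities (through the u term) rather than on the Q values observed so far. n_playout: the number of playouts MCTS runs per move; larger values take more time but generally give stronger play.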
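A quick way to see the effect of c_puct: below, two hypothetical children have the same visit count, one with a better Q and one with a higher prior. The helper pick() is hypothetical and simply reuses the get_value() formula.

```python
import math


def pick(children, parent_visits, c_puct):
    """Return the action with the largest Q + u (same formula as get_value)."""
    def score(stats):
        q, prior, visits = stats
        return q + c_puct * prior * math.sqrt(parent_visits) / (1 + visits)
    return max(children, key=lambda act: score(children[act]))


# Hypothetical children: (Q, prior P, visit count)
children = {
    "high_q": (0.5, 0.1, 20),  # has looked good so far, but low prior
    "high_p": (0.0, 0.8, 20),  # not convincing yet, but high prior
}
print(pick(children, 40, 1.0), pick(children, 40, 5.0))  # high_q high_p
```

With a small c_puct the observed value Q dominates; with the default of 5 the prior-weighted bonus outweighs a Q advantage of 0.5, so the search keeps probing the move the network likes.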

It also creates a root node, self._root = TreeNode(None, 1.0), with parent None and prior probability 1.0.

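_playout(self, state): takes one argument, state, which represents the current game state, and performs one simulation. Starting from the root, it greedily follows the child with the largest Q + u until it reaches a leaf, applying each move to state along the way. It then evaluates the leaf with policy_value_fn and checks whether the game is over. If the game is not over, it expands the leaf with the returned priors and keeps the network's value; if it is over, it replaces that value with the true result (0 for a draw, +1 or -1 otherwise). In both cases it then backpropagates, updating the leaf and all of its ancestors. Because _playout modifies state in place, get_move_probs passes it a deep copy each time.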
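To see selection, expansion, evaluation and backpropagation working together, here is a self-contained sketch that runs a condensed copy of the TreeNode/MCTS logic on a toy game instead of Gomoku. The Nim class (take 1 or 2 stones, whoever takes the last stone wins), uniform_policy (uniform priors and value 0 in place of a trained network) and visit_counts() are all assumptions made for this example; they only mimic the interface (do_move, game_end, get_current_player, availables) that the code above expects.

```python
import copy
import math


class TreeNode:
    """Condensed TreeNode (update() folded into update_recursive())."""

    def __init__(self, parent, prior_p):
        self._parent, self._children = parent, {}
        self._n_visits, self._Q, self._P = 0, 0.0, prior_p

    def get_value(self, c_puct):
        u = c_puct * self._P * math.sqrt(self._parent._n_visits) / (1 + self._n_visits)
        return self._Q + u

    def select(self, c_puct):
        return max(self._children.items(), key=lambda kv: kv[1].get_value(c_puct))

    def expand(self, action_priors):
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

    def update_recursive(self, leaf_value):
        if self._parent:
            self._parent.update_recursive(-leaf_value)
        self._n_visits += 1
        self._Q += (leaf_value - self._Q) / self._n_visits

    def is_leaf(self):
        return not self._children


class MCTS:
    def __init__(self, policy_value_fn, c_puct=5, n_playout=400):
        self._root = TreeNode(None, 1.0)
        self._policy, self._c_puct, self._n_playout = policy_value_fn, c_puct, n_playout

    def _playout(self, state):
        node = self._root
        while not node.is_leaf():                       # selection: follow max Q + u
            action, node = node.select(self._c_puct)
            state.do_move(action)
        action_probs, leaf_value = self._policy(state)  # evaluation instead of a rollout
        end, winner = state.game_end()
        if not end:
            node.expand(action_probs)                   # expansion
        elif winner == -1:
            leaf_value = 0.0                            # draw
        else:
            leaf_value = 1.0 if winner == state.get_current_player() else -1.0
        node.update_recursive(-leaf_value)              # backpropagation, in both cases

    def visit_counts(self, state):
        for _ in range(self._n_playout):
            self._playout(copy.deepcopy(state))         # _playout mutates its state
        return {a: n._n_visits for a, n in self._root._children.items()}


class Nim:
    """Toy game (assumption): take 1 or 2 stones; taking the last stone wins."""

    def __init__(self, stones):
        self.stones, self.current, self.last_mover = stones, 1, None

    @property
    def availables(self):
        return [a for a in (1, 2) if a <= self.stones]

    def do_move(self, action):
        self.stones -= action
        self.last_mover, self.current = self.current, 3 - self.current

    def get_current_player(self):
        return self.current

    def game_end(self):
        return (True, self.last_mover) if self.stones == 0 else (False, -1)


def uniform_policy(state):
    """Stand-in for the network: uniform priors over legal moves, value 0."""
    moves = state.availables
    return [(m, 1.0 / len(moves)) for m in moves], 0.0


counts = MCTS(uniform_policy).visit_counts(Nim(4))
print(counts)
```

From 4 stones, taking 1 leaves 3, a lost position for the opponent, so action 1 should end up with far more visits than action 2 even though the stand-in network knows nothing: the true results at terminal states alone steer the search.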

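get_move_probs(self, state, temp): all the code so far prepares for this function. It runs n_playout playouts from the current state and returns every legal move together with its probability. These probabilities are not win rates: they come from the visit counts N of the root's children, computed as softmax(log(N) / temp), i.e. proportional to N^(1/temp), so a small temp puts almost all of the probability on the most-visited move. In other words, given a board position, it tells us how strongly the search favors each point on the board. With it, we can get the computer to learn to play.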
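The effect of temp can be checked in isolation. This sketch repeats the softmax(log(N) / temp) step of get_move_probs on hypothetical visit counts; visits_to_probs is a name introduced only for this example.

```python
import numpy as np


def visits_to_probs(visits, temp):
    """pi(a) proportional to N(a)^(1/temp), computed as softmax(log(N) / temp)."""
    logits = np.log(np.asarray(visits, dtype=float) + 1e-10) / temp
    probs = np.exp(logits - np.max(logits))  # subtract the max for numerical stability
    return probs / np.sum(probs)


visits = [60, 30, 10]                          # hypothetical root visit counts
p_one = visits_to_probs(visits, temp=1.0)      # proportional to the visits
p_small = visits_to_probs(visits, temp=1e-3)   # nearly all mass on the most-visited move
print(p_one, p_small)
```

With temp = 1 the distribution simply mirrors the visit counts, which keeps variety in self-play games; with the default temp = 1e-3 it is effectively an argmax.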

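update_with_move(self, last_move): in self-play, after each move the subtree under that move becomes the new root, so the statistics already collected for it are reused. When playing against a human, the tree is reset every turn: MCTSPlayer calls update_with_move(-1), and since -1 is never a child, a fresh root is created.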
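A minimal sketch of subtree reuse versus reset. TreeNode and MCTS are reduced to the fields update_with_move touches, and the move index 7 and the visit count 42 are made-up values.

```python
class TreeNode:
    """Bare-bones node: just enough structure to show subtree reuse."""

    def __init__(self, parent, prior_p):
        self._parent = parent
        self._children = {}
        self._n_visits = 0
        self._P = prior_p


class MCTS:
    def __init__(self):
        self._root = TreeNode(None, 1.0)

    def update_with_move(self, last_move):
        if last_move in self._root._children:
            # Keep the subtree under last_move; its siblings become unreachable.
            self._root = self._root._children[last_move]
            self._root._parent = None
        else:
            # Unknown move (e.g. -1): start again from an empty tree.
            self._root = TreeNode(None, 1.0)


mcts = MCTS()
child = TreeNode(mcts._root, 0.5)
child._n_visits = 42              # pretend earlier playouts went through move 7
mcts._root._children[7] = child

mcts.update_with_move(7)          # self-play: reuse the subtree
reused_visits = mcts._root._n_visits
mcts.update_with_move(-1)         # against a human: reset
reset_visits = mcts._root._n_visits
print(reused_visits, reset_visits)  # 42 0
```

Setting `_parent = None` matters: it makes the reused node a proper root, so backpropagation stops there and the discarded part of the old tree can be garbage-collected.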


```python
import numpy as np


class MCTSPlayer(object):
    """AI player based on MCTS"""

    def __init__(self, policy_value_function, c_puct=5, n_playout=2000, is_selfplay=0):
        self.mcts = MCTS(policy_value_function, c_puct, n_playout)
        self._is_selfplay = is_selfplay

    def set_player_ind(self, p):
        self.player = p

    def reset_player(self):
        self.mcts.update_with_move(-1)

    def get_action(self, board, temp=1e-3, return_prob=0):
        sensible_moves = board.availables
        # the pi vector returned by MCTS as in the alphaGo Zero paper
        move_probs = np.zeros(board.width * board.height)
        if len(sensible_moves) > 0:
            acts, probs = self.mcts.get_move_probs(board, temp)
            move_probs[list(acts)] = probs
            if self._is_selfplay:
                # add Dirichlet Noise for exploration (needed for self-play training)
                move = np.random.choice(
                    acts,
                    p=0.75 * probs + 0.25 * np.random.dirichlet(0.3 * np.ones(len(probs))),
                )
                # update the root node and reuse the search tree
                self.mcts.update_with_move(move)
            else:
                # with the default temp=1e-3, this is almost equivalent to
                # choosing the move with the highest prob
                move = np.random.choice(acts, p=probs)
                # reset the root node
                self.mcts.update_with_move(-1)

            if return_prob:
                return move, move_probs
            else:
                return move
        else:
            print("WARNING: the board is full")
```

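The main logic of MCTSPlayer lives in get_action(self, board, temp=1e-3, return_prob=0). During self-play it keeps some randomness for exploration by mixing Dirichlet noise into the move probabilities, which is what training needs. When playing against a human it (almost always) picks the best move, which is used to check how well training has worked.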
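The self-play branch of get_action can be isolated as follows. This sketch uses the same 0.75/0.25 mix and alpha = 0.3 as the code above, but swaps in NumPy's seeded Generator API (default_rng) for reproducibility; the move indices and probabilities are hypothetical, and sample_selfplay_move is a name introduced only here. Note that this code mixes the noise into the final move distribution, whereas the AlphaGo Zero paper adds Dirichlet noise to the prior probabilities at the root node.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded Generator (the original code uses np.random directly)


def sample_selfplay_move(acts, probs, epsilon=0.25, alpha=0.3):
    """Mix MCTS move probabilities with Dirichlet noise, then sample a move."""
    noise = rng.dirichlet(alpha * np.ones(len(probs)))
    mixed = (1 - epsilon) * np.asarray(probs) + epsilon * noise
    return rng.choice(acts, p=mixed), mixed


# Hypothetical MCTS output in which one move already dominates.
acts = [12, 13, 24]
probs = np.array([0.9, 0.08, 0.02])
move, mixed = sample_selfplay_move(acts, probs)
print(move, mixed)
```

Because the noise is non-negative, every entry of the mixture is at least 75% of its MCTS probability, so the search result is never overridden; the remaining 25% gives moves the search nearly ignored an occasional chance to be played and learned from.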


