import logging
import math

import numpy as np

EPS = 1e-8

log = logging.getLogger(__name__)


class MCTS():
    """
    This class handles the MCTS tree.
    """

    def __init__(self, game, nnet, args):
        self.game = game
        self.nnet = nnet
        self.args = args
        self.Qsa = {}  # stores Q values for s,a (as defined in the paper)
        self.Nsa = {}  # stores #times edge s,a was visited
        self.Ns = {}  # stores #times board s was visited
        self.Ps = {}  # stores initial policy (returned by neural net)

        self.Es = {}  # stores game.getGameEnded for board s
        self.Vs = {}  # stores game.getValidMoves for board s

    def getActionProb(self, canonicalBoard, temp=1):
        """
        This function performs numMCTSSims simulations of MCTS starting from
        canonicalBoard.

        Returns:
            probs: a policy vector where the probability of the ith action is
                   proportional to Nsa[(s,a)]**(1./temp)
        """
        for i in range(self.args.numMCTSSims):
            self.game.reset_steps()
            self.search(canonicalBoard)

        s = self.game.stringRepresentation(canonicalBoard)
        counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]

        if temp == 0:
            bestAs = np.array(np.argwhere(counts == np.max(counts))).flatten()
            bestA = np.random.choice(bestAs)
            probs = [0] * len(counts)
            probs[bestA] = 1
            return probs

        counts = [x ** (1. / temp) for x in counts]
        counts_sum = float(sum(counts))
        if counts_sum == 0:
            # No edge from the root was ever visited; fall back to a uniform policy
            # over the valid moves instead of dividing by zero.
            log.error("getActionProb: all visit counts are zero for the root state.")
            valids = self.game.getValidMoves(canonicalBoard, 1)
            return [v / float(np.sum(valids)) for v in valids]
        probs = [x / counts_sum for x in counts]
        return probs

    def search_iterate(self, canonicalBoard):
        """
        Iterative (stack-based) equivalent of search(). Each stack entry is a
        (stage, payload) pair: stage 0 expands or evaluates a state, stage 1
        backpropagates the value returned by its child.
        """
        stack = [(0, (canonicalBoard,))]
        results = []  # values produced by leaf/terminal nodes, consumed during backpropagation

        while stack:
            st, sv = stack.pop()
            if st == 0:
                result, ns = self.search_iterate_st0(sv[0])
                if result is not None:
                    results.append(result)
                if ns is not None:
                    # Finish the child search first (stage 0), then update this
                    # node's edge statistics with the child's value (stage 1).
                    stack.append((1, (ns[1], ns[2])))
                    stack.append((0, (ns[0],)))
            elif st == 1:
                v = results.pop()
                v = self.search_iterate_update(v, sv[0], sv[1])
                results.append(v)
            else:
                raise ValueError("Invalid state")

        return results.pop()

    def search_iterate_st0(self, canonicalBoard):
        """
        Stage 0 of the iterative search.

        Returns:
            (value, None) if the state is terminal or a freshly expanded leaf,
            (None, (next_s, s, best_act)) if the search should descend further.
        """
        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            # terminal node
            return -self.Es[s], None

        if s not in self.Ps:
            # leaf node
            self.Ps[s], v = self.nnet.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids  # mask invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # renormalize
            else:
                # if all valid moves were masked, make all valid moves equally probable
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0
            return -v, None

        valids = self.Vs[s]
        cur_best = -float('inf')
        best_act = -1

        # pick the action with the highest upper confidence bound
        for a in range(self.game.getActionSize()):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (
                            1 + self.Nsa[(s, a)])
                else:
                    u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)

                if u > cur_best:
                    cur_best = u
                    best_act = a

        next_s, next_player = self.game.getNextState(canonicalBoard, 1, best_act)
        next_s = self.game.getCanonicalForm(next_s, next_player)
        return None, (next_s, s, best_act)
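
    # search_iterate_st0 (selection/expansion) and search_iterate_update (backpropagation)
    # are the two stages driven by search_iterate; together they perform the same work as
    # one level of the recursive search() below, including the convention of handing -v
    # back to the parent node.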
    def search_iterate_update(self, v, s, a):
        """
        Stage 1 of the iterative search: update Qsa, Nsa and Ns for edge (s, a)
        with the value v propagated from the child, then return the negated
        value for the parent.
        """
        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v

    def search(self, canonicalBoard, depth=0):
        """
        This function performs one iteration of MCTS. It is recursively called
        until a leaf node is found. The action chosen at each node is the one
        with the maximum upper confidence bound, as in the paper.

        Once a leaf node is found, the neural network is called to return an
        initial policy P and a value v for the state. This value is propagated
        up the search path. In case the leaf node is a terminal state, the
        outcome is propagated up the search path. The values of Ns, Nsa, Qsa
        are updated.

        NOTE: the return value is the negative of the value of the current
        state. This is done since v is in [-1,1] and if v is the value of a
        state for the current player, then its value is -v for the other player.

        Returns:
            v: the negative of the value of the current canonicalBoard
        """
        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            # terminal node
            return -self.Es[s]

        if s not in self.Ps:
            # leaf node
            self.Ps[s], v = self.nnet.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids  # masking invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # renormalize
            else:
                # If all valid moves were masked, make all valid moves equally probable.
                # NB! All valid moves may be masked if your NNet architecture is
                # insufficient, you've got overfitting, or something else is wrong.
                # If you get dozens or hundreds of these messages you should pay
                # attention to your NNet and/or training process.
                log.error("All valid moves were masked, doing a workaround.")
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0
            return -v

        valids = self.Vs[s]
        cur_best = -float('inf')
        best_act = -1

        # pick the action with the highest upper confidence bound
        for a in range(self.game.getActionSize()):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (
                            1 + self.Nsa[(s, a)])
                else:
                    u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)  # Q = 0 ?
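                # u above is the PUCT upper confidence bound
                #   u = Q(s,a) + cpuct * P(s,a) * sqrt(N(s)) / (1 + N(s,a)),
                # with Q treated as 0 for a not-yet-visited edge; EPS keeps the
                # exploration term non-zero while N(s) is still 0.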
                if u > cur_best:
                    cur_best = u
                    best_act = a

        a = best_act
        if depth > 100:
            # Very deep search: play a random valid move instead of the UCB choice
            # and cap the recorded depth so this fallback can trigger again soon.
            candidates = self.game.getValidMoves(canonicalBoard, 1)
            a = np.random.choice([i for i in range(len(candidates)) if candidates[i] == 1])
            depth = 80

        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
        next_s = self.game.getCanonicalForm(next_s, next_player)

        v = self.search(next_s, depth=depth + 1)

        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v
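
# Usage sketch (illustrative, not part of this module): `SomeGame` and `SomeNNet`
# are placeholder names for a Game implementation (which must also provide
# reset_steps(), used by getActionProb, and is assumed to expose getInitBoard())
# and a network wrapper exposing predict(canonicalBoard) -> (policy, value).
# The args fields shown are the only ones this class reads (numMCTSSims, cpuct).
#
#     from types import SimpleNamespace
#
#     game = SomeGame()
#     nnet = SomeNNet(game)
#     args = SimpleNamespace(numMCTSSims=25, cpuct=1.0)
#
#     mcts = MCTS(game, nnet, args)
#     board = game.getInitBoard()
#     canonical = game.getCanonicalForm(board, 1)
#     pi = mcts.getActionProb(canonical, temp=1)   # visit-count based policy
#     action = int(np.argmax(pi))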