# no_one/MCTS.py
import logging
import math
import numpy as np
EPS = 1e-8
log = logging.getLogger(__name__)
class MCTS():
"""
This class handles the MCTS tree.
"""
def __init__(self, game, nnet, args):
self.game = game
self.nnet = nnet
self.args = args
self.Qsa = {} # stores Q values for s,a (as defined in the paper)
self.Nsa = {} # stores #times edge s,a was visited
self.Ns = {} # stores #times board s was visited
self.Ps = {} # stores initial policy (returned by neural net)
        self.Es = {} # stores game.getGameEnded results for board s
self.Vs = {} # stores game.getValidMoves for board s
def getActionProb(self, canonicalBoard, temp=1):
"""
This function performs numMCTSSims simulations of MCTS starting from
canonicalBoard.
Returns:
probs: a policy vector where the probability of the ith action is
proportional to Nsa[(s,a)]**(1./temp)
"""
for i in range(self.args.numMCTSSims):
            self.game.reset_steps()  # reset the game's internal step counter before each simulation
            self.search(canonicalBoard)
s = self.game.stringRepresentation(canonicalBoard)
counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]
if temp == 0:
bestAs = np.array(np.argwhere(counts == np.max(counts))).flatten()
bestA = np.random.choice(bestAs)
probs = [0] * len(counts)
probs[bestA] = 1
return probs
counts = [x ** (1. / temp) for x in counts]
counts_sum = float(sum(counts))
        if counts_sum == 0:
            # No action was ever visited (should not happen after running the
            # simulations above); fall back to a uniform policy rather than
            # dividing by zero below.
            log.error("getActionProb: all visit counts are zero, returning a uniform policy.")
            return [1. / len(counts)] * len(counts)
probs = [x / counts_sum for x in counts]
return probs
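    # Worked example of the temperature step in getActionProb: with visit
    # counts [10, 30, 60] and temp=1 the policy is just the normalized counts,
    # [0.1, 0.3, 0.6]; with temp=0.5 each count is squared first,
    # [100, 900, 3600] -> [~0.022, ~0.196, ~0.783], sharpening the policy;
    # and temp=0 puts probability 1 on the argmax (the third action here).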
    def search_iterate(self, canonicalBoard):
        """
        Iterative (explicit-stack) variant of search(); useful when
        simulations run deep enough to hit Python's recursion limit.
        """
        # Each stack entry is (stage, payload):
        #   stage 0: select/expand the given board; payload = (board,)
        #   stage 1: backpropagate the child's value; payload = (s, a)
        stack = [(0, (canonicalBoard,))]
        results = []  # value stack holding leaf/terminal evaluations
        while stack:
            stage, payload = stack.pop()
            if stage == 0:
                result, ns = self.search_iterate_st0(payload[0])
                if result is not None:
                    results.append(result)
                if ns is not None:
                    next_s, s, a = ns
                    # Schedule the backprop update (stage 1) to run after the
                    # child's stage-0 entry has pushed its value onto results.
                    stack.append((1, (s, a)))
                    stack.append((0, (next_s,)))
            elif stage == 1:
                v = results.pop()
                v = self.search_iterate_update(v, payload[0], payload[1])
                results.append(v)
            else:
                raise ValueError("Invalid stage on search stack")
        return results.pop()
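    # Example trace of the stack protocol for a depth-2 descent
    # (root r -> child c -> leaf l), writing s_x for stringRepresentation(x):
    #   pop (0, (r,))       -> push (1, (s_r, a_r)) then (0, (c,))
    #   pop (0, (c,))       -> push (1, (s_c, a_c)) then (0, (l,))
    #   pop (0, (l,))       -> leaf: append its (parent-negated) value v
    #   pop (1, (s_c, a_c)) -> fold v into Q(s_c, a_c), append -v
    #   pop (1, (s_r, a_r)) -> fold -v into Q(s_r, a_r), append v
    # so the final results.pop() is the root's value with the correct sign.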
    def search_iterate_st0(self, canonicalBoard):
        """
        Stage 0 of the iterative search: evaluate terminal and leaf nodes,
        otherwise pick the action with the highest upper confidence bound.
        Returns:
            (v, None) for a terminal or leaf node, where v is the value
            negated for the parent;
            (None, (next_board, s, best_act)) when the search should descend.
        """
        s = self.game.stringRepresentation(canonicalBoard)
        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            # terminal node
            return -self.Es[s], None
if s not in self.Ps:
# leaf node
self.Ps[s], v = self.nnet.predict(canonicalBoard)
valids = self.game.getValidMoves(canonicalBoard, 1)
self.Ps[s] = self.Ps[s] * valids
sum_Ps_s = np.sum(self.Ps[s])
if sum_Ps_s > 0:
self.Ps[s] /= sum_Ps_s
            else:
                # All valid moves were masked; fall back to a uniform policy
                # over the valid moves (same workaround as in search()).
                log.error("All valid moves were masked, doing a workaround.")
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])
self.Vs[s] = valids
self.Ns[s] = 0
return -v, None
valids = self.Vs[s]
cur_best = -float('inf')
best_act = -1
for a in range(self.game.getActionSize()):
if valids[a]:
if (s, a) in self.Qsa:
u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
else:
u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)
if u > cur_best:
cur_best = u
best_act = a
next_s, next_player = self.game.getNextState(canonicalBoard, 1, best_act)
next_s = self.game.getCanonicalForm(next_s, next_player)
return None, (next_s, s, best_act)
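    # Worked PUCT example for the selection loop above: with cpuct=1.0,
    # Ns[s]=16, prior Ps[s][a]=0.25, Qsa[(s, a)]=0.1 and Nsa[(s, a)]=3,
    # u = 0.1 + 1.0 * 0.25 * sqrt(16) / (1 + 3) = 0.1 + 0.25 = 0.35,
    # while an unvisited sibling with the same prior scores
    # u = 1.0 * 0.25 * sqrt(16 + EPS) ~= 1.0, so unexplored edges win
    # until their visit counts catch up.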
    def search_iterate_update(self, v, s, a):
        """
        Stage 1 of the iterative search: fold the child's value v into the
        running mean Q(s, a), bump the visit counts, and return -v for the
        parent (values flip sign between the two players).
        """
if (s, a) in self.Qsa:
self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
self.Nsa[(s, a)] += 1
else:
self.Qsa[(s, a)] = v
self.Nsa[(s, a)] = 1
self.Ns[s] += 1
return -v
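    # Worked example of the running-mean update above: with Qsa[(s, a)]=0.5
    # after Nsa[(s, a)]=3 visits and a new child value v=1.0, the update gives
    # Qsa[(s, a)] = (3 * 0.5 + 1.0) / (3 + 1) = 0.625 and Nsa[(s, a)]=4,
    # i.e. the mean of the four values seen so far.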
def search(self, canonicalBoard, depth=0):
"""
        This function performs one iteration of MCTS. It is called recursively
        until a leaf node is found. The action chosen at each node is the one
        with the maximum upper confidence bound, as in the paper.
Once a leaf node is found, the neural network is called to return an
initial policy P and a value v for the state. This value is propagated
up the search path. In case the leaf node is a terminal state, the
outcome is propagated up the search path. The values of Ns, Nsa, Qsa are
updated.
NOTE: the return values are the negative of the value of the current
state. This is done since v is in [-1,1] and if v is the value of a
state for the current player, then its value is -v for the other player.
Returns:
v: the negative of the value of the current canonicalBoard
"""
s = self.game.stringRepresentation(canonicalBoard)
if s not in self.Es:
self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
if self.Es[s] != 0:
# terminal node
return -self.Es[s]
if s not in self.Ps:
# leaf node
self.Ps[s], v = self.nnet.predict(canonicalBoard)
valids = self.game.getValidMoves(canonicalBoard, 1)
self.Ps[s] = self.Ps[s] * valids # masking invalid moves
sum_Ps_s = np.sum(self.Ps[s])
if sum_Ps_s > 0:
self.Ps[s] /= sum_Ps_s # renormalize
else:
                # if all valid moves were masked, make all valid moves equally probable
                # NB! All valid moves may be masked if your NNet architecture is
                # insufficient, you have overfitting, or something else is wrong.
                # If you get dozens or hundreds of these messages, you should pay
                # attention to your NNet and/or training process.
log.error("All valid moves were masked, doing a workaround.")
self.Ps[s] = self.Ps[s] + valids
self.Ps[s] /= np.sum(self.Ps[s])
self.Vs[s] = valids
self.Ns[s] = 0
return -v
valids = self.Vs[s]
cur_best = -float('inf')
best_act = -1
# pick the action with the highest upper confidence bound
for a in range(self.game.getActionSize()):
if valids[a]:
if (s, a) in self.Qsa:
u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (
1 + self.Nsa[(s, a)])
else:
                    u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)  # Q = 0 for an unvisited edge
if u > cur_best:
cur_best = u
best_act = a
        a = best_act
        if depth > 100:
            # Escape hatch for simulations stuck in (near-)cyclic positions:
            # after 100 plies, play a uniformly random valid move instead of
            # the UCB choice, and lower the depth counter so another random
            # move is injected every ~20 plies if the loop persists.
            candidates = self.game.getValidMoves(canonicalBoard, 1)
            a = np.random.choice([i for i in range(len(candidates)) if candidates[i] == 1])
            depth = 80
        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
        next_s = self.game.getCanonicalForm(next_s, next_player)
        v = self.search(next_s, depth=depth + 1)
if (s, a) in self.Qsa:
self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
self.Nsa[(s, a)] += 1
else:
self.Qsa[(s, a)] = v
self.Nsa[(s, a)] = 1
self.Ns[s] += 1
return -v
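

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). SubtractionGame and
# UniformNNet below are hypothetical stand-ins that implement only the
# interface MCTS actually calls: getActionSize, stringRepresentation,
# getGameEnded, getValidMoves, getNextState, getCanonicalForm, reset_steps,
# and predict() returning (policy, value). The game is a trivial Nim-like
# subtraction game and the "network" returns a uniform prior, so this
# demonstrates the plumbing rather than playing strength.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    class SubtractionGame:
        """Take 1 or 2 stones from a pile; whoever takes the last stone wins."""

        def getActionSize(self):
            return 2  # action 0 takes 1 stone, action 1 takes 2

        def stringRepresentation(self, board):
            return str(board)

        def getGameEnded(self, board, player):
            # The player to move has already lost if the pile is empty.
            return -1 if board == 0 else 0

        def getValidMoves(self, board, player):
            return np.array([1 if board >= 1 else 0, 1 if board >= 2 else 0])

        def getNextState(self, board, player, action):
            return board - (action + 1), -player

        def getCanonicalForm(self, board, player):
            return board  # a bare pile size looks the same to both players

        def reset_steps(self):
            pass  # assumed per-simulation hook used by getActionProb(); a no-op here

    class UniformNNet:
        def predict(self, board):
            return np.array([0.5, 0.5]), 0.0  # uniform prior, neutral value

    args = SimpleNamespace(numMCTSSims=50, cpuct=1.0)
    mcts = MCTS(SubtractionGame(), UniformNNet(), args)
    # From a pile of 5 the winning move is to take 2, leaving the opponent
    # on a losing multiple of 3; the search should weight action 1 heavily.
    print(mcts.getActionProb(5, temp=1))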