import copy

from easydict import EasyDict


class Node:
    """
    Overview:
        Alpha-Beta-Pruning search node.
        Reference: https://mathspp.com/blog/minimax-algorithm-and-alpha-beta-pruning
    Arguments:
        env: the environment class that provides the game rules, e.g.
            zoo.board_games.tictactoe.envs.tictactoe_env.TicTacToeEnv,
            zoo.board_games.gomoku.envs.gomoku_env.GomokuEnv
    """

    def __init__(self, board, legal_actions, start_player_index=0, parent=None, prev_action=None, env=None):
        super().__init__()
        self.env = env
        self.board = board
        self.legal_actions = copy.deepcopy(legal_actions)
        self.children = []
        self.parent = parent
        self.prev_action = prev_action
        self.start_player_index = start_player_index
        # Set to True once all children of this node have been generated.
        self.tree_expanded = False

    def __str__(self):
        return f"Tree({', '.join(str(child) for child in self.children)})"

    def expand(self):
        # The child nodes are evaluated from the other player's perspective.
        if self.start_player_index == 0:
            next_start_player_index = 1
        else:
            next_start_player_index = 0
        if not self.is_terminal_node:
            # Create one child node per legal action of the current board.
            while len(self.legal_actions) > 0:
                action = self.legal_actions.pop(0)
                board, legal_actions = self.env.simulate_action_v2(self.board, self.start_player_index, action)
                child_node = Node(
                    board,
                    legal_actions,
                    start_player_index=next_start_player_index,
                    parent=self,
                    prev_action=action,
                    env=self.env
                )
                self.children.append(child_node)
        self.tree_expanded = True

    @property
    def expanded(self):
        return self.tree_expanded

    def is_fully_expanded(self):
        return len(self.children) == len(self.legal_actions)

    @property
    def is_terminal_node(self):
        # Reset the simulator env to this node's board before querying the game state.
        self.env.reset_v2(self.start_player_index, init_state=self.board)
        return self.env.get_done_reward()[0]

    @property
    def value(self):
        """
        Overview:
            Return the terminal reward of this node's board, as given by
            ``env.get_done_reward()``: 1 if player 1 wins, -1 if player 2 wins,
            and 0 for a draw (winner == -1).
        """
        self.env.reset_v2(self.start_player_index, init_state=self.board)
        return self.env.get_done_reward()[1]

    @property
    def estimated_value(self):
        # Placeholder heuristic used when the search depth limit is reached.
        return 0

    @property
    def state(self):
        return self.board
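

# A minimal sketch (illustrative only, never called in this module) of how a Node is
# built and expanded by hand. It assumes an env instance exposing the same interface
# used above: `legal_actions`, `reset`, `reset_v2`, `simulate_action_v2` and
# `get_done_reward` (e.g. a TicTacToeEnv created as in the __main__ block below).
def _node_expand_sketch(env, board, start_player_index=0):
    # Mirror AlphaBetaPruningBot.get_best_action (defined below): reset the env to the
    # given board so that `env.legal_actions` matches it, then build and expand a root node.
    env.reset(start_player_index=start_player_index, init_state=board)
    root = Node(board, env.legal_actions, start_player_index=start_player_index, env=env)
    root.expand()  # creates one child per legal action of `board`
    # Each child records the action that produced it and whether it ends the game.
    return [(child.prev_action, child.is_terminal_node) for child in root.children]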


def pruning(tree, maximising_player, alpha=float("-inf"), beta=float("+inf"), depth=999, first_level=True):
    if tree.is_terminal_node:
        return tree.value
    # TODO(pu): use a limited search depth
    if depth == 0:
        return tree.estimated_value
    if not tree.expanded:
        tree.expand()

    val = float("-inf") if maximising_player else float("+inf")
    best_subtree = None
    for subtree in tree.children:
        sub_val = pruning(subtree, not maximising_player, alpha, beta, depth - 1, first_level=False)
        if maximising_player:
            val = max(sub_val, val)
            if val > alpha:
                best_subtree = subtree
                alpha = val
        else:
            val = min(sub_val, val)
            if val < beta:
                best_subtree = subtree
                beta = val
        if beta <= alpha:
            # The remaining siblings cannot change the result: prune them.
            break
    if first_level:
        return val, best_subtree
    else:
        return val
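

# A minimal usage sketch of `pruning` on its own, assuming `root` is a Node whose
# `start_player_index` is the side to move and that player 1 is the maximising player
# (this mirrors how AlphaBetaPruningBot.get_best_action calls it below). Illustrative
# only; never called in this module.
def _pruning_usage_sketch(root, maximising_player=True, depth=4):
    # At the first level, pruning returns both the minimax value and the best child,
    # so the action to play is the chosen child's `prev_action`.
    val, best_child = pruning(root, maximising_player, depth=depth, first_level=True)
    return val, best_child.prev_action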


class AlphaBetaPruningBot:
    """
    Overview:
        A bot that selects actions by running an alpha-beta-pruning search over the
        game tree rooted at the current board.
    """

    def __init__(self, ENV, cfg, bot_name):
        self.name = bot_name
        self.ENV = ENV
        self.cfg = cfg

    def get_best_action(self, board, player_index, depth=999):
        try:
            # ENV is an env class: instantiate a fresh simulator from the config.
            simulator_env = copy.deepcopy(self.ENV(EasyDict(self.cfg)))
        except Exception:
            # ENV is already an env instance: just copy it.
            simulator_env = copy.deepcopy(self.ENV)
        simulator_env.reset(start_player_index=player_index, init_state=board)
        root = Node(board, simulator_env.legal_actions, start_player_index=player_index, env=simulator_env)
        if player_index == 0:
            # Player 1 is the maximising player.
            val, best_subtree = pruning(root, True, depth=depth, first_level=True)
        else:
            # Player 2 is the minimising player.
            val, best_subtree = pruning(root, False, depth=depth, first_level=True)
        # print(f'player_index: {player_index}, alpha-beta searched best_action: {best_subtree.prev_action}, its val: {val}')
        return best_subtree.prev_action
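

# A minimal one-shot sketch of querying the bot for a single move. Assumptions: the
# TicTacToeEnv config keys mirror the `cfg` used in the __main__ block below, and the
# board encoding is 0 = empty, 1 = player 1, 2 = player 2 (as in the commented-out test
# board further down). Illustrative only; never called in this module.
def _bot_usage_sketch():
    from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
    cfg = dict(
        prob_random_agent=0,
        prob_expert_agent=0,
        battle_mode='self_play_mode',
        agent_vs_human=False,
        bot_action_type='alpha_beta_pruning',
        channel_last=True,
        scale=True,
    )
    bot = AlphaBetaPruningBot(TicTacToeEnv, cfg, 'player 1')
    board = [
        [1, 0, 1],
        [0, 0, 2],
        [2, 0, 1],
    ]
    # Player 1 (player_index=0) to move; the search should pick one of the two
    # immediately winning moves, (0, 1) or (1, 1), consistent with the commented-out
    # specified-board test in the __main__ block below.
    return bot.get_best_action(board, player_index=0)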


if __name__ == "__main__":
    import time

    ##### TicTacToe #####
    from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv

    cfg = dict(
        prob_random_agent=0,
        prob_expert_agent=0,
        battle_mode='self_play_mode',
        agent_vs_human=False,
        bot_action_type='alpha_beta_pruning',  # {'v0', 'alpha_beta_pruning'}
        channel_last=True,
        scale=True,
    )
    env = TicTacToeEnv(EasyDict(cfg))
    player_0 = AlphaBetaPruningBot(TicTacToeEnv, cfg, 'player 1')  # player_index = 0, player = 1
    player_1 = AlphaBetaPruningBot(TicTacToeEnv, cfg, 'player 2')  # player_index = 1, player = 2

    ### test from the initial empty board ###
    player_index = 0  # player 1 moves first
    env.reset()

    ### test from a specified initial board ###
    # player_index = 0  # player 1 moves first
    # init_state = [[1, 0, 1],
    #               [0, 0, 2],
    #               [2, 0, 1]]
    # env.reset(player_index, init_state)

    state = env.board
    print('-' * 15)
    print(state)
    while not env.get_done_reward()[0]:
        if player_index == 0:
            start = time.time()
            action = player_0.get_best_action(state, player_index=player_index)
            print('player 1 action time: ', time.time() - start)
            player_index = 1
        else:
            start = time.time()
            action = player_1.get_best_action(state, player_index=player_index)
            print('player 2 action time: ', time.time() - start)
            player_index = 0
        env.step(action)
        state = env.board
        print('-' * 15)
        print(state)
    row, col = env.action_to_coord(action)

    ### test from the initial empty board ###
    # With two perfect players, TicTacToe from the empty board ends in a draw.
    assert env.get_done_winner()[0] is True and env.get_done_winner()[1] == -1

    ### test from a specified initial board ###
    # assert (row == 0 and col == 1) or (row == 1 and col == 1)
    # assert env.get_done_winner()[0] is True and env.get_done_winner()[1] == 1
""" | |
##### Gomoku ##### | |
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv | |
cfg = dict( | |
board_size=5, | |
prob_random_agent=0, | |
prob_expert_agent=0, | |
battle_mode='self_play_mode', | |
scale=True, | |
channel_last=True, | |
agent_vs_human=False, | |
bot_action_type='alpha_beta_pruning', # {'v0', 'alpha_beta_pruning'} | |
prob_random_action_in_bot=0., | |
check_action_to_connect4_in_bot_v0=False, | |
) | |
env = GomokuEnv(EasyDict(cfg)) | |
player_0 = AlphaBetaPruningBot(GomokuEnv, cfg, 'player 1') # player_index = 0, player = 1 | |
player_1 = AlphaBetaPruningBot(GomokuEnv, cfg, 'player 2') # player_index = 1, player = 2 | |
### test from the init empty board ### | |
player_index = 0 # player 1 fist | |
env.reset() | |
### test from the init specified board ### | |
# player_index = 1 # player 2 fist | |
# init_state = [[1, 1, 1, 1, 0], | |
# [1, 0, 0, 0, 2], | |
# [0, 0, 2, 0, 2], | |
# [0, 2, 0, 0, 2], | |
# [2, 1, 1, 0, 0], ] | |
# # init_state = [[1, 1, 1, 1, 2], | |
# # [1, 1, 2, 1, 2], | |
# # [2, 1, 2, 2, 2], | |
# # [0, 0, 0, 2, 2], | |
# # [2, 1, 1, 1, 0], ] | |
# env.reset(player_index, init_state) | |
state = env.board | |
print('-' * 15) | |
print(state) | |
while not env.get_done_reward()[0]: | |
if player_index == 0: | |
start = time.time() | |
action = player_0.get_best_action(state, player_index=player_index) | |
print('player 1 action time: ', time.time() - start) | |
player_index = 1 | |
else: | |
start = time.time() | |
action = player_1.get_best_action(state, player_index=player_index) | |
print('player 2 action time: ', time.time() - start) | |
player_index = 0 | |
env.step(action) | |
state = env.board | |
print('-' * 15) | |
print(state) | |
assert env.get_done_winner()[0] is False, env.get_done_winner()[1] == -1 | |
# assert env.get_done_winner()[0] is True, env.get_done_winner()[1] == 2 | |
""" | |