""" | |
Overview: | |
This code implements an MCTSbot that uses MCTS to make decisions. | |
The MCTSnode is an abstract base class that specifies the basic methods that a Monte Carlo Tree node should have. | |
The TwoPlayersMCTSnode class inherits from this base class and implements the specific methods. | |
MCTS implements the search function, which takes in a root node and performs a search to obtain the optimal action. | |
MCTSbot integrates the above functions and can create a root node based on the current game environment, | |
and then calls MCTS to perform a search and make a decision. | |
For more details, you can refer to: https://github.com/int8/monte-carlo-tree-search. | |
""" | |
import copy
import time
from abc import ABC, abstractmethod
from collections import defaultdict

import numpy as np


class MCTSNode(ABC):
    """
    Overview:
        This is an abstract base class that outlines the fundamental methods for a Monte Carlo Tree node.
        Each method must be implemented in a subclass for the specific use case.
    """

    def __init__(self, env, parent=None):
        """
        Arguments:
            - env (:obj:`BaseEnv`): The game environment of the current node.
                The properties of this object contain information about the current game environment.
                For instance, in a game of tictactoe:
                    - env.board: A (3,3) array representing the game board, e.g.,
                        [[0,2,0],
                         [1,1,0],
                         [2,0,0]]
                      Here, 0 denotes an unplayed position, 1 a position occupied by player 1, and 2 a position taken by player 2.
                    - env.players: A list [1,2] representing the two players, player 1 and player 2 respectively.
                    - env._current_player: The player who is to move in the current turn; it alternates every turn, not only at reset.
                The methods of this object implement functionalities such as game state transitions and retrieving game results.
            - parent (:obj:`MCTSNode`): The parent node of the current node, used primarily for backpropagation during the Monte Carlo Tree Search.
                For the root node, ``parent`` is ``None`` as it does not have a parent node.
        """
        self.env = env
        self.parent = parent
        self.children = []
        self.expanded_actions = []
        self.best_action = -1

    @property
    @abstractmethod
    def legal_actions(self):
        pass

    @property
    @abstractmethod
    def value(self):
        pass

    @property
    @abstractmethod
    def visit_count(self):
        pass

    @abstractmethod
    def expand(self):
        pass

    @abstractmethod
    def is_terminal_node(self):
        pass

    @abstractmethod
    def rollout(self):
        pass

    @abstractmethod
    def backpropagate(self, reward):
        pass

    def is_fully_expanded(self):
        """
        Overview:
            This method checks whether the node is fully expanded.
            A node is considered fully expanded when all of its child nodes have been selected at least once.
            Whenever a new child node is selected for the first time, the corresponding action is removed from the list of legal actions.
            Once the list of legal actions is depleted, all child nodes have been selected,
            which means this node is fully expanded.
        """
        return len(self.legal_actions) == 0

    def best_child(self, c_param=1.4):
        r"""
        Overview:
            This method finds the best child node, i.e. the one with the highest UCB (Upper Confidence Bound) score.
            The UCB formula, matching the implementation below, is:
            {UCT}(v_i, v) = \frac{Q(v_i)}{N(v_i)} + c \sqrt{\frac{2 \ln N(v)}{N(v_i)}}
            - Q(v_i) is the estimated value of the child node v_i.
            - N(v_i) is a counter of how many times the child node v_i has been on the backpropagation path.
            - N(v) is a counter of how many times the parent (current) node v has been on the backpropagation path.
            - c is a parameter which balances exploration and exploitation.
        Arguments:
            - c_param (:obj:`float`): a parameter which controls the balance between exploration and exploitation. Defaults to 1.4.
        Returns:
            - node (:obj:`MCTSNode`): The child node with the highest UCB score.
        """
        # Calculate the UCB score for every child node in the list.
        choices_weights = [
            (child_node.value / child_node.visit_count) +
            c_param * np.sqrt(2 * np.log(self.visit_count) / child_node.visit_count)
            for child_node in self.children
        ]
        best_index = np.argmax(choices_weights)
        # Save the best action based on the highest UCB score.
        self.best_action = self.expanded_actions[best_index]
        # Return the child node which has the highest UCB score.
        return self.children[best_index]
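
    # A worked example of the UCB score (numbers are illustrative only):
    # suppose a child node has value Q(v_i) = 5 and visit_count N(v_i) = 10,
    # while the current node has visit_count N(v) = 50. With c_param = 1.4:
    #   5 / 10 + 1.4 * sqrt(2 * ln(50) / 10) ≈ 0.5 + 1.24 ≈ 1.74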

    def rollout_policy(self, possible_actions):
        """
        Overview:
            This method implements the rollout policy for a node during the Monte Carlo Tree Search.
            The rollout policy determines the action taken during the simulation phase of the MCTS.
            In this case, the policy is to choose an action uniformly at random from the list of possible actions.
        Arguments:
            - possible_actions (:obj:`list`): A list of all possible actions that can be taken from the current state.
        Returns:
            - action (:obj:`int`): A randomly chosen action from the list of possible actions.
        """
        return possible_actions[np.random.randint(len(possible_actions))]


class TwoPlayersMCTSNode(MCTSNode):
    """
    Overview:
        This subclass inherits from the abstract base class and implements the methods required for a two-player Monte Carlo Tree node.
    """

    def __init__(self, env, parent=None):
        """
        Overview:
            This method initializes a new instance of the class. It sets the environment and parent node,
            and initializes the visit count, results, and legal actions.
        Arguments:
            - env (:obj:`BaseEnv`): the environment object which contains information about the current game state.
            - parent (:obj:`MCTSNode`): the parent node of this node. If None, then this node is the root node.
        """
        super().__init__(env, parent)
        self._number_of_visits = 0.
        # A default dictionary which sets the value to 0 for undefined keys.
        self._results = defaultdict(int)
        self._legal_actions = None

    @property
    def legal_actions(self):
        # Get all legal actions in the current state from the environment object and cache them.
        if self._legal_actions is None:
            self._legal_actions = copy.deepcopy(self.env.legal_actions)
        return self._legal_actions

    @property
    def value(self):
        """
        Overview:
            This property is the estimated value (Q-value) of the current node.
            self._results[1] is the number of wins for player 1.
            self._results[-1] is the number of wins for player 2.
            The Q-value depends on which player is the current player at the parent node,
            and is computed as the difference between the wins of that player and those of the opponent.
            If the parent's current player is player 1, the Q-value is player 1's wins minus player 2's wins.
            If the parent's current player is player 2, the Q-value is player 2's wins minus player 1's wins.
            For example, if self._results[1] = 10 (player 1's wins) and self._results[-1] = 5 (player 2's wins):
                - If the parent's current player is player 1, then Q-value = 10 - 5 = 5.
                - If the parent's current player is player 2, then Q-value = 5 - 10 = -5.
            This way, a higher Q-value for a node indicates a higher win rate for the parent's current player.
        """
        # Determine the number of wins and losses based on the current player at the parent node.
        wins, losses = (self._results[1], self._results[-1]) if self.parent.env.current_player == 1 else (
            self._results[-1], self._results[1])
        # Calculate and return the Q-value as the difference between wins and losses.
        return wins - losses

    @property
    def visit_count(self):
        """
        Overview:
            This property is the number of times the node has been visited during the search.
        """
        return self._number_of_visits

    def expand(self):
        """
        Overview:
            This method expands the current node by creating a new child node.
            It pops an action from the list of legal actions, simulates that action to get the next game state,
            and creates a new child node with that state. The new child node is then added to the list of children nodes.
        Returns:
            - node (:obj:`TwoPlayersMCTSNode`): The child node object that has been created.
        """
        # Pop an untried action from the list of legal actions, so that only untried actions remain in the list.
        action = self.legal_actions.pop()
        # The simulate_action() method returns a new environment with the board and the current-player flag updated.
        next_simulator_env = self.env.simulate_action(action)
        # Create a new node object for the child node and append it to the children list.
        child_node = TwoPlayersMCTSNode(next_simulator_env, parent=self)
        self.children.append(child_node)
        # Record the tried action in the expanded_actions list.
        self.expanded_actions.append(action)
        # Return the child node object.
        return child_node
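
    # Example of one expansion step (values are illustrative): if legal_actions
    # is [3, 5, 7], pop() removes 7, simulate_action(7) yields the child state,
    # and 7 is stored in expanded_actions at the same index as the new child in
    # self.children. best_child() relies on this index alignment when it maps
    # the best UCB score back to an action.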

    def is_terminal_node(self):
        """
        Overview:
            This method checks whether the current node is a terminal node.
            It uses the environment's get_done_reward() method to check whether the game has ended.
        Returns:
            - A bool flag indicating whether the game is over.
        """
        # get_done_reward() returns a tuple (done, reward):
        #   reward = ±1 when player 1 wins/loses the game,
        #   reward = 0 when it is a tie,
        #   reward = None when the current node is not a terminal node.
        # done is the bool flag representing whether the game is over.
        return self.env.get_done_reward()[0]

    def rollout(self):
        """
        Overview:
            This method performs a rollout (simulation) from the current node.
            It repeatedly selects an action based on the rollout policy and simulates it until the game ends.
            The method then returns the reward of the game's final state.
        Returns:
            - reward (:obj:`int`): reward = ±1 when player 1 wins/loses the game, reward = 0 when it is a tie.
        """
        current_rollout_env = self.env
        while not current_rollout_env.get_done_reward()[0]:
            possible_actions = current_rollout_env.legal_actions
            action = self.rollout_policy(possible_actions)
            current_rollout_env = current_rollout_env.simulate_action(action)
        return current_rollout_env.get_done_reward()[1]

    def backpropagate(self, result):
        """
        Overview:
            This method performs backpropagation from the current node.
            It increments the node's visit count and the win count for the given result.
            If the current node has a parent, the method recursively backpropagates the result to the parent.
        """
        self._number_of_visits += 1.
        # result is used as a key into the self._results dict:
        # result = ±1 when player 1 wins/loses the game.
        self._results[result] += 1.
        if self.parent:
            self.parent.backpropagate(result)


class MCTS(object):
    """
    Overview:
        This class implements Monte Carlo Tree Search from a root node whose environment is the real environment of the game at the current moment.
        After the tree search and rollout simulations, every child node of the root node has a UCB value.
        The decision for the root node is then to choose the child node with the highest UCB value.
    """

    def __init__(self, node):
        """
        Overview:
            This method initializes a new instance of the MCTS class with the given root node object.
        Arguments:
            - node (:obj:`TwoPlayersMCTSNode`): The root node object for the MCTS.
        """
        self.root = node

    def best_action(self, simulations_number=None, total_simulation_seconds=None, best_action_type="UCB"):
        """
        Overview:
            This method runs the simulations, repeatedly selecting a leaf node, rolling out from it, and backpropagating the result.
        Arguments:
            - simulations_number (:obj:`int`): The number of simulations performed to get the best action.
            - total_simulation_seconds (:obj:`float`): The amount of time the algorithm has to run, specified in seconds.
            - best_action_type (:obj:`str`): The type of best-action selection to use, either "UCB" or "most visited".
        Returns:
            - node (:obj:`TwoPlayersMCTSNode`): The best child node object, which contains the best action to take.
        """
        # The search budget is determined by either a maximum number of simulations or a maximum simulation time.
        # If no simulation number is provided, run simulations for the specified time.
        if simulations_number is None:
            assert total_simulation_seconds is not None
            end_time = time.time() + total_simulation_seconds
            while True:
                # Get the leaf node.
                leaf_node = self._tree_policy()
                # Rollout from the leaf node.
                reward = leaf_node.rollout()
                # Backpropagate from the leaf node to the root node.
                leaf_node.backpropagate(reward)
                if time.time() > end_time:
                    break
        # If a simulation number is provided, run the specified number of simulations.
        else:
            for i in range(0, simulations_number):
                # Get the leaf node.
                leaf_node = self._tree_policy()
                # Rollout from the leaf node.
                reward = leaf_node.rollout()
                # Backpropagate from the leaf node to the root node.
                leaf_node.backpropagate(reward)
        if best_action_type == "UCB":
            # To select the best child, go for exploitation only: with c_param=0 the UCB score reduces to the value/visit ratio.
            return self.root.best_child(c_param=0.)
        else:
            # Select the most-visited child node.
            children_visit_counts = [child_node.visit_count for child_node in self.root.children]
            best_index = np.argmax(children_visit_counts)
            self.root.best_action = self.root.expanded_actions[best_index]
            return self.root.children[best_index]

    def _tree_policy(self):
        """
        Overview:
            This method implements the tree search from the root node down to a leaf node, which is either being visited for the first time or is a terminal node.
            At each step, if the current node is not fully expanded, it is expanded.
            Otherwise the search moves to the best child according to the tree policy.
        Returns:
            - node (:obj:`TwoPlayersMCTSNode`): The leaf node object reached by the tree search.
        """
        current_node = self.root
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expanded():
                # Choose a child node which has not been visited before.
                return current_node.expand()
            else:
                current_node = current_node.best_child()
        return current_node
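

# How the classes above map onto the four canonical MCTS phases:
#   selection       -> MCTS._tree_policy() walking down the tree via best_child()
#   expansion       -> MCTSNode.expand() whenever a node is not fully expanded
#   simulation      -> MCTSNode.rollout() using rollout_policy()
#   backpropagation -> MCTSNode.backpropagate() from the leaf back to the root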


class MCTSBot:
    """
    Overview:
        A bot that uses MCTS to make decisions, i.e. to choose an action to take.
    """

    def __init__(self, env, bot_name, num_simulation=50):
        """
        Overview:
            This method initializes a new instance of the MCTSBot class.
        Arguments:
            - env (:obj:`BaseEnv`): The environment object for the game.
            - bot_name (:obj:`str`): The name of the MCTS bot.
            - num_simulation (:obj:`int`): The number of simulations to perform during the MCTS.
        """
        self.name = bot_name
        self.num_simulation = num_simulation
        self.simulator_env = env

    def get_actions(self, state, player_index, best_action_type="UCB"):
        """
        Overview:
            This method gets the action that the MCTS bot will take.
            The environment is first reset to the given state.
            Then, MCTS is performed with the specified number of simulations to find the best action.
        Arguments:
            - state (:obj:`list`): The current game state.
            - player_index (:obj:`int`): The index of the current player.
            - best_action_type (:obj:`str`): The type of best-action selection to use, either "UCB" or "most visited".
        Returns:
            - action (:obj:`int`): The best action that the MCTS bot will take.
        """
        # Before each decision, reset the environment to the current state of the game.
        self.simulator_env.reset(start_player_index=player_index, init_state=state)
        root = TwoPlayersMCTSNode(self.simulator_env)
        # Run MCTS to find the best action to take.
        mcts = MCTS(root)
        mcts.best_action(self.num_simulation, best_action_type=best_action_type)
        return root.best_action
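

# A minimal, self-contained usage sketch (not part of the original module):
# the tiny NimEnv below is a hypothetical stand-in for a real environment such
# as a tictactoe env, implementing just the interface that MCTSBot relies on:
# legal_actions, simulate_action(), get_done_reward(), current_player, and
# reset(start_player_index=..., init_state=...).
# Rules of the toy game: players alternately remove 1-3 stones from a pile;
# whoever takes the last stone wins.
class NimEnv:

    def __init__(self, stones=10, current_player=1):
        self.stones = stones
        self.current_player = current_player
        self.winner = None

    @property
    def legal_actions(self):
        # A player may take 1, 2, or 3 stones, but no more than remain.
        return list(range(1, min(3, self.stones) + 1))

    def simulate_action(self, action):
        # Return a new environment with the action applied and the turn passed,
        # leaving this environment untouched (as expand() and rollout() expect).
        next_env = NimEnv(self.stones - action, 3 - self.current_player)
        if next_env.stones == 0:
            next_env.winner = self.current_player  # taker of the last stone wins
        return next_env

    def get_done_reward(self):
        # Returns (done, reward): reward = +1/-1 for a player 1 win/loss, None if ongoing.
        if self.winner is None:
            return False, None
        return True, 1 if self.winner == 1 else -1

    def reset(self, start_player_index=0, init_state=None):
        self.stones = 10 if init_state is None else init_state
        self.current_player = start_player_index + 1
        self.winner = None


if __name__ == '__main__':
    bot = MCTSBot(NimEnv(), 'nim_bot', num_simulation=200)
    # Ask the bot for player 1's move from a pile of 10 stones. With enough
    # simulations it should usually take 2, leaving the opponent a multiple of 4.
    action = bot.get_actions(10, player_index=0)
    print('MCTSBot takes {} stone(s)'.format(action))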