|
import streamlit as st |
|
import numpy as np |
|
|
|
import Arena |
|
|
|
from MCTS import MCTS |
|
|
|
from no_one.NoOneGame import NoOneGame |
|
|
|
|
|
from no_one.pytorch.NNet import NNetWrapper as NNet |
|
from utils import * |
|
|
|
import time |
|
|
|
|
|
# Glyph rendered on a board cell, keyed by the cell's value:
# -1 is a piece of player ❌, 0 is an empty square, +1 is a piece of player ⭕.
square_content = {-1: "❌", 0: "·", 1: "⭕"}

# Per-seat display metadata. Seat 0 is player -1 (❌), seat 1 is player +1 (⭕).
players = [{"name": square_content[side]} for side in (-1, 1)]

# Options of the "Who will play" radio: index 0 is the AI, index 1 is a human.
player_types = [":rainbow[AI]", "Human"]

# AI strength levels, in the order offered by the "AI rank" radio.
ai_ranks = ["Super", "First Blood", "Random"]

# Checkpoint file name (loaded from ./models/) for each AI rank.
ai_models = dict(
    zip(ai_ranks, ("super.pth.tar", "firstBlood.pth.tar", "random.pth.tar"))
)
|
|
|
# Single game-rules instance shared by every helper in this module (players,
# move application, board rendering). The board UI draws 4 playable columns per
# row, so the argument is presumably the board size — TODO confirm in NoOneGame.
game = NoOneGame(4)
|
|
|
def check_clicked(i, j):
    """Return True when cell (i, j) is the piece currently selected by the human."""
    # `clicked` is either None (nothing selected) or an (row, col) tuple.
    return st.session_state.clicked == (i, j)
|
|
|
class HumanPlayer:
    """Turn a human's two clicks (source cell, target cell) into a game action."""

    def __init__(self, game):
        self.game = game

    def encode(self, src, target):
        """Map a one-step orthogonal move from *src* to *target* to an action id.

        The direction code handed to ``game.encodeAction`` is:
        0 when target is in row + 1, 1 when in column + 1,
        2 when in row - 1, 3 when in column - 1.
        Returns None when *target* is not orthogonally adjacent to *src*.
        """
        # Keys are (src - target) offsets, values the direction codes above.
        direction_by_offset = {(-1, 0): 0, (0, -1): 1, (1, 0): 2, (0, 1): 3}
        offset = (src[0] - target[0], src[1] - target[1])
        direction = direction_by_offset.get(offset)
        if direction is None:
            return None
        return self.game.encodeAction(src, direction)

    def play(self, board, player, src, target):
        """Return the action for moving *src* -> *target*, or None if illegal.

        A move is legal when it encodes to an action and that action's entry in
        ``game.getValidMoves(board, player)`` equals 1.
        """
        action = self.encode(src, target)
        # Short-circuit keeps getValidMoves from being called for non-adjacent clicks.
        if action is None or self.game.getValidMoves(board, player)[action] != 1:
            return None
        return action
|
|
|
|
|
def ai_player(ai_model):
    """Build a move-picking callable for the AI rank *ai_model*.

    Loads the checkpoint ``ai_models[ai_model]`` from ``./models/`` into a fresh
    network and wraps it in an MCTS (50 simulations, cpuct 1.0). The returned
    function takes a canonical board and returns the argmax of the search's
    action probabilities at temperature 0. Every call builds a new network and
    search tree.
    """
    net = NNet(game)
    net.load_checkpoint('./models/', ai_models[ai_model])
    search = MCTS(game, net, dotdict({'numMCTSSims': 50, 'cpuct': 1.0}))

    def choose_action(canonical_board):
        return np.argmax(search.getActionProb(canonical_board, temp=0))

    return choose_action
|
|
|
|
|
def human_step(i, j):
    """Handle a click on board cell (i, j) for the side to move.

    Click flow:
    - no piece selected: select (i, j) if it holds one of the mover's pieces;
    - clicking the selected piece again deselects it;
    - clicking another of the mover's pieces moves the selection there;
    - clicking any other cell tries to move the selected piece to it.
    Rejected clicks show a toast and leave the game state unchanged.
    """
    state = st.session_state

    # Once a winner is recorded the board is frozen. Otherwise further moves
    # would reach move() again and re-increment the winner's tally.
    if state.winner is not None:
        st.toast('Game finished! Start a new game.')
        return

    # Board buttons stay clickable when the side to move is an AI (AI vs AI).
    # The AI entry is a plain callable with no .play(), so refuse the click
    # instead of raising AttributeError further down.
    mover = state.players[state.player + 1]
    if not isinstance(mover, HumanPlayer):
        st.toast('Invalid move!')
        return

    if state.clicked is None:
        # A selection must start on one of the mover's own pieces.
        if state.board[i, j] != state.player:
            st.toast('Invalid move!')
            return
        state.clicked = (i, j)
        return

    if state.clicked == (i, j):
        # Clicking the selected piece again cancels the selection.
        state.clicked = None
        return

    if state.board[i, j] == state.player:
        # Switch the selection to another own piece.
        state.clicked = (i, j)
        return

    # Otherwise treat (i, j) as the destination for the selected piece.
    action = mover.play(state.board, state.player, state.clicked, (i, j))
    if action is None:
        st.toast('Invalid move!')
        return
    move(action)
    state.clicked = None
|
|
|
|
|
def ai_step():
    """Let the AI that owns the current turn pick and play one move."""
    side = st.session_state.player
    # players is indexed by side + 1: -1 -> slot 0, +1 -> slot 2.
    pick_action = st.session_state.players[side + 1]
    canonical_board = game.getCanonicalForm(st.session_state.board, side)
    move(pick_action(canonical_board))
|
|
|
|
|
def move(action):
    """Apply *action* for the side to move, pass the turn, and record a win.

    The board and side-to-move in session state are replaced by
    ``game.getNextState``. If ``game.getGameEnded`` (asked from the mover's
    perspective) is non-zero, ``outcome * mover`` is stored as the winner, that
    side's score is incremented and balloons are shown.
    """
    mover = st.session_state.player
    next_board, next_player = game.getNextState(
        st.session_state.board, mover, action,
    )
    st.session_state.board = next_board
    st.session_state.player = next_player

    outcome = game.getGameEnded(next_board, mover)
    if outcome == 0:
        return
    # NOTE(review): assumes outcome is +/-1; any other non-zero value (e.g. a
    # draw marker) would not be a key of st.session_state.win and would raise.
    winning_side = outcome * mover
    st.session_state.winner = winning_side
    st.session_state.win[winning_side] += 1
    st.balloons()
|
|
|
def reinit():
    """Widget callback: ask main() to rebuild the game on the next rerun."""
    st.session_state["reinit"] = True
|
|
|
def init(post_init=False):
    """Start a fresh game from the current player settings.

    Args:
        post_init: when True (the "New Game" button), keep the running score;
            when False, reset both sides' win counters as well.
    """
    st.session_state.reinit = False
    if not post_init:
        st.session_state.win = {-1: 0, 1: 0}

    st.session_state.board = game.getInitBoard()
    st.session_state.player = -1  # ❌ (player -1) always opens

    seats = [None, None]
    for seat, setting in enumerate(st.session_state.player_settings):
        if setting["pt"] == player_types[1]:
            seats[seat] = HumanPlayer(game)
        else:
            seats[seat] = ai_player(setting["ai_model"])
    # Indexed by player + 1: slot 0 is player -1, slot 2 is player +1, slot 1 unused.
    st.session_state.players = [seats[0], None, seats[1]]

    st.session_state.winner = None
    st.session_state.clicked = None
|
|
|
|
|
def player_to_index(p):
    """Map a side (-1 or +1) to its seat index (0 or 1); None for anything else."""
    return 0 if p == -1 else (1 if p == 1 else None)
|
|
|
|
|
# Rules shown in the sidebar when "Show me how to play" is ticked (Markdown).
# Items 3.1-3.4 are indented so Markdown nests them under item 3.
# NOTE(review): confirm item "5." is meant as a top-level note; Markdown
# renumbers it sequentially when rendering.
_HOW_TO_PLAY = '''1. 游戏的目标是吃掉对方的棋子
   开局双方各有四枚棋子,被吃剩一枚棋子即可判负
2. 棋子可以向上下左右移动到空白位置
3. 如何吃子
    1. 移动后的棋子需要跟本方其他某个棋子,在水平方向,或者垂直方向上连住
    2. 如果连住后的棋子两端有对方一个棋子,那么这个对方的棋子就被吃掉
    3. 如果对方也有两颗棋子,则互相不吃
    4. 如果落子在对方连起来的两枚棋子一端,也不会被吃
5. 如果选择 AI vs AI,需要手动点击比分下面的回合按钮,触发下一次落子
'''


def _ensure_session_defaults():
    """Seed per-session settings the first time this browser session runs."""
    if "player_settings" not in st.session_state:
        st.session_state.player_settings = [
            {"pt": player_types[1], "ai_model": ai_ranks[0]},
            {"pt": player_types[0], "ai_model": ai_ranks[1]},
        ]
    if "reinit" not in st.session_state:
        st.session_state.reinit = False


def _render_seat_settings(i):
    """Render the player-type radio (and AI-rank radio for AIs) for seat *i*."""
    setting = st.session_state.player_settings[i]
    setting["pt"] = st.radio(
        f"Who will play {players[i]['name']}",
        player_types,
        key=f"xp_type_{i}",
        horizontal=True,
        index=player_types.index(setting["pt"]),
        on_change=reinit,
    )
    if setting["pt"] == player_types[0]:
        setting["ai_model"] = st.radio(
            "AI rank",
            ai_ranks,
            key=f"xp_rank_{i}",
            index=ai_ranks.index(setting["ai_model"]),
            on_change=reinit,
        )


def _render_controls():
    """Render the New Game button and the Settings expander."""
    fire, settings = st.columns([1, 2])
    # New Game keeps the score (post_init=True); setting changes go through
    # reinit(), which makes main() call init() and reset the score.
    fire.button('New Game', on_click=init, args=(True,))
    with settings.expander('Settings', expanded=False):
        st.warning('Any setting changing will restart the game immediately', icon="⚠️")
        for i, column in enumerate(st.columns([0.5, 0.5])):
            with column:
                _render_seat_settings(i)
        st.divider()
        st.checkbox("Show me how to play", key="show_how_to_play")


def _render_how_to_play():
    """Show the rules in the sidebar when the checkbox is ticked."""
    if st.session_state.show_how_to_play:
        with st.sidebar:
            st.title("How to play")
            st.markdown(_HOW_TO_PLAY)


def _render_player_headers(xp, op):
    """Write each seat's glyph and a caption naming its controller."""
    for i, column in enumerate([xp, op]):
        column.title(players[i]["name"], anchor=False)
        setting = st.session_state.player_settings[i]
        caption = setting["pt"]
        if setting["pt"] == player_types[0]:
            caption = f'{setting["ai_model"]} {caption}'
        column.caption(caption)


def _play_ai_turn_if_due():
    """Let the AI move when it owns the current turn and the game is still on."""
    seat = st.session_state.player_settings[player_to_index(st.session_state.player)]
    if seat['pt'] != player_types[1] and st.session_state.winner is None:
        ai_step()


def _render_board():
    """Draw the board as a grid of buttons; clicks are handled by human_step."""
    for i, row in enumerate(st.session_state.board):
        # Outer wide columns centre the four cell columns.
        cols = st.columns([5, 1, 1, 1, 1, 5])
        for j, field in enumerate(row):
            cols[j + 1].button(
                square_content[field],
                key=f"{i}-{j}",
                type="primary" if check_clicked(i, j) else "secondary",
                on_click=human_step,
                args=(i, j),
            )


def _render_score(score):
    """Show the running score and a turn/finished button under it."""
    s = score.columns([2, 2, 2])
    s[1].title(f'{st.session_state.win[-1]} : {st.session_state.win[1]}', anchor=False)
    if st.session_state.winner is None:
        label = f'{"❌" if st.session_state.player == -1 else "⭕"}\'s turn'
    else:
        label = '🏁 Game finished'
    # Clicking it only triggers a rerun, which lets an AI-vs-AI game advance
    # one move via _play_ai_turn_if_due().
    s[1].button(label)


def main():
    """Render one Streamlit run of the app and advance the AI if it is its turn."""
    _ensure_session_defaults()
    _render_controls()

    # First run, or a setting changed (reinit flag): build a fresh game.
    if "board" not in st.session_state or st.session_state.reinit:
        init()

    _render_how_to_play()

    xp, score, op = st.columns([2, 8, 2])
    _render_player_headers(xp, op)

    # The AI moves before the board is drawn, so the human always sees the
    # position after the AI's reply.
    _play_ai_turn_if_due()

    st.divider()
    _render_board()
    _render_score(score)
|
|
|
|
|
|
|
# Entry point: render the app when this file runs as the main script.
if __name__ == '__main__':
    main()