# wordle-solver / main.py
# Author: santit96 — commit 254d61f ("Code refactor")
import sys
import os
import gym
import matplotlib.pyplot as plt
from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
from wordle_env.wordle import WordleEnvBase
def print_results(global_ep, win_ep, res):
    """Print the played/won episode counters and plot the reward curve.

    Args:
        global_ep: counter object whose ``.value`` is printed as the number
            of episodes played ("Jugadas"). Presumably a shared
            multiprocessing value returned by ``train`` -- confirm.
        win_ep: counter object whose ``.value`` is printed as the number of
            episodes won ("Ganadas").
        res: sequence of values plotted against their index; labelled as the
            moving-average episode reward.
    """
    # Each counter is read right before it is printed, in a fixed order.
    for label, counter in (("Jugadas:", global_ep), ("Ganadas:", win_ep)):
        print(label, counter.value)

    plt.plot(res)
    plt.xlabel('Step')
    plt.ylabel('Moving average ep reward')
    # Blocks until the figure window is closed (interactive backends).
    plt.show()
def parse_args(args):
    """Parse the positional command-line arguments of this script.

    Usage: ``main.py [max_ep] [env_id] [evaluation]``

    Args:
        args: list of argument strings, excluding the program name
            (i.e. ``sys.argv[1:]``).

    Returns:
        Tuple ``(max_ep, env_id, evaluation)``:
        - ``max_ep`` (int): number of training episodes, default 100000.
        - ``env_id`` (str): gym environment id, default
          ``'WordleEnv100FullAction-v0'``.
        - ``evaluation`` (bool): True only when the third argument is exactly
          ``'evaluation'``; any other third argument is ignored.

    Raises:
        SystemExit: if ``max_ep`` is given but is not an integer.
    """
    max_ep = 100000
    if args:
        try:
            max_ep = int(args[0])
        except ValueError:
            # A clean usage error instead of a bare ValueError traceback.
            raise SystemExit(
                f"max_ep must be an integer, got {args[0]!r}"
            ) from None
    env_id = args[1] if len(args) > 1 else 'WordleEnv100FullAction-v0'
    evaluation = len(args) > 2 and args[2] == 'evaluation'
    return max_ep, env_id, evaluation


def main(argv=None):
    """Train then evaluate an A3C agent, or evaluate saved checkpoints.

    Args:
        argv: argument strings without the program name; defaults to
            ``sys.argv[1:]``.
    """
    args = sys.argv[1:] if argv is None else argv
    max_ep, env_id, evaluation = parse_args(args)

    env = gym.make(env_id)
    try:
        # Checkpoints are stored per environment id.
        model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
        if not evaluation:
            global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
            print_results(global_ep, win_ep, res)
            evaluate(gnet, env)
        else:
            print("Evaluation mode")
            results = evaluate_checkpoints(model_checkpoint_dir, env)
            print(results)
    finally:
        # Release environment resources even if training/evaluation fails.
        env.close()


if __name__ == "__main__":
    main()