# qdqn-MountainCar-v0 / example.py
# Committed by Arnas — commit 5cb9176:
#   "Add QDQN Mountain Car agent trained for 600 episodes"
import gym
import torch
import yaml
from model.agent import Agent
from model.qnn import QuantumNet
# Environment: the Gym task this example works with.
env_name = 'MountainCar-v0'
env = gym.make(env_name)

# Network: rebuild the quantum Q-network from the hyperparameters stored in
# config.yaml (presumably the ones used during training — confirm), then load
# the saved weights into it.
with open('config.yaml', 'r', encoding='utf-8') as f:
    # safe_load only constructs plain Python objects (dicts, lists, scalars).
    hparams = yaml.safe_load(f)

net = QuantumNet(
    n_layers=hparams['n_layers'],
    w_input=hparams['w_input'],
    w_output=hparams['w_output'],
    data_reupload=hparams['data_reupload'],
)

# NOTE(review): torch.load unpickles the checkpoint file, and unpickling can
# execute arbitrary code. Only load checkpoints from a trusted source. On
# torch >= 1.13, passing weights_only=True restricts loading to tensors and
# plain containers.
# map_location remaps any tensors saved on a GPU onto the CPU, so the example
# runs on machines without CUDA.
state_dict = torch.load('qdqn-MountainCar-v0.pt', map_location=torch.device('cpu'))
net.load_state_dict(state_dict)

# Agent: wraps the network that now holds the loaded weights.
agent = Agent(net)