File size: 555 Bytes
5cb9176 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import gym
import torch
import yaml
from model.agent import Agent
from model.qnn import QuantumNet
# Environment: the Gym task the agent acts in.
env_name = 'MountainCar-v0'
env = gym.make(env_name)

# Network: build the quantum Q-network from the hyperparameters in config.yaml.
# The constructor arguments presumably must match those used when the checkpoint
# was trained, otherwise load_state_dict below will reject the weights
# (assumes QuantumNet is a torch.nn.Module with the default strict loading).
with open('config.yaml', 'r', encoding='utf-8') as f:
    hparams = yaml.safe_load(f)
net = QuantumNet(
    n_layers=hparams['n_layers'],
    w_input=hparams['w_input'],
    w_output=hparams['w_output'],
    data_reupload=hparams['data_reupload']
)
# The checkpoint filename is derived from env_name ('qdqn-MountainCar-v0.pt'),
# so changing the environment above also selects the matching weights file.
# map_location remaps every stored tensor onto the CPU, so a checkpoint saved
# from a GPU run still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(f'qdqn-{env_name}.pt', map_location=torch.device('cpu'))
net.load_state_dict(state_dict)

# Agent: wraps the loaded network.
agent = Agent(net)
|