import gym | |
import torch | |
import yaml | |
from model.agent import Agent | |
from model.qnn import QuantumNet | |
# Environment
env_name = 'MountainCar-v0'
env = gym.make(env_name)

# Network
# Architecture hyperparameters for QuantumNet are read from config.yaml.
# NOTE(review): these presumably must match the settings the checkpoint
# below was trained with, otherwise load_state_dict is likely to reject it.
with open('config.yaml', 'r', encoding='utf-8') as f:
    hparams = yaml.safe_load(f)
# safe_load yields None for an empty file (and could yield a list/scalar);
# fail with a clear message instead of an opaque TypeError on subscripting.
if not isinstance(hparams, dict):
    raise ValueError("config.yaml must contain a mapping of hyperparameters")

net = QuantumNet(
    n_layers=hparams['n_layers'],
    w_input=hparams['w_input'],
    w_output=hparams['w_output'],
    data_reupload=hparams['data_reupload'],
)

# Derive the checkpoint name from env_name so that switching environments
# cannot silently load weights saved for a different one.
checkpoint_path = f'qdqn-{env_name}.pt'
# map_location remaps every stored tensor onto the CPU, so a checkpoint
# saved from a GPU run still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted
# checkpoints. On torch >= 1.13, weights_only=True restricts loading to
# plain tensors/containers and is the safer choice for a state dict.
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
net.load_state_dict(state_dict)

# Agent
agent = Agent(net)