File size: 555 Bytes
5cb9176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import gym
import torch
import yaml

from model.agent import Agent
from model.qnn import QuantumNet

# Environment
env_name = 'MountainCar-v0'
env = gym.make(env_name)

# Network: hyperparameters are read from config.yaml, which must provide the
# keys 'n_layers', 'w_input', 'w_output' and 'data_reupload'.
with open('config.yaml', 'r', encoding='utf-8') as f:
    hparams = yaml.safe_load(f)

net = QuantumNet(
    n_layers=hparams['n_layers'],
    w_input=hparams['w_input'],
    w_output=hparams['w_output'],
    data_reupload=hparams['data_reupload'],
)

# Derive the checkpoint name from env_name so that switching environments
# cannot silently load weights trained on a different one. For the current
# env_name this is exactly 'qdqn-MountainCar-v0.pt'.
checkpoint_path = f'qdqn-{env_name}.pt'
# map_location remaps saved tensor storages to the CPU, so the checkpoint
# loads even on a machine without the device it was saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
net.load_state_dict(state_dict)

# Agent
agent = Agent(net)