import math
import random
import sys
from collections import deque, namedtuple
from itertools import count
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    no_type_check,
)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoModel, AutoTokenizer

# Make the project root importable before pulling in local modules.
sys.path.insert(0, sys.path[0] + "/../")

from tianshou.utils.net.common import MLP, ModuleType, Net
from tianshou.utils.net.discrete import NoisyLinear
def bert_embedding(x, max_length=512, device='cuda'):
    """Return token-level BERT embeddings for a batch of strings.

    Note: the tokenizer and model are loaded on every call, which is slow;
    the network classes below cache them instead.
    """
    from transformers import logging
    logging.set_verbosity_error()
    model_name = 'bert-base-uncased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bert_model = AutoModel.from_pretrained(model_name).to(device)
    text = x
    if isinstance(text, np.ndarray):
        text = list(text)
    tokens = tokenizer(text, max_length=max_length, padding='max_length',
                       truncation=True, return_tensors='pt')
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    with torch.no_grad():
        outputs = bert_model(input_ids, attention_mask=attention_mask)
    embeddings = outputs.last_hidden_state  # (batch, max_length, 768)
    return embeddings
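# Sketch (not part of the original module): a minimal smoke test for
# bert_embedding. Assumes 'bert-base-uncased' can be downloaded on first use.
def _demo_bert_embedding():
    emb = bert_embedding(["hello world", "how are you"], max_length=16, device="cpu")
    return emb.shape  # torch.Size([2, 16, 768])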
class Net_GRU(nn.Module):
    def __init__(self, input_size, n_actions, hidden_dim, n_layers, dropout, bidirectional):
        super(Net_GRU, self).__init__()
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        self.num_classes = n_actions
        self.n_layers = n_layers
        self.dropout = dropout
        self.bidirectional = bidirectional
        # Layers
        self.gru = nn.GRU(self.input_size, self.hidden_dim, self.n_layers,
                          batch_first=True, dropout=self.dropout, bidirectional=self.bidirectional)
        self.final_layer = nn.Linear(self.hidden_dim * (1 + int(self.bidirectional)), self.num_classes)

    def forward(self, x):
        # Input shape: (batch_size, seq_length, emb_size)
        batch_size, seq_length, emb_size = x.size()
        gru_out, hidden = self.gru(x)
        # Use the final hidden state of the last layer.
        # hidden -> (n_layers * num_directions, batch, hidden_size)
        if self.bidirectional:
            hidden = hidden.view(self.n_layers, 2, batch_size, self.hidden_dim)
            # Concatenate the backward and forward states of the last layer.
            final_hidden = torch.cat((hidden[-1, 1], hidden[-1, 0]), dim=1)
        else:
            final_hidden = hidden[-1]
        # final_hidden -> (batch_size, hidden_dim * num_directions)
        logits = self.final_layer(final_hidden)
        return logits
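# Sketch (not part of the original module): exercising Net_GRU on random 768-d
# embeddings such as those produced by bert_embedding above.
def _demo_net_gru():
    net = Net_GRU(input_size=768, n_actions=4, hidden_dim=64,
                  n_layers=1, dropout=0.0, bidirectional=True)
    x = torch.randn(2, 16, 768)  # (batch, seq_len, emb_size)
    return net(x).shape          # torch.Size([2, 4])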
class MyGRU(nn.Module):
    def __init__(self, input_size, hidden_dim, n_layers, dropout, bidirectional, output_dim):
        super(MyGRU, self).__init__()
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.dropout = dropout
        self.bidirectional = bidirectional
        # Layers
        self.gru = nn.GRU(self.input_size, self.hidden_dim, self.n_layers,
                          batch_first=True, dropout=self.dropout, bidirectional=self.bidirectional)
        self.final_layer = nn.Linear(self.hidden_dim * (1 + int(self.bidirectional)), output_dim)

    def forward(self, x):
        # x -> (batch_size, seq_length, emb_size)
        batch_size, seq_length, emb_size = x.size()
        gru_out, hidden = self.gru(x)
        # Use the final hidden state of the last layer.
        # hidden -> (n_layers * num_directions, batch, hidden_size)
        if self.bidirectional:
            hidden = hidden.view(self.n_layers, 2, batch_size, self.hidden_dim)
            final_hidden = torch.cat((hidden[-1, 1], hidden[-1, 0]), dim=1)
        else:
            final_hidden = hidden[-1]
        # final_hidden -> (batch_size, hidden_dim * num_directions)
        logits = self.final_layer(final_hidden)
        return logits
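# Sketch (not part of the original module): MyGRU is Net_GRU with a
# configurable output width; it is the feature/Q head used by the classes below.
def _demo_my_gru():
    net = MyGRU(input_size=768, hidden_dim=64, n_layers=2,
                dropout=0.0, bidirectional=True, output_dim=10)
    return net(torch.randn(2, 16, 768)).shape  # torch.Size([2, 10])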
class MyCNN(nn.Module):
    """1-D CNN encoder over a sequence of embeddings.

    The extra constructor arguments (norm_layer, device, linear_layer,
    flatten_input) are accepted only to mirror the tianshou MLP interface
    and are currently unused.
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int = 0,
                 hidden_sizes: Sequence[int] = (),
                 norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
                 activation: ModuleType = nn.ReLU,
                 device: Optional[Union[str, int, torch.device]] = None,
                 linear_layer: Type[nn.Linear] = nn.Linear,
                 flatten_input: bool = True) -> None:
        super().__init__()
        layers = []
        input_dim_temp = input_dim
        for h in hidden_sizes:
            layers.append(nn.Conv1d(in_channels=input_dim_temp, out_channels=h, kernel_size=3, padding=1))
            layers.append(activation())
            layers.append(nn.MaxPool1d(kernel_size=2))
            input_dim_temp = h
        self.model = nn.Sequential(*layers)
        self.fc = nn.Linear(in_features=input_dim_temp, out_features=output_dim)

    def forward(self, x):
        # x -> (batch, seq_len, input_dim); Conv1d expects (batch, channels, seq_len).
        x = self.model(x.transpose(1, 2))
        # Global average pooling over the sequence dimension so the head
        # receives a fixed-size (batch, channels) feature vector.
        x = x.mean(dim=2)
        x = self.fc(x)
        return x
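# Sketch (not part of the original module): MyCNN on random 768-d embeddings.
def _demo_my_cnn():
    cnn = MyCNN(input_dim=768, output_dim=10, hidden_sizes=(128, 64))
    x = torch.randn(2, 16, 768)  # (batch, seq_len, emb_size)
    return cnn(x).shape          # torch.Size([2, 10])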
class Net_GRU_Bert_tianshou(Net):
    def __init__(
        self,
        state_shape: Union[int, Sequence[int]],
        action_shape: Union[int, Sequence[int]] = 0,
        hidden_sizes: Sequence[int] = (),
        norm_layer: Optional[ModuleType] = None,
        activation: Optional[ModuleType] = nn.ReLU,
        device: Union[str, int, torch.device] = "cpu",
        softmax: bool = False,
        concat: bool = False,
        num_atoms: int = 1,
        dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
        hidden_dim: int = 128,
        bidirectional: bool = True,
        dropout: float = 0.,
        n_layers: int = 1,
        max_length: int = 512,
        trans_model_name: str = 'bert-base-uncased',
    ) -> None:
        nn.Module.__init__(self)
        self.device = device
        self.softmax = softmax
        self.num_atoms = num_atoms
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.n_layers = n_layers
        self.trans_model_name = trans_model_name
        self.max_length = max_length
        input_dim = int(np.prod(state_shape))
        action_dim = int(np.prod(action_shape)) * num_atoms
        if concat:
            input_dim += action_dim
        self.use_dueling = dueling_param is not None
        output_dim = action_dim if not self.use_dueling and not concat else 0
        self.output_dim = output_dim or hidden_dim
        # GRU encoder over frozen BERT token embeddings (768-d for bert-base).
        self.model = MyGRU(768, self.hidden_dim, self.n_layers,
                           self.dropout, self.bidirectional, self.output_dim)
        if self.use_dueling:  # dueling DQN
            q_kwargs, v_kwargs = dueling_param  # type: ignore
            q_output_dim, v_output_dim = 0, 0
            if not concat:
                q_output_dim, v_output_dim = action_dim, num_atoms
            q_kwargs: Dict[str, Any] = {
                **q_kwargs, "input_dim": self.output_dim,
                "output_dim": q_output_dim,
                "device": self.device
            }
            v_kwargs: Dict[str, Any] = {
                **v_kwargs, "input_dim": self.output_dim,
                "output_dim": v_output_dim,
                "device": self.device
            }
            self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
            self.output_dim = self.Q.output_dim
        self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
        from transformers import logging
        logging.set_verbosity_error()

    def bert_embedding(self, x, max_length=512):
        text = x
        if isinstance(text, np.ndarray):
            text = list(text)
        tokens = self.tokenizer(text, max_length=max_length, padding='max_length',
                                truncation=True, return_tensors='pt')
        input_ids = tokens['input_ids'].to(self.device)
        attention_mask = tokens['attention_mask'].to(self.device)
        # BERT is used as a frozen feature extractor.
        with torch.no_grad():
            outputs = self.bert_model(input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state  # (batch, max_length, 768)
        return embeddings

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Any = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        """Mapping: obs (batch of strings) -> BERT embedding -> GRU -> logits."""
        embedding = self.bert_embedding(obs, max_length=self.max_length)
        logits = self.model(embedding)
        bsz = logits.shape[0]
        if self.use_dueling:  # Dueling DQN
            q, v = self.Q(logits), self.V(logits)
            if self.num_atoms > 1:
                q = q.view(bsz, -1, self.num_atoms)
                v = v.view(bsz, -1, self.num_atoms)
            logits = q - q.mean(dim=1, keepdim=True) + v
        elif self.num_atoms > 1:
            logits = logits.view(bsz, -1, self.num_atoms)
        if self.softmax:
            logits = torch.softmax(logits, dim=-1)
        return logits, state
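# Sketch (not part of the original module): Net_GRU_Bert_tianshou maps a batch
# of raw text observations to Q-values. Downloads 'bert-base-uncased' on first use.
def _demo_net_gru_bert():
    net = Net_GRU_Bert_tianshou(state_shape=1, action_shape=4,
                                device="cpu", max_length=16)
    obs = np.array(["turn left at the door", "pick up the key"])
    logits, state = net(obs)
    return logits.shape  # torch.Size([2, 4])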
class Net_Bert_CLS_tianshou(Net):
    def __init__(
        self,
        state_shape: Union[int, Sequence[int]],
        action_shape: Union[int, Sequence[int]] = 0,
        hidden_sizes: Sequence[int] = (),
        norm_layer: Optional[ModuleType] = None,
        activation: Optional[ModuleType] = nn.ReLU,
        device: Union[str, int, torch.device] = "cpu",
        softmax: bool = False,
        concat: bool = False,
        num_atoms: int = 1,
        dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
        hidden_dim: int = 128,
        bidirectional: bool = True,
        dropout: float = 0.,
        n_layers: int = 1,
        max_length: int = 512,
        trans_model_name: str = 'bert-base-uncased',
    ) -> None:
        nn.Module.__init__(self)
        self.device = device
        self.softmax = softmax
        self.num_atoms = num_atoms
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.n_layers = n_layers
        self.trans_model_name = trans_model_name
        self.max_length = max_length
        input_dim = int(np.prod(state_shape))
        action_dim = int(np.prod(action_shape)) * num_atoms
        if concat:
            input_dim += action_dim
        self.use_dueling = dueling_param is not None
        output_dim = action_dim if not self.use_dueling and not concat else 0
        self.output_dim = output_dim or hidden_dim
        # MLP head over the frozen BERT [CLS] embedding; it must emit
        # self.output_dim features so the optional dueling heads line up.
        self.model = MLP(768, self.output_dim, hidden_sizes, norm_layer,
                         activation, device, linear_layer)
        if self.use_dueling:  # dueling DQN
            q_kwargs, v_kwargs = dueling_param  # type: ignore
            q_output_dim, v_output_dim = 0, 0
            if not concat:
                q_output_dim, v_output_dim = action_dim, num_atoms
            q_kwargs: Dict[str, Any] = {
                **q_kwargs, "input_dim": self.output_dim,
                "output_dim": q_output_dim,
                "device": self.device
            }
            v_kwargs: Dict[str, Any] = {
                **v_kwargs, "input_dim": self.output_dim,
                "output_dim": v_output_dim,
                "device": self.device
            }
            self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
            self.output_dim = self.Q.output_dim
        self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
        from transformers import logging
        logging.set_verbosity_error()

    def bert_CLS_embedding(self, x, max_length=512):
        text = x
        if isinstance(text, np.ndarray):
            text = list(text)
        tokens = self.tokenizer(text, max_length=max_length, padding='max_length',
                                truncation=True, return_tensors='pt')
        input_ids = tokens['input_ids'].to(self.device)
        attention_mask = tokens['attention_mask'].to(self.device)
        with torch.no_grad():
            outputs = self.bert_model(input_ids, attention_mask=attention_mask)
        embeddings = outputs[0][:, 0, :]  # [CLS] token embedding, (batch, 768)
        return embeddings

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Any = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        """Mapping: obs (batch of strings) -> BERT [CLS] embedding -> MLP -> logits."""
        embedding = self.bert_CLS_embedding(obs, max_length=self.max_length)
        logits = self.model(embedding)
        bsz = logits.shape[0]
        if self.use_dueling:  # Dueling DQN
            q, v = self.Q(logits), self.V(logits)
            if self.num_atoms > 1:
                q = q.view(bsz, -1, self.num_atoms)
                v = v.view(bsz, -1, self.num_atoms)
            logits = q - q.mean(dim=1, keepdim=True) + v
        elif self.num_atoms > 1:
            logits = logits.view(bsz, -1, self.num_atoms)
        if self.softmax:
            logits = torch.softmax(logits, dim=-1)
        return logits, state
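# Sketch (not part of the original module): Q-values from the BERT [CLS] embedding.
def _demo_net_bert_cls():
    net = Net_Bert_CLS_tianshou(state_shape=1, action_shape=4,
                                hidden_sizes=(128,), device="cpu", max_length=16)
    logits, _ = net(np.array(["open the chest", "go north"]))
    return logits.shape  # torch.Size([2, 4])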
class Net_Bert_CNN_tianshou(Net_GRU_Bert_tianshou):
    """Same as Net_GRU_Bert_tianshou but with a 1-D CNN encoder instead of a GRU;
    bert_embedding and forward are inherited from the parent class."""

    def __init__(
        self,
        state_shape: Union[int, Sequence[int]],
        action_shape: Union[int, Sequence[int]] = 0,
        hidden_sizes: Sequence[int] = (),
        norm_layer: Optional[ModuleType] = None,
        activation: Optional[ModuleType] = nn.ReLU,
        device: Union[str, int, torch.device] = "cpu",
        softmax: bool = False,
        concat: bool = False,
        num_atoms: int = 1,
        dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
        hidden_dim: int = 128,
        bidirectional: bool = True,
        dropout: float = 0.,
        n_layers: int = 1,
        max_length: int = 512,
        trans_model_name: str = 'bert-base-uncased',
    ) -> None:
        nn.Module.__init__(self)
        self.device = device
        self.softmax = softmax
        self.num_atoms = num_atoms
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.n_layers = n_layers
        self.trans_model_name = trans_model_name
        self.max_length = max_length
        input_dim = int(np.prod(state_shape))
        action_dim = int(np.prod(action_shape)) * num_atoms
        if concat:
            input_dim += action_dim
        self.use_dueling = dueling_param is not None
        output_dim = action_dim if not self.use_dueling and not concat else 0
        self.output_dim = output_dim or hidden_dim
        # CNN encoder over the frozen BERT token embeddings; it must emit
        # self.output_dim features so the optional dueling heads line up.
        self.model = MyCNN(768, self.output_dim, hidden_sizes, norm_layer,
                           activation, device, linear_layer, flatten_input=False)
        if self.use_dueling:  # dueling DQN
            q_kwargs, v_kwargs = dueling_param  # type: ignore
            q_output_dim, v_output_dim = 0, 0
            if not concat:
                q_output_dim, v_output_dim = action_dim, num_atoms
            q_kwargs: Dict[str, Any] = {
                **q_kwargs, "input_dim": self.output_dim,
                "output_dim": q_output_dim,
                "device": self.device
            }
            v_kwargs: Dict[str, Any] = {
                **v_kwargs, "input_dim": self.output_dim,
                "output_dim": v_output_dim,
                "device": self.device
            }
            self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
            self.output_dim = self.Q.output_dim
        self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
        from transformers import logging
        logging.set_verbosity_error()
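# Sketch (not part of the original module): the CNN variant reuses the parent's
# forward, so the call signature is identical.
def _demo_net_bert_cnn():
    net = Net_Bert_CNN_tianshou(state_shape=1, action_shape=4,
                                hidden_sizes=(64,), device="cpu", max_length=16)
    logits, _ = net(np.array(["look around", "take the lamp"]))
    return logits.shape  # torch.Size([2, 4])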
class DQN_GRU(nn.Module):
    """Reference: Human-level control through deep reinforcement learning."""

    def __init__(
        self,
        state_shape: Union[int, Sequence[int]],
        action_shape: Sequence[int],
        device: Union[str, int, torch.device] = "cpu",
        features_only: bool = False,
        output_dim: Optional[int] = None,
        hidden_dim: int = 128,
        n_layers: int = 1,
        dropout: float = 0.,
        bidirectional: bool = True,
        trans_model_name: str = 'bert-base-uncased',
        max_length: int = 512,
    ) -> None:
        super().__init__()
        self.device = device
        self.max_length = max_length
        action_dim = int(np.prod(action_shape))
        if not features_only:
            self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional,
                             action_dim)
            self.output_dim = action_dim
        elif output_dim is not None:
            self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional,
                             output_dim)
            self.output_dim = output_dim
        else:
            self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional,
                             hidden_dim)
            self.output_dim = hidden_dim
        self.trans_model_name = trans_model_name
        self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
        from transformers import logging
        logging.set_verbosity_error()

    def bert_embedding(self, x, max_length=512):
        text = x
        if isinstance(text, np.ndarray):
            text = list(text)
        tokens = self.tokenizer(text, max_length=max_length, padding='max_length',
                                truncation=True, return_tensors='pt')
        input_ids = tokens['input_ids'].to(self.device)
        attention_mask = tokens['attention_mask'].to(self.device)
        with torch.no_grad():
            outputs = self.bert_model(input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state
        return embeddings

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: s -> Q(s, \*)."""
        embedding = self.bert_embedding(obs, max_length=self.max_length)
        return self.net(embedding), state
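# Sketch (not part of the original module): DQN over GRU-encoded BERT features.
def _demo_dqn_gru():
    net = DQN_GRU(state_shape=1, action_shape=4, device="cpu", max_length=16)
    q_values, state = net(np.array(["go east", "open the door"]))
    return q_values.shape  # torch.Size([2, 4])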
class Rainbow_GRU(DQN_GRU):
    """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning."""

    def __init__(
        self,
        state_shape: Union[int, Sequence[int]],
        action_shape: Sequence[int],
        num_atoms: int = 51,
        noisy_std: float = 0.5,
        device: Union[str, int, torch.device] = "cpu",
        is_dueling: bool = True,
        is_noisy: bool = True,
        output_dim: Optional[int] = None,
        hidden_dim: int = 128,
        n_layers: int = 1,
        dropout: float = 0.,
        bidirectional: bool = True,
        trans_model_name: str = 'bert-base-uncased',
        max_length: int = 512,
    ) -> None:
        super().__init__(state_shape, action_shape, device, features_only=True,
                         output_dim=output_dim, hidden_dim=hidden_dim, n_layers=n_layers,
                         dropout=dropout, bidirectional=bidirectional,
                         trans_model_name=trans_model_name, max_length=max_length)
        self.action_num = int(np.prod(action_shape))
        self.num_atoms = num_atoms

        def linear(x, y):
            # NoisyLinear layers provide the exploration noise used by Rainbow.
            if is_noisy:
                return NoisyLinear(x, y, noisy_std)
            else:
                return nn.Linear(x, y)

        self.Q = nn.Sequential(
            linear(self.output_dim, 512), nn.ReLU(inplace=True),
            linear(512, self.action_num * self.num_atoms)
        )
        self._is_dueling = is_dueling
        if self._is_dueling:
            self.V = nn.Sequential(
                linear(self.output_dim, 512), nn.ReLU(inplace=True),
                linear(512, self.num_atoms)
            )
        self.output_dim = self.action_num * self.num_atoms

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*)."""
        obs, state = super().forward(obs)
        q = self.Q(obs)
        q = q.view(-1, self.action_num, self.num_atoms)
        if self._is_dueling:
            v = self.V(obs)
            v = v.view(-1, 1, self.num_atoms)
            logits = q - q.mean(dim=1, keepdim=True) + v
        else:
            logits = q
        probs = logits.softmax(dim=2)
        return probs, state
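# Sketch (not part of the original module): Rainbow returns per-action
# distributions over num_atoms support points rather than scalar Q-values.
def _demo_rainbow_gru():
    net = Rainbow_GRU(state_shape=1, action_shape=4, num_atoms=51,
                      device="cpu", max_length=16)
    probs, _ = net(np.array(["go west", "read the note"]))
    return probs.shape  # torch.Size([2, 4, 51])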
class Net_GRU_nn_emb_tianshou(Net):
    def __init__(
        self,
        action_shape: Union[int, Sequence[int]] = 0,
        hidden_sizes: Sequence[int] = (),
        norm_layer: Optional[ModuleType] = None,
        activation: Optional[ModuleType] = nn.ReLU,
        device: Union[str, int, torch.device] = "cpu",
        softmax: bool = False,
        concat: bool = False,
        num_atoms: int = 1,
        dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
        hidden_dim: int = 128,
        bidirectional: bool = True,
        dropout: float = 0.,
        n_layers: int = 1,
        max_length: int = 512,
        trans_model_name: str = 'bert-base-uncased',
        word_emb_dim: int = 128,
    ) -> None:
        nn.Module.__init__(self)
        self.device = device
        self.softmax = softmax
        self.num_atoms = num_atoms
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.n_layers = n_layers
        self.trans_model_name = trans_model_name
        self.max_length = max_length
        action_dim = int(np.prod(action_shape)) * num_atoms
        self.use_dueling = dueling_param is not None
        output_dim = action_dim if not self.use_dueling and not concat else 0
        self.output_dim = output_dim or hidden_dim
        # Only the BERT tokenizer (vocabulary) is used here; word embeddings
        # are trained from scratch instead of taking BERT hidden states.
        self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
        from transformers import logging
        logging.set_verbosity_error()
        self.vocab_size = self.tokenizer.vocab_size
        self.embedding = nn.Embedding(self.vocab_size, word_emb_dim)
        self.model = MyGRU(word_emb_dim, self.hidden_dim, self.n_layers,
                           self.dropout, self.bidirectional, self.output_dim)
        if self.use_dueling:  # dueling DQN
            q_kwargs, v_kwargs = dueling_param  # type: ignore
            q_output_dim, v_output_dim = 0, 0
            if not concat:
                q_output_dim, v_output_dim = action_dim, num_atoms
            q_kwargs: Dict[str, Any] = {
                **q_kwargs, "input_dim": self.output_dim,
                "output_dim": q_output_dim,
                "device": self.device
            }
            v_kwargs: Dict[str, Any] = {
                **v_kwargs, "input_dim": self.output_dim,
                "output_dim": v_output_dim,
                "device": self.device
            }
            self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
            self.output_dim = self.Q.output_dim

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Any = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        """Mapping: obs (batch of strings) -> tokenize -> nn.Embedding -> GRU -> logits."""
        if isinstance(obs, np.ndarray):
            text = list(obs)
        else:
            text = obs
        tokens = self.tokenizer(text, max_length=self.max_length, padding='max_length',
                                truncation=True, return_tensors='pt')
        input_ids = tokens['input_ids'].to(self.device)
        attention_mask = tokens['attention_mask'].to(self.device)
        embedding = self.embedding(input_ids)
        # Zero out the embeddings of padded positions.
        mask = attention_mask.unsqueeze(-1).expand(embedding.size()).float()
        embedding = embedding * mask
        logits = self.model(embedding)
        bsz = logits.shape[0]
        if self.use_dueling:  # Dueling DQN
            q, v = self.Q(logits), self.V(logits)
            if self.num_atoms > 1:
                q = q.view(bsz, -1, self.num_atoms)
                v = v.view(bsz, -1, self.num_atoms)
            logits = q - q.mean(dim=1, keepdim=True) + v
        elif self.num_atoms > 1:
            logits = logits.view(bsz, -1, self.num_atoms)
        if self.softmax:
            logits = torch.softmax(logits, dim=-1)
        return logits, state
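# Sketch (not part of the original module): the learned-embedding variant needs
# the BERT tokenizer (vocabulary only) but no pretrained transformer weights.
def _demo_net_gru_nn_emb():
    net = Net_GRU_nn_emb_tianshou(action_shape=4, device="cpu",
                                  word_emb_dim=32, hidden_dim=32, max_length=16)
    logits, _ = net(np.array(["climb the ladder", "light the torch"]))
    return logits.shape  # torch.Size([2, 4])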