from typing import Tuple, Optional, Union

from mlagents.trainers.torch_entities.layers import linear_layer, Initialization, Swish
from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch_entities.model_serialization import exporting_to_onnx


class Normalizer(nn.Module):
    def __init__(self, vec_obs_size: int):
        """
        Maintains a running mean and variance of vector observations and uses
        them to normalize inputs.
        :param vec_obs_size: The size of the vector observation to normalize.
        """
        super().__init__()
        self.register_buffer("normalization_steps", torch.tensor(1))
        self.register_buffer("running_mean", torch.zeros(vec_obs_size))
        self.register_buffer("running_variance", torch.ones(vec_obs_size))

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # running_variance holds a running sum of squared deviations, so it is
        # divided by normalization_steps to recover the variance.
        normalized_state = torch.clamp(
            (inputs - self.running_mean)
            / torch.sqrt(self.running_variance / self.normalization_steps),
            -5,
            5,
        )
        return normalized_state

    def update(self, vector_input: torch.Tensor) -> None:
        # Welford-style batched update: running_variance accumulates the sum of
        # squared deviations rather than the variance itself.
        with torch.no_grad():
            steps_increment = vector_input.size()[0]
            total_new_steps = self.normalization_steps + steps_increment

            input_to_old_mean = vector_input - self.running_mean
            new_mean: torch.Tensor = self.running_mean + (
                input_to_old_mean / total_new_steps
            ).sum(0)

            input_to_new_mean = vector_input - new_mean
            new_variance = self.running_variance + (
                input_to_new_mean * input_to_old_mean
            ).sum(0)
            # Update references. This is much faster than in-place data update.
            self.running_mean: torch.Tensor = new_mean
            self.running_variance: torch.Tensor = new_variance
            self.normalization_steps: torch.Tensor = total_new_steps

    def copy_from(self, other_normalizer: "Normalizer") -> None:
        self.normalization_steps.data.copy_(other_normalizer.normalization_steps.data)
        self.running_mean.data.copy_(other_normalizer.running_mean.data)
        self.running_variance.data.copy_(other_normalizer.running_variance.data)
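
# Example usage (a minimal sketch, not part of the module): feed a batch of
# vector observations, update the running statistics, then normalize.
#
#   norm = Normalizer(vec_obs_size=4)
#   batch = torch.randn(32, 4)
#   norm.update(batch)        # accumulate running mean / squared-deviation sums
#   normalized = norm(batch)  # standardized and clamped to [-5, 5]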


def conv_output_shape(
    h_w: Tuple[int, int],
    kernel_size: Union[int, Tuple[int, int]] = 1,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
) -> Tuple[int, int]:
    """
    Calculates the output height and width of a convolution layer.
    kernel_size, stride, padding and dilation correspond to the inputs of the
    torch.nn.Conv2d layer (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)
    :param h_w: The height and width of the input.
    :param kernel_size: The size of the kernel of the convolution (can be an int or a
    tuple (height, width))
    :param stride: The stride of the convolution
    :param padding: The padding of the convolution
    :param dilation: The dilation of the convolution
    """
    from math import floor

    if not isinstance(kernel_size, tuple):
        kernel_size = (int(kernel_size), int(kernel_size))
    h = floor(
        ((h_w[0] + (2 * padding) - (dilation * (kernel_size[0] - 1)) - 1) / stride) + 1
    )
    w = floor(
        ((h_w[1] + (2 * padding) - (dilation * (kernel_size[1] - 1)) - 1) / stride) + 1
    )
    return h, w
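
# Worked example: the first Nature-DQN convolution below (8x8 kernel, stride 4)
# maps an 84x84 input to floor((84 - 8) / 4) + 1 = 20 per dimension, i.e.
# conv_output_shape((84, 84), 8, 4) == (20, 20).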


def pool_out_shape(h_w: Tuple[int, int], kernel_size: int) -> Tuple[int, int]:
    """
    Calculates the output height and width of a max pooling layer with stride 2.
    kernel_size corresponds to the inputs of the
    torch.nn.MaxPool2d layer (https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html)
    :param h_w: The height and width of the input.
    :param kernel_size: The size of the kernel of the pooling layer
    """
    height = (h_w[0] - kernel_size) // 2 + 1
    width = (h_w[1] - kernel_size) // 2 + 1
    return height, width
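
# Worked example: with the 3x3, stride-2 pooling used by ResNetVisualEncoder,
# pool_out_shape((84, 84), 3) == ((84 - 3) // 2 + 1, (84 - 3) // 2 + 1) == (41, 41).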


class VectorInput(nn.Module):
    def __init__(self, input_size: int, normalize: bool = False):
        super().__init__()
        self.normalizer: Optional[Normalizer] = None
        if normalize:
            self.normalizer = Normalizer(input_size)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if self.normalizer is not None:
            inputs = self.normalizer(inputs)
        return inputs

    def copy_normalization(self, other_input: "VectorInput") -> None:
        if self.normalizer is not None and other_input.normalizer is not None:
            self.normalizer.copy_from(other_input.normalizer)

    def update_normalization(self, inputs: torch.Tensor) -> None:
        if self.normalizer is not None:
            self.normalizer.update(inputs)
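
# Example usage (illustrative only): a normalizing vector input whose running
# statistics are updated from each incoming batch.
#
#   vec_in = VectorInput(input_size=8, normalize=True)
#   obs = torch.randn(64, 8)
#   vec_in.update_normalization(obs)
#   encoded = vec_in(obs)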


class FullyConnectedVisualEncoder(nn.Module):
    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        super().__init__()
        self.output_size = output_size
        self.input_size = height * width * initial_channels
        self.dense = nn.Sequential(
            linear_layer(
                self.input_size,
                self.output_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=1.41,  # Use ReLU gain
            ),
            nn.LeakyReLU(),
        )

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        if not exporting_to_onnx.is_exporting():
            # Observations arrive channels-last (NHWC); permute to channels-first
            # (NCHW). The permute is skipped when exporting to ONNX.
            visual_obs = visual_obs.permute([0, 3, 1, 2])
        hidden = visual_obs.reshape(-1, self.input_size)
        return self.dense(hidden)
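
# Example usage (illustrative only; the 84x84x3 size is an arbitrary choice):
# like all visual encoders in this module, input is expected channels-last,
# i.e. [batch, height, width, channels].
#
#   enc = FullyConnectedVisualEncoder(height=84, width=84, initial_channels=3,
#                                     output_size=256)
#   out = enc(torch.rand(8, 84, 84, 3))  # -> shape [8, 256]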


class SmallVisualEncoder(nn.Module):
    """
    CNN architecture used by King in their Candy Crush predictor
    https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning
    """

    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        super().__init__()
        self.h_size = output_size
        conv_1_hw = conv_output_shape((height, width), 3, 1)
        conv_2_hw = conv_output_shape(conv_1_hw, 3, 1)
        self.final_flat = conv_2_hw[0] * conv_2_hw[1] * 144

        self.conv_layers = nn.Sequential(
            nn.Conv2d(initial_channels, 35, [3, 3], [1, 1]),
            nn.LeakyReLU(),
            nn.Conv2d(35, 144, [3, 3], [1, 1]),
            nn.LeakyReLU(),
        )
        self.dense = nn.Sequential(
            linear_layer(
                self.final_flat,
                self.h_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=1.41,  # Use ReLU gain
            ),
            nn.LeakyReLU(),
        )

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        if not exporting_to_onnx.is_exporting():
            visual_obs = visual_obs.permute([0, 3, 1, 2])
        hidden = self.conv_layers(visual_obs)
        hidden = hidden.reshape(-1, self.final_flat)
        return self.dense(hidden)
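
# Shape check (illustrative only): each 3x3, stride-1 convolution shrinks height
# and width by 2, so e.g. a hypothetical 20x20 board would flatten to
# (20 - 4) * (20 - 4) * 144 = 36864 features before the dense layer.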


class SimpleVisualEncoder(nn.Module):
    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        super().__init__()
        self.h_size = output_size
        conv_1_hw = conv_output_shape((height, width), 8, 4)
        conv_2_hw = conv_output_shape(conv_1_hw, 4, 2)
        self.final_flat = conv_2_hw[0] * conv_2_hw[1] * 32

        self.conv_layers = nn.Sequential(
            nn.Conv2d(initial_channels, 16, [8, 8], [4, 4]),
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, [4, 4], [2, 2]),
            nn.LeakyReLU(),
        )
        self.dense = nn.Sequential(
            linear_layer(
                self.final_flat,
                self.h_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=1.41,  # Use ReLU gain
            ),
            nn.LeakyReLU(),
        )

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        if not exporting_to_onnx.is_exporting():
            visual_obs = visual_obs.permute([0, 3, 1, 2])
        hidden = self.conv_layers(visual_obs)
        hidden = hidden.reshape(-1, self.final_flat)
        return self.dense(hidden)


class NatureVisualEncoder(nn.Module):
    """
    CNN architecture from the Nature DQN paper
    (Mnih et al., 2015, "Human-level control through deep reinforcement learning").
    """

    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        super().__init__()
        self.h_size = output_size
        conv_1_hw = conv_output_shape((height, width), 8, 4)
        conv_2_hw = conv_output_shape(conv_1_hw, 4, 2)
        conv_3_hw = conv_output_shape(conv_2_hw, 3, 1)
        self.final_flat = conv_3_hw[0] * conv_3_hw[1] * 64

        self.conv_layers = nn.Sequential(
            nn.Conv2d(initial_channels, 32, [8, 8], [4, 4]),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, [4, 4], [2, 2]),
            nn.LeakyReLU(),
            nn.Conv2d(64, 64, [3, 3], [1, 1]),
            nn.LeakyReLU(),
        )
        self.dense = nn.Sequential(
            linear_layer(
                self.final_flat,
                self.h_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=1.41,  # Use ReLU gain
            ),
            nn.LeakyReLU(),
        )

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        if not exporting_to_onnx.is_exporting():
            visual_obs = visual_obs.permute([0, 3, 1, 2])
        hidden = self.conv_layers(visual_obs)
        hidden = hidden.reshape([-1, self.final_flat])
        return self.dense(hidden)
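
# Shape check (the classic DQN figures): an 84x84 input becomes 20x20 after the
# first convolution, 9x9 after the second, and 7x7 after the third, giving
# final_flat = 7 * 7 * 64 = 3136.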


class ResNetBlock(nn.Module):
    def __init__(self, channel: int):
        """
        Creates a ResNet Block.
        :param channel: The number of channels in the input (and output) tensors of the
        convolutions
        """
        super().__init__()
        self.layers = nn.Sequential(
            Swish(),
            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
            Swish(),
            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return input_tensor + self.layers(input_tensor)
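
# Note: padding=1 with 3x3 kernels and stride 1 preserves height and width, so
# the identity skip connection in forward() needs no projection. A quick sanity
# check (illustrative only):
#
#   block = ResNetBlock(channel=32)
#   x = torch.rand(1, 32, 21, 21)
#   assert block(x).shape == x.shape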


class ResNetVisualEncoder(nn.Module):
    """
    Visual encoder built from stacks of convolution, max pooling, and residual
    blocks, following the IMPALA architecture (Espeholt et al., 2018).
    """

    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        super().__init__()
        n_channels = [16, 32, 32]  # channels for each stack
        n_blocks = 2  # number of residual blocks per stack
        layers = []
        last_channel = initial_channels
        for channel in n_channels:
            layers.append(nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1))
            layers.append(nn.MaxPool2d([3, 3], [2, 2]))
            height, width = pool_out_shape((height, width), 3)
            for _ in range(n_blocks):
                layers.append(ResNetBlock(channel))
            last_channel = channel
        layers.append(Swish())
        self.final_flat_size = n_channels[-1] * height * width

        self.dense = linear_layer(
            self.final_flat_size,
            output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=1.41,  # Use ReLU gain
        )
        self.sequential = nn.Sequential(*layers)

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        if not exporting_to_onnx.is_exporting():
            visual_obs = visual_obs.permute([0, 3, 1, 2])
        hidden = self.sequential(visual_obs)
        before_out = hidden.reshape(-1, self.final_flat_size)
        return torch.relu(self.dense(before_out))
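
# Shape trace (derived from pool_out_shape above; the 84x84x3 input is an
# arbitrary example): max pooling shrinks 84 -> 41 -> 20 -> 9 across the three
# stacks, so final_flat_size = 32 * 9 * 9 = 2592 before the dense layer.
#
#   enc = ResNetVisualEncoder(height=84, width=84, initial_channels=3,
#                             output_size=256)
#   out = enc(torch.rand(4, 84, 84, 3))  # -> shape [4, 256]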