from mlagents.torch_utils import torch
from typing import List
import math
from mlagents.trainers.torch_entities.layers import (
    linear_layer,
    Swish,
    Initialization,
    LayerNorm,
)


class ConditionalEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        goal_size: int,
        hidden_size: int,
        num_layers: int,
        num_conditional_layers: int,
        kernel_init: Initialization = Initialization.KaimingHeNormal,
        kernel_gain: float = 1.0,
    ):
""" | |
ConditionalEncoder module. A fully connected network of which some of the | |
weights are generated by a goal conditioning. Uses the HyperNetwork module to | |
generate the weights of the network. Only the weights of the last | |
"num_conditional_layers" layers will be generated by HyperNetworks, the others | |
will use regular parameters. | |
:param input_size: The size of the input of the encoder | |
:param goal_size: The size of the goal tensor that will condition the encoder | |
:param hidden_size: The number of hidden units in the encoder | |
:param num_layers: The total number of layers of the encoder (both regular and | |
generated by HyperNetwork) | |
:param num_conditional_layers: The number of layers generated with hypernetworks | |
:param kernel_init: The Initialization to use for the weights of the layer | |
:param kernel_gain: The multiplier for the weights of the kernel. | |
""" | |
        super().__init__()
        layers: List[torch.nn.Module] = []
        prev_size = input_size
        for i in range(num_layers):
            if num_layers - i <= num_conditional_layers:
                # Layer i is a conditional layer, since the conditional
                # layers are the last num_conditional_layers of the network
                layers.append(
                    HyperNetwork(prev_size, hidden_size, goal_size, hidden_size, 2)
                )
            else:
                layers.append(
                    linear_layer(
                        prev_size,
                        hidden_size,
                        kernel_init=kernel_init,
                        kernel_gain=kernel_gain,
                    )
                )
            layers.append(Swish())
            prev_size = hidden_size
        self.layers = torch.nn.ModuleList(layers)

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:  # type: ignore
        activation = input_tensor
        for layer in self.layers:
            if isinstance(layer, HyperNetwork):
                # Conditional layers also consume the goal tensor
                activation = layer(activation, goal_tensor)
            else:
                activation = layer(activation)
        return activation
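

# A minimal usage sketch (not part of the module): the sizes below are
# illustrative assumptions, chosen only to show how the encoder is called
# and what shapes to expect.
#
#     encoder = ConditionalEncoder(
#         input_size=16, goal_size=4, hidden_size=32,
#         num_layers=3, num_conditional_layers=1,
#     )
#     obs = torch.randn(8, 16)   # batch of 8 observations
#     goal = torch.randn(8, 4)   # batch of 8 goal vectors
#     out = encoder(obs, goal)   # -> shape (8, 32)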


class HyperNetwork(torch.nn.Module):
    def __init__(
        self, input_size, output_size, hyper_input_size, layer_size, num_layers
    ):
        """
        Hyper Network module. This module uses the hyper_input tensor to generate
        the weights of the main network. The main network is a single fully
        connected layer. A usage sketch follows the class definition.
        :param input_size: The size of the input of the main network
        :param output_size: The size of the output of the main network
        :param hyper_input_size: The size of the input of the hypernetwork that will
        generate the main network.
        :param layer_size: The number of hidden units in the layers of the hypernetwork
        :param num_layers: The number of layers of the hypernetwork
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        layer_in_size = hyper_input_size
        layers = []
        for _ in range(num_layers):
            layers.append(
                linear_layer(
                    layer_in_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
            layer_in_size = layer_size
        flat_output = linear_layer(
            layer_size,
            input_size * output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )
        # Re-initialize the weights of the last layer of the hypernetwork so the
        # generated weights start small, on a scale comparable to a standard
        # linear-layer initialization of the main network
        bound = math.sqrt(1 / (layer_size * self.input_size))
        flat_output.weight.data.uniform_(-bound, bound)
        self.hypernet = torch.nn.Sequential(*layers, LayerNorm(), flat_output)
        # The hypernetwork does not generate the bias of the main network layer;
        # the bias is a regular learned parameter shared across goals
        self.bias = torch.nn.Parameter(torch.zeros(output_size))

    def forward(self, input_activation, hyper_input):
        # Generate one flat weight vector per batch element, then reshape it
        # into a (batch, input_size, output_size) weight matrix
        output_weights = self.hypernet(hyper_input)
        output_weights = output_weights.view(-1, self.input_size, self.output_size)
        # Apply the generated weights with a batched matrix multiply:
        # (batch, 1, input_size) x (batch, input_size, output_size)
        result = (
            torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1)
            + self.bias
        )
        return result
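

# A minimal usage sketch (not part of the module): HyperNetwork generates a
# separate weight matrix for each goal in the batch. The sizes below are
# illustrative assumptions.
#
#     hn = HyperNetwork(
#         input_size=32, output_size=32, hyper_input_size=4,
#         layer_size=32, num_layers=2,
#     )
#     x = torch.randn(8, 32)   # main-network input
#     g = torch.randn(8, 4)    # conditioning (goal) input
#     y = hn(x, g)             # -> shape (8, 32); weights depend on g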