Spaces:
Sleeping
Sleeping
from typing import List, Optional, Tuple, Dict | |
from mlagents.torch_utils import torch, nn | |
from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization | |
import numpy as np | |
from mlagents.trainers.torch_entities.encoders import ( | |
SimpleVisualEncoder, | |
ResNetVisualEncoder, | |
NatureVisualEncoder, | |
SmallVisualEncoder, | |
FullyConnectedVisualEncoder, | |
VectorInput, | |
) | |
from mlagents.trainers.settings import EncoderType, ScheduleType | |
from mlagents.trainers.torch_entities.attention import ( | |
EntityEmbedding, | |
ResidualSelfAttention, | |
) | |
from mlagents.trainers.exception import UnityTrainerException | |
from mlagents_envs.base_env import ObservationSpec, DimensionProperty | |
class ModelUtils:
    """
    Namespace of static helpers for building observation encoders, converting
    between numpy arrays and torch tensors, decaying schedule values, and computing
    the trust-region (PPO/POCA) value and policy losses.
    """

    # Minimum supported side for each encoder type. If refactoring an encoder, please
    # adjust these also.
    MIN_RESOLUTION_FOR_ENCODER = {
        EncoderType.FULLY_CONNECTED: 1,
        EncoderType.MATCH3: 5,
        EncoderType.SIMPLE: 20,
        EncoderType.NATURE_CNN: 36,
        EncoderType.RESNET: 15,
    }

    # Dimension-property tuples used by get_encoder_for_obs() to classify an
    # observation: visual (3 dims), vector (1 dim) or variable length (2 dims).
    VALID_VISUAL_PROP = frozenset(
        [
            (
                DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
                DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
                DimensionProperty.NONE,
            ),
            (DimensionProperty.UNSPECIFIED,) * 3,
        ]
    )

    VALID_VECTOR_PROP = frozenset(
        [(DimensionProperty.NONE,), (DimensionProperty.UNSPECIFIED,)]
    )

    VALID_VAR_LEN_PROP = frozenset(
        [(DimensionProperty.VARIABLE_SIZE, DimensionProperty.NONE)]
    )

    @staticmethod
    def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None:
        """
        Apply a learning rate to every parameter group of a torch optimizer.
        :param optim: Optimizer whose param_groups are updated in place.
        :param lr: Learning rate
        """
        for param_group in optim.param_groups:
            param_group["lr"] = lr

    class DecayedValue:
        def __init__(
            self,
            schedule: ScheduleType,
            initial_value: float,
            min_value: float,
            max_step: int,
        ):
            """
            Object that represents the value of a parameter that should be decayed,
            assuming it is a function of the global step.
            :param schedule: Type of schedule. get_value() supports
                ScheduleType.CONSTANT and ScheduleType.LINEAR.
            :param initial_value: Initial value before decay.
            :param min_value: Decay value to this value by max_step.
            :param max_step: The final step count where the return value should equal min_value.
            """
            self.schedule = schedule
            self.initial_value = initial_value
            self.min_value = min_value
            self.max_step = max_step

        def get_value(self, global_step: int) -> float:
            """
            Get the value at a given global step.
            :param global_step: Step count.
            :returns: Decayed value at this global step.
            :raises UnityTrainerException: If the schedule is not CONSTANT or LINEAR.
            """
            if self.schedule == ScheduleType.CONSTANT:
                return self.initial_value
            elif self.schedule == ScheduleType.LINEAR:
                return ModelUtils.polynomial_decay(
                    self.initial_value, self.min_value, self.max_step, global_step
                )
            else:
                raise UnityTrainerException(f"The schedule {self.schedule} is invalid.")

    @staticmethod
    def polynomial_decay(
        initial_value: float,
        min_value: float,
        max_step: int,
        global_step: int,
        power: float = 1.0,
    ) -> float:
        """
        Get a decayed value based on a polynomial schedule, with respect to the current global step.
        :param initial_value: Initial value before decay.
        :param min_value: Decay value to this value by max_step.
        :param max_step: The final step count where the return value should equal min_value.
            A non-positive max_step means the decay is already complete, so min_value
            is returned.
        :param global_step: The current step count. Steps past max_step are clamped.
        :param power: Power of polynomial decay. 1.0 (default) is a linear decay.
        :return: The current decayed value.
        """
        if max_step <= 0:
            # Avoid ZeroDivisionError below; decay is finished by step 0.
            return min_value
        global_step = min(global_step, max_step)
        remaining_fraction = 1 - float(global_step) / max_step
        return (initial_value - min_value) * remaining_fraction ** power + min_value

    @staticmethod
    def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
        """
        Look up the visual encoder class registered for an EncoderType.
        :param encoder_type: The requested encoder type.
        :return: The encoder *class* (not an instance), or None if encoder_type has
            no registered encoder. NOTE: the annotation is kept as nn.Module for
            backward compatibility even though a class is returned.
        """
        ENCODER_FUNCTION_BY_TYPE = {
            EncoderType.SIMPLE: SimpleVisualEncoder,
            EncoderType.NATURE_CNN: NatureVisualEncoder,
            EncoderType.RESNET: ResNetVisualEncoder,
            EncoderType.MATCH3: SmallVisualEncoder,
            EncoderType.FULLY_CONNECTED: FullyConnectedVisualEncoder,
        }
        return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)

    @staticmethod
    def _check_resolution_for_encoder(
        height: int, width: int, vis_encoder_type: EncoderType
    ) -> None:
        """
        Validate that a visual observation is large enough for the chosen encoder.
        :param height: Observation height.
        :param width: Observation width.
        :param vis_encoder_type: Encoder type; looked up in MIN_RESOLUTION_FOR_ENCODER
            (raises KeyError if it is not listed there).
        :raises UnityTrainerException: If height or width is below the minimum side.
        """
        min_res = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[vis_encoder_type]
        if height < min_res or width < min_res:
            raise UnityTrainerException(
                f"Visual observation resolution ({width}x{height}) is too small for "
                f"the provided EncoderType ({vis_encoder_type.value}). The min dimension is {min_res}"
            )

    @staticmethod
    def get_encoder_for_obs(
        obs_spec: ObservationSpec,
        normalize: bool,
        h_size: int,
        attention_embedding_size: int,
        vis_encode_type: EncoderType,
    ) -> Tuple[nn.Module, int]:
        """
        Returns the encoder and the size of the appropriate encoder.
        :param obs_spec: ObservationSpec whose shape and dimension_property select
            the encoder (visual, vector or variable length).
        :param normalize: Normalize all vector inputs.
        :param h_size: Number of hidden units per layer excluding attention layers.
        :param attention_embedding_size: Number of hidden units per attention layer.
        :param vis_encode_type: Type of visual encoder to use.
        :return: Tuple of (encoder module, embedding size). The embedding size is 0
            for variable length observations.
        :raises UnityTrainerException: If the observation's dimension properties
            match no supported encoder, or a visual observation is too small.
        """
        shape = obs_spec.shape
        dim_prop = obs_spec.dimension_property
        # VISUAL: shape[0] and shape[1] are validated as height and width; the third
        # entry is presumably the channel count — TODO confirm against encoders.
        if dim_prop in ModelUtils.VALID_VISUAL_PROP:
            visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type)
            ModelUtils._check_resolution_for_encoder(
                shape[0], shape[1], vis_encode_type
            )
            return (visual_encoder_class(shape[0], shape[1], shape[2], h_size), h_size)
        # VECTOR
        if dim_prop in ModelUtils.VALID_VECTOR_PROP:
            return (VectorInput(shape[0], normalize), shape[0])
        # VARIABLE LENGTH: shape is (max number of entities, entity size).
        if dim_prop in ModelUtils.VALID_VAR_LEN_PROP:
            return (
                EntityEmbedding(
                    entity_size=shape[1],
                    entity_num_max_elements=shape[0],
                    embedding_size=attention_embedding_size,
                ),
                0,
            )
        # OTHER
        raise UnityTrainerException(f"Unsupported Sensor with specs {obs_spec}")

    @staticmethod
    def create_input_processors(
        observation_specs: List[ObservationSpec],
        h_size: int,
        vis_encode_type: EncoderType,
        attention_embedding_size: int,
        normalize: bool = False,
    ) -> Tuple[nn.ModuleList, List[int]]:
        """
        Creates visual, vector and variable length encoders for each observation.
        :param observation_specs: List of ObservationSpec that represent the observation dimensions.
        :param h_size: Number of hidden units per layer excluding attention layers.
        :param vis_encode_type: Type of visual encoder to use.
        :param attention_embedding_size: Number of hidden units per attention layer.
        :param normalize: Normalize all vector inputs.
        :return: Tuple of :
         - ModuleList of the encoders
         - A list of embedding sizes (0 if the input requires to be processed with a variable length
         observation encoder)
        """
        encoders: List[nn.Module] = []
        embedding_sizes: List[int] = []
        for obs_spec in observation_specs:
            encoder, embedding_size = ModelUtils.get_encoder_for_obs(
                obs_spec, normalize, h_size, attention_embedding_size, vis_encode_type
            )
            encoders.append(encoder)
            embedding_sizes.append(embedding_size)

        x_self_size = sum(embedding_sizes)  # The size of the "self" embedding
        # When fixed-size observations produce a "self" embedding, every variable
        # length encoder is also given a self embedding of attention_embedding_size.
        if x_self_size > 0:
            for enc in encoders:
                if isinstance(enc, EntityEmbedding):
                    enc.add_self_embedding(attention_embedding_size)
        return (nn.ModuleList(encoders), embedding_sizes)

    @staticmethod
    def list_to_tensor(
        ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = torch.float32
    ) -> torch.Tensor:
        """
        Converts a list of numpy arrays into a tensor. MUCH faster than
        calling as_tensor on the list directly.
        """
        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)

    @staticmethod
    def list_to_tensor_list(
        ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = torch.float32
    ) -> List[torch.Tensor]:
        """
        Converts a list of numpy arrays into a list of tensors, one per array.
        MUCH faster than calling as_tensor on the list directly.
        """
        return [
            torch.as_tensor(np.asanyarray(_arr), dtype=dtype) for _arr in ndarray_list
        ]

    @staticmethod
    def to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """
        Converts a Torch Tensor to a numpy array. If the Tensor is on the GPU, it will
        be brought to the CPU.
        """
        return tensor.detach().cpu().numpy()

    @staticmethod
    def break_into_branches(
        concatenated_logits: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a concatenated set of logits that represent multiple discrete action branches
        and breaks it up into one Tensor per branch.
        :param concatenated_logits: Tensor that represents the concatenated action branches;
            branches are concatenated along dimension 1.
        :param action_size: List of ints containing the number of possible actions for each branch.
        :return: A List of Tensors containing one tensor per branch.
        """
        # Branch boundaries: [0, s0, s0 + s1, ...]
        action_idx = [0] + list(np.cumsum(action_size))
        return [
            concatenated_logits[:, start:end]
            for start, end in zip(action_idx[:-1], action_idx[1:])
        ]

    @staticmethod
    def actions_to_onehot(
        discrete_actions: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a tensor of discrete actions and turns it into a List of onehot encoding for each
        action.
        :param discrete_actions: Actions in integer form; the last dimension indexes
            the branch.
        :param action_size: List of branch sizes. Should be of same size as discrete_actions'
        last dimension.
        :return: List of one-hot tensors, one representing each branch.
        """
        # unbind(-1) yields one slice per branch, keeping the remaining dimensions in
        # their original order. This is equivalent to the former `x.T` / `_act.T`
        # double transpose, without torch's deprecation warning for `.T` on
        # tensors that are not 2-D.
        return [
            torch.nn.functional.one_hot(branch_actions, action_size[i]).float()
            for i, branch_actions in enumerate(discrete_actions.long().unbind(-1))
        ]

    @staticmethod
    def dynamic_partition(
        data: torch.Tensor, partitions: torch.Tensor, num_partitions: int
    ) -> List[torch.Tensor]:
        """
        Torch implementation of dynamic_partition :
        https://www.tensorflow.org/api_docs/python/tf/dynamic_partition
        Splits the data Tensor input into num_partitions Tensors according to the indices in
        partitions.
        :param data: The Tensor data that will be split into partitions.
        :param partitions: An indices tensor that determines in which partition each element
        of data will be in.
        :param num_partitions: The number of partitions to output. Corresponds to the
        maximum possible index in the partitions argument.
        :return: A list of Tensor partitions (Their indices correspond to their partition index).
        """
        return [
            data[(partitions == i).nonzero().squeeze(1)] for i in range(num_partitions)
        ]

    @staticmethod
    def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        """
        Returns the mean of the tensor but ignoring the values specified by masks.
        Used for masking out loss functions.
        :param tensor: Tensor which needs mean computation.
        :param masks: Tensor of masks. It is broadcast against `tensor` with its
            dimensions reversed, so a 1-D mask of length tensor.shape[0] masks along
            the first dimension. The divisor is clamped to at least 1.
        """
        if tensor.ndim == 0:
            aligned = tensor
        else:
            # Reverse all dimensions so dim 0 becomes the trailing one, which is
            # the dimension broadcasting aligns with the mask.
            aligned = tensor.permute(*range(tensor.ndim - 1, -1, -1))
        return (aligned * masks).sum() / torch.clamp(
            (torch.ones_like(aligned) * masks).float().sum(), min=1.0
        )

    @staticmethod
    def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
        """
        Performs an in-place polyak update of the target module based on the source,
        by a ratio of tau. Note that source and target modules must have the same
        parameters, where:
            target = tau * source + (1-tau) * target
        :param source: Source module whose parameters will be used.
        :param target: Target module whose parameters will be updated.
        :param tau: Percentage of source parameters to use in average. Setting tau to
            1 will copy the source parameters to the target.
        """
        with torch.no_grad():
            for source_param, target_param in zip(
                source.parameters(), target.parameters()
            ):
                target_param.data.mul_(1.0 - tau).add_(source_param.data, alpha=tau)

    @staticmethod
    def create_residual_self_attention(
        input_processors: nn.ModuleList, embedding_sizes: List[int], hidden_size: int
    ) -> Tuple[Optional[ResidualSelfAttention], Optional[LinearEncoder]]:
        """
        Creates an RSA if there are variable length observations found in the input processors.
        :param input_processors: A ModuleList of input processors as returned by the function
            create_input_processors().
        :param embedding_sizes: A List of embedding sizes as returned by create_input_processors().
        :param hidden_size: The hidden size to use for the RSA.
        :returns: A Tuple of the RSA itself and a self encoder. Both are None if no
            var len inputs are detected; the self encoder is also None when
            sum(embedding_sizes) is 0.
        """
        rsa, x_self_encoder = None, None
        var_processors = [p for p in input_processors if isinstance(p, EntityEmbedding)]
        # Only adds entity max if it was known at construction (> 0).
        entity_num_max: int = sum(
            p.entity_num_max_elements
            for p in var_processors
            if p.entity_num_max_elements > 0
        )
        if var_processors:
            x_self_size = sum(embedding_sizes)
            if x_self_size:
                x_self_encoder = LinearEncoder(
                    x_self_size,
                    1,
                    hidden_size,
                    kernel_init=Initialization.Normal,
                    kernel_gain=(0.125 / hidden_size) ** 0.5,
                )
            rsa = ResidualSelfAttention(hidden_size, entity_num_max)
        return rsa, x_self_encoder

    @staticmethod
    def trust_region_value_loss(
        values: Dict[str, torch.Tensor],
        old_values: Dict[str, torch.Tensor],
        returns: Dict[str, torch.Tensor],
        epsilon: float,
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluates value loss, clipping to stay within a trust region of old value estimates.
        Used for PPO and POCA.
        :param values: Value output of the current network, keyed by value head name.
        :param old_values: Value stored with experiences in buffer (same keys as values).
        :param returns: Computed returns (same keys as values).
        :param epsilon: Clipping value for value estimate.
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
        :return: Mean over heads of the per-head masked, clipped value loss.
        """
        value_losses = []
        for name, head in values.items():
            old_val_tensor = old_values[name]
            returns_tensor = returns[name]
            clipped_value_estimate = old_val_tensor + torch.clamp(
                head - old_val_tensor, -1 * epsilon, epsilon
            )
            v_opt_a = (returns_tensor - head) ** 2
            v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
            value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
            value_losses.append(value_loss)
        return torch.mean(torch.stack(value_losses))

    @staticmethod
    def trust_region_policy_loss(
        advantages: torch.Tensor,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        loss_masks: torch.Tensor,
        epsilon: float,
    ) -> torch.Tensor:
        """
        Evaluate policy loss clipped to stay within a trust region. Used for PPO and POCA.
        :param advantages: Computed advantages.
        :param log_probs: Current policy log probabilities
        :param old_log_probs: Past policy log probabilities
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
        :param epsilon: Clipping value for the probability ratio.
        :return: Negated masked mean of the clipped surrogate objective.
        """
        # Trailing axis so advantages broadcast against the log-prob tensors.
        advantage = advantages.unsqueeze(-1)
        r_theta = torch.exp(log_probs - old_log_probs)
        p_opt_a = r_theta * advantage
        p_opt_b = torch.clamp(r_theta, 1.0 - epsilon, 1.0 + epsilon) * advantage
        return -1 * ModelUtils.masked_mean(torch.min(p_opt_a, p_opt_b), loss_masks)