# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
""" | |
对3d数据进行归一化 | |
""" | |


class MaskedBatchNorm1d(nn.Module):
    """A masked version of nn.BatchNorm1d. Only tested for 3D inputs.

    Args:
        num_features (int): :math:`C` from an expected input of size
            :math:`(N, C, L)`
        eps (float): a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum (float): the value used for the running_mean and running_var
            computation; unlike nn.BatchNorm1d, ``None`` is not supported by
            this implementation. Default: 0.1
        affine (bool): a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats (bool): a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Inputs:
        x (torch.Tensor): of size (batch_size, num_features, sequence_length)
        input_mask (torch.Tensor or None): optional byte or bool tensor of shape
            (batch_size, 1, sequence_length); zeros (or False) mark positions
            that do not contribute to the statistics

    Outputs:
        y (torch.Tensor): of size (batch_size, num_features, sequence_length)
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        if affine:
            self.weight = nn.Parameter(torch.ones(num_features, 1))
            self.bias = nn.Parameter(torch.zeros(num_features, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.track_running_stats = track_running_stats
        if self.track_running_stats:
            self.register_buffer("running_mean", torch.zeros(num_features, 1))
            self.register_buffer("running_var", torch.ones(num_features, 1))
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            self.register_parameter("running_mean", None)
            self.register_parameter("running_var", None)
            self.register_parameter("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input, input_mask=None):
        # Calculate the masked mean and variance
        B, C, L = input.shape
        if input_mask is not None and input_mask.shape != (B, 1, L):
            raise ValueError("Mask should have shape (B, 1, L).")
        if C != self.num_features:
            raise ValueError(
                "Expected %d channels but input has %d channels"
                % (self.num_features, C)
            )
        if input_mask is not None:
            masked = input * input_mask
            n = input_mask.sum()
        else:
            masked = input
            n = B * L
        # Sum over the batch and length dimensions
        masked_sum = masked.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True)
        # Divide by the number of unmasked positions
        current_mean = masked_sum / n
        # Re-apply the mask so that masked positions do not contribute
        # spurious (0 - mean)^2 terms to the variance
        deviations = (masked - current_mean) ** 2
        if input_mask is not None:
            deviations = deviations * input_mask
        current_var = deviations.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / n
        # Update running stats without tracking gradients through them
        if self.track_running_stats and self.training:
            with torch.no_grad():
                if self.num_batches_tracked == 0:
                    self.running_mean = current_mean.detach()
                    self.running_var = current_var.detach()
                else:
                    self.running_mean = (
                        1 - self.momentum
                    ) * self.running_mean + self.momentum * current_mean.detach()
                    self.running_var = (
                        1 - self.momentum
                    ) * self.running_var + self.momentum * current_var.detach()
                self.num_batches_tracked += 1
        # Normalize the input
        if self.track_running_stats and not self.training:
            normed = (masked - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        else:
            normed = (masked - current_mean) / torch.sqrt(current_var + self.eps)
        # Apply affine parameters
        if self.affine:
            normed = normed * self.weight + self.bias
        return normed
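

# A minimal usage sketch (illustrative, not part of the original module): it
# assumes variable-length sequences padded to a common length, with a validity
# mask of shape (B, 1, L). The tensor sizes and names below are arbitrary
# choices for demonstration.
def _demo_masked_batchnorm1d():
    bn = MaskedBatchNorm1d(num_features=8)
    x = torch.randn(4, 8, 10)  # (batch, channels, length)
    lengths = torch.tensor([10, 7, 5, 3])
    # 1.0 at valid positions, 0.0 at padding; shape (4, 1, 10)
    mask = (torch.arange(10)[None, :] < lengths[:, None]).unsqueeze(1).float()
    y = bn(x, input_mask=mask)
    return y.shape  # torch.Size([4, 8, 10])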


class MaskedBatchNorm2d(nn.Module):
    """A masked version of nn.BatchNorm2d for channels-last pairwise inputs.

    Args:
        num_features (int): :math:`C` from an expected input of size
            :math:`(N, L, L, C)`
        eps (float): a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum (float): the value used for the running_mean and running_var
            computation; unlike nn.BatchNorm2d, ``None`` is not supported by
            this implementation. Default: 0.1
        affine (bool): a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats (bool): a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Inputs:
        x (torch.Tensor): of size (batch_size, length, length, num_features)
        mask (torch.Tensor or None): optional byte or bool tensor of shape
            (batch_size, length, length, 1) or (batch_size, length, length,
            num_features); zeros (or False) mark positions that do not
            contribute to the statistics

    Outputs:
        y (torch.Tensor): of size (batch_size, length, length, num_features)
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        if affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.track_running_stats = track_running_stats
        if self.track_running_stats:
            self.register_buffer("running_mean", torch.zeros(1, 1, 1, num_features))
            self.register_buffer("running_var", torch.ones(1, 1, 1, num_features))
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            self.register_parameter("running_mean", None)
            self.register_parameter("running_var", None)
            self.register_parameter("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input, mask=None):
        # Calculate the masked mean and variance
        B, L1, L2, C = input.size()
        if mask is not None:
            if mask.dim() != 4:
                raise ValueError(
                    f"Input mask must have four dimensions, but has {mask.dim()}"
                )
            b, l1, l2, d = mask.size()
            if (b != B) or (l1 != L1) or (l2 != L2):
                raise ValueError(
                    f"Input mask must have shape {(B, L1, L2, 1)} or"
                    f" {(B, L1, L2, C)} to match input."
                )
            if d == 1:
                mask = mask.expand(input.size())
        if C != self.num_features:
            raise ValueError(
                "Expected %d channels but input has %d channels"
                % (self.num_features, C)
            )
        if mask is None:
            mask = input.new_ones(input.size())
        masked = input * mask
        # Per-channel count of unmasked positions
        n = mask.sum(dim=(0, 1, 2), keepdim=True)
        masked_sum = masked.sum(dim=(0, 1, 2), keepdim=True)
        current_mean = masked_sum / n
        # Re-apply the mask so masked positions do not contribute to the variance
        current_var = (mask * (masked - current_mean).pow(2)).sum(
            dim=(0, 1, 2), keepdim=True
        ) / n
        # Update running stats without tracking gradients through them
        with torch.no_grad():
            if self.track_running_stats and self.training:
                if self.num_batches_tracked == 0:
                    self.running_mean = current_mean.detach()
                    self.running_var = current_var.detach()
                else:
                    self.running_mean = (
                        1 - self.momentum
                    ) * self.running_mean + self.momentum * current_mean.detach()
                    self.running_var = (
                        1 - self.momentum
                    ) * self.running_var + self.momentum * current_var.detach()
                self.num_batches_tracked += 1
        # Normalize the input
        if self.track_running_stats and not self.training:
            normed = (masked - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        else:
            normed = (masked - current_mean) / torch.sqrt(current_var + self.eps)
        # Apply affine parameters
        if self.affine:
            normed = normed * self.weight + self.bias
        normed = normed * mask
        return normed
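

# A minimal usage sketch (illustrative, not part of the original module): for a
# channels-last pairwise feature map of shape (B, L, L, C), a 2D validity mask
# is built as the outer product of a per-position mask with itself. Names and
# sizes here are arbitrary choices for demonstration.
def _demo_masked_batchnorm2d():
    bn = MaskedBatchNorm2d(num_features=16)
    x = torch.randn(2, 12, 12, 16)  # (batch, length, length, channels)
    valid = torch.zeros(2, 12)
    valid[0, :12] = 1.0
    valid[1, :8] = 1.0  # second example has only 8 valid positions
    # Pairwise mask: position (i, j) is valid iff both i and j are valid
    mask = (valid[:, :, None] * valid[:, None, :]).unsqueeze(-1)  # (B, L, L, 1)
    y = bn(x, mask=mask)
    return y.shape  # torch.Size([2, 12, 12, 16])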


class NormedReductionLayer(nn.Module):
    """A reduction layer (masked mean pooling followed by a two-layer MLP),
    with a LayerNorm on the pooled features and another after the hidden layer."""

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.0):
        super().__init__()
        # Note: these dropout layers are defined but not applied in forward
        self.d1 = nn.Dropout(p=dropout)
        self.d2 = nn.Dropout(p=dropout)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.output = nn.Linear(hidden_dim, output_dim)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def reduce(self, x, mask):
        # Masked mean pooling over the sequence dimension
        masked_x = x * mask
        mean_x = masked_x.sum(dim=1) / torch.sum(mask, dim=1)
        return mean_x

    def forward(self, x, mask):
        reduced_x = self.norm1(self.reduce(x, mask))
        h = self.norm2(self.hidden(reduced_x))
        return self.output(self.relu(h))
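

# A minimal usage sketch (illustrative, not part of the original module): pool
# a masked sequence of per-position features of shape (B, L, D) down to one
# vector per example. The mask shape (B, L, 1) broadcasts over the feature
# dimension; names and sizes are arbitrary choices for demonstration.
def _demo_normed_reduction_layer():
    layer = NormedReductionLayer(input_dim=32, hidden_dim=64, output_dim=8)
    x = torch.randn(4, 10, 32)  # (batch, length, features)
    lengths = torch.tensor([10, 7, 5, 3])
    mask = (torch.arange(10)[None, :] < lengths[:, None]).unsqueeze(-1).float()
    y = layer(x, mask)
    return y.shape  # torch.Size([4, 8])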