# Hukuna's picture
# Upload 221 files
# ce7bf5b verified
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
"""
对3d数据进行归一化
"""
class MaskedBatchNorm1d(nn.Module):
    """A masked version of nn.BatchNorm1d. Only tested for 3D inputs.

    Positions where the mask is zero are excluded from the batch statistics,
    so padded sequence positions do not skew the mean/variance.

    Args:
        num_features (int): :math:`C` from an expected input of size
            :math:`(N, C, L)`
        eps (float): a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum (float): the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine (bool): a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats (bool) : a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``
    Inputs:
        input (torch.tensor): of size (batch_size, num_features, sequence_length)
        input_mask (torch.tensor or None) : (optional) of dtype (byte) or (bool)
            of shape (batch_size, 1, sequence_length); zeroes (or False) indicate
            positions that cannot contribute to computation
    Outputs:
        y (torch.tensor): of size (batch_size, num_features, sequence_length)
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        if affine:
            # Shaped (C, 1) so they broadcast over the length dimension.
            self.weight = nn.Parameter(torch.Tensor(num_features, 1))
            self.bias = nn.Parameter(torch.Tensor(num_features, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.track_running_stats = track_running_stats
        if self.track_running_stats:
            self.register_buffer("running_mean", torch.zeros(num_features, 1))
            self.register_buffer("running_var", torch.ones(num_features, 1))
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            self.register_parameter("running_mean", None)
            self.register_parameter("running_var", None)
            self.register_parameter("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running mean/var to 0/1 and the batch counter to zero."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running stats and re-initialize affine weight/bias to 1/0."""
        self.reset_running_stats()
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input, input_mask=None):
        """Normalize ``input`` using statistics over unmasked positions only.

        Raises:
            ValueError: if the mask shape is not (B, 1, L) or the channel
                count does not match ``num_features``.
        """
        B, C, L = input.shape
        if input_mask is not None and input_mask.shape != (B, 1, L):
            raise ValueError("Mask should have shape (B, 1, L).")
        if C != self.num_features:
            raise ValueError(
                "Expected %d channels but input has %d channels"
                % (self.num_features, C)
            )
        if input_mask is not None:
            masked = input * input_mask
            n = input_mask.sum()
        else:
            masked = input
            n = B * L
        # Per-channel sum over batch and length -> shape (1, C, 1).
        masked_sum = masked.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True)
        current_mean = masked_sum / n
        # Exclude masked-out positions from the variance as well: without the
        # mask factor each padded position would wrongly add mean**2 to the sum
        # (the 2d sibling class already applies the mask here).
        sq_dev = (masked - current_mean) ** 2
        if input_mask is not None:
            sq_dev = sq_dev * input_mask
        current_var = sq_dev.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / n
        # Update running stats outside autograd so the buffers do not retain
        # the computation graph across iterations.
        with torch.no_grad():
            if self.track_running_stats and self.training:
                if self.num_batches_tracked == 0:
                    self.running_mean = current_mean.detach()
                    self.running_var = current_var.detach()
                else:
                    # momentum=None -> cumulative moving average, as documented.
                    if self.momentum is None:
                        factor = 1.0 / float(self.num_batches_tracked + 1)
                    else:
                        factor = self.momentum
                    self.running_mean = (
                        1 - factor
                    ) * self.running_mean + factor * current_mean.detach()
                    self.running_var = (
                        1 - factor
                    ) * self.running_var + factor * current_var.detach()
                self.num_batches_tracked += 1
        # Norm the input: running stats in eval, batch stats otherwise.
        if self.track_running_stats and not self.training:
            normed = (masked - self.running_mean) / (
                torch.sqrt(self.running_var + self.eps)
            )
        else:
            normed = (masked - current_mean) / (torch.sqrt(current_var + self.eps))
        # Apply affine parameters.
        if self.affine:
            normed = normed * self.weight + self.bias
        return normed
class MaskedBatchNorm2d(nn.Module):
    """A masked batch normalization for 4D, channels-last inputs.

    Normalizes over the batch and both spatial dimensions per channel,
    excluding positions where the mask is zero from the statistics, and
    zeroing those positions in the output.

    Args:
        num_features (int): :math:`C` from an expected input of size
            :math:`(B, H, W, C)`
        eps (float): a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum (float): the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine (bool): a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats (bool) : a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``
    Inputs:
        input (torch.tensor): of size (B, H, W, C)
        mask (torch.tensor or None) : (optional) of shape (B, H, W, 1) or
            (B, H, W, C); zeroes (or False) indicate positions that cannot
            contribute to computation
    Outputs:
        y (torch.tensor): of size (B, H, W, C), zero at masked positions
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        if affine:
            # 1-D (C,) parameters broadcast over the trailing channel dim.
            self.weight = nn.Parameter(torch.ones(num_features,))
            self.bias = nn.Parameter(torch.zeros(num_features,))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.track_running_stats = track_running_stats
        if self.track_running_stats:
            self.register_buffer("running_mean", torch.zeros(1, 1, 1, num_features))
            self.register_buffer("running_var", torch.ones(1, 1, 1, num_features))
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            self.register_parameter("running_mean", None)
            self.register_parameter("running_var", None)
            self.register_parameter("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running mean/var to 0/1 and the batch counter to zero."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running stats and re-initialize affine weight/bias to 1/0."""
        self.reset_running_stats()
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input, mask=None):
        """Normalize ``input`` using per-channel masked batch statistics.

        Raises:
            ValueError: if the mask is not 4-D, its batch/spatial dims do not
                match the input, or the channel count differs from
                ``num_features``.
        """
        # Unpack both spatial dims separately; the original code reused one
        # name (B, L, L, C), which silently skipped validating the first
        # spatial dimension of the mask.
        B, H, W, C = input.size()
        if mask is not None:
            if mask.dim() != 4:
                raise ValueError(
                    f"Input mask must have four dimensions, but has {mask.dim()}"
                )
            b, h, w, d = mask.size()
            if (b != B) or (h != H) or (w != W):
                raise ValueError(
                    f"Input mask must have shape {(B, H, W, 1)} or {(B, H, W, C)} to match input."
                )
            if d == 1:
                mask = mask.expand(input.size())
        if C != self.num_features:
            raise ValueError(
                "Expected %d channels but input has %d channels"
                % (self.num_features, C)
            )
        if mask is None:
            mask = input.new_ones(input.size())
        masked = input * mask
        # Valid-position count and sums per channel, shape (1, 1, 1, C).
        n = mask.sum(dim=(0, 1, 2), keepdim=True)
        masked_sum = masked.sum(dim=(0, 1, 2), keepdim=True)
        current_mean = masked_sum / n
        # Mask factor keeps padded positions out of the variance sum.
        current_var = (mask * (masked - current_mean).pow(2)).sum(
            dim=(0, 1, 2), keepdim=True
        ) / n
        # Update running stats outside autograd so the buffers stay graph-free.
        with torch.no_grad():
            if self.track_running_stats and self.training:
                if self.num_batches_tracked == 0:
                    self.running_mean = current_mean.detach()
                    self.running_var = current_var.detach()
                else:
                    # momentum=None -> cumulative moving average, as documented.
                    if self.momentum is None:
                        factor = 1.0 / float(self.num_batches_tracked + 1)
                    else:
                        factor = self.momentum
                    self.running_mean = (
                        1 - factor
                    ) * self.running_mean + factor * current_mean.detach()
                    self.running_var = (
                        1 - factor
                    ) * self.running_var + factor * current_var.detach()
                self.num_batches_tracked += 1
        # Norm the input: running stats in eval, batch stats otherwise.
        if self.track_running_stats and not self.training:
            normed = (masked - self.running_mean) / (
                torch.sqrt(self.running_var + self.eps)
            )
        else:
            normed = (masked - current_mean) / (torch.sqrt(current_var + self.eps))
        # Apply affine parameters.
        if self.affine:
            normed = normed * self.weight + self.bias
        # Re-mask so padded positions stay exactly zero after the affine shift.
        normed = normed * mask
        return normed
class NormedReductionLayer(nn.Module):
    """A ReductionLayer with LayerNorms after the hidden layers.

    Pools a sequence via a mask-weighted mean over dim 1, then projects it
    through LayerNorm -> Linear -> LayerNorm -> ReLU -> Linear.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.0):
        super().__init__()
        # NOTE: d1/d2 dropouts are instantiated but not applied in forward();
        # kept as-is to preserve the module's state_dict layout and behavior.
        self.d1 = nn.Dropout(p=dropout)
        self.d2 = nn.Dropout(p=dropout)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.output = nn.Linear(hidden_dim, output_dim)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def reduce(self, x, mask):
        """Return the masked mean of ``x`` along the sequence dimension."""
        weighted_sum = (x * mask).sum(dim=1)
        valid_count = torch.sum(mask, dim=1)
        return weighted_sum / valid_count

    def forward(self, x, mask):
        """Pool the masked sequence and project it to ``output_dim``."""
        pooled = self.norm1(self.reduce(x, mask))
        hidden_act = self.relu(self.norm2(self.hidden(pooled)))
        return self.output(hidden_act)