Hukuna's picture
Upload 221 files
ce7bf5b verified
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import torch
import torch.nn.functional as F
# Machine/architecture string reported by the host (e.g. "arm64"). It is
# compared against "arm64" in filter1D_linear_decay to enable a NaN-retry
# workaround on that architecture only.
MACHINE = platform.machine()
"""
One-dimensional linear-decay filter
"""
def filter1D_linear_decay(
    Z: torch.Tensor, B: torch.Tensor, max_retries: int = 10
) -> torch.Tensor:
    """Apply a low-pass filter with batch-heterogeneous coefficients.

    Computes the linear recurrence `x_i = z_i + b * x_{i-1}` (starting from
    an implicit `x_{-1} = 0`), where `b` varies per batch member. The
    recurrence is unrolled into `x_i = sum_{m <= i} b^(i - m) * z_m` and
    evaluated as one grouped 1D convolution, with one group per batch row.

    Args:
        Z (torch.Tensor): Batch of one-dimensional signals with shape `(N, W)`.
        B (torch.Tensor): Batch of coefficients with shape `(N)`.
        max_retries (int): Upper bound on how many times the convolution is
            re-run on arm64 when it yields NaNs from NaN-free inputs. When
            the budget is exhausted the last result is returned as is, which
            matches what other architectures return without retrying.

    Returns:
        X (torch.Tensor): Result of applying linear recurrence with shape
            `(N, W)`. An empty input (`N == 0` or `W == 0`) produces an
            empty result of the same shape.
    """
    N, W = Z.shape
    if N == 0 or W == 0:
        # Nothing to filter, and the pad/convolution path below cannot handle
        # zero rows or zero width, so short-circuit with an empty result.
        return torch.zeros_like(Z)

    # Build filter coefficients as powers of B: tap j holds B ** (W - 1 - j),
    # so the last tap (weight 1) lines up with the current sample.
    exponents = (W - 1) - torch.arange(W, device=Z.device)
    kernel = B[:, None, None] ** exponents[None, None, :]

    # Pad on the left so every output only looks backwards in time.
    Z_pad = F.pad(Z, (W - 1, 0))[None, ...]

    def _convolve() -> torch.Tensor:
        # Group convolution effectively applies one filter per batch member.
        return F.conv1d(Z_pad, kernel, stride=1, padding=0, groups=N)[0, :, :]

    X = _convolve()

    # On arm64 (M1 Mac) this convolution erroneously sometimes produces NaNs.
    # Checking MACHINE first keeps other platforms free of the extra NaN scan.
    if MACHINE == "arm64" and torch.isnan(X).any():
        # Only retry when the NaNs cannot be blamed on the inputs. The inputs
        # do not change between attempts, so inspect them a single time.
        inputs_clean = not (
            torch.isnan(Z_pad).any() or torch.isnan(kernel).any()
        )
        if inputs_clean:
            # Bounded retry: NaNs can also arise deterministically from
            # NaN-free inputs (e.g. an infinite tap multiplied by a zero
            # sample), which would otherwise loop forever.
            for _ in range(max_retries):
                X = _convolve()
                if not torch.isnan(X).any():
                    break
    return X