# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from nncore.nn import Parameter


class Permute(nn.Module):
    # Swap the last two dims, e.g. (B, T, C) <-> (B, C, T), so Conv1d layers
    # can operate along the temporal axis.

    def forward(self, x):
        return x.transpose(-1, -2)


class LearnableEmbedding(nn.Module):
    # Add a single learnable embedding vector to every position of the input.

    def __init__(self, dims):
        super().__init__()
        self.weights = Parameter(1, 1, dims)

    def forward(self, x):
        return x + self.weights.expand_as(x)


class ConvPyramid(nn.Module):
    # Build one temporal branch per stride: strides > 1 downsample with strided
    # Conv1d, strides < 1 upsample with ConvTranspose1d, and a stride of 1
    # keeps the resolution and only applies the activation.

    def __init__(self, dims, strides, act_cls=nn.ReLU):
        super().__init__()

        self.blocks = nn.ModuleList()

        for s in strides:
            p = int(math.log2(s))
            if p == 0:
                layers = act_cls()
            else:
                conv_cls = nn.Conv1d if p > 0 else nn.ConvTranspose1d
                layers = nn.Sequential()
                for _ in range(abs(p)):
                    module = [Permute(), conv_cls(dims, dims, 2, stride=2), Permute(), nn.LayerNorm(dims), act_cls()]
                    layers.extend(module)
            self.blocks.append(layers)

        self.strides = strides

    def forward(self, x, mask, return_mask=False):
        pymid, pymid_msk = [], []

        for s, blk in zip(self.strides, self.blocks):
            # Skip scales whose stride exceeds the sequence length.
            if x.size(1) < s:
                continue

            pymid.append(blk(x))

            if return_mask:
                if s > 1:
                    msk = F.max_pool1d(mask.float(), s, stride=s).long()
                elif s < 1:
                    msk = mask.repeat_interleave(int(1 / s), dim=1)
                else:
                    msk = mask
                pymid_msk.append(msk)

        return (pymid, pymid_msk) if return_mask else pymid


class Scale(nn.Module):
    # Per-pyramid-level learnable scalar applied to the head outputs.

    def __init__(self, strides):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(len(strides)))

    def forward(self, x, i):
        return x * self.scale[i]


class ConvHead(nn.Module):
    # Two-layer temporal Conv1d head projecting dims -> out_dims per position.

    def __init__(self, dims, out_dims, kernel_size=3, act_cls=nn.ReLU):
        super().__init__()

        # yapf:disable
        self.module = nn.Sequential(
            Permute(),
            nn.Conv1d(dims, dims, kernel_size, padding=kernel_size // 2),
            act_cls(),
            nn.Conv1d(dims, out_dims, kernel_size, padding=kernel_size // 2),
            Permute())
        # yapf:enable

    def forward(self, x):
        return self.module(x)
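

# Minimal usage sketch (not part of the original module): the hyper-parameters
# below (dims=256, strides=(1, 2, 4), a batch of 2 sequences with 64 time
# steps) are illustrative assumptions, chosen only to show how the blocks
# compose into a temporal feature pyramid with a per-level prediction head.
if __name__ == '__main__':
    dims, strides = 256, (1, 2, 4)

    x = torch.randn(2, 64, dims)                # (batch, time, channels)
    mask = torch.ones(2, 64, dtype=torch.long)  # valid-position mask

    x = LearnableEmbedding(dims)(x)

    pyramid = ConvPyramid(dims, strides)
    head = ConvHead(dims, 1)
    scale = Scale(strides)

    feats, masks = pyramid(x, mask, return_mask=True)

    for i, (f, m) in enumerate(zip(feats, masks)):
        out = scale(head(f), i)                 # (batch, 64 // strides[i], 1)
        print(out.shape, m.shape)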