Spaces:
Paused
Paused
import numpy as np | |
import torch.nn.functional as F | |
from torch import nn | |
from .model import MLPLayers | |
class LinearProbe(nn.Module):
    """Classification head (linear or MLP) stacked on a CLAP model's audio path."""

    def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
        """
        Args:
            model: nn.Module, the CLAP model; its ``text_branch`` attribute is
                set to None by this constructor
            mlp: bool, if True, then use the MLP layer as the linear probe module
            freeze: bool, if True, then freeze all the CLAP model's layers when
                training the linear probe
            in_ch: int, the output channel from CLAP model
                (NOTE: currently overridden with 512 below, so the passed value
                is ignored)
            out_ch: int, the output channel from linear probe (class_num)
            act: None or str, the activation applied before the loss function;
                one of "None", "relu", "elu", "prelu", "softmax", "sigmoid"
        Raises:
            ValueError: if ``act`` is not one of the supported names
        """
        super().__init__()
        # The embedding width is pinned here; kept as-is so existing callers
        # and checkpoints continue to work.
        in_ch = 512
        self.clap_model = model
        self.clap_model.text_branch = None  # to save memory
        self.freeze = freeze
        if mlp:
            self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
        else:
            self.lp_layer = nn.Linear(in_ch, out_ch)

        if self.freeze:
            for param in self.clap_model.parameters():
                param.requires_grad = False

        # Always assign self.act (also for the default act=None) so that
        # forward() can safely test it.
        self.act = self._build_activation(act, out_ch)

    @staticmethod
    def _build_activation(act, out_ch):
        """Return the activation module for ``act`` or None when disabled."""
        if act is None or act == "None":
            return None
        if act == "relu":
            return nn.ReLU()
        if act == "elu":
            return nn.ELU()
        if act == "prelu":
            # The activation runs on the probe output, so the per-channel
            # slopes must match out_ch rather than the embedding width.
            return nn.PReLU(num_parameters=out_ch)
        if act == "softmax":
            return nn.Softmax(dim=-1)
        if act == "sigmoid":
            return nn.Sigmoid()
        raise ValueError(f"Unsupported activation for LinearProbe: {act!r}")

    def forward(self, x, mix_lambda=None, device=None):
        """
        Args:
            x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
            mix_lambda: torch.tensor [batch], the mixup lambda
            device: forwarded unchanged to the audio branch
        Returns:
            class_prob: torch.tensor [batch, class_num]
        """
        # Keep the frozen CLAP model in eval mode on every call, even if the
        # wrapper itself was switched to train mode.
        if self.freeze:
            self.clap_model.eval()

        branch_out = self.clap_model.audio_branch(
            x, mixup_lambda=mix_lambda, device=device
        )
        x = self.clap_model.audio_projection(branch_out["embedding"])
        out = self.lp_layer(x)
        if self.act is not None:
            out = self.act(out)
        return out