from typing import Any, Callable, Dict

import random

import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


class AudioSep(pl.LightningModule):
    def __init__(
        self,
        ss_model: nn.Module,
        waveform_mixer,
        query_encoder,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda_func: Callable[[int], float],
        use_text_ratio: float = 1.0,
    ):
        r"""PyTorch Lightning wrapper of a PyTorch separation model, covering
        the forward pass, optimization, etc.

        Args:
            ss_model: nn.Module, the source separation model
            waveform_mixer: mixes clean waveforms into training mixtures
            query_encoder: encodes text/audio queries into condition embeddings
            loss_function: Callable, computes the training loss
            optimizer_type: str, e.g. "AdamW"
            learning_rate: float
            lr_lambda_func: Callable, maps the global step to an LR multiplier
            use_text_ratio: float, ratio of text (vs. audio) embeddings used
                when the query encoder runs in hybrid mode
        """
        super().__init__()
        self.ss_model = ss_model
        self.waveform_mixer = waveform_mixer
        self.query_encoder = query_encoder
        self.query_encoder_type = self.query_encoder.encoder_type
        self.use_text_ratio = use_text_ratio
        self.loss_function = loss_function
        self.optimizer_type = optimizer_type
        self.learning_rate = learning_rate
        self.lr_lambda_func = lr_lambda_func

    def forward(self, x):
        # Not used directly; separation runs inside training_step and the
        # inference scripts.
        pass

    def training_step(self, batch_data_dict: Dict[str, Any], batch_idx: int):
        r"""Forward a mini-batch through the model, calculate the loss, and
        train for one step. A mini-batch is evenly distributed to multiple
        devices (if present) for parallel training.

        Args:
            batch_data_dict: e.g.
                'audio_text': {
                    'text': ['a sound of dog', ...],
                    'waveform': (batch_size, 1, samples),
                }
            batch_idx: int

        Returns:
            loss: torch.Tensor, loss of this mini-batch
        """
        # [important] Fix the random seed across devices so that waveform
        # mixing is identical on every device for a given step.
        random.seed(batch_idx)

        batch_audio_text_dict = batch_data_dict['audio_text']
        batch_text = batch_audio_text_dict['text']
        batch_audio = batch_audio_text_dict['waveform']

        # Mix the clean waveforms into training mixtures.
        mixtures, segments = self.waveform_mixer(
            waveforms=batch_audio
        )

        # Calculate the condition embeddings for the audio-text data.
        if self.query_encoder_type == 'CLAP':
            conditions = self.query_encoder.get_query_embed(
                modality='hybird',  # sic: the key expected by the query encoder
                text=batch_text,
                audio=segments.squeeze(1),
                use_text_ratio=self.use_text_ratio,
            )
        else:
            raise NotImplementedError(
                f'Unsupported query encoder type: {self.query_encoder_type}')

        input_dict = {
            'mixture': mixtures,
            'condition': conditions,
        }

        target_dict = {
            'segment': segments.squeeze(1),
        }

        self.ss_model.train()
        sep_segment = self.ss_model(input_dict)['waveform']
        sep_segment = sep_segment.squeeze(1)
        # (batch_size, segment_samples)

        output_dict = {
            'segment': sep_segment,
        }

        # Calculate the loss.
        loss = self.loss_function(output_dict, target_dict)
        self.log_dict({"train_loss": loss})

        return loss
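
    # `loss_function` consumes the output/target dicts built above. A minimal
    # compatible sketch (an assumption for illustration, not necessarily the
    # project's loss):
    #
    #     def l1_wav(output_dict, target_dict):
    #         return torch.mean(
    #             torch.abs(output_dict['segment'] - target_dict['segment']))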

    def test_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        r"""Configure the optimizer and learning-rate scheduler."""
        if self.optimizer_type == "AdamW":
            optimizer = optim.AdamW(
                params=self.ss_model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )
        else:
            raise NotImplementedError(
                f'Unsupported optimizer type: {self.optimizer_type}')

        scheduler = LambdaLR(optimizer, self.lr_lambda_func)

        output_dict = {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1,
            }
        }

        return output_dict
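

# `lr_lambda_func` feeds the LambdaLR scheduler above and maps the global step
# to an LR multiplier. A minimal linear warm-up sketch (the 1000-step warm-up
# length is an assumption, not the project's setting):
def linear_warmup_lr_lambda(step: int, warmup_steps: int = 1000) -> float:
    """Ramp the LR linearly over `warmup_steps` steps, then hold it constant."""
    return min(1.0, (step + 1) / warmup_steps)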


def get_model_class(model_type: str):
    r"""Return the separation model class that corresponds to ``model_type``."""
    if model_type == 'ResUNet30':
        from models.resunet import ResUNet30
        return ResUNet30
    else:
        raise NotImplementedError(
            f'Unsupported model type: {model_type}')
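

# Example wiring (a minimal sketch; the constructor arguments and the
# `WaveformMixer`, `CLAP_Encoder`, and `l1_wav` names are assumptions, not
# the project's exact API):
#
#     SSModel = get_model_class('ResUNet30')
#     model = AudioSep(
#         ss_model=SSModel(input_channels=1, output_channels=1),
#         waveform_mixer=WaveformMixer(max_mix_num=2),
#         query_encoder=CLAP_Encoder(),
#         loss_function=l1_wav,
#         optimizer_type='AdamW',
#         learning_rate=1e-4,
#         lr_lambda_func=linear_warmup_lr_lambda,
#     )
#     pl.Trainer(max_steps=100_000).fit(model, train_dataloaders=...)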