# AudioSep/models/audiosep.py
from typing import Callable
import random
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


class AudioSep(pl.LightningModule):
    def __init__(
        self,
        ss_model: nn.Module,
        waveform_mixer,
        query_encoder,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda_func: Callable,
        use_text_ratio: float = 1.0,
    ):
        r"""PyTorch Lightning wrapper around a separation model: runs the
        forward pass, computes the loss, and configures the optimizer.

        Args:
            ss_model: nn.Module, the source separation model
            waveform_mixer: callable that mixes waveforms into training mixtures
            query_encoder: query embedding model (e.g. CLAP) exposing
                ``encoder_type`` and ``get_query_embed``
            loss_function: callable invoked as loss_function(output_dict, target_dict)
            optimizer_type: str, e.g. 'AdamW'
            learning_rate: float
            lr_lambda_func: callable mapping a step index to an LR multiplier
            use_text_ratio: float, fraction of queries encoded from text
        """
super().__init__()
self.ss_model = ss_model
self.waveform_mixer = waveform_mixer
self.query_encoder = query_encoder
self.query_encoder_type = self.query_encoder.encoder_type
self.use_text_ratio = use_text_ratio
self.loss_function = loss_function
self.optimizer_type = optimizer_type
self.learning_rate = learning_rate
self.lr_lambda_func = lr_lambda_func

    def forward(self, x):
        r"""Placeholder; the separation model is invoked in ``training_step``."""
        pass

    def training_step(self, batch_data_dict, batch_idx):
        r"""Forward one mini-batch through the model, compute the loss, and
        train for one step. The mini-batch is evenly distributed across
        multiple devices (if any) for parallel training.

        Args:
            batch_data_dict: e.g.
                'audio_text': {
                    'text': ['a sound of dog', ...],
                    'waveform': (batch_size, 1, samples),
                }
            batch_idx: int

        Returns:
            loss: torch.Tensor, scalar loss for this mini-batch
        """
        # [important] Fix the random seed by batch index so the random mixing
        # below is reproducible and identical across devices.
        random.seed(batch_idx)

batch_audio_text_dict = batch_data_dict['audio_text']
batch_text = batch_audio_text_dict['text']
batch_audio = batch_audio_text_dict['waveform']
device = batch_audio.device
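
        # Mix source waveforms within the batch into synthetic training
        # mixtures; ``segments`` holds the corresponding target sources.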
mixtures, segments = self.waveform_mixer(
waveforms=batch_audio
)

        # Compute the query embedding from text, audio, or a random mix of both.
        if self.query_encoder_type == 'CLAP':
            conditions = self.query_encoder.get_query_embed(
                modality='hybird',  # note: 'hybird' [sic] is the string this codebase uses
                text=batch_text,
                audio=segments.squeeze(1),
                use_text_ratio=self.use_text_ratio,
            )
        else:
            raise NotImplementedError(
                f'Unsupported query encoder type: {self.query_encoder_type}')

        input_dict = {
            # The original expression mixtures[:, None, :].squeeze(1) is an
            # identity operation, so the tensor is passed through unchanged.
            'mixture': mixtures,
            'condition': conditions,
        }

        target_dict = {
            'segment': segments.squeeze(1),
        }

        self.ss_model.train()
        # Separated waveform: (batch_size, 1, segment_samples).
        sep_segment = self.ss_model(input_dict)['waveform']
        # Drop the channel dimension to match target_dict['segment'];
        # squeeze(1) (rather than squeeze()) keeps a batch of size 1 intact.
        sep_segment = sep_segment.squeeze(1)

        output_dict = {
            'segment': sep_segment,
        }
# Calculate loss.
loss = self.loss_function(output_dict, target_dict)
self.log_dict({"train_loss": loss})
return loss

    def test_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        r"""Configure the optimizer and the per-step LR scheduler."""
if self.optimizer_type == "AdamW":
optimizer = optim.AdamW(
params=self.ss_model.parameters(),
lr=self.learning_rate,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=0.0,
amsgrad=True,
)
        else:
            raise NotImplementedError(
                f'Unsupported optimizer type: {self.optimizer_type}')
scheduler = LambdaLR(optimizer, self.lr_lambda_func)
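        # Step the scheduler every optimizer step (see 'interval' below) so
        # that ``lr_lambda_func`` receives the global step index rather than
        # the epoch index.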
output_dict = {
"optimizer": optimizer,
"lr_scheduler": {
'scheduler': scheduler,
'interval': 'step',
'frequency': 1,
}
}
return output_dict
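

# ``lr_lambda_func`` is supplied by the caller; a linear warm-up is one common
# choice. The sketch below is an illustration, not the schedule used by this
# repository, and the warm-up length is an assumption.
def linear_warm_up_lambda(step: int, warm_up_steps: int = 1000) -> float:
    r"""Hypothetical example: LR multiplier ramping from 0 to 1 over warm-up."""
    return min(1.0, (step + 1) / warm_up_steps)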


def get_model_class(model_type):
    r"""Return the separation model class for the given model type string."""
    if model_type == 'ResUNet30':
        from models.resunet import ResUNet30
        return ResUNet30
    else:
        raise NotImplementedError(f'Unsupported model type: {model_type}')
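

# ``loss_function`` is injected by the caller and invoked as
# loss_function(output_dict, target_dict) in ``training_step``. The L1
# waveform loss below is a sketch of one compatible implementation, assuming
# both dicts carry a 'segment' tensor of shape (batch_size, segment_samples);
# the repository defines its actual loss elsewhere.
def example_l1_waveform_loss(output_dict, target_dict):
    r"""Hypothetical example: mean absolute error between waveforms."""
    return torch.mean(torch.abs(output_dict['segment'] - target_dict['segment']))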