import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Optional
import numpy as np
from pathlib import Path
from deepspeed.ops.adam import FusedAdam

class MusicAudioClassifier(pl.LightningModule):
    def __init__(self,
                 input_dim: int,
                 hidden_dim: int = 256,
                 learning_rate: float = 1e-4,
                 emb_model: Optional[nn.Module] = None,
                 is_emb: bool = False,
                 mode: str = 'both',
                 share_parameter: bool = False):
        super().__init__()
        self.save_hyperparameters()

        self.model = SegmentTransformer(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            mode=mode,
            share_parameter=share_parameter
        )
        self.emb_model = emb_model
        self.learning_rate = learning_rate
        self.is_emb = is_emb

    def _process_audio_batch(self, x: torch.Tensor) -> torch.Tensor:
        B, S = x.shape[:2]  # [B, S, C, M, T] or [B, S, C, T] for wav, [B, S, 1(?), emb_size] for precomputed embeddings
        x = x.view(B * S, *x.shape[2:])  # [B*S, C, M, T]
        if not self.is_emb:
            embeddings = self.emb_model(x)  # [B*S, emb_dim]
        else:
            embeddings = x
        if embeddings.dim() == 3:
            pooled_features = embeddings.mean(dim=1)  # transformer backbone: pool over the token dimension
        else:
            pooled_features = embeddings  # already a flat embedding (e.g. CNN output); no pooling needed
        return pooled_features.view(B, S, -1)  # [B, S, emb_dim]

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self._process_audio_batch(x)  # the embedding model is used frozen here, effectively as a fixed feature extractor
        x = x.half()  # cast features to fp16 for the half-precision downstream model
        return self.model(x, mask)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y, mask = batch
        x = x.half()
        y_hat = self(x, mask)

        # Handle the batch-size-1 case separately so squeeze() does not drop the batch dimension
        if y_hat.size(0) == 1:
            loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float().flatten())
            probs = torch.sigmoid(y_hat.flatten())
            y_true = y.float().flatten()
        else:
            loss = F.binary_cross_entropy_with_logits(y_hat.squeeze(), y.float())
            probs = torch.sigmoid(y_hat.squeeze())
            y_true = y.float()

        # Log only the per-batch loss here (step level)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        # Store predictions and targets for epoch-level metric computation
        self.training_step_outputs.append({'preds': probs, 'targets': y_true})
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        x, y, mask = batch
        x = x.half()
        y_hat = self(x, mask)

        # Handle the batch-size-1 case separately so squeeze() does not drop the batch dimension
        if y_hat.size(0) == 1:
            loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float().flatten())
            probs = torch.sigmoid(y_hat.flatten())
            y_true = y.float().flatten()
        else:
            loss = F.binary_cross_entropy_with_logits(y_hat.squeeze(), y.float())
            probs = torch.sigmoid(y_hat.squeeze())
            y_true = y.float()

        # Log only the per-batch loss here (step level)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        # Store predictions and targets for epoch-level metric computation
        self.validation_step_outputs.append({'preds': probs, 'targets': y_true})

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        x, y, mask = batch
        x = x.half()
        y_hat = self(x, mask)

        # Handle the batch-size-1 case separately so squeeze() does not drop the batch dimension
        if y_hat.size(0) == 1:
            loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float().flatten())
            probs = torch.sigmoid(y_hat.flatten())
            y_true = y.float().flatten()
        else:
            loss = F.binary_cross_entropy_with_logits(y_hat.squeeze(), y.float())
            probs = torch.sigmoid(y_hat.squeeze())
            y_true = y.float()

        # Log only the per-batch loss here (step level)
        self.log('test_loss', loss, on_epoch=True, prog_bar=True)

        # Store predictions and targets for epoch-level metric computation
        self.test_step_outputs.append({'preds': probs, 'targets': y_true})

    def on_train_epoch_start(self):
        # Reset the list that collects step outputs at the start of each epoch
        self.training_step_outputs = []

    def on_validation_epoch_start(self):
        # Reset the list that collects step outputs at the start of each epoch
        self.validation_step_outputs = []

    def on_test_epoch_start(self):
        # Reset the list that collects step outputs at the start of each epoch
        self.test_step_outputs = []

    def on_train_epoch_end(self):
        # Compute metrics over the full epoch once it finishes
        if not hasattr(self, 'training_step_outputs') or not self.training_step_outputs:
            return

        all_preds = torch.cat([x['preds'] for x in self.training_step_outputs])
        all_targets = torch.cat([x['targets'] for x in self.training_step_outputs])

        # Threshold probabilities to get binary predictions
        binary_preds = (all_preds > 0.5).float()

        # Accuracy
        acc = (binary_preds == all_targets).float().mean()

        # Confusion-matrix components
        tp = torch.sum((binary_preds == 1) & (all_targets == 1)).float()
        fp = torch.sum((binary_preds == 1) & (all_targets == 0)).float()
        tn = torch.sum((binary_preds == 0) & (all_targets == 0)).float()
        fn = torch.sum((binary_preds == 0) & (all_targets == 1)).float()

        # Derived metrics, guarding against division by zero
        precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0, device=tp.device)
        recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0, device=tp.device)
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0, device=tp.device)
        specificity = tn / (tn + fp) if (tn + fp) > 0 else torch.tensor(0.0, device=tn.device)

        # Logging -- use consistent metric names
        self.log('train_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('train_precision', precision, on_epoch=True, sync_dist=True)
        self.log('train_recall', recall, on_epoch=True, sync_dist=True)
        self.log('train_f1', f1, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('train_specificity', specificity, on_epoch=True, sync_dist=True)

    def on_validation_epoch_end(self):
        # Compute metrics over the full validation set once the epoch finishes
        if not hasattr(self, 'validation_step_outputs') or not self.validation_step_outputs:
            return

        all_preds = torch.cat([x['preds'] for x in self.validation_step_outputs])
        all_targets = torch.cat([x['targets'] for x in self.validation_step_outputs])

        # ROC-AUC (simple approximation)
        sorted_indices = torch.argsort(all_preds, descending=True)
        sorted_targets = all_targets[sorted_indices]
        n_pos = torch.sum(all_targets)
        n_neg = len(all_targets) - n_pos

        if n_pos > 0 and n_neg > 0:
            # Build the TPR and FPR curves as cumulative sums
            tpr_curve = torch.cumsum(sorted_targets, dim=0) / n_pos
            fpr_curve = torch.cumsum(1 - sorted_targets, dim=0) / n_neg

            # AUC via the trapezoidal rule
            width = fpr_curve[1:] - fpr_curve[:-1]
            height = (tpr_curve[1:] + tpr_curve[:-1]) / 2
            auc_approx = torch.sum(width * height)
            self.log('val_auc', auc_approx, on_epoch=True)

        # Threshold probabilities to get binary predictions
        binary_preds = (all_preds > 0.5).float()

        # Accuracy
        acc = (binary_preds == all_targets).float().mean()

        # Confusion-matrix components
        tp = torch.sum((binary_preds == 1) & (all_targets == 1)).float()
        fp = torch.sum((binary_preds == 1) & (all_targets == 0)).float()
        tn = torch.sum((binary_preds == 0) & (all_targets == 0)).float()
        fn = torch.sum((binary_preds == 0) & (all_targets == 1)).float()

        # Derived metrics, guarding against division by zero
        precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0, device=tp.device)
        recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0, device=tp.device)
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0, device=tp.device)
        specificity = tn / (tn + fp) if (tn + fp) > 0 else torch.tensor(0.0, device=tn.device)

        # Logging -- use consistent metric names (val_f1 rather than val_epoch_f1)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('val_precision', precision, on_epoch=True, sync_dist=True)
        self.log('val_recall', recall, on_epoch=True, sync_dist=True)
        self.log('val_f1', f1, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('val_specificity', specificity, on_epoch=True, sync_dist=True)

    def on_test_epoch_end(self):
        # Compute metrics over the full test set once the epoch finishes
        if not hasattr(self, 'test_step_outputs') or not self.test_step_outputs:
            return

        all_preds = torch.cat([x['preds'] for x in self.test_step_outputs])
        all_targets = torch.cat([x['targets'] for x in self.test_step_outputs])

        # ROC-AUC (simple approximation)
        sorted_indices = torch.argsort(all_preds, descending=True)
        sorted_targets = all_targets[sorted_indices]
        n_pos = torch.sum(all_targets)
        n_neg = len(all_targets) - n_pos

        if n_pos > 0 and n_neg > 0:
            # Build the TPR and FPR curves as cumulative sums
            tpr_curve = torch.cumsum(sorted_targets, dim=0) / n_pos
            fpr_curve = torch.cumsum(1 - sorted_targets, dim=0) / n_neg

            # AUC via the trapezoidal rule
            width = fpr_curve[1:] - fpr_curve[:-1]
            height = (tpr_curve[1:] + tpr_curve[:-1]) / 2
            auc_approx = torch.sum(width * height)
            self.log('test_auc', auc_approx, on_epoch=True, sync_dist=True)

        # Threshold probabilities to get binary predictions
        binary_preds = (all_preds > 0.5).float()

        # Accuracy
        acc = (binary_preds == all_targets).float().mean()

        # Confusion-matrix components
        tp = torch.sum((binary_preds == 1) & (all_targets == 1)).float()
        fp = torch.sum((binary_preds == 1) & (all_targets == 0)).float()
        tn = torch.sum((binary_preds == 0) & (all_targets == 0)).float()
        fn = torch.sum((binary_preds == 0) & (all_targets == 1)).float()

        # Derived metrics, guarding against division by zero
        precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0, device=tp.device)
        recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0, device=tp.device)
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0, device=tp.device)
        specificity = tn / (tn + fp) if (tn + fp) > 0 else torch.tensor(0.0, device=tn.device)
        balanced_acc = (recall + specificity) / 2

        # Logging -- use consistent metric names
        self.log('test_acc', acc, on_epoch=True, prog_bar=True)
        self.log('test_precision', precision, on_epoch=True)
        self.log('test_recall', recall, on_epoch=True)
        self.log('test_f1', f1, on_epoch=True, prog_bar=True)
        self.log('test_specificity', specificity, on_epoch=True)
        self.log('test_balanced_acc', balanced_acc, on_epoch=True)

    def configure_optimizers(self):
        optimizer = FusedAdam(self.parameters(), lr=self.learning_rate,
                              weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=100,  # adjust to the number of training epochs
            eta_min=1e-6
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss',
        }
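
# Hedged usage sketch (illustrative, not part of the original file): how this
# LightningModule might be driven by a pl.Trainer. The explicit .half() casts
# and FusedAdam above imply fp16 training under a DeepSpeed strategy, so the
# precision/strategy arguments below are assumptions, as are `input_dim=768`
# and the dataloader arguments.
def _example_fit(train_loader: DataLoader, val_loader: DataLoader) -> None:
    model = MusicAudioClassifier(input_dim=768, is_emb=True, mode='both')  # 768 is a placeholder emb dim
    trainer = pl.Trainer(
        max_epochs=100,                 # matches T_max in configure_optimizers
        precision=16,                   # fp16; newer Lightning versions may prefer '16-mixed'
        strategy='deepspeed_stage_2',   # assumed; FusedAdam requires a DeepSpeed setup
        accelerator='gpu',
        devices=1,
    )
    trainer.fit(model, train_loader, val_loader)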

def pad_sequence_with_mask(batch: List[Tuple[torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate function for DataLoader that creates padded sequences and attention masks with a fixed length (48)."""
    embeddings, labels = zip(*batch)
    fixed_len = 48  # fixed sequence length
    batch_size = len(embeddings)
    feat_dim = embeddings[0].shape[-1]

    padded = torch.zeros((batch_size, fixed_len, feat_dim))  # tensor padded to the fixed length
    mask = torch.ones((batch_size, fixed_len), dtype=torch.bool)  # True means padding

    for i, emb in enumerate(embeddings):
        length = emb.shape[0]
        # Truncate sequences longer than the fixed length; pad shorter ones
        if length > fixed_len:
            padded[i, :] = emb[:fixed_len]  # keep only the first fixed_len steps
            mask[i, :] = False
        else:
            padded[i, :length] = emb  # fill up to the actual sequence length
            mask[i, :length] = False  # mark non-padded positions as False

    return padded, torch.tensor(labels), mask
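
# Hedged usage sketch (illustrative, not from the original file): wiring the
# collate function into a DataLoader. The `dataset` argument stands for any
# hypothetical Dataset whose __getitem__ returns
# (embedding [num_segments, emb_dim], label).
def _example_dataloader(dataset: Dataset, batch_size: int = 16) -> DataLoader:
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=pad_sequence_with_mask,  # yields (padded [B, 48, D], labels [B], mask [B, 48])
    )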

class SegmentTransformer(nn.Module):
    def __init__(self,
                 input_dim: int,
                 hidden_dim: int = 256,
                 num_heads: int = 8,
                 num_layers: int = 4,
                 dropout: float = 0.1,
                 max_sequence_length: int = 1000,
                 mode: str = 'only_emb',
                 share_parameter: bool = False):
        super().__init__()

        # Original sequence processing
        self.input_projection = nn.Linear(input_dim, hidden_dim)
        self.mode = mode
        self.share_parameter = share_parameter

        # Sinusoidal positional encoding
        position = torch.arange(max_sequence_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-np.log(10000.0) / hidden_dim))
        pos_encoding = torch.zeros(max_sequence_length, hidden_dim)
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        pos_encoding[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pos_encoding', pos_encoding)

        # Transformer for the original sequence
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Self-similarity stream processing
        self.similarity_projection = nn.Sequential(
            nn.Conv1d(1, hidden_dim // 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Transformer for the similarity stream
        self.similarity_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Final classification head
        self.classification_head_dim = hidden_dim * 2 if mode == 'both' else hidden_dim
        self.classification_head = nn.Sequential(
            nn.Linear(self.classification_head_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape

        # 1. Process the original sequence
        x1 = self.input_projection(x)
        x1 = x1 + self.pos_encoding[:seq_len].unsqueeze(0)
        x1 = self.transformer(x1, src_key_padding_mask=padding_mask)

        # 2. Calculate and process the self-similarity stream
        x_expanded = x.unsqueeze(2)
        x_transposed = x.unsqueeze(1)
        distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1)
        similarity_matrix = torch.exp(-distances)  # (batch_size, seq_len, seq_len)

        # Build and apply a self-similarity mask (zero out any entry whose row or column touches padding)
        if padding_mask is not None:
            similarity_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2)  # (batch_size, seq_len, seq_len)
            similarity_matrix = similarity_matrix.masked_fill(similarity_mask, 0.0)

        # Process the similarity matrix row by row using Conv1d
        x2 = similarity_matrix.unsqueeze(1)  # (batch_size, 1, seq_len, seq_len)
        x2 = x2.view(batch_size * seq_len, 1, seq_len)  # reshape for Conv1d
        x2 = self.similarity_projection(x2)  # (batch_size * seq_len, hidden_dim, seq_len)
        x2 = x2.mean(dim=2)  # pool across the sequence dimension
        x2 = x2.view(batch_size, seq_len, -1)  # reshape back
        x2 = x2 + self.pos_encoding[:seq_len].unsqueeze(0)
        if self.share_parameter:
            x2 = self.transformer(x2, src_key_padding_mask=padding_mask)
        else:
            # Use the dedicated similarity transformer when parameters are not shared
            x2 = self.similarity_transformer(x2, src_key_padding_mask=padding_mask)

        # 3. Global average pooling for both streams, excluding padded positions
        if padding_mask is not None:
            mask_expanded = (~padding_mask).float().unsqueeze(-1)
            x1 = (x1 * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
            x2 = (x2 * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
        else:
            x1 = x1.mean(dim=1)
            x2 = x2.mean(dim=1)

        # 4. Select stream(s) according to mode and classify
        if self.mode == 'only_emb':
            x = x1
        elif self.mode == 'only_structure':
            x = x2
        elif self.mode == 'both':
            x = torch.cat([x1, x2], dim=-1)
        else:
            raise ValueError(f"Unknown mode: {self.mode}")
        x = x.to(dtype=self.classification_head[0].weight.dtype)  # match the head's dtype (fp16 under DeepSpeed)
        return self.classification_head(x)
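
# Minimal shape smoke test (added for illustration; random inputs only, all
# sizes are assumptions). It exercises SegmentTransformer directly with a
# padded batch and a mask of the same layout that pad_sequence_with_mask
# produces (True = padding).
if __name__ == '__main__':
    batch_size, seq_len, emb_dim = 2, 48, 128
    model = SegmentTransformer(input_dim=emb_dim, hidden_dim=256, mode='both')
    x = torch.randn(batch_size, seq_len, emb_dim)
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    mask[1, 40:] = True  # pad out the tail of the second sequence
    logits = model(x, mask)
    print(logits.shape)  # expected: torch.Size([2, 1])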