# dance-classifier / train.py
# Author: waidhoferj — first commit (c914273)
# Training, evaluation, and k-fold cross-validation for the dance-style classifier.
import datetime
import os
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
from torch.utils.data import random_split, SubsetRandomSampler
import json
from sklearn.model_selection import KFold
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from preprocessing.dataset import SongDataset
from preprocessing.preprocess import get_examples
from dancer_net.dancer_net import ShortChunkCNN
# Compute device for training/inference. "mps" targets Apple-silicon GPUs;
# switch to "cuda" or "cpu" on other hardware.
DEVICE = "mps"
# Global seed used for the reproducible train/validation split in train_model().
SEED = 42
def get_timestamp() -> str:
    """Return the current local time as a ``YYYY-MM-DD_HH:MM:SS`` string.

    Used to name per-run log directories.
    """
    now = datetime.datetime.now()
    return now.strftime("%Y-%m-%d_%H:%M:%S")
class EarlyStopping:
    """Signal a stop when a monitored value stops decreasing.

    Assumes lower is better (e.g. a loss). Each ``step`` compares the new
    value against the previous one; once the value has failed to decrease
    for more than ``patience`` consecutive steps, ``step`` returns True.
    """

    def __init__(self, patience=0):
        # Number of non-improving steps tolerated before signalling a stop.
        self.patience = patience
        # Previous observation; starts at +inf so the first step always "improves".
        self.last_measure = np.inf
        # Running count of consecutive steps without a decrease.
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        """Record ``val`` and return True when training should stop."""
        did_not_improve = val >= self.last_measure
        self.consecutive_increase = self.consecutive_increase + 1 if did_not_improve else 0
        self.last_measure = val
        return self.consecutive_increase > self.patience
def calculate_metrics(pred, target, threshold=0.5, prefix=""):
    """Compute macro-averaged classification metrics for multi-label outputs.

    ``pred`` holds per-class scores (expected in [0, 1]); each score is
    binarized at ``threshold`` before scoring against ``target``. When
    ``prefix`` is non-empty it is prepended to every metric name
    (e.g. ``"val_"`` -> ``"val_f1"``).
    """
    y_true = target.detach().cpu().numpy()
    scores = pred.detach().cpu().numpy()
    y_pred = np.array(scores > threshold, dtype=float)
    metrics = {
        'precision': precision_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=0),
        'recall': recall_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=0),
        'f1': f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=0),
        # accuracy_score on multi-label data is exact-match (subset) accuracy.
        'accuracy': accuracy_score(y_true=y_true, y_pred=y_pred),
    }
    if prefix:
        return {prefix + name: value for name, value in metrics.items()}
    return metrics
def evaluate(model: nn.Module, data_loader: DataLoader, criterion, device="mps") -> pd.Series:
    """Run a full validation pass and return the mean of the per-batch metrics.

    Args:
        model: Trained network producing per-class scores in [0, 1].
        data_loader: Yields (features, labels) batches.
        criterion: Loss function applied to (outputs, labels).
        device: Device the batches are moved to.

    Returns:
        pd.Series of metric means across batches ("val_loss", "val_f1", ...).
    """
    # BUG FIX: the original never switched the model out of training mode, so
    # dropout stayed active and BatchNorm used batch statistics, corrupting
    # validation scores. Restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    val_metrics = []
    # Gradients are never needed during evaluation; hoist no_grad around the loop.
    with torch.no_grad():
        for features, labels in (prog_bar := tqdm(data_loader)):
            features = features.to(device)
            labels = labels.to(device)
            outputs = model(features)
            loss = criterion(outputs, labels)
            batch_metrics = calculate_metrics(outputs, labels, prefix="val_")
            batch_metrics["val_loss"] = loss.item()
            prog_bar.set_description(f'Validation - Loss: {batch_metrics["val_loss"]:.2f}, Accuracy: {batch_metrics["val_accuracy"]:.2f}')
            val_metrics.append(batch_metrics)
    if was_training:
        model.train()
    return pd.DataFrame(val_metrics).mean()
def train(
    model: nn.Module,
    data_loader: DataLoader,
    val_loader=None,
    epochs=3,
    lr=1e-3,
    device="mps"):
    """Train ``model`` with BCE loss, Adam, and a OneCycle LR schedule.

    Args:
        model: Network producing per-class probabilities (BCELoss expects [0, 1]).
        data_loader: Training batches of (features, labels).
        val_loader: Optional validation batches; enables per-epoch evaluation
            and early stopping on validation f1.
        epochs: Maximum number of epochs.
        lr: Peak learning rate for OneCycleLR.
        device: Device batches and gradients live on.

    Returns:
        (model, metrics) where metrics is a list of per-epoch metric dicts.
    """
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    early_stop = EarlyStopping(1)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr,
                                                    steps_per_epoch=int(len(data_loader)),
                                                    epochs=epochs,
                                                    anneal_strategy='linear')
    metrics = []
    for epoch in range(1, epochs + 1):
        # Ensure training mode each epoch (evaluation may have toggled it).
        model.train()
        train_metrics = []
        prog_bar = tqdm(data_loader)
        for features, labels in prog_bar:
            features = features.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()  # OneCycleLR steps per batch, not per epoch
            batch_metrics = calculate_metrics(outputs, labels)
            batch_metrics["loss"] = loss.item()
            train_metrics.append(batch_metrics)
            prog_bar.set_description(f'Training - Epoch: {epoch}/{epochs}, Loss: {batch_metrics["loss"]:.2f}, Accuracy: {batch_metrics["accuracy"]:.2f}')
        train_metrics = pd.DataFrame(train_metrics).mean()
        if val_loader is not None:
            # BUG FIX: forward the training device; the original fell back to
            # evaluate()'s hardcoded "mps" default regardless of `device`.
            val_metrics = evaluate(model, val_loader, criterion, device=device)
            epoch_metrics = pd.concat([train_metrics, val_metrics], axis=0)
            # BUG FIX: record the epoch BEFORE checking early stop; the original
            # dropped the final epoch's metrics when the stop triggered.
            metrics.append(dict(epoch_metrics))
            # BUG FIX: EarlyStopping assumes lower-is-better, but f1 improves
            # upward — the original stopped when f1 *improved*. Monitor -f1.
            if early_stop.step(-val_metrics["val_f1"]):
                break
        else:
            metrics.append(dict(train_metrics))
    return model, metrics
def cross_validation(seed=42, batch_size=64, k=5, device="mps"):
    """Run k-fold cross-validation and save per-fold validation metrics to CSV.

    Args:
        seed: Random state for the KFold shuffle.
        batch_size: Batch size for both train and validation loaders.
        k: Number of folds.
        device: Device models and batches are placed on.

    Side effects:
        Writes ``logs/<timestamp>/cross_val.csv`` with one row per fold.
    """
    target_classes = ['ATN',
                      'BBA',
                      'BCH',
                      'BLU',
                      'CHA',
                      'CMB',
                      'CSG',
                      'ECS',
                      'HST',
                      'JIV',
                      'LHP',
                      'QST',
                      'RMB',
                      'SFT',
                      'SLS',
                      'SMB',
                      'SWZ',
                      'TGO',
                      'VWZ',
                      'WCS']
    df = pd.read_csv("data/songs.csv")
    x, y = get_examples(df, "data/samples", class_list=target_classes)
    dataset = SongDataset(x, y)
    splits = KFold(n_splits=k, shuffle=True, random_state=seed)
    # Loop-invariant: the label width is fixed by the dataset, not the fold.
    n_classes = len(y[0])
    metrics = []
    for fold, (train_idx, val_idx) in enumerate(splits.split(x, y)):
        print(f"Fold {fold+1}")
        # Samplers restrict the shared dataset to this fold's index sets.
        train_sampler = SubsetRandomSampler(train_idx)
        test_sampler = SubsetRandomSampler(val_idx)
        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
        test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
        model = ShortChunkCNN(n_class=n_classes).to(device)
        model, _ = train(model, train_loader, epochs=2, device=device)
        # BUG FIX: forward `device` (original used evaluate()'s "mps" default).
        val_metrics = evaluate(model, test_loader, nn.BCELoss(), device=device)
        metrics.append(val_metrics)
    metrics = pd.DataFrame(metrics)
    log_dir = os.path.join(
        "logs", get_timestamp()
    )
    os.makedirs(log_dir, exist_ok=True)
    # BUG FIX: DataFrame.to_csv takes the output path as its first argument;
    # the original passed model.state_dict() here, which raises at runtime.
    metrics.to_csv(os.path.join(log_dir, "cross_val.csv"))
def train_model():
    """Train a ShortChunkCNN on a 90/10 split and persist the run artifacts.

    Side effects:
        Writes ``logs/<timestamp>/dancer_net.pt`` (weights), ``metrics.csv``
        (per-epoch metrics), and ``config.json`` (class list).
    """
    target_classes = ['ATN',
                      'BBA',
                      'BCH',
                      'BLU',
                      'CHA',
                      'CMB',
                      'CSG',
                      'ECS',
                      'HST',
                      'JIV',
                      'LHP',
                      'QST',
                      'RMB',
                      'SFT',
                      'SLS',
                      'SMB',
                      'SWZ',
                      'TGO',
                      'VWZ',
                      'WCS']
    df = pd.read_csv("data/songs.csv")
    x, y = get_examples(df, "data/samples", class_list=target_classes)
    dataset = SongDataset(x, y)
    # Seeded 90/10 train/validation split for reproducibility.
    train_count = int(len(dataset) * 0.9)
    datasets = random_split(dataset, [train_count, len(dataset) - train_count], torch.Generator().manual_seed(SEED))
    data_loaders = [DataLoader(data, batch_size=64, shuffle=True) for data in datasets]
    train_data, val_data = data_loaders
    # Infer the output width from one example's label vector.
    example_spec, example_label = dataset[0]
    n_classes = len(example_label)
    model = ShortChunkCNN(n_class=n_classes).to(DEVICE)
    model, metrics = train(model, train_data, val_data, epochs=3, device=DEVICE)
    log_dir = os.path.join(
        "logs", get_timestamp()
    )
    os.makedirs(log_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(log_dir, "dancer_net.pt"))
    metrics = pd.DataFrame(metrics)
    metrics.to_csv(os.path.join(log_dir, "metrics.csv"))
    config = {
        "classes": target_classes
    }
    # BUG FIX: the original opened config.json without "w", so json.dump hit a
    # read-only handle (and the file never exists on a fresh run).
    with open(os.path.join(log_dir, "config.json"), "w") as f:
        json.dump(config, f)
    print("Training information saved!")
if __name__ == "__main__":
    # Default entry point: run k-fold cross-validation rather than a full
    # training run (see train_model for the latter).
    cross_validation()