|
import torch |
|
import torch.optim as optim |
|
import torch.nn as nn |
|
from torch.utils.data import DataLoader |
|
from sklearn.metrics import accuracy_score |
|
from utils.dataset import load_audio_data |
|
from utils.preprocessing import preprocess_audio, prepare_features, collate_fn |
|
from model.emotion_classifier import EmotionClassifier |
|
from config import DEVICE, NUM_LABELS, BEST_MODEL_NAME |
|
import os |
|
|
|
|
|
# Resolve the "dataset" folder located next to this script and load it.
data_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "dataset")
)
ds = load_audio_data(data_dir)

# Both splits go through the same two-stage pipeline: audio preprocessing
# first, then feature preparation with max_length=128.
for _split in ("train", "test"):
    ds[_split] = (
        ds[_split]
        .map(preprocess_audio)
        .map(lambda batch: prepare_features(batch, max_length=128))
    )

# Batches of 8 built with the project's collate function; only the training
# split is shuffled.
train_loader = DataLoader(
    ds["train"],
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    ds["test"],
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
)

# Build the classifier, then move it to the configured device.
classifier = EmotionClassifier(feature_dim=40, num_labels=NUM_LABELS)
classifier = classifier.to(DEVICE)
|
|
|
|
|
def train_classifier(classifier, train_loader, test_loader, epochs=20):
    """Train ``classifier`` and checkpoint the best-scoring weights.

    Runs ``epochs`` passes over ``train_loader`` using AdamW (lr=2e-5,
    weight_decay=0.01) and cross-entropy loss. After every epoch the model is
    scored, and whenever the score beats the best one seen so far its
    ``state_dict`` is written to ``BEST_MODEL_NAME``.

    The selection score is the accuracy on ``test_loader`` as computed by
    ``evaluate``. Selecting on training accuracy would reward overfitting.
    If ``test_loader`` is ``None`` the training accuracy is used as a
    fallback.

    Args:
        classifier: Model to optimise; called as ``classifier(inputs)`` and
            expected to return per-class logits.
        train_loader: Iterable yielding ``(inputs, labels)`` training batches.
        test_loader: Iterable yielding ``(inputs, labels)`` held-out batches
            used for checkpoint selection, or ``None`` to select on training
            accuracy.
        epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``classifier`` object holding the weights of the *last*
        epoch (the best weights are the ones stored in ``BEST_MODEL_NAME``).
        When ``test_loader`` is given, the model is left in evaluation mode
        by the final ``evaluate`` call.
    """
    optimizer = optim.AdamW(classifier.parameters(), lr=2e-5, weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so that the first scored epoch always
    # writes a checkpoint, even when its accuracy is exactly 0.0.
    best_accuracy = float("-inf")

    for epoch in range(epochs):
        # Re-enable training mode each epoch: evaluate() switches to eval mode.
        classifier.train()
        total_loss, correct, seen = 0.0, 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()

            logits = classifier(inputs)
            loss = loss_fn(logits, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            # Count the samples actually processed so the accuracy stays
            # correct even if the loader drops or skips samples.
            seen += labels.numel()

        # Guard against an empty loader instead of dividing by zero.
        train_acc = correct / seen if seen else 0.0

        # NOTE(review): ideally this would be a separate validation split so
        # the test set stays untouched for the final report.
        if test_loader is not None:
            test_acc = evaluate(classifier, test_loader)
            selection_acc = test_acc
        else:
            test_acc = None
            selection_acc = train_acc

        if selection_acc > best_accuracy:
            best_accuracy = selection_acc
            torch.save(classifier.state_dict(), BEST_MODEL_NAME)
            print(f"✔️ Nouveau meilleur modèle sauvegardé ! Accuracy: {best_accuracy:.4f}")

        print(f"📢 Epoch {epoch+1}/{epochs} - Loss: {total_loss:.4f} - Accuracy: {train_acc:.4f}")
        if test_acc is not None:
            print(f"🔎 Epoch {epoch+1}/{epochs} - Test Accuracy: {test_acc:.4f}")

    return classifier
|
|
|
|
|
def evaluate(model, test_loader):
    """Return the accuracy of ``model`` over every batch of ``test_loader``.

    The model is switched to evaluation mode and run without gradient
    tracking. The predicted class for each sample is the argmax over the last
    dimension of the logits. Predictions and labels are gathered on the CPU
    and scored with ``accuracy_score``.

    Args:
        model: Callable model returning logits for a batch of inputs.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        The value produced by ``accuracy_score(labels, predictions)``.
    """
    predicted, expected = [], []

    model.eval()
    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            scores = model(batch_inputs)
            batch_preds = scores.argmax(dim=-1)

            # Bring both sides back to the CPU as NumPy arrays before
            # accumulating them for the metric.
            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(batch_labels.cpu().numpy())

    return accuracy_score(expected, predicted)
|
|
|
|
|
# Launch the 20-epoch training run on the module-level model and loaders.
trained_classifier = train_classifier(
    classifier=classifier,
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=20,
)

# Final status message once training has returned.
print("✅ Entraînement terminé, le meilleur modèle a été sauvegardé !")
|
|