# OlympIA / bart_ft.py
# Committed by johannoriel
# Commit f34a6fd: "Initial release. Tested. Working"
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BartForSequenceClassification, BartTokenizer, get_linear_schedule_with_warmup
from transformers import AdamW
from tqdm import tqdm
import gc
from plugins.scansite import ScansitePlugin # Assurez-vous que l'import est correct
# Release memory held by PyTorch's CUDA caching allocator before anything else runs.
# NOTE(review): at import time little or nothing is expected to be allocated on the
# GPU (unless an imported plugin already used it), so this is presumably a
# precaution; confirm whether it is still needed.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# Définition du dataset personnalisé
class PreferenceDataset(Dataset):
def __init__(self, data, tokenizer, max_length=128):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
_, title, score = self.data[idx]
encoding = self.tokenizer(title, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
# Convertir le score en label binaire (0 ou 1)
label = 1 if score > 0 else 0
return {key: val.squeeze(0) for key, val in encoding.items()}, torch.tensor(label, dtype=torch.long)
# Main fine-tuning entry point and its private helpers.


def _collect_training_rows():
    """Fetch Scansite reference data as a list of ``(url, title, score)`` tuples.

    Rows from the "valid" set are kept as returned by the plugin; every rejected
    ``(url, title)`` pair is given a score of 0 so it is labelled negative.
    """
    # NOTE(review): constructor arguments kept from the original script; adjust to
    # the plugin's real signature if it changes.
    scansite_plugin = ScansitePlugin("scansite", None)
    reference_data_valid, reference_data_rejected = scansite_plugin.get_reference_data()
    # assumes reference_data_valid already holds (url, title, score) tuples - TODO confirm
    return list(reference_data_valid) + [(url, title, 0) for url, title in reference_data_rejected]


def _train_one_epoch(model, dataloader, optimizer, scheduler, device, desc):
    """Run one optimisation pass over *dataloader*; return the mean training loss.

    Returns 0.0 when the dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    progress = tqdm(dataloader, desc=desc)
    for inputs, labels in progress:
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device)
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        # Clip the global gradient norm to 1.0 to keep updates stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        step_loss = loss.item()
        total_loss += step_loss
        num_batches += 1
        progress.set_postfix(loss=f"{step_loss:.4f}")
        # Drop this step's tensors and flush cached GPU memory before the next
        # batch. Kept from the original script (apparently to limit GPU memory
        # pressure); note it slows every step.
        del inputs, outputs, labels, loss
        torch.cuda.empty_cache()
        gc.collect()
    return total_loss / num_batches if num_batches else 0.0


def finetune(num_epochs=2, lr=2e-5, weight_decay=0.01, batch_size=1, model_name='facebook/bart-large-mnli', output_model='./bart-large-ft', num_warmup_steps=0):
    """Fine-tune a BART sequence classifier on Scansite preference data and save it.

    Titles from the plugin's valid set (label 1 when score > 0) and rejected set
    (label 0) train a 2-label classification head on top of *model_name*.
    ``ignore_mismatched_sizes=True`` lets a checkpoint whose head has a different
    number of labels (e.g. the 3-way MNLI head) load, with the mismatched
    weights freshly initialised.

    Args:
        num_epochs: number of passes over the data.
        lr: peak learning rate for AdamW.
        weight_decay: decoupled weight decay for AdamW.
        batch_size: examples per optimisation step.
        model_name: Hugging Face checkpoint to start from.
        output_model: directory where the fine-tuned model and tokenizer are saved.
        num_warmup_steps: linear warmup steps before the linear decay.

    Raises:
        ValueError: if the plugin returns no reference data (training would
            otherwise be skipped and an un-fine-tuned model saved).
    """
    print(f"Fine-tuning parameters:\n"
          f"num_epochs: {num_epochs}\n"
          f"learning rate (lr): {lr}\n"
          f"weight_decay: {weight_decay}\n"
          f"batch_size: {batch_size}\n"
          f"model_name: {model_name}\n"
          f"num_warmup_steps: {num_warmup_steps}")

    # Check for data before the (expensive) model download/load.
    all_data = _collect_training_rows()
    if not all_data:
        raise ValueError("No reference data returned by ScansitePlugin; nothing to fine-tune on.")

    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForSequenceClassification.from_pretrained(model_name, num_labels=2, ignore_mismatched_sizes=True)
    dataset = PreferenceDataset(all_data, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6
    # matches the transformers implementation's default so numerics stay close.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, eps=1e-6)
    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    # Recompute activations during backward to save memory at the cost of compute.
    model.gradient_checkpointing_enable()

    for epoch in range(num_epochs):
        mean_loss = _train_one_epoch(
            model, dataloader, optimizer, scheduler, device,
            desc=f"Epoch {epoch + 1}/{num_epochs}",
        )
        print(f"Epoch {epoch + 1}/{num_epochs} terminée - perte moyenne : {mean_loss:.4f}")

    print("Finetuning terminé sauvegarde en cours.")
    model.save_pretrained(output_model)
    tokenizer.save_pretrained(output_model)
    print("Finetuning terminé et modèle sauvegardé.")
# Default entry point: run fine-tuning with the default parameters when this file
# is executed as a script rather than imported as a module.
if __name__ == "__main__":
    finetune()