Spaces:
Running
Running
import torch | |
from torch.utils.data import Dataset, DataLoader | |
from transformers import BartForSequenceClassification, BartTokenizer, get_linear_schedule_with_warmup | |
from transformers import AdamW | |
from tqdm import tqdm | |
import gc | |
from plugins.scansite import ScansitePlugin # Assurez-vous que l'import est correct | |
# Ask PyTorch's CUDA caching allocator to release unoccupied cached memory.
# NOTE(review): at import time this is presumably a no-op (nothing in this
# module has touched the GPU yet) — confirm it is still wanted here.
torch.cuda.empty_cache()
# Custom dataset definition.
class PreferenceDataset(Dataset):
    """Map-style dataset turning ``(first, title, score)`` triples into classifier samples.

    Each item is the tokenized ``title`` plus a binary label derived from
    ``score``. The first element of every triple is ignored (``finetune``
    stores a URL there).

    Args:
        data: Indexable sequence whose items unpack into exactly three values
            ``(first, title, score)``.
        tokenizer: Hugging Face-style tokenizer callable; invoked with
            ``truncation=True``, ``padding='max_length'``, ``max_length`` and
            ``return_tensors='pt'``.
        max_length: Length every title is padded / truncated to.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for sample ``idx``.

        ``features`` maps every tokenizer output key to its tensor with the
        leading size-1 batch dimension (added by ``return_tensors='pt'``)
        squeezed away; ``label`` is a scalar ``torch.long`` tensor.
        """
        _first, title, score = self.data[idx]
        batch_encoding = self.tokenizer(
            title,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        features = {}
        for name, tensor in batch_encoding.items():
            features[name] = tensor.squeeze(0)
        # Binarise the score: strictly positive -> 1, anything else -> 0.
        target = torch.tensor(1 if score > 0 else 0, dtype=torch.long)
        return features, target
# Main fine-tuning entry point.
def finetune(num_epochs=2, lr=2e-5, weight_decay=0.01, batch_size=1,
             model_name='facebook/bart-large-mnli', output_model='./bart-large-ft',
             num_warmup_steps=0):
    """Fine-tune a BART sequence classifier on the Scansite reference data and save it.

    Valid reference entries are used as-is (their third field is the score);
    rejected ``(url, title)`` pairs are given a score of 0. ``PreferenceDataset``
    turns scores into binary labels (``score > 0`` -> 1, else 0). The model is
    loaded with ``num_labels=2``; ``ignore_mismatched_sizes=True`` allows a
    checkpoint whose classification head has a different size to be loaded.

    Args:
        num_epochs: Number of passes over the training data.
        lr: Peak learning rate for AdamW.
        weight_decay: Weight decay for AdamW.
        batch_size: Training batch size.
        model_name: Hugging Face model id or local path to start from.
        output_model: Directory the fine-tuned model and tokenizer are saved to.
        num_warmup_steps: Linear warm-up steps of the learning-rate schedule.

    Raises:
        ValueError: If the plugin returns no reference data at all.
    """
    print(f"Fine-tuning parameters:\n"
          f"num_epochs: {num_epochs}\n"
          f"learning rate (lr): {lr}\n"
          f"weight_decay: {weight_decay}\n"
          f"batch_size: {batch_size}\n"
          f"model_name: {model_name}\n"
          f"num_warmup_steps: {num_warmup_steps}")

    # Fetch the reference data.
    # NOTE(review): the constructor arguments may need adjusting to the plugin's structure.
    scansite_plugin = ScansitePlugin("scansite", None)
    reference_data_valid, reference_data_rejected = scansite_plugin.get_reference_data()

    # Combine valid and rejected data. Valid entries are assumed to already be
    # (url, title, score) triples — TODO confirm; rejected pairs get score 0.
    # list(...) keeps this working if the plugin returns a tuple or other sequence.
    all_data = list(reference_data_valid) + [
        (url, title, 0) for url, title in reference_data_rejected
    ]
    # Fail fast, before downloading the (large) model. Otherwise
    # DataLoader(shuffle=True) later raises a cryptic "num_samples" ValueError.
    if not all_data:
        raise ValueError("Aucune donnée de référence disponible pour le finetuning.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForSequenceClassification.from_pretrained(
        model_name, num_labels=2, ignore_mismatched_sizes=True
    )
    # Move the model to its device *before* constructing the optimizer,
    # as recommended by the torch.optim documentation.
    model.to(device)
    # Trade compute for memory: activations are recomputed during backward.
    model.gradient_checkpointing_enable()

    dataset = PreferenceDataset(all_data, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6
    # keeps the epsilon that transformers.AdamW used by default.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, weight_decay=weight_decay, eps=1e-6
    )
    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps
    )

    for epoch in range(num_epochs):
        model.train()
        for inputs, labels in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            optimizer.zero_grad()
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            # Drop per-step tensors and release cached GPU blocks every step,
            # presumably to keep peak memory low on a tight budget (batch_size
            # defaults to 1). This costs some speed.
            del inputs, outputs, labels, loss
            torch.cuda.empty_cache()
            gc.collect()

    print("Finetuning terminé sauvegarde en cours.")
    model.save_pretrained(output_model)
    tokenizer.save_pretrained(output_model)
    print("Finetuning terminé et modèle sauvegardé.")
# Default call when the script is executed directly: run finetune() with its
# default hyper-parameters.
if __name__ == "__main__":
    finetune()