import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BartForSequenceClassification, BartTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW  # transformers.AdamW is deprecated in recent versions; use the PyTorch implementation
from tqdm import tqdm
import gc
from plugins.scansite import ScansitePlugin  # Make sure this import path is correct

torch.cuda.empty_cache()

# Custom dataset definition
class PreferenceDataset(Dataset):
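    """Wraps (url, title, score) tuples: each title is tokenized and its score mapped to a binary label (1 if score > 0, else 0)."""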
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        _, title, score = self.data[idx]
        encoding = self.tokenizer(title, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
        # Convert the score to a binary label (0 or 1)
        label = 1 if score > 0 else 0
        return {key: val.squeeze(0) for key, val in encoding.items()}, torch.tensor(label, dtype=torch.long)

# Main fine-tuning function
def finetune(num_epochs=2, lr=2e-5, weight_decay=0.01, batch_size=1, model_name='facebook/bart-large-mnli', output_model='./bart-large-ft', num_warmup_steps=0):
    print(f"Fine-tuning parameters:\n"
          f"num_epochs: {num_epochs}\n"
          f"learning rate (lr): {lr}\n"
          f"weight_decay: {weight_decay}\n"
          f"batch_size: {batch_size}\n"
          f"model_name: {model_name}\n"
          f"num_warmup_steps: {num_warmup_steps}")

    # Retrieve the reference data
    scansite_plugin = ScansitePlugin("scansite", None)  # You may need to adjust this to match your project structure
    reference_data_valid, reference_data_rejected = scansite_plugin.get_reference_data()

    # Combine the valid and rejected data
    all_data = reference_data_valid + [(url, title, 0) for url, title in reference_data_rejected]

    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForSequenceClassification.from_pretrained(model_name, num_labels=2, ignore_mismatched_sizes=True)

    dataset = PreferenceDataset(all_data, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.gradient_checkpointing_enable()  # reduce memory usage at the cost of extra compute

    for epoch in range(num_epochs):
        model.train()
        for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            optimizer.zero_grad()

            inputs, labels = batch
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)

            outputs = model(**inputs, labels=labels)
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Aggressively free memory after each batch (helps on memory-constrained GPUs)
            del inputs, outputs, labels
            torch.cuda.empty_cache()
            gc.collect()
            #print(f"Finetuning en cours round {epoch}/{num_epochs}")
    print("Finetuning terminé sauvegarde en cours.")
    model.save_pretrained(output_model)
    tokenizer.save_pretrained(output_model)

    print("Finetuning terminé et modèle sauvegardé.")

# Default call when the script is run directly
if __name__ == "__main__":
    finetune()