File size: 3,747 Bytes
f34a6fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from torch.utils.data import Dataset, DataLoader
from sentence_transformers import SentenceTransformer, losses
from tqdm import tqdm
import gc
from plugins.scansite import ScansitePlugin

# Release any cached GPU memory left over from previous runs at import time.
torch.cuda.empty_cache()

class PreferenceDataset(Dataset):
    """Dataset of (url, title, score) triples for score regression.

    Each item tokenizes the title on access and yields the tokenized
    inputs (with the batch dimension squeezed out) together with the
    target score as a float tensor. The URL is carried in the data but
    not used by the model.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        _, title, score = self.data[idx]
        tokens = self.tokenizer(
            title,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        inputs = {name: tensor.squeeze(0) for name, tensor in tokens.items()}
        target = torch.tensor(score, dtype=torch.float)
        return inputs, target

def collate_fn(batch):
    """Collate (features, score) pairs into batched tensors.

    Takes the per-sample dicts produced by PreferenceDataset and stacks
    their 'input_ids' and 'attention_mask' along a new batch dimension;
    the scalar scores are stacked into a 1-D tensor.
    """
    features = [sample[0] for sample in batch]
    targets = [sample[1] for sample in batch]
    batched_inputs = {
        'input_ids': torch.stack([f['input_ids'] for f in features]),
        'attention_mask': torch.stack([f['attention_mask'] for f in features]),
    }
    return batched_inputs, torch.stack(targets)

def finetune(model_name='nomic-ai/nomic-embed-text-v1', output_model_name="embeddings-ft", num_epochs=2, learning_rate=2e-5, weight_decay=0.01, batch_size=8, num_warmup_steps=0):
    """Fine-tune a SentenceTransformer to regress preference scores for titles.

    Loads labelled reference data from the Scansite plugin, builds scalar
    regression targets (rejected entries get 0.0), and trains the model
    with MSE between a per-title scalar derived from the embedding and
    the target score, saving the result under ``output_model_name``.

    Args:
        model_name: Hugging Face model id to load (trust_remote_code is enabled).
        output_model_name: path/name passed to ``model.save()``.
        num_epochs: number of passes over the combined dataset.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight-decay coefficient.
        batch_size: DataLoader batch size.
        num_warmup_steps: printed only — the LinearLR schedule below does
            not use it. NOTE(review): likely a leftover parameter; confirm.
    """
    print(f"Fine-tuning parameters:\n"
          f"num_epochs: {num_epochs}\n"
          f"learning rate (lr): {learning_rate}\n"
          f"weight_decay: {weight_decay}\n"
          f"batch_size: {batch_size}\n"
          f"model_name: {model_name}\n"
          f"num_warmup_steps: {num_warmup_steps}")

    # Pull (url, title, score) triples for accepted items and (url, title)
    # pairs for rejected ones from the project's Scansite plugin.
    scansite_plugin = ScansitePlugin("scansite", None)
    reference_data_valid, reference_data_rejected = scansite_plugin.get_reference_data()

    # Map raw scores via (score - 1) / 8 + 0.5: score 1 -> 0.5, score 5 -> 1.0.
    # NOTE(review): a raw score of 9 would map to 1.5, above the apparent
    # 0..1 target scale (rejected = 0.0) — verify the expected score range.
    valid_data_with_scores = [(url, title, (score - 1) / 8 + 0.5) for url, title, score in reference_data_valid]
    rejected_data_with_scores = [(url, title, 0.0) for url, title in reference_data_rejected]

    all_data = valid_data_with_scores + rejected_data_with_scores

    model = SentenceTransformer(model_name, trust_remote_code=True)
    tokenizer = model.tokenizer

    dataset = PreferenceDataset(all_data, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    loss_function = torch.nn.MSELoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Linear decay from the full learning rate down to 10% over all steps;
    # no warmup phase despite the num_warmup_steps parameter.
    total_steps = len(dataloader) * num_epochs
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=total_steps)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            input_data, scores = batch
            input_data = {k: v.to(device) for k, v in input_data.items()}
            scores = scores.to(device)

            optimizer.zero_grad()

            embeddings = model(input_data)['sentence_embedding']

            # NOTE(review): the original comment called this "cosine similarity",
            # but the code sums the components of each L2-normalized embedding.
            # That is not a similarity between two vectors, and its range is
            # [-sqrt(d), sqrt(d)] rather than the targets' 0..1 scale — this
            # looks like a bug; confirm the intended training objective.
            embeddings_norm = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            cosine_similarities = torch.sum(embeddings_norm, dim=1)

            # MSE between the per-title scalar and the preference target.
            loss = loss_function(cosine_similarities, scores)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Aggressively free per-batch GPU memory; empty_cache/gc per step
            # is costly but presumably needed to fit the model in memory.
            del embeddings, cosine_similarities
            torch.cuda.empty_cache()
            gc.collect()

    model.save(output_model_name)

    print("Finetuning terminé et modèle sauvegardé.")

# Run fine-tuning with default hyperparameters when executed as a script.
if __name__ == "__main__":
    finetune()