File size: 4,025 Bytes
6e81596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

import torch
import wandb
from transformers import Trainer


class ORPOTrainer(Trainer):
    def __init__(self, alpha, pad, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pad = pad
        self.alpha = alpha
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        print("Pad Token ID: ", self.pad)
        
    def compute_custom_loss(self, logits, labels):
        
        logits = logits.contiguous()
        
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            
            # Flatten the tokens
            loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(dim=-1)
            
        return loss
    
    def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
        mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
        per_token_logps = torch.gather(logits[:, :-1, :].log_softmax(-1), dim=2, 
                                       index=(mask * chosen_inputs[:, 1:]).unsqueeze(2)).squeeze(2)
        return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(dtype=torch.float64) / mask.sum(dim=1).to(dtype=torch.float64)
        
    def compute_loss(self, model, inputs, return_outputs=False):
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        
        # Generate the hidden states for 'chosen' and 'reject'
        neg_labels = inputs['negative_input_ids'].clone()
        pos_labels = inputs['positive_input_ids'].clone()

        neg_labels[neg_labels == self.pad] = -100
        pos_labels[pos_labels == self.pad] = -100

        outputs_neg = model(**{'input_ids': inputs['negative_input_ids'],
                               'attention_mask': inputs['negative_attention_mask'],
                               'labels': neg_labels,}, output_hidden_states=True)      
        outputs_pos = model(**{'input_ids': inputs['positive_input_ids'],
                               'attention_mask': inputs['positive_attention_mask'],
                               'labels': pos_labels,}, output_hidden_states=True)
            
        # Calculate NLL loss
        pos_loss = self.compute_custom_loss(logits=outputs_pos.logits, labels=inputs['positive_input_ids']) 

        # Calculate Log Probability
        pos_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'], 
                                      chosen_inputs=inputs['positive_input_ids'], 
                                      chosen_attention_mask=inputs['positive_attention_mask'], 
                                      logits=outputs_pos.logits)
        neg_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'], 
                                      chosen_inputs=inputs['negative_input_ids'], 
                                      chosen_attention_mask=inputs['negative_attention_mask'], 
                                      logits=outputs_neg.logits)

        # Calculate log odds
        log_odds = (pos_prob - neg_prob) - (torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)))
        sig_ratio = torch.nn.functional.sigmoid(log_odds)
        ratio = torch.log(sig_ratio)
        
        # Calculate the Final Loss
        loss = torch.mean(pos_loss - self.alpha * ratio).to(dtype=torch.bfloat16)
        
        wandb.log({'Positive Geometric Mean': torch.mean(pos_prob).item(),
                   'Negative Geometric Mean': torch.mean(neg_prob).item(),
                   'Log Odds Ratio': torch.mean(ratio).item(),
                   'Log Odds': torch.mean(log_odds).item()})
        
        return (loss, outputs_pos) if return_outputs else loss