File size: 2,986 Bytes
9f606aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers import PreTrainedModel, PretrainedConfig
from transformers import CONFIG_MAPPING, MODEL_MAPPING
import torch
import torch.nn.functional as F
import torch.nn as nn


class JinaJudgeConfig(PretrainedConfig):
    """Hyper-parameters for :class:`JinaJudge`.

    Only the head placed on top of the jina encoder is configured here; any
    additional keyword arguments are passed straight through to
    :class:`transformers.PretrainedConfig`.

    Args:
        n_classes: Size of the classification head's output (number of logits).
        hidden_dim: Width the encoder states are projected to; also used as the
            decoder's ``d_model`` (and, doubled, as its feed-forward width).
        num_decoder_layers: How many decoder layers are stacked.
        nhead: Number of attention heads in each decoder layer.
        dropout_prob: Dropout probability passed to each decoder layer.
    """

    model_type = "jina-judge"

    def __init__(
        self,
        n_classes: int = 3,
        hidden_dim: int = 512,
        num_decoder_layers: int = 5,
        nhead: int = 8,
        dropout_prob: float = 0.2,
        **kwargs,
    ):
        # Let the base class consume the generic config kwargs first; the
        # head-specific fields below are named parameters, so kwargs can never
        # carry (and overwrite) them.
        super().__init__(**kwargs)

        # Classification head size.
        self.n_classes = n_classes
        # Decoder geometry.
        self.hidden_dim = hidden_dim
        self.num_decoder_layers = num_decoder_layers
        self.nhead = nhead
        # Regularisation.
        self.dropout_prob = dropout_prob


class JinaJudge(PreTrainedModel):
    """Prompt classifier built on the ``jinaai/jina-embeddings-v3`` encoder.

    Prompts are tokenized and encoded by the jina encoder. The encoder states
    are linearly projected to ``config.hidden_dim`` and used as the memory of a
    small ``nn.TransformerDecoder`` whose only query is one learned token. The
    decoder output for that token is mapped to ``config.n_classes`` logits.
    """

    config_class = JinaJudgeConfig

    def __init__(self, config: JinaJudgeConfig):
        super().__init__(config)
        # NOTE: tokenizer and encoder config are fetched via ``from_pretrained``
        # on every construction, so the Hub (or a local cache) must be reachable.
        self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)
        jina_config = AutoConfig.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)
        # ``from_config`` builds the encoder architecture WITHOUT loading its
        # pretrained weights; presumably those are expected to come from this
        # model's own checkpoint via ``JinaJudge.from_pretrained`` — TODO confirm.
        self.encoder = AutoModel.from_config(jina_config, trust_remote_code=True, torch_dtype=torch.bfloat16)
        # Attribute consumed by the jina remote code; presumably it makes the
        # main (non-LoRA) encoder parameters trainable — verify against that code.
        self.encoder.lora_main_params_trainable = True

        self.projection = nn.Linear(self.encoder.config.hidden_size, config.hidden_dim)

        # Decoder stack. ``batch_first`` is left at its default (False), so the
        # decoder consumes (seq, batch, dim) tensors — see ``forward``.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.nhead,
            dim_feedforward=config.hidden_dim * 2,
            dropout=config.dropout_prob,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=config.num_decoder_layers,
        )

        # Single learned query token fed to the decoder, stored seq-first as
        # (1, 1, hidden_dim) and broadcast over the batch in ``forward``.
        self.decoder_input_embedding = nn.Parameter(
            torch.randn(1, 1, config.hidden_dim)
        )

        # Classification head.
        self.classification_head = nn.Linear(config.hidden_dim, config.n_classes)

    def forward(self, prompts):
        """Return classification logits for ``prompts``.

        Args:
            prompts: Text input passed directly to the tokenizer (presumably a
                ``str`` or ``list[str]``).

        Returns:
            Logits tensor of shape ``(batch_size, config.n_classes)``.
        """
        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(self.device)
        encoder_outputs = self.encoder(**inputs)

        # The encoder runs in bfloat16. Cast its output to the head's dtype
        # (float32 by default) instead of hard-coding float32, so the model
        # still works after ``.half()`` / ``.to(torch.bfloat16)``.
        head_dtype = self.projection.weight.dtype
        encoder_hidden_states = encoder_outputs.last_hidden_state.to(head_dtype)
        encoder_hidden_states = self.projection(encoder_hidden_states)

        # True marks padding positions the decoder must not attend to. The
        # mask is already on ``self.device`` because ``inputs`` were moved there.
        encoder_padding_mask = inputs["attention_mask"] == 0

        batch_size = encoder_hidden_states.size(0)
        # One copy of the learned query per example: (1, batch, hidden_dim).
        decoder_input = self.decoder_input_embedding.expand(1, batch_size, -1)

        decoder_output = self.decoder(
            tgt=decoder_input,
            # The decoder is seq-first, so swap the batch and sequence axes.
            memory=encoder_hidden_states.transpose(0, 1),
            memory_key_padding_mask=encoder_padding_mask,
        ).squeeze(0)  # drop the length-1 query axis -> (batch, hidden_dim)

        logits = self.classification_head(decoder_output)
        return logits


# Make the custom pair discoverable through the Auto* factories: AutoConfig
# maps the model-type string to JinaJudgeConfig, and AutoModel maps that
# config class to JinaJudge. Using the class attribute keeps the registered
# name identical to ``JinaJudgeConfig.model_type``.
AutoConfig.register(JinaJudgeConfig.model_type, JinaJudgeConfig)
AutoModel.register(JinaJudgeConfig, JinaJudge)