from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers import PreTrainedModel, PretrainedConfig
from transformers import CONFIG_MAPPING, MODEL_MAPPING
import torch
import torch.nn.functional as F
import torch.nn as nn


class JinaJudgeConfig(PretrainedConfig):
    """Configuration for :class:`JinaJudge`.

    Args:
        n_classes: Number of logits produced by the classification head.
        hidden_dim: Width of the encoder projection and of the Transformer decoder
            (the decoder feed-forward width is ``2 * hidden_dim``).
        num_decoder_layers: Number of stacked ``nn.TransformerDecoderLayer`` modules.
        nhead: Number of attention heads in each decoder layer.
        dropout_prob: Dropout probability used inside the decoder layers.
        **kwargs: Forwarded unchanged to :class:`transformers.PretrainedConfig`.
    """

    model_type = "jina-judge"

    def __init__(
        self,
        n_classes=3,
        hidden_dim=512,
        num_decoder_layers=5,
        nhead=8,
        dropout_prob=0.2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.n_classes = n_classes
        self.hidden_dim = hidden_dim
        self.num_decoder_layers = num_decoder_layers
        self.nhead = nhead
        self.dropout_prob = dropout_prob


class JinaJudge(PreTrainedModel):
    """Text classifier built on top of the jina-embeddings-v3 encoder.

    Raw prompts are tokenized and encoded by jina-embeddings-v3. The token states
    are projected to ``config.hidden_dim`` and attended over by a Transformer
    decoder whose only query is one learned token per example. The decoder output
    for that token is mapped to ``config.n_classes`` logits.
    """

    config_class = JinaJudgeConfig

    def __init__(self, config: JinaJudgeConfig):
        super().__init__(config)
        self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)
        jina_config = AutoConfig.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)
        # Request trainable (non-LoRA) encoder weights through the config, so the
        # flag is already set while the remote encoder is being constructed.
        # NOTE(review): the remote jina-embeddings-v3 LoRA model appears to read
        # `lora_main_params_trainable` from its config only in `__init__`, which
        # would make the post-construction attribute assignment below a no-op on
        # its own -- confirm against the remote modeling code.
        jina_config.lora_main_params_trainable = True
        # `from_config` builds the architecture without loading pretrained weights.
        # Presumably the weights are restored from this model's own checkpoint.
        self.encoder = AutoModel.from_config(jina_config, trust_remote_code=True, torch_dtype=torch.bfloat16)
        # Kept for backward compatibility with code that may read this attribute.
        self.encoder.lora_main_params_trainable = True
        self.projection = nn.Linear(self.encoder.config.hidden_size, config.hidden_dim)

        # Transformer decoder layer. `batch_first` is not set, so PyTorch's default
        # (seq, batch, dim) layout applies; `forward` arranges its inputs to match.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.nhead,
            dim_feedforward=config.hidden_dim * 2,
            dropout=config.dropout_prob,
        )

        # Transformer decoder
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=config.num_decoder_layers,
        )

        # Learned embedding used as the single query token fed to the decoder.
        self.decoder_input_embedding = nn.Parameter(
            torch.randn(1, 1, config.hidden_dim)
        )

        # Classification head
        self.classification_head = nn.Linear(config.hidden_dim, config.n_classes)

    def forward(self, prompts):
        """Return unnormalised class logits for a batch of prompts.

        Args:
            prompts: Text input accepted by the Hugging Face tokenizer. Presumably a
                ``str`` or a list of ``str``; TODO confirm with callers.

        Returns:
            Tensor of shape ``(batch, config.n_classes)`` containing logits.
        """
        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(self.device)
        encoder_outputs = self.encoder(**inputs)
        # The encoder is built in bfloat16. Cast its output to the head's dtype
        # instead of hard-coding float32, so the head still works if the whole
        # model has been loaded or cast to a different dtype.
        head_dtype = self.projection.weight.dtype
        encoder_hidden_states = encoder_outputs.last_hidden_state.to(head_dtype)
        encoder_hidden_states = self.projection(encoder_hidden_states)

        # True marks padding positions that the decoder must not attend to.
        encoder_padding_mask = (inputs["attention_mask"] == 0).to(self.device)

        batch_size = encoder_hidden_states.size(0)
        # One learned query token per example, laid out as (1, batch, hidden_dim).
        decoder_input = self.decoder_input_embedding.expand(1, batch_size, -1).to(self.device)

        decoder_output = self.decoder(
            tgt=decoder_input,
            # The encoder output is batch-first; the decoder expects (seq, batch, dim).
            memory=encoder_hidden_states.transpose(0, 1),
            memory_key_padding_mask=encoder_padding_mask,
        ).squeeze(0)  # (1, batch, hidden_dim) -> (batch, hidden_dim)

        logits = self.classification_head(decoder_output)
        return logits


# Make the custom config/model resolvable through the Auto* factories under the
# "jina-judge" model type.
AutoConfig.register("jina-judge", JinaJudgeConfig)
AutoModel.register(JinaJudgeConfig, JinaJudge)