import json

import torch
from torch import nn
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model state_dict from safetensors
def load_model_safetensors(model, load_path="model.safetensors"):
    # Load the safetensors file
    state_dict = load_file(load_path)
    # Load the state dict into the model
    model.load_state_dict(state_dict)
    return model

###################
# JINA EMBEDDINGS
###################

# Jina Configs
JINA_CONTEXT_LEN = 1024

# Adapter for embeddings
class Adapter(nn.Module):
    def __init__(self, hidden_size):
        super(Adapter, self).__init__()
        self.down_project = nn.Linear(hidden_size, hidden_size // 2)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(hidden_size // 2, hidden_size)

    def forward(self, x):
        down = self.down_project(x)
        activated = self.activation(down)
        up = self.up_project(activated)
        return up + x  # Residual connection

# Pool by attention score
class AttentionPooling(nn.Module):
    def __init__(self, hidden_size):
        super(AttentionPooling, self).__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size))

    def forward(self, hidden_states):
        # hidden_states: [seq_len, batch_size, hidden_size]
        scores = torch.matmul(hidden_states, self.attention_weights)
        attention_weights = torch.softmax(scores, dim=0)
        weighted_sum = torch.sum(attention_weights.unsqueeze(-1) * hidden_states, dim=0)
        return weighted_sum

# Custom bi-encoder model with cross-attention and MLP layers for interaction
class CrossEncoderWithSharedBase(nn.Module):
    def __init__(self, base_model, num_labels=2, num_heads=8):
        super(CrossEncoderWithSharedBase, self).__init__()
        # Shared pre-trained model
        self.shared_encoder = base_model
        hidden_size = self.shared_encoder.config.hidden_size
        # Sentence-specific adapters
        self.adapter1 = Adapter(hidden_size)
        self.adapter2 = Adapter(hidden_size)
        # Cross-attention layers
        self.cross_attention_1_to_2 = nn.MultiheadAttention(hidden_size, num_heads)
        self.cross_attention_2_to_1 = nn.MultiheadAttention(hidden_size, num_heads)
        # Attention pooling layers
        self.attn_pooling_1_to_2 = AttentionPooling(hidden_size)
        self.attn_pooling_2_to_1 = AttentionPooling(hidden_size)
        # Projection layer with non-linearity
        self.projection_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU()
        )
        # Classifier with three hidden layers
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_labels)
        )

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        # Encode sentences
        outputs1 = self.shared_encoder(input_ids1, attention_mask=attention_mask1)
        outputs2 = self.shared_encoder(input_ids2, attention_mask=attention_mask2)
        # Apply sentence-specific adapters
        embeds1 = self.adapter1(outputs1.last_hidden_state)
        embeds2 = self.adapter2(outputs2.last_hidden_state)
        # Transpose for attention layers
        embeds1 = embeds1.transpose(0, 1)
        embeds2 = embeds2.transpose(0, 1)
        # Cross-attention
        cross_attn_1_to_2, _ = self.cross_attention_1_to_2(embeds1, embeds2, embeds2)
        cross_attn_2_to_1, _ = self.cross_attention_2_to_1(embeds2, embeds1, embeds1)
        # Attention pooling
        pooled_1_to_2 = self.attn_pooling_1_to_2(cross_attn_1_to_2)
        pooled_2_to_1 = self.attn_pooling_2_to_1(cross_attn_2_to_1)
        # Concatenate and project
        combined = torch.cat((pooled_1_to_2, pooled_2_to_1), dim=1)
        projected = self.projection_layer(combined)
        # Classification
        logits = self.classifier(projected)
        return logits

# Prediction function for embeddings relevance
def embeddings_predict_relevance(sentence1, sentence2, model, tokenizer, device):
    model.eval()
    inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True,
                        padding="max_length", max_length=JINA_CONTEXT_LEN)
    inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True,
                        padding="max_length", max_length=JINA_CONTEXT_LEN)
    input_ids1 = inputs1['input_ids'].to(device)
    attention_mask1 = inputs1['attention_mask'].to(device)
    input_ids2 = inputs2['input_ids'].to(device)
    attention_mask2 = inputs2['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids1=input_ids1, attention_mask1=attention_mask1,
                        input_ids2=input_ids2, attention_mask2=attention_mask2)
    probabilities = torch.softmax(outputs, dim=1)
    predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.cpu().numpy()

# Load configuration file
jina_repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
jina_config_path = hf_hub_download(repo_id=jina_repo_path, filename="config.json")
with open(jina_config_path, 'r') as f:
    jina_config = json.load(f)

# Load Jina model configuration
JINA_MODEL_NAME = jina_config['classifier']['embedding']['model_name']
jina_model_weights_fp = jina_config['classifier']['embedding']['model_weights_fp']

# Load tokenizer and model
jina_tokenizer = AutoTokenizer.from_pretrained(JINA_MODEL_NAME)
jina_base_model = AutoModel.from_pretrained(JINA_MODEL_NAME)
jina_model = CrossEncoderWithSharedBase(jina_base_model, num_labels=2)

# Load model weights from safetensors
jina_model_weights_path = hf_hub_download(repo_id=jina_repo_path, filename=jina_model_weights_fp)
jina_model = load_model_safetensors(jina_model, jina_model_weights_path)

#################
# CROSS-ENCODER
#################

# STSB Configuration
STSB_CONTEXT_LEN = 512

class CrossEncoderWithMLP(nn.Module):
    def __init__(self, base_model, num_labels=2):
        super(CrossEncoderWithMLP, self).__init__()
        # Existing cross-encoder model
        self.base_model = base_model
        # Hidden size of the base model
        hidden_size = base_model.config.hidden_size
        # MLP layers applied to the pooled output of the cross-encoder
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),  # Input: pooled embedding of the sentence pair
            nn.ReLU(),
            nn.Linear(hidden_size // 2, hidden_size // 4),  # Reduce dimensionality further
            nn.ReLU()
        )
        # Classifier head
        self.classifier = nn.Linear(hidden_size // 4, num_labels)

    def forward(self, input_ids, attention_mask):
        # Encode the pair of sentences in one pass
        outputs = self.base_model(input_ids, attention_mask)
        pooled_output = outputs.pooler_output
        # Pass the pooled output through the MLP layers
        mlp_output = self.mlp(pooled_output)
        # Pass the final MLP output through the classifier
        logits = self.classifier(mlp_output)
        return logits

# Prediction function for cross-encoder
def cross_encoder_predict_relevance(sentence1, sentence2, model, tokenizer, device):
    model.eval()
    # Tokenize the two sentences as a pair
    encoding = tokenizer(
        sentence1, sentence2,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=STSB_CONTEXT_LEN,
        return_token_type_ids=False
    )
    # Extract the input_ids and attention mask
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )  # Returns logits
    # Convert raw logits into probabilities for each class and get the predicted label
    probabilities = torch.softmax(outputs, dim=1)
    predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.cpu().numpy()

# Load STSB model configuration
stsb_repo_path = "govtech/stsb-roberta-base-off-topic"
stsb_config_path = hf_hub_download(repo_id=stsb_repo_path, filename="config.json")
with open(stsb_config_path, 'r') as f:
    stsb_config = json.load(f)

STSB_MODEL_NAME = stsb_config['classifier']['embedding']['model_name']
stsb_model_weights_fp = stsb_config['classifier']['embedding']['model_weights_fp']

# Load STSB tokenizer and model
stsb_tokenizer = AutoTokenizer.from_pretrained(STSB_MODEL_NAME)
stsb_base_model = AutoModel.from_pretrained(STSB_MODEL_NAME)
stsb_model = CrossEncoderWithMLP(stsb_base_model, num_labels=2)

# Load model weights from safetensors for STSB
stsb_model_weights_path = hf_hub_download(repo_id=stsb_repo_path, filename=stsb_model_weights_fp)
stsb_model = load_model_safetensors(stsb_model, stsb_model_weights_path)
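
#################
# EXAMPLE USAGE
#################

# Minimal usage sketch, not part of the original snippet: the prompt/response pair below
# is hypothetical, and the mapping of label indices to on-topic/off-topic should be
# confirmed against each model card. The loading code above leaves both models on CPU,
# so they are moved to `device` here before calling the prediction functions.
jina_model = jina_model.to(device)
stsb_model = stsb_model.to(device)

prompt = "Summarise the key findings of the attached quarterly financial report."
response = "Here is a recipe for chocolate chip cookies."

jina_label, jina_probs = embeddings_predict_relevance(
    prompt, response, jina_model, jina_tokenizer, device
)
stsb_label, stsb_probs = cross_encoder_predict_relevance(
    prompt, response, stsb_model, stsb_tokenizer, device
)

print(f"Jina adapted encoder: label={jina_label}, probabilities={jina_probs}")
print(f"STSB cross-encoder:   label={stsb_label}, probabilities={stsb_probs}")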