import json
import torch
from torch import nn
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the model state_dict from safetensors
def load_model_safetensors(model, load_path="model.safetensors"):
    # Load the safetensors file
    state_dict = load_file(load_path)
    # Load the state dict into the model
    model.load_state_dict(state_dict)
    return model

###################
# JINA EMBEDDINGS
###################
# Jina Configs
JINA_CONTEXT_LEN = 1024
# Adapter for embeddings
class Adapter(nn.Module):
    def __init__(self, hidden_size):
        super(Adapter, self).__init__()
        self.down_project = nn.Linear(hidden_size, hidden_size // 2)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(hidden_size // 2, hidden_size)

    def forward(self, x):
        down = self.down_project(x)
        activated = self.activation(down)
        up = self.up_project(activated)
        return up + x  # Residual connection

# Pool by attention score
class AttentionPooling(nn.Module):
    def __init__(self, hidden_size):
        super(AttentionPooling, self).__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size))

    def forward(self, hidden_states):
        # hidden_states: [seq_len, batch_size, hidden_size]
        scores = torch.matmul(hidden_states, self.attention_weights)
        attention_weights = torch.softmax(scores, dim=0)
        weighted_sum = torch.sum(attention_weights.unsqueeze(-1) * hidden_states, dim=0)
        return weighted_sum

# Bi-encoder with a shared base model, cross-attention for interaction, and an MLP classifier head
class CrossEncoderWithSharedBase(nn.Module):
    def __init__(self, base_model, num_labels=2, num_heads=8):
        super(CrossEncoderWithSharedBase, self).__init__()
        # Shared pre-trained model
        self.shared_encoder = base_model
        hidden_size = self.shared_encoder.config.hidden_size
        # Sentence-specific adapters
        self.adapter1 = Adapter(hidden_size)
        self.adapter2 = Adapter(hidden_size)
        # Cross-attention layers
        self.cross_attention_1_to_2 = nn.MultiheadAttention(hidden_size, num_heads)
        self.cross_attention_2_to_1 = nn.MultiheadAttention(hidden_size, num_heads)
        # Attention pooling layers
        self.attn_pooling_1_to_2 = AttentionPooling(hidden_size)
        self.attn_pooling_2_to_1 = AttentionPooling(hidden_size)
        # Projection layer with non-linearity
        self.projection_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU()
        )
        # Classifier head with two hidden layers and dropout
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_labels)
        )

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        # Encode sentences
        outputs1 = self.shared_encoder(input_ids1, attention_mask=attention_mask1)
        outputs2 = self.shared_encoder(input_ids2, attention_mask=attention_mask2)
        # Apply sentence-specific adapters
        embeds1 = self.adapter1(outputs1.last_hidden_state)
        embeds2 = self.adapter2(outputs2.last_hidden_state)
        # Transpose to [seq_len, batch_size, hidden_size] for the attention layers
        embeds1 = embeds1.transpose(0, 1)
        embeds2 = embeds2.transpose(0, 1)
        # Cross-attention in both directions
        cross_attn_1_to_2, _ = self.cross_attention_1_to_2(embeds1, embeds2, embeds2)
        cross_attn_2_to_1, _ = self.cross_attention_2_to_1(embeds2, embeds1, embeds1)
        # Attention pooling
        pooled_1_to_2 = self.attn_pooling_1_to_2(cross_attn_1_to_2)
        pooled_2_to_1 = self.attn_pooling_2_to_1(cross_attn_2_to_1)
        # Concatenate and project
        combined = torch.cat((pooled_1_to_2, pooled_2_to_1), dim=1)
        projected = self.projection_layer(combined)
        # Classification
        logits = self.classifier(projected)
        return logits

# Prediction function for embeddings relevance
def embeddings_predict_relevance(sentence1, sentence2, model, tokenizer, device):
    model.eval()
    inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=JINA_CONTEXT_LEN)
    inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True, padding="max_length", max_length=JINA_CONTEXT_LEN)
    input_ids1 = inputs1['input_ids'].to(device)
    attention_mask1 = inputs1['attention_mask'].to(device)
    input_ids2 = inputs2['input_ids'].to(device)
    attention_mask2 = inputs2['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids1=input_ids1, attention_mask1=attention_mask1,
                        input_ids2=input_ids2, attention_mask2=attention_mask2)
    probabilities = torch.softmax(outputs, dim=1)
    predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.cpu().numpy()

# Load configuration file
jina_repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
jina_config_path = hf_hub_download(repo_id=jina_repo_path, filename="config.json")
with open(jina_config_path, 'r') as f:
    jina_config = json.load(f)
# Load Jina model configuration
JINA_MODEL_NAME = jina_config['classifier']['embedding']['model_name']
jina_model_weights_fp = jina_config['classifier']['embedding']['model_weights_fp']
# Load tokenizer and model
jina_tokenizer = AutoTokenizer.from_pretrained(JINA_MODEL_NAME)
jina_base_model = AutoModel.from_pretrained(JINA_MODEL_NAME)
jina_model = CrossEncoderWithSharedBase(jina_base_model, num_labels=2)
# Load model weights from safetensors
jina_model_weights_path = hf_hub_download(repo_id=jina_repo_path, filename=jina_model_weights_fp)
jina_model = load_model_safetensors(jina_model, jina_model_weights_path)
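# Illustrative usage sketch (not part of the original file): the sentence pair below
# is a made-up placeholder, and the .to(device) call is an assumption in case the
# model is not moved to the device elsewhere. Kept commented out so it does not run
# at import time.
# jina_model.to(device)
# label, probs = embeddings_predict_relevance(
#     "How do I renew my passport online?",
#     "What is a good recipe for pancakes?",
#     jina_model, jina_tokenizer, device
# )
# print(f"Jina off-topic prediction: {label}, probabilities: {probs}")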
#################
# CROSS-ENCODER
#################
# STSB Configuration
STSB_CONTEXT_LEN = 512
class CrossEncoderWithMLP(nn.Module):
    def __init__(self, base_model, num_labels=2):
        super(CrossEncoderWithMLP, self).__init__()
        # Existing cross-encoder model
        self.base_model = base_model
        # Hidden size of the base model
        hidden_size = base_model.config.hidden_size
        # MLP layers applied to the pooled cross-encoder output
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),  # Input: pooled representation of the sentence pair
            nn.ReLU(),
            nn.Linear(hidden_size // 2, hidden_size // 4),  # Reduce the size of the layer further
            nn.ReLU()
        )
        # Classifier head
        self.classifier = nn.Linear(hidden_size // 4, num_labels)

    def forward(self, input_ids, attention_mask):
        # Encode the pair of sentences in one pass
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        # Pass the pooled output through the MLP layers
        mlp_output = self.mlp(pooled_output)
        # Pass the final MLP output through the classifier
        logits = self.classifier(mlp_output)
        return logits

# Prediction function for cross-encoder
def cross_encoder_predict_relevance(sentence1, sentence2, model, tokenizer, device):
    model.eval()
    # Tokenize the two sentences as a single pair
    encoding = tokenizer(
        sentence1, sentence2,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=STSB_CONTEXT_LEN,
        return_token_type_ids=False
    )
    # Extract the input_ids and attention mask
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )  # Returns logits
    # Convert raw logits into probabilities for each class and get the predicted label
    probabilities = torch.softmax(outputs, dim=1)
    predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.cpu().numpy()

# Load STSB model configuration
stsb_repo_path = "govtech/stsb-roberta-base-off-topic"
stsb_config_path = hf_hub_download(repo_id=stsb_repo_path, filename="config.json")
with open(stsb_config_path, 'r') as f:
    stsb_config = json.load(f)
STSB_MODEL_NAME = stsb_config['classifier']['embedding']['model_name']
stsb_model_weights_fp = stsb_config['classifier']['embedding']['model_weights_fp']
# Load STSB tokenizer and model
stsb_tokenizer = AutoTokenizer.from_pretrained(STSB_MODEL_NAME)
stsb_base_model = AutoModel.from_pretrained(STSB_MODEL_NAME)
stsb_model = CrossEncoderWithMLP(stsb_base_model, num_labels=2)
# Load model weights from safetensors for STSB
stsb_model_weights_path = hf_hub_download(repo_id=stsb_repo_path, filename=stsb_model_weights_fp)
stsb_model = load_model_safetensors(stsb_model, stsb_model_weights_path)
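# Illustrative usage sketch (not part of the original file): placeholder inputs
# mirroring the Jina example above; the .to(device) call is an assumption in case
# the model is not moved to the device elsewhere. Kept commented out.
# stsb_model.to(device)
# label, probs = cross_encoder_predict_relevance(
#     "How do I renew my passport online?",
#     "What is a good recipe for pancakes?",
#     stsb_model, stsb_tokenizer, device
# )
# print(f"STSB off-topic prediction: {label}, probabilities: {probs}")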