"""
inference_safetensors.py
Defines the architecture of the fine-tuned embedding model used for Off-Topic classification.
"""
import json
import sys

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModel

# Adapter for embeddings
class Adapter(nn.Module):
    def __init__(self, hidden_size):
        super(Adapter, self).__init__()
        self.down_project = nn.Linear(hidden_size, hidden_size // 2)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(hidden_size // 2, hidden_size)

    def forward(self, x):
        down = self.down_project(x)
        activated = self.activation(down)
        up = self.up_project(activated)
        return up + x  # Residual connection
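
# The Adapter above is a bottleneck block: hidden_size -> hidden_size // 2 -> ReLU
# -> hidden_size, plus a residual connection, so its output keeps the same shape as
# its input. For example, with hidden_size = 512 it maps [..., 512] -> [..., 512]
# through a 256-dimensional bottleneck.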

# Pool by attention score
class AttentionPooling(nn.Module):
    def __init__(self, hidden_size):
        super(AttentionPooling, self).__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size))

    def forward(self, hidden_states):
        # hidden_states: [seq_len, batch_size, hidden_size]
        scores = torch.matmul(hidden_states, self.attention_weights)
        attention_weights = torch.softmax(scores, dim=0)
        weighted_sum = torch.sum(attention_weights.unsqueeze(-1) * hidden_states, dim=0)
        return weighted_sum
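
# Shape walkthrough for AttentionPooling, following the comment above:
#   hidden_states     : [seq_len, batch_size, hidden_size]
#   scores            : [seq_len, batch_size]       (dot product with the learned vector)
#   attention_weights : softmax over dim=0, i.e. over sequence positions
#   weighted_sum      : [batch_size, hidden_size]   (one pooled vector per example)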

# Cross-encoder with a shared base encoder: each sentence gets its own adapter,
# the two are fused with cross-attention and attention pooling, and an MLP head
# classifies the pair
class CrossEncoderWithSharedBase(nn.Module):
    def __init__(self, base_model, num_labels=2, num_heads=8):
        super(CrossEncoderWithSharedBase, self).__init__()
        # Shared pre-trained model
        self.shared_encoder = base_model
        hidden_size = self.shared_encoder.config.hidden_size
        # Sentence-specific adapters
        self.adapter1 = Adapter(hidden_size)
        self.adapter2 = Adapter(hidden_size)
        # Cross-attention layers
        self.cross_attention_1_to_2 = nn.MultiheadAttention(hidden_size, num_heads)
        self.cross_attention_2_to_1 = nn.MultiheadAttention(hidden_size, num_heads)
        # Attention pooling layers
        self.attn_pooling_1_to_2 = AttentionPooling(hidden_size)
        self.attn_pooling_2_to_1 = AttentionPooling(hidden_size)
        # Projection layer with non-linearity
        self.projection_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU()
        )
        # Classification head: two hidden layers, then the output layer
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_labels)
        )

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        # Encode sentences
        outputs1 = self.shared_encoder(input_ids1, attention_mask=attention_mask1)
        outputs2 = self.shared_encoder(input_ids2, attention_mask=attention_mask2)
        # Apply sentence-specific adapters
        embeds1 = self.adapter1(outputs1.last_hidden_state)
        embeds2 = self.adapter2(outputs2.last_hidden_state)
        # Transpose for attention layers
        embeds1 = embeds1.transpose(0, 1)
        embeds2 = embeds2.transpose(0, 1)
        # Cross-attention
        cross_attn_1_to_2, _ = self.cross_attention_1_to_2(embeds1, embeds2, embeds2)
        cross_attn_2_to_1, _ = self.cross_attention_2_to_1(embeds2, embeds1, embeds1)
        # Attention pooling
        pooled_1_to_2 = self.attn_pooling_1_to_2(cross_attn_1_to_2)
        pooled_2_to_1 = self.attn_pooling_2_to_1(cross_attn_2_to_1)
        # Concatenate and project
        combined = torch.cat((pooled_1_to_2, pooled_2_to_1), dim=1)
        projected = self.projection_layer(combined)
        # Classification
        logits = self.classifier(projected)
        return logits
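
# Minimal usage sketch for the model defined above (hypothetical tensors, kept as a
# comment so nothing runs on import; shapes assume inputs padded to length L):
#   base_model = AutoModel.from_pretrained(model_name)
#   model = CrossEncoderWithSharedBase(base_model, num_labels=2)
#   logits = model(input_ids1,        # [batch, L]
#                  attention_mask1,   # [batch, L]
#                  input_ids2,        # [batch, L]
#                  attention_mask2)   # [batch, L]
#   # logits: [batch, num_labels]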

# Load configuration file from the model repo
repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
with open(config_path, 'r') as f:
    config = json.load(f)
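
# predict() below reads these keys from the downloaded config; an illustrative
# (assumed, not verified) layout would be:
#   {
#     "classifier": {
#       "embedding": {
#         "model_name": "jinaai/jina-embeddings-v2-small-en",
#         "max_length": 1024,
#         "model_weights_fp": "model.safetensors"
#       }
#     }
#   }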

def predict(sentence1, sentence2):
    """
    Predicts the label for a pair of sentences using the fine-tuned model with SafeTensors weights.

    Args:
        sentence1 (str): The first input sentence.
        sentence2 (str): The second input sentence.

    Returns:
        tuple:
            - predicted_label (int): The predicted label (e.g., 0 or 1).
            - probabilities (numpy.ndarray): The probabilities for each class.
    """
    # Load model configuration
    model_name = config['classifier']['embedding']['model_name']
    max_length = config['classifier']['embedding']['max_length']
    model_weights_fp = config['classifier']['embedding']['model_weights_fp']

    # Load tokenizer and base model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    model = CrossEncoderWithSharedBase(base_model, num_labels=2)

    # Load weights into the model
    weights = load_file(model_weights_fp)
    model.load_state_dict(weights)
    model.to(device)
    model.eval()

    # Tokenize inputs
    inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
    inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
    input_ids1 = inputs1['input_ids'].to(device)
    attention_mask1 = inputs1['attention_mask'].to(device)
    input_ids2 = inputs2['input_ids'].to(device)
    attention_mask2 = inputs2['attention_mask'].to(device)

    # Run inference
    with torch.no_grad():
        outputs = model(input_ids1=input_ids1, attention_mask1=attention_mask1,
                        input_ids2=input_ids2, attention_mask2=attention_mask2)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()

    return predicted_label, probabilities.cpu().numpy()
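
# Example call (illustrative sentences, not from the training data):
#   predicted_label, probabilities = predict(
#       "How do I renew my passport online?",
#       "The weather in Singapore is hot and humid all year round."
#   )
#   # predicted_label -> 0 or 1; probabilities -> array of shape (1, 2)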
if __name__ == "__main__":
# Load data
input_data = sys.argv[1]
sentence_pairs = json.loads(input_data)
# Validate input data format
if not all(isinstance(pair[0], str) and isinstance(pair[1], str) for pair in sentence_pairs):
raise ValueError("Each pair must contain two strings.")
for idx, (sentence1, sentence2) in enumerate(sentence_pairs):
# Generate prediction and scores
predicted_label, probabilities = predict(sentence1, sentence2)
# Print the results
print(f"Pair {idx + 1}:")
print(f" Sentence 1: {sentence1}")
print(f" Sentence 2: {sentence2}")
print(f" Predicted Label: {predicted_label}")
print(f" Probabilities: {probabilities}")
print('-' * 50)
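
# Example invocation (illustrative input):
#   python inference_safetensors.py '[["How do I renew my passport online?", "Passport renewals can be submitted via the e-service."]]'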