"""
inference_safetensors.py

Defines the architecture of the fine-tuned embedding model used for Off-Topic classification.
"""
import json
import sys

import torch
import torch.nn as nn

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModel


class Adapter(nn.Module):
    """Residual bottleneck adapter: down-project, apply ReLU, up-project, then add the input back."""

    def __init__(self, hidden_size):
        super(Adapter, self).__init__()
        self.down_project = nn.Linear(hidden_size, hidden_size // 2)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(hidden_size // 2, hidden_size)

    def forward(self, x):
        down = self.down_project(x)
        activated = self.activation(down)
        up = self.up_project(activated)
        return up + x  # residual connection


class AttentionPooling(nn.Module):
    """Pools a sequence of hidden states into a single vector using learned attention scores."""

    def __init__(self, hidden_size):
        super(AttentionPooling, self).__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size))

    def forward(self, hidden_states):
        # hidden_states: (seq_len, batch_size, hidden_size)
        scores = torch.matmul(hidden_states, self.attention_weights)
        attention_weights = torch.softmax(scores, dim=0)
        weighted_sum = torch.sum(attention_weights.unsqueeze(-1) * hidden_states, dim=0)
        return weighted_sum  # (batch_size, hidden_size)


class CrossEncoderWithSharedBase(nn.Module):
    """Cross-encoder that runs both inputs through a shared base encoder, then applies per-input
    adapters, bidirectional cross-attention, attention pooling, and an MLP classification head."""

    def __init__(self, base_model, num_labels=2, num_heads=8):
        super(CrossEncoderWithSharedBase, self).__init__()
        self.shared_encoder = base_model
        hidden_size = self.shared_encoder.config.hidden_size

        # Per-input adapters on top of the shared encoder
        self.adapter1 = Adapter(hidden_size)
        self.adapter2 = Adapter(hidden_size)

        # Bidirectional cross-attention between the two inputs
        self.cross_attention_1_to_2 = nn.MultiheadAttention(hidden_size, num_heads)
        self.cross_attention_2_to_1 = nn.MultiheadAttention(hidden_size, num_heads)

        self.attn_pooling_1_to_2 = AttentionPooling(hidden_size)
        self.attn_pooling_2_to_1 = AttentionPooling(hidden_size)

        self.projection_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU()
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_labels)
        )

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        # Encode both inputs with the shared base encoder: (batch, seq_len, hidden)
        outputs1 = self.shared_encoder(input_ids1, attention_mask=attention_mask1)
        outputs2 = self.shared_encoder(input_ids2, attention_mask=attention_mask2)

        # Shape-preserving adapters, one per input
        embeds1 = self.adapter1(outputs1.last_hidden_state)
        embeds2 = self.adapter2(outputs2.last_hidden_state)

        # nn.MultiheadAttention expects (seq_len, batch, hidden) by default
        embeds1 = embeds1.transpose(0, 1)
        embeds2 = embeds2.transpose(0, 1)

        # Cross-attend each input over the other
        cross_attn_1_to_2, _ = self.cross_attention_1_to_2(embeds1, embeds2, embeds2)
        cross_attn_2_to_1, _ = self.cross_attention_2_to_1(embeds2, embeds1, embeds1)

        # Pool over the sequence dimension: (batch, hidden)
        pooled_1_to_2 = self.attn_pooling_1_to_2(cross_attn_1_to_2)
        pooled_2_to_1 = self.attn_pooling_2_to_1(cross_attn_2_to_1)

        # Concatenate both pooled views, project, and classify
        combined = torch.cat((pooled_1_to_2, pooled_2_to_1), dim=1)
        projected = self.projection_layer(combined)
        logits = self.classifier(projected)
        return logits


repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
config_path = hf_hub_download(repo_id=repo_path, filename="config.json")

with open(config_path, 'r') as f:
    config = json.load(f)


def predict(sentence1, sentence2):
    """
    Predicts the label for a pair of sentences using the fine-tuned model with SafeTensors weights.

    Args:
        sentence1 (str): The first input sentence.
        sentence2 (str): The second input sentence.

    Returns:
        tuple:
            - predicted_label (int): The predicted label (e.g., 0 or 1).
            - probabilities (numpy.ndarray): The probabilities for each class.
    """
    # Read model settings from the configuration
    model_name = config['classifier']['embedding']['model_name']
    max_length = config['classifier']['embedding']['max_length']
    model_weights_fp = config['classifier']['embedding']['model_weights_fp']

    # Build the model and load the fine-tuned SafeTensors weights
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    model = CrossEncoderWithSharedBase(base_model, num_labels=2)

    weights = load_file(model_weights_fp)
    model.load_state_dict(weights)
    model.to(device)
    model.eval()

    # Tokenize both sentences and move tensors to the target device
    inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
    inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
    input_ids1 = inputs1['input_ids'].to(device)
    attention_mask1 = inputs1['attention_mask'].to(device)
    input_ids2 = inputs2['input_ids'].to(device)
    attention_mask2 = inputs2['attention_mask'].to(device)

    # Run inference and convert logits to class probabilities
    with torch.no_grad():
        outputs = model(input_ids1=input_ids1, attention_mask1=attention_mask1,
                        input_ids2=input_ids2, attention_mask2=attention_mask2)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()

    return predicted_label, probabilities.cpu().numpy()


if __name__ == "__main__":
    # Expect a JSON-encoded list of [sentence1, sentence2] pairs as the first CLI argument
    input_data = sys.argv[1]
    sentence_pairs = json.loads(input_data)

    if not all(len(pair) == 2 and isinstance(pair[0], str) and isinstance(pair[1], str) for pair in sentence_pairs):
        raise ValueError("Each pair must contain two strings.")

    for idx, (sentence1, sentence2) in enumerate(sentence_pairs):
        predicted_label, probabilities = predict(sentence1, sentence2)

        print(f"Pair {idx + 1}:")
        print(f" Sentence 1: {sentence1}")
        print(f" Sentence 2: {sentence2}")
        print(f" Predicted Label: {predicted_label}")
        print(f" Probabilities: {probabilities}")
        print('-' * 50)