"""
inference_safetensors.py

Defines the architecture of the fine-tuned embedding model used for Off-Topic classification.
"""
import json
import sys

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModel

# Residual bottleneck adapter applied to token embeddings
class Adapter(nn.Module):
    def __init__(self, hidden_size):
        super(Adapter, self).__init__()
        self.down_project = nn.Linear(hidden_size, hidden_size // 2)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(hidden_size // 2, hidden_size)

    def forward(self, x):
        down = self.down_project(x)
        activated = self.activation(down)
        up = self.up_project(activated)
        return up + x  # Residual connection
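
# A minimal shape check (illustrative; not part of the original script): the
# adapter is shape-preserving, so it can wrap the encoder output directly.
#
#   adapter = Adapter(512)
#   x = torch.randn(8, 128, 512)        # [batch_size, seq_len, hidden_size]
#   assert adapter(x).shape == x.shape  # residual connection keeps the shape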

# Pool token states using a learned attention vector
class AttentionPooling(nn.Module):
    def __init__(self, hidden_size):
        super(AttentionPooling, self).__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size))

    def forward(self, hidden_states):
        # hidden_states: [seq_len, batch_size, hidden_size]
        scores = torch.matmul(hidden_states, self.attention_weights)
        attention_weights = torch.softmax(scores, dim=0)
        weighted_sum = torch.sum(attention_weights.unsqueeze(-1) * hidden_states, dim=0)
        return weighted_sum
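
# Illustrative pooling shapes (assumed example, using the [seq_len, batch_size,
# hidden_size] layout expected by the forward pass below):
#
#   pool = AttentionPooling(512)
#   h = torch.randn(128, 8, 512)   # [seq_len, batch_size, hidden_size]
#   pooled = pool(h)               # -> [batch_size, hidden_size] == [8, 512]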

# Cross-encoder with a shared base encoder, cross-attention interaction, and an MLP head
class CrossEncoderWithSharedBase(nn.Module):
    def __init__(self, base_model, num_labels=2, num_heads=8):
        super(CrossEncoderWithSharedBase, self).__init__()
        # Shared pre-trained model
        self.shared_encoder = base_model
        hidden_size = self.shared_encoder.config.hidden_size
        # Sentence-specific adapters
        self.adapter1 = Adapter(hidden_size)
        self.adapter2 = Adapter(hidden_size)
        # Cross-attention layers
        self.cross_attention_1_to_2 = nn.MultiheadAttention(hidden_size, num_heads)
        self.cross_attention_2_to_1 = nn.MultiheadAttention(hidden_size, num_heads)
        # Attention pooling layers
        self.attn_pooling_1_to_2 = AttentionPooling(hidden_size)
        self.attn_pooling_2_to_1 = AttentionPooling(hidden_size)
        # Projection layer with non-linearity
        self.projection_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU()
        )
        # Classifier with three hidden layers
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_labels)
        )
    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        # Encode sentences
        outputs1 = self.shared_encoder(input_ids1, attention_mask=attention_mask1)
        outputs2 = self.shared_encoder(input_ids2, attention_mask=attention_mask2)
        # Apply sentence-specific adapters
        embeds1 = self.adapter1(outputs1.last_hidden_state)
        embeds2 = self.adapter2(outputs2.last_hidden_state)
        # Transpose for attention layers
        embeds1 = embeds1.transpose(0, 1)
        embeds2 = embeds2.transpose(0, 1)
        # Cross-attention
        cross_attn_1_to_2, _ = self.cross_attention_1_to_2(embeds1, embeds2, embeds2)
        cross_attn_2_to_1, _ = self.cross_attention_2_to_1(embeds2, embeds1, embeds1)
        # Attention pooling
        pooled_1_to_2 = self.attn_pooling_1_to_2(cross_attn_1_to_2)
        pooled_2_to_1 = self.attn_pooling_2_to_1(cross_attn_2_to_1)
        # Concatenate and project
        combined = torch.cat((pooled_1_to_2, pooled_2_to_1), dim=1)
        projected = self.projection_layer(combined)
        # Classification
        logits = self.classifier(projected)
        return logits
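
# Forward-pass shape trace (derived from the code above; B = batch size,
# S = sequence length, H = hidden size):
#   encoder output          -> [B, S, H]
#   after transpose(0, 1)   -> [S, B, H]  (nn.MultiheadAttention default layout)
#   cross-attention output  -> [S, B, H]
#   attention pooling       -> [B, H]
#   concatenation           -> [B, 2H] -> projection [B, H] -> logits [B, num_labels]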

# Load configuration file from the model repository
repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
config_path = hf_hub_download(repo_id=repo_path, filename="config.json")

with open(config_path, 'r') as f:
    config = json.load(f)
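
# Expected config layout (inferred from the keys read in predict(); the values
# shown are placeholders, not the actual repository contents):
# {
#   "classifier": {
#     "embedding": {
#       "model_name": "<base model id>",
#       "max_length": <int>,
#       "model_weights_fp": "<safetensors filename>"
#     }
#   }
# }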

def predict(sentence1, sentence2):
    """
    Predicts the label for a pair of sentences using a fine-tuned model with SafeTensors weights.

    Args:
    - sentence1 (str): The first input sentence.
    - sentence2 (str): The second input sentence.

    Returns:
    tuple:
        - predicted_label (int): The predicted label (e.g., 0 or 1).
        - probabilities (numpy.ndarray): The probabilities for each class.
    """
    # Load model configuration
    model_name = config['classifier']['embedding']['model_name']
    max_length = config['classifier']['embedding']['max_length']
    model_weights_fp = config['classifier']['embedding']['model_weights_fp']

    # Load tokenizer and base model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    model = CrossEncoderWithSharedBase(base_model, num_labels=2)

    # Load weights into the model (assumes model_weights_fp names a file in the same repo)
    weights_path = hf_hub_download(repo_id=repo_path, filename=model_weights_fp)
    weights = load_file(weights_path)
    model.load_state_dict(weights)
    model.to(device)
    model.eval()

    # Get inputs
    inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
    inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
    input_ids1 = inputs1['input_ids'].to(device)
    attention_mask1 = inputs1['attention_mask'].to(device)
    input_ids2 = inputs2['input_ids'].to(device)
    attention_mask2 = inputs2['attention_mask'].to(device)

    # Get outputs
    with torch.no_grad():
        outputs = model(input_ids1=input_ids1, attention_mask1=attention_mask1,
                        input_ids2=input_ids2, attention_mask2=attention_mask2)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()

    return predicted_label, probabilities.cpu().numpy()
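
# Illustrative call (sentences are made up; note that predict() reloads the
# tokenizer and weights on every invocation, so cache them for repeated use):
#
#   label, probs = predict("How do I reset my password?",
#                          "Here is how to reset your password ...")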

if __name__ == "__main__":

    # Load data from the command line
    if len(sys.argv) < 2:
        raise SystemExit("Usage: python inference_safetensors.py '<JSON list of [sentence1, sentence2] pairs>'")
    sentence_pairs = json.loads(sys.argv[1])

    # Validate input data format
    if not isinstance(sentence_pairs, list) or not all(
        isinstance(pair, (list, tuple)) and len(pair) == 2
        and isinstance(pair[0], str) and isinstance(pair[1], str)
        for pair in sentence_pairs
    ):
        raise ValueError("Each pair must contain exactly two strings.")

    for idx, (sentence1, sentence2) in enumerate(sentence_pairs):

        # Generate prediction and scores
        predicted_label, probabilities = predict(sentence1, sentence2)

        # Print the results
        print(f"Pair {idx + 1}:")
        print(f"  Sentence 1: {sentence1}")
        print(f"  Sentence 2: {sentence2}")
        print(f"  Predicted Label: {predicted_label}")
        print(f"  Probabilities: {probabilities}")
        print('-' * 50)
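
# Example invocation (hypothetical sentence pair; the argument is a JSON list
# of [sentence1, sentence2] pairs, as parsed above):
#
#   python inference_safetensors.py \
#     '[["What is the weather today?", "Can you explain quantum physics?"]]'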