import json
import torch
from torch import nn
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model state_dict from safetensors
def load_model_safetensors(model, load_path="model.safetensors"):
    # Load the safetensors file
    state_dict = load_file(load_path)
    # Load the state dict into the model
    model.load_state_dict(state_dict)
    return model

###################
# JINA EMBEDDINGS
###################

# Jina Configs
JINA_CONTEXT_LEN = 1024

# Adapter for embeddings
class Adapter(nn.Module):
    def __init__(self, hidden_size):
        super(Adapter, self).__init__()
        self.down_project = nn.Linear(hidden_size, hidden_size // 2)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(hidden_size // 2, hidden_size)

    def forward(self, x):
        down = self.down_project(x)
        activated = self.activation(down)
        up = self.up_project(activated)
        return up + x  # Residual connection

# Pool by attention score
class AttentionPooling(nn.Module):
    def __init__(self, hidden_size):
        super(AttentionPooling, self).__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size))

    def forward(self, hidden_states):
        # hidden_states: [seq_len, batch_size, hidden_size]
        scores = torch.matmul(hidden_states, self.attention_weights)
        attention_weights = torch.softmax(scores, dim=0)
        weighted_sum = torch.sum(attention_weights.unsqueeze(-1) * hidden_states, dim=0)
        return weighted_sum

# Bi-encoder with a shared base model, cross-attention for interaction, and an MLP classifier head
class CrossEncoderWithSharedBase(nn.Module):
    def __init__(self, base_model, num_labels=2, num_heads=8):
        super(CrossEncoderWithSharedBase, self).__init__()
        # Shared pre-trained model
        self.shared_encoder = base_model
        hidden_size = self.shared_encoder.config.hidden_size
        # Sentence-specific adapters
        self.adapter1 = Adapter(hidden_size)
        self.adapter2 = Adapter(hidden_size)
        # Cross-attention layers
        self.cross_attention_1_to_2 = nn.MultiheadAttention(hidden_size, num_heads)
        self.cross_attention_2_to_1 = nn.MultiheadAttention(hidden_size, num_heads)
        # Attention pooling layers
        self.attn_pooling_1_to_2 = AttentionPooling(hidden_size)
        self.attn_pooling_2_to_1 = AttentionPooling(hidden_size)
        # Projection layer with non-linearity
        self.projection_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU()
        )
        # Classifier with three hidden layers
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_labels)
        )
    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        # Encode sentences
        outputs1 = self.shared_encoder(input_ids1, attention_mask=attention_mask1)
        outputs2 = self.shared_encoder(input_ids2, attention_mask=attention_mask2)
        # Apply sentence-specific adapters
        embeds1 = self.adapter1(outputs1.last_hidden_state)
        embeds2 = self.adapter2(outputs2.last_hidden_state)
        # Transpose for attention layers
        embeds1 = embeds1.transpose(0, 1)
        embeds2 = embeds2.transpose(0, 1)
        # Cross-attention
        cross_attn_1_to_2, _ = self.cross_attention_1_to_2(embeds1, embeds2, embeds2)
        cross_attn_2_to_1, _ = self.cross_attention_2_to_1(embeds2, embeds1, embeds1)
        # Attention pooling
        pooled_1_to_2 = self.attn_pooling_1_to_2(cross_attn_1_to_2)
        pooled_2_to_1 = self.attn_pooling_2_to_1(cross_attn_2_to_1)
        # Concatenate and project
        combined = torch.cat((pooled_1_to_2, pooled_2_to_1), dim=1)
        projected = self.projection_layer(combined)
        # Classification
        logits = self.classifier(projected)
        return logits

# Prediction function for the embeddings-based (Jina) model
def embeddings_predict_relevance(sentence1, sentence2, model, tokenizer, device):
    model.eval()
    inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=JINA_CONTEXT_LEN)
    inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True, padding="max_length", max_length=JINA_CONTEXT_LEN)
    input_ids1 = inputs1['input_ids'].to(device)
    attention_mask1 = inputs1['attention_mask'].to(device)
    input_ids2 = inputs2['input_ids'].to(device)
    attention_mask2 = inputs2['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids1=input_ids1, attention_mask1=attention_mask1,
                        input_ids2=input_ids2, attention_mask2=attention_mask2)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.cpu().numpy()

# Load configuration file
jina_repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
jina_config_path = hf_hub_download(repo_id=jina_repo_path, filename="config.json")
with open(jina_config_path, 'r') as f:
    jina_config = json.load(f)

# Load Jina model configuration
JINA_MODEL_NAME = jina_config['classifier']['embedding']['model_name']
jina_model_weights_fp = jina_config['classifier']['embedding']['model_weights_fp']

# Load tokenizer and model
jina_tokenizer = AutoTokenizer.from_pretrained(JINA_MODEL_NAME)
jina_base_model = AutoModel.from_pretrained(JINA_MODEL_NAME)
jina_model = CrossEncoderWithSharedBase(jina_base_model, num_labels=2)

# Load model weights from safetensors and move the model to the target device
jina_model_weights_path = hf_hub_download(repo_id=jina_repo_path, filename=jina_model_weights_fp)
jina_model = load_model_safetensors(jina_model, jina_model_weights_path)
jina_model = jina_model.to(device)
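
# Illustrative usage of the Jina-based classifier. The example prompt/response pair below
# is made up for demonstration and is not part of the original repository; what each label
# value means follows the repo's config rather than this sketch.
example_prompt = "Summarise the attached quarterly sales report."
example_response = "Here is a recipe for banana bread."
jina_label, jina_probs = embeddings_predict_relevance(
    example_prompt, example_response, jina_model, jina_tokenizer, device
)
print(f"Jina prediction: label={jina_label}, probabilities={jina_probs}")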

#################
# CROSS-ENCODER
#################

# STSB Configuration
STSB_CONTEXT_LEN = 512

class CrossEncoderWithMLP(nn.Module):
    def __init__(self, base_model, num_labels=2):
        super(CrossEncoderWithMLP, self).__init__()

        # Existing cross-encoder model
        self.base_model = base_model
        # Hidden size of the base model
        hidden_size = base_model.config.hidden_size
        # MLP head applied to the cross-encoder's pooled output
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),  # Input: pooled representation of the sentence pair
            nn.ReLU(),
            nn.Linear(hidden_size // 2, hidden_size // 4),  # Progressively reduce the hidden dimension
            nn.ReLU()
        )
        # Classifier head
        self.classifier = nn.Linear(hidden_size // 4, num_labels)

    def forward(self, input_ids, attention_mask):
        # Encode the pair of sentences in one pass
        outputs = self.base_model(input_ids, attention_mask)
        pooled_output = outputs.pooler_output
        # Pass the pooled output through mlp layers
        mlp_output = self.mlp(pooled_output)
        # Pass the final MLP output through the classifier
        logits = self.classifier(mlp_output)
        return logits

# Prediction function for cross-encoder
def cross_encoder_predict_relevance(sentence1, sentence2, model, tokenizer, device):
    model.eval()
    # Tokenize the pair of sentences
    encoding = tokenizer(
        sentence1, sentence2,  # The two sentences are encoded together as a pair
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=STSB_CONTEXT_LEN,
        return_token_type_ids=False
    )
    # Extract the input_ids and attention mask
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )  # Returns logits
        # Convert raw logits into probabilities for each class and get the predicted label
        probabilities = torch.softmax(outputs, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.cpu().numpy()

# Load STSB model configuration
stsb_repo_path = "govtech/stsb-roberta-base-off-topic"
stsb_config_path = hf_hub_download(repo_id=stsb_repo_path, filename="config.json")
with open(stsb_config_path, 'r') as f:
    stsb_config = json.load(f)

STSB_MODEL_NAME = stsb_config['classifier']['embedding']['model_name']
stsb_model_weights_fp = stsb_config['classifier']['embedding']['model_weights_fp']

# Load STSB tokenizer and model
stsb_tokenizer = AutoTokenizer.from_pretrained(STSB_MODEL_NAME)
stsb_base_model = AutoModel.from_pretrained(STSB_MODEL_NAME)
stsb_model = CrossEncoderWithMLP(stsb_base_model, num_labels=2)

# Load model weights from safetensors for STSB and move the model to the target device
stsb_model_weights_path = hf_hub_download(repo_id=stsb_repo_path, filename=stsb_model_weights_fp)
stsb_model = load_model_safetensors(stsb_model, stsb_model_weights_path)
stsb_model = stsb_model.to(device)
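
# Illustrative usage of the STSB cross-encoder classifier, reusing the example pair defined
# above. As before, the pair is made up for demonstration and label semantics follow the
# repo's config.
stsb_label, stsb_probs = cross_encoder_predict_relevance(
    example_prompt, example_response, stsb_model, stsb_tokenizer, device
)
print(f"STSB prediction: label={stsb_label}, probabilities={stsb_probs}")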