Spaces:
Running
Running
Shing Yee
commited on
Commit
β’
70c7861
1
Parent(s):
885748b
update models
Browse files- .DS_Store +0 -0
- app.py +0 -7
- models/cross-encoder-ms-marco-MiniLM-L-6-v2-CrossEncoder-OffTopic-Classifier-20240918-090615.safetensors +0 -3
- models/cross-encoder-stsb-roberta-base-CrossEncoder-OffTopic-Classifier-20240920-174009.safetensors +0 -3
- models/jinaai-jina-embeddings-v2-small-en-TwinEncoder-OffTopic-Classifier-20240915-151858.safetensors +0 -3
- utils.py +39 -22
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
app.py
CHANGED
@@ -7,15 +7,12 @@ from utils import (
|
|
7 |
embeddings_predict_relevance,
|
8 |
stsb_model,
|
9 |
stsb_tokenizer,
|
10 |
-
ms_model,
|
11 |
-
ms_tokenizer,
|
12 |
cross_encoder_predict_relevance
|
13 |
)
|
14 |
|
15 |
def predict(system_prompt, user_prompt):
|
16 |
predicted_label_jina, probabilities_jina = embeddings_predict_relevance(system_prompt, user_prompt, jina_model, jina_tokenizer, device)
|
17 |
predicted_label_stsb, probabilities_stsb = cross_encoder_predict_relevance(system_prompt, user_prompt, stsb_model, stsb_tokenizer, device)
|
18 |
-
predicted_label_ms, probabilities_ms = cross_encoder_predict_relevance(system_prompt, user_prompt, ms_model, ms_tokenizer, device)
|
19 |
|
20 |
result = f"""
|
21 |
**Prediction Summary**
|
@@ -27,10 +24,6 @@ def predict(system_prompt, user_prompt):
|
|
27 |
**2. Model: cross-encoder/stsb-roberta-base**
|
28 |
- **Prediction**: {"π₯ Off-topic" if predicted_label_stsb==1 else "π© On-topic"}
|
29 |
- **Probability of being off-topic**: {probabilities_stsb[0][1]:.2%}
|
30 |
-
|
31 |
-
**3. Model: cross-encoder/ms-marco-MiniLM-L-6-v2**
|
32 |
-
- **Prediction**: {"π₯ Off-topic" if predicted_label_ms==1 else "π© On-topic"}
|
33 |
-
- **Probability of being off-topic**: {probabilities_ms[0][1]:.2%}
|
34 |
"""
|
35 |
|
36 |
return result
|
|
|
7 |
embeddings_predict_relevance,
|
8 |
stsb_model,
|
9 |
stsb_tokenizer,
|
|
|
|
|
10 |
cross_encoder_predict_relevance
|
11 |
)
|
12 |
|
13 |
def predict(system_prompt, user_prompt):
|
14 |
predicted_label_jina, probabilities_jina = embeddings_predict_relevance(system_prompt, user_prompt, jina_model, jina_tokenizer, device)
|
15 |
predicted_label_stsb, probabilities_stsb = cross_encoder_predict_relevance(system_prompt, user_prompt, stsb_model, stsb_tokenizer, device)
|
|
|
16 |
|
17 |
result = f"""
|
18 |
**Prediction Summary**
|
|
|
24 |
**2. Model: cross-encoder/stsb-roberta-base**
|
25 |
- **Prediction**: {"π₯ Off-topic" if predicted_label_stsb==1 else "π© On-topic"}
|
26 |
- **Probability of being off-topic**: {probabilities_stsb[0][1]:.2%}
|
|
|
|
|
|
|
|
|
27 |
"""
|
28 |
|
29 |
return result
|
models/cross-encoder-ms-marco-MiniLM-L-6-v2-CrossEncoder-OffTopic-Classifier-20240918-090615.safetensors
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:78a99fac3bc5b4729fee844d2154ea625aa9ceac2928cd648984ee1da5b8a203
|
3 |
-
size 91236352
|
|
|
|
|
|
|
|
models/cross-encoder-stsb-roberta-base-CrossEncoder-OffTopic-Classifier-20240920-174009.safetensors
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:1e90752828e92bc2f8ec567b85b3de5a0c8c5ddc331c1907d4dfa950624f71ce
|
3 |
-
size 500085976
|
|
|
|
|
|
|
|
models/jinaai-jina-embeddings-v2-small-en-TwinEncoder-OffTopic-Classifier-20240915-151858.safetensors
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:223687abc28cf0fa198d326d2786374000396d841e66d684c022941da2ca9628
|
3 |
-
size 144076480
|
|
|
|
|
|
|
|
utils.py
CHANGED
@@ -1,8 +1,11 @@
|
|
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
from safetensors.torch import load_file
|
4 |
from transformers import AutoModel, AutoTokenizer
|
|
|
5 |
|
|
|
6 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
7 |
|
8 |
# Load the model state_dict from safetensors
|
@@ -13,9 +16,9 @@ def load_model_safetensors(model, load_path="model.safetensors"):
|
|
13 |
model.load_state_dict(state_dict)
|
14 |
return model
|
15 |
|
16 |
-
|
17 |
# JINA EMBEDDINGS
|
18 |
-
|
19 |
|
20 |
# Jina Configs
|
21 |
JINA_CONTEXT_LEN = 1024
|
@@ -101,7 +104,7 @@ class CrossEncoderWithSharedBase(nn.Module):
|
|
101 |
logits = self.classifier(projected)
|
102 |
return logits
|
103 |
|
104 |
-
# Prediction function
|
105 |
def embeddings_predict_relevance(sentence1, sentence2, model, tokenizer, device):
|
106 |
model.eval()
|
107 |
inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
|
@@ -117,23 +120,32 @@ def embeddings_predict_relevance(sentence1, sentence2, model, tokenizer, device)
|
|
117 |
predicted_label = torch.argmax(probabilities, dim=1).item()
|
118 |
return predicted_label, probabilities.cpu().numpy()
|
119 |
|
120 |
-
#
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
jina_tokenizer = AutoTokenizer.from_pretrained(JINA_MODEL_NAME)
|
123 |
jina_base_model = AutoModel.from_pretrained(JINA_MODEL_NAME)
|
124 |
jina_model = CrossEncoderWithSharedBase(jina_base_model, num_labels=2)
|
125 |
-
jina_model = load_model_safetensors(jina_model, load_path="models/jinaai-jina-embeddings-v2-small-en-TwinEncoder-OffTopic-Classifier-20240915-151858.safetensors")
|
126 |
|
127 |
-
|
|
|
|
|
|
|
|
|
128 |
# CROSS-ENCODER
|
129 |
-
|
130 |
|
131 |
-
# STSB
|
132 |
STSB_CONTEXT_LEN = 512
|
133 |
|
134 |
-
# ms-macro Configs
|
135 |
-
MS_CONTEXT_LEN = 512
|
136 |
-
|
137 |
class CrossEncoderWithMLP(nn.Module):
|
138 |
def __init__(self, base_model, num_labels=2):
|
139 |
super(CrossEncoderWithMLP, self).__init__()
|
@@ -162,6 +174,7 @@ class CrossEncoderWithMLP(nn.Module):
|
|
162 |
logits = self.classifier(mlp_output)
|
163 |
return logits
|
164 |
|
|
|
165 |
def cross_encoder_predict_relevance(sentence1, sentence2, model, tokenizer, device):
|
166 |
model.eval()
|
167 |
# Tokenize the pair of sentences
|
@@ -187,16 +200,20 @@ def cross_encoder_predict_relevance(sentence1, sentence2, model, tokenizer, devi
|
|
187 |
predicted_label = torch.argmax(probabilities, dim=1).item()
|
188 |
return predicted_label, probabilities.cpu().numpy()
|
189 |
|
190 |
-
# STSB model
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
stsb_tokenizer = AutoTokenizer.from_pretrained(STSB_MODEL_NAME)
|
193 |
stsb_base_model = AutoModel.from_pretrained(STSB_MODEL_NAME)
|
194 |
stsb_model = CrossEncoderWithMLP(stsb_base_model, num_labels=2)
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
ms_tokenizer = AutoTokenizer.from_pretrained(MS_MODEL_NAME)
|
200 |
-
ms_base_model = AutoModel.from_pretrained(MS_MODEL_NAME)
|
201 |
-
ms_model = CrossEncoderWithMLP(ms_base_model, num_labels=2)
|
202 |
-
ms_model = load_model_safetensors(ms_model, load_path="models/cross-encoder-ms-marco-MiniLM-L-6-v2-CrossEncoder-OffTopic-Classifier-20240918-090615.safetensors")
|
|
|
1 |
+
import json
|
2 |
import torch
|
3 |
from torch import nn
|
4 |
from safetensors.torch import load_file
|
5 |
from transformers import AutoModel, AutoTokenizer
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
|
8 |
+
# Set device
|
9 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
|
11 |
# Load the model state_dict from safetensors
|
|
|
16 |
model.load_state_dict(state_dict)
|
17 |
return model
|
18 |
|
19 |
+
###################
|
20 |
# JINA EMBEDDINGS
|
21 |
+
###################
|
22 |
|
23 |
# Jina Configs
|
24 |
JINA_CONTEXT_LEN = 1024
|
|
|
104 |
logits = self.classifier(projected)
|
105 |
return logits
|
106 |
|
107 |
+
# Prediction function for embeddings relevance
|
108 |
def embeddings_predict_relevance(sentence1, sentence2, model, tokenizer, device):
|
109 |
model.eval()
|
110 |
inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
|
|
|
120 |
predicted_label = torch.argmax(probabilities, dim=1).item()
|
121 |
return predicted_label, probabilities.cpu().numpy()
|
122 |
|
123 |
+
# Load configuration file
|
124 |
+
jina_repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
|
125 |
+
jina_config_path = hf_hub_download(repo_id=jina_repo_path, filename="config.json")
|
126 |
+
with open(jina_config_path, 'r') as f:
|
127 |
+
jina_config = json.load(f)
|
128 |
+
|
129 |
+
# Load Jina model configuration
|
130 |
+
JINA_MODEL_NAME = jina_config['classifier']['embedding']['model_name']
|
131 |
+
jina_model_weights_fp = jina_config['classifier']['embedding']['model_weights_fp']
|
132 |
+
|
133 |
+
# Load tokenizer and model
|
134 |
jina_tokenizer = AutoTokenizer.from_pretrained(JINA_MODEL_NAME)
|
135 |
jina_base_model = AutoModel.from_pretrained(JINA_MODEL_NAME)
|
136 |
jina_model = CrossEncoderWithSharedBase(jina_base_model, num_labels=2)
|
|
|
137 |
|
138 |
+
# Load model weights from safetensors
|
139 |
+
jina_model_weights_path = hf_hub_download(repo_id=jina_repo_path, filename=jina_model_weights_fp)
|
140 |
+
jina_model = load_model_safetensors(jina_model, jina_model_weights_path)
|
141 |
+
|
142 |
+
#################
|
143 |
# CROSS-ENCODER
|
144 |
+
#################
|
145 |
|
146 |
+
# STSB Configuration
|
147 |
STSB_CONTEXT_LEN = 512
|
148 |
|
|
|
|
|
|
|
149 |
class CrossEncoderWithMLP(nn.Module):
|
150 |
def __init__(self, base_model, num_labels=2):
|
151 |
super(CrossEncoderWithMLP, self).__init__()
|
|
|
174 |
logits = self.classifier(mlp_output)
|
175 |
return logits
|
176 |
|
177 |
+
# Prediction function for cross-encoder
|
178 |
def cross_encoder_predict_relevance(sentence1, sentence2, model, tokenizer, device):
|
179 |
model.eval()
|
180 |
# Tokenize the pair of sentences
|
|
|
200 |
predicted_label = torch.argmax(probabilities, dim=1).item()
|
201 |
return predicted_label, probabilities.cpu().numpy()
|
202 |
|
203 |
+
# Load STSB model configuration
|
204 |
+
stsb_repo_path = "govtech/stsb-roberta-base-off-topic"
|
205 |
+
stsb_config_path = hf_hub_download(repo_id=stsb_repo_path, filename="config.json")
|
206 |
+
with open(stsb_config_path, 'r') as f:
|
207 |
+
stsb_config = json.load(f)
|
208 |
+
|
209 |
+
STSB_MODEL_NAME = stsb_config['classifier']['embedding']['model_name']
|
210 |
+
stsb_model_weights_fp = stsb_config['classifier']['embedding']['model_weights_fp']
|
211 |
+
|
212 |
+
# Load STSB tokenizer and model
|
213 |
stsb_tokenizer = AutoTokenizer.from_pretrained(STSB_MODEL_NAME)
|
214 |
stsb_base_model = AutoModel.from_pretrained(STSB_MODEL_NAME)
|
215 |
stsb_model = CrossEncoderWithMLP(stsb_base_model, num_labels=2)
|
216 |
+
|
217 |
+
# Load model weights from safetensors for STSB
|
218 |
+
stsb_model_weights_path = hf_hub_download(repo_id=stsb_repo_path, filename=stsb_model_weights_fp)
|
219 |
+
stsb_model = load_model_safetensors(stsb_model, stsb_model_weights_path)
|
|
|
|
|
|
|
|