Shing Yee
committed on
feat: add files
Browse files- .gitignore +2 -0
- README.md +48 -1
- config.json +11 -0
- inference_onnx.py +82 -0
- inference_safetensors.py +121 -0
- models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.onnx +3 -0
- govtech-stsb-roberta-base-off-topic → models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.safetensors +0 -0
- requirements.txt +6 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.venv/
|
2 |
+
.DS_store
|
README.md
CHANGED
@@ -2,4 +2,51 @@
|
|
2 |
license: other
|
3 |
license_name: govtech-singapore
|
4 |
license_link: LICENSE
|
5 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
license: other
|
3 |
license_name: govtech-singapore
|
4 |
license_link: LICENSE
|
5 |
+
---
|
6 |
+
|
7 |
+
# Off-Topic Classification Model
|
9 |
+
|
10 |
+
This model leverages a fine-tuned **Cross Encoder STSB Roberta Base** to perform binary classification, determining whether a user prompt is off-topic in relation to the system's intended purpose as defined by the system prompt.
|
11 |
+
|
12 |
+
## Model Highlights
|
13 |
+
|
14 |
+
- **Base Model**: [`stsb-roberta-base`](https://huggingface.co/sentence-transformers/stsb-roberta-base)
|
15 |
+
- **Maximum Context Length**: 512 tokens (per `max_length` in `config.json`)
|
16 |
+
- **Task**: Binary classification (on-topic/off-topic)
|
17 |
+
|
18 |
+
## Performance
|
19 |
+
|
20 |
+
| Approach | Model | ROC-AUC | F1 | Precision | Recall |
|
21 |
+
|---------------------------------------|--------------------------------|---------|------|-----------|--------|
|
22 |
+
| Fine-tuned cross-encoder classifier | stsb-roberta-base | 0.99 | 0.99 | 0.99 | 0.99 |
|
23 |
+
| Pre-trained cross-encoder | stsb-roberta-base | 0.73 | 0.68 | 0.53 | 0.93 |
|
24 |
+
|
25 |
+
## Usage
|
26 |
+
1. Clone this repository and install the required dependencies:
|
27 |
+
|
28 |
+
```bash
|
29 |
+
pip install -r requirements.txt
|
30 |
+
```
|
31 |
+
|
32 |
+
2. You can run the model using two options:
|
33 |
+
|
34 |
+
**Option 1**: Using `inference_onnx.py` with the ONNX Model.
|
35 |
+
|
36 |
+
```
|
37 |
+
python inference_onnx.py '[
|
38 |
+
["System prompt example 1", "User prompt example 1"],
|
39 |
+
["System prompt example 2", "System prompt example 2]
|
40 |
+
]'
|
41 |
+
```
|
42 |
+
|
43 |
+
**Option 2**: Using `inference_safetensors.py` with PyTorch and SafeTensors.
|
44 |
+
|
45 |
+
```
|
46 |
+
python inference_safetensors.py '[
|
47 |
+
["System prompt example 1", "User prompt example 1"],
|
48 |
+
["System prompt example 2", "System prompt example 2]
|
49 |
+
]'
|
50 |
+
```
|
51 |
+
|
52 |
+
Read more about this model in our [technical report](https://arxiv.org/abs/2411.12946).
|
config.json
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"description": "Off-Topic classifier designed to block user prompts that do not align with the intended purpose of the system, as determined by the system prompt.",
|
3 |
+
"classifier": {
|
4 |
+
"embedding": {
|
5 |
+
"model_name": "cross-encoder/stsb-roberta-base",
|
6 |
+
"max_length": 512,
|
7 |
+
"model_weights_fp": "models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.safetensors",
|
8 |
+
"model_fp": "models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.onnx"
|
9 |
+
}
|
10 |
+
}
|
11 |
+
}
|
inference_onnx.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
inference_onnx.py
|
3 |
+
|
4 |
+
This script leverages ONNX runtime to perform inference with a pre-trained model.
|
5 |
+
"""
|
6 |
+
import json
|
7 |
+
import torch
|
8 |
+
import sys
|
9 |
+
import numpy as np
|
10 |
+
import onnxruntime as rt
|
11 |
+
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
from transformers import AutoTokenizer
|
14 |
+
|
15 |
+
repo_path = "govtech/stsb-roberta-base-off-topic"
|
16 |
+
config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
|
17 |
+
|
18 |
+
config_path = "config.json"
|
19 |
+
|
20 |
+
with open(config_path, 'r') as f:
|
21 |
+
config = json.load(f)
|
22 |
+
|
23 |
+
def predict(sentence1, sentence2):
    """
    Classify a (system prompt, user prompt) pair with the ONNX off-topic model.

    Args:
    - sentence1 (str): The first input sentence (system prompt).
    - sentence2 (str): The second input sentence (user prompt).

    Returns:
        tuple:
        - predicted_label (int): The predicted class index (argmax over classes).
        - probabilities (numpy.ndarray): Per-class probabilities, shape (1, num_classes).
    """
    # Configuration
    model_name = config['classifier']['embedding']['model_name']
    max_length = config['classifier']['embedding']['max_length']
    model_fp = config['classifier']['embedding']['model_fp']

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Tokenize the two sentences as a single cross-encoder pair.
    # ONNX Runtime only consumes CPU NumPy arrays, so there is no need to
    # move the tensors to a CUDA device and back (the original code did).
    encoding = tokenizer(
        sentence1, sentence2,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_token_type_ids=False
    )
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    # Download the classifier from the Hugging Face Hub (cached locally by
    # huggingface_hub after the first call).
    local_model_fp = hf_hub_download(repo_id=repo_path, filename=model_fp)

    # Run inference
    session = rt.InferenceSession(local_model_fp)  # Load the ONNX model
    onnx_inputs = {
        session.get_inputs()[0].name: input_ids.numpy(),
        session.get_inputs()[1].name: attention_mask.numpy()
    }
    outputs = session.run(None, onnx_inputs)

    # Convert logits to probabilities and pick the argmax class.
    probabilities = torch.softmax(torch.tensor(outputs[0]), dim=1)
    predicted_label = torch.argmax(probabilities, dim=1).item()

    return predicted_label, probabilities.numpy()
|
61 |
+
|
62 |
+
if __name__ == "__main__":
    # The single CLI argument is a JSON-encoded list of [sentence1, sentence2] pairs.
    sentence_pairs = json.loads(sys.argv[1])

    # Reject anything that is not a pair of strings.
    if not all(isinstance(p[0], str) and isinstance(p[1], str) for p in sentence_pairs):
        raise ValueError("Each pair must contain two strings.")

    for i, (s1, s2) in enumerate(sentence_pairs, start=1):
        # Classify this pair and report the outcome.
        label, probs = predict(s1, s2)

        print(f"Pair {i}:")
        print(f" Sentence 1: {s1}")
        print(f" Sentence 2: {s2}")
        print(f" Predicted Label: {label}")
        print(f" Probabilities: {probs}")
        print('-' * 50)
|
inference_safetensors.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
inference_safetensors.py
|
3 |
+
|
4 |
+
Defines the architecture of the fine-tuned embedding model used for Off-Topic classification.
|
5 |
+
"""
|
6 |
+
import json
|
7 |
+
import torch
|
8 |
+
import sys
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
from safetensors.torch import load_file
|
13 |
+
from transformers import AutoTokenizer, AutoModel
|
14 |
+
|
15 |
+
class CrossEncoderWithMLP(nn.Module):
    """Cross-encoder backbone followed by a shrinking MLP and a classifier head."""

    def __init__(self, base_model, num_labels=2):
        super().__init__()

        # Pre-trained cross-encoder that embeds the sentence pair.
        self.base_model = base_model
        dim = base_model.config.hidden_size
        # Funnel the pooled embedding through two progressively smaller layers.
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.ReLU(),
            nn.Linear(dim // 2, dim // 4),
            nn.ReLU(),
        )
        # Final projection onto the label logits.
        self.classifier = nn.Linear(dim // 4, num_labels)

    def forward(self, input_ids, attention_mask):
        # Encode both sentences in a single pass and take the pooled output.
        pooled = self.base_model(input_ids, attention_mask).pooler_output
        # MLP then classifier head produce the class logits.
        return self.classifier(self.mlp(pooled))
|
42 |
+
|
43 |
+
# Load configuration file.
# The repo id must match the model this script ships with (stsb-roberta-base,
# see config.json and the onnx script); it previously pointed at the
# jina-embeddings-v2-small-en repository by mistake.
repo_path = "govtech/stsb-roberta-base-off-topic"
config_path = "config.json"

with open(config_path, 'r') as f:
    config = json.load(f)
|
50 |
+
|
51 |
+
# Cache for the tokenizer and fully-assembled model: rebuilding them from
# scratch for every sentence pair dominated the runtime of the original code.
_ARTIFACTS = {}


def predict(sentence1, sentence2):
    """
    Predicts the label for a pair of sentences using a fine-tuned model with SafeTensors weights.

    Args:
    - sentence1 (str): The first input sentence.
    - sentence2 (str): The second input sentence.

    Returns:
        tuple:
        - predicted_label (int): The predicted label (e.g., 0 or 1).
        - probabilities (numpy.ndarray): The probabilities for each class.
    """
    # Load model configuration
    model_name = config['classifier']['embedding']['model_name']
    max_length = config['classifier']['embedding']['max_length']
    model_weights_fp = config['classifier']['embedding']['model_weights_fp']

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the tokenizer and classifier once; reuse them on later calls.
    if "model" not in _ARTIFACTS:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        base_model = AutoModel.from_pretrained(model_name)
        model = CrossEncoderWithMLP(base_model, num_labels=2)

        # Load the fine-tuned weights into the model.
        weights = load_file(model_weights_fp)
        model.load_state_dict(weights)
        model.to(device)
        model.eval()

        _ARTIFACTS["tokenizer"] = tokenizer
        _ARTIFACTS["model"] = model
    tokenizer = _ARTIFACTS["tokenizer"]
    model = _ARTIFACTS["model"]

    # Tokenize the two sentences as a single cross-encoder pair.
    encoding = tokenizer(
        sentence1, sentence2,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_token_type_ids=False
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()

    return predicted_label, probabilities.cpu().numpy()
|
100 |
+
|
101 |
+
if __name__ == "__main__":
    # The single CLI argument is a JSON-encoded list of [sentence1, sentence2] pairs.
    sentence_pairs = json.loads(sys.argv[1])

    # Reject anything that is not a pair of strings.
    if not all(isinstance(p[0], str) and isinstance(p[1], str) for p in sentence_pairs):
        raise ValueError("Each pair must contain two strings.")

    for i, (s1, s2) in enumerate(sentence_pairs, start=1):
        # Classify this pair and report the outcome.
        label, probs = predict(s1, s2)

        print(f"Pair {i}:")
        print(f" Sentence 1: {s1}")
        print(f" Sentence 2: {s2}")
        print(f" Predicted Label: {label}")
        print(f" Probabilities: {probs}")
        print('-' * 50)
|
models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4c311c4a80aae3477b3688d52f3c5dfc9e2e761242f7884d2e713f028e3aa21c
|
3 |
+
size 500394446
|
govtech-stsb-roberta-base-off-topic → models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.safetensors
RENAMED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
huggingface_hub==0.26.2
|
2 |
+
numpy==2.1.3
|
3 |
+
onnxruntime==1.20.0
|
4 |
+
safetensors==0.4.5
|
5 |
+
torch==2.5.1
|
6 |
+
transformers==4.46.3
|