Shing Yee committed
Commit 1d52dd2 · unverified · 1 Parent(s): 54d897a

feat: add files

.gitignore ADDED
@@ -0,0 +1,2 @@
+ .venv/
+ .DS_Store
README.md CHANGED
@@ -2,4 +2,51 @@
  license: other
  license_name: govtech-singapore
  license_link: LICENSE
- ---
+ ---
+
+ # Off-Topic Classification Model
+
+ This model uses a fine-tuned **cross-encoder (STSB RoBERTa base)** for binary classification: it determines whether a user prompt is off-topic relative to the system's intended purpose, as defined by the system prompt.
+
+ ## Model Highlights
+
+ - **Base Model**: [`stsb-roberta-base`](https://huggingface.co/sentence-transformers/stsb-roberta-base)
+ - **Maximum Context Length**: 512 tokens
+ - **Task**: Binary classification (on-topic/off-topic)
+
+ ## Performance
+
+ | Approach                             | Model             | ROC-AUC | F1   | Precision | Recall |
+ |--------------------------------------|-------------------|---------|------|-----------|--------|
+ | Fine-tuned cross-encoder classifier  | stsb-roberta-base | 0.99    | 0.99 | 0.99      | 0.99   |
+ | Pre-trained cross-encoder            | stsb-roberta-base | 0.73    | 0.68 | 0.53      | 0.93   |
+
+ ## Usage
+
+ 1. Clone this repository and install the required dependencies:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 2. Run the model using one of two options:
+
+ **Option 1**: Using `inference_onnx.py` with the ONNX model.
+
+ ```bash
+ python inference_onnx.py '[
+     ["System prompt example 1", "User prompt example 1"],
+     ["System prompt example 2", "User prompt example 2"]
+ ]'
+ ```
+
+ **Option 2**: Using `inference_safetensors.py` with PyTorch and SafeTensors.
+
+ ```bash
+ python inference_safetensors.py '[
+     ["System prompt example 1", "User prompt example 1"],
+     ["System prompt example 2", "User prompt example 2"]
+ ]'
+ ```
+
+ Read more about this model in our [technical report](https://arxiv.org/abs/2411.12946).
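
Both inference scripts print one result block per input pair. Based on their print statements, the output has roughly this shape (the label and probabilities shown here are illustrative placeholders, not real model output):

```
Pair 1:
  Sentence 1: System prompt example 1
  Sentence 2: User prompt example 1
  Predicted Label: 1
  Probabilities: [[0.01 0.99]]
--------------------------------------------------
```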
config.json ADDED
@@ -0,0 +1,11 @@
+ {
+     "description": "Off-Topic classifier designed to block user prompts that do not align with the intended purpose of the system, as determined by the system prompt.",
+     "classifier": {
+         "embedding": {
+             "model_name": "cross-encoder/stsb-roberta-base",
+             "max_length": 512,
+             "model_weights_fp": "models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.safetensors",
+             "model_fp": "models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.onnx"
+         }
+     }
+ }
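
Both inference scripts read this file and pull their settings from the nested `classifier.embedding` object. For example, the following snippet mirrors how the scripts access it (using the local copy of the file):

```python
import json

# Load the repo's config.json, as the inference scripts do
with open("config.json") as f:
    config = json.load(f)

embedding = config["classifier"]["embedding"]
print(embedding["model_name"])  # cross-encoder/stsb-roberta-base
print(embedding["max_length"])  # 512
```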
inference_onnx.py ADDED
@@ -0,0 +1,82 @@
+ """
+ inference_onnx.py
+
+ This script uses ONNX Runtime to perform inference with the fine-tuned off-topic classifier.
+ """
+ import json
+ import sys
+
+ import torch
+ import onnxruntime as rt
+
+ from huggingface_hub import hf_hub_download
+ from transformers import AutoTokenizer
+
+ repo_path = "govtech/stsb-roberta-base-off-topic"
+ config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
+ # config_path = "config.json"  # Uncomment to use the local copy instead
+
+ with open(config_path, 'r') as f:
+     config = json.load(f)
+
+ def predict(sentence1, sentence2):
+     # Configuration
+     model_name = config['classifier']['embedding']['model_name']
+     max_length = config['classifier']['embedding']['max_length']
+     model_fp = config['classifier']['embedding']['model_fp']
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     # Tokenize the two sentences as a single pair
+     encoding = tokenizer(
+         sentence1, sentence2,
+         return_tensors="pt",
+         truncation=True,
+         padding="max_length",
+         max_length=max_length,
+         return_token_type_ids=False
+     )
+     input_ids = encoding["input_ids"].to(device)
+     attention_mask = encoding["attention_mask"].to(device)
+
+     # Download the ONNX classifier from the Hugging Face Hub
+     local_model_fp = hf_hub_download(repo_id=repo_path, filename=model_fp)
+
+     # Run inference
+     session = rt.InferenceSession(local_model_fp)  # Load the ONNX model
+     onnx_inputs = {
+         session.get_inputs()[0].name: input_ids.cpu().numpy(),
+         session.get_inputs()[1].name: attention_mask.cpu().numpy()
+     }
+     outputs = session.run(None, onnx_inputs)
+
+     # Convert logits to probabilities; the argmax is the predicted label
+     probabilities = torch.softmax(torch.tensor(outputs[0]), dim=1)
+     predicted_label = torch.argmax(probabilities, dim=1).item()
+
+     return predicted_label, probabilities.cpu().numpy()
+
+ if __name__ == "__main__":
+     # Load data
+     input_data = sys.argv[1]
+     sentence_pairs = json.loads(input_data)
+
+     # Validate input data format
+     if not all(isinstance(pair[0], str) and isinstance(pair[1], str) for pair in sentence_pairs):
+         raise ValueError("Each pair must contain two strings.")
+
+     for idx, (sentence1, sentence2) in enumerate(sentence_pairs):
+         # Generate prediction and scores
+         predicted_label, probabilities = predict(sentence1, sentence2)
+
+         # Print the results
+         print(f"Pair {idx + 1}:")
+         print(f"  Sentence 1: {sentence1}")
+         print(f"  Sentence 2: {sentence2}")
+         print(f"  Predicted Label: {predicted_label}")
+         print(f"  Probabilities: {probabilities}")
+         print('-' * 50)
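
The script is written for command-line use, but `predict()` can also be imported directly. A minimal sketch, assuming the repository has been cloned and the requirements installed; the prompts are made up, and the convention that label 1 means off-topic is an assumption, not stated in this commit:

```python
# Minimal sketch: calling predict() from inference_onnx programmatically.
from inference_onnx import predict

label, probabilities = predict(
    "You are a customer support assistant for a bank.",  # hypothetical system prompt
    "Write me a poem about the ocean."                   # hypothetical user prompt
)
print(label)          # 0 or 1 (assumed: 1 = off-topic)
print(probabilities)  # softmax probabilities over the two classes
```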
inference_safetensors.py ADDED
@@ -0,0 +1,121 @@
+ """
+ inference_safetensors.py
+
+ Defines the architecture of the fine-tuned cross-encoder used for off-topic
+ classification and runs inference with its SafeTensors weights.
+ """
+ import json
+ import sys
+
+ import torch
+ import torch.nn as nn
+
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ from transformers import AutoTokenizer, AutoModel
+
+ class CrossEncoderWithMLP(nn.Module):
+     def __init__(self, base_model, num_labels=2):
+         super(CrossEncoderWithMLP, self).__init__()
+
+         # Existing cross-encoder model
+         self.base_model = base_model
+         # Hidden size of the base model
+         hidden_size = base_model.config.hidden_size
+         # MLP layers applied on top of the cross-encoder's pooled output
+         self.mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size // 2),       # First projection
+             nn.ReLU(),
+             nn.Linear(hidden_size // 2, hidden_size // 4),  # Reduce the layer size further
+             nn.ReLU()
+         )
+         # Classifier head
+         self.classifier = nn.Linear(hidden_size // 4, num_labels)
+
+     def forward(self, input_ids, attention_mask):
+         # Encode the pair of sentences in one pass
+         outputs = self.base_model(input_ids, attention_mask)
+         pooled_output = outputs.pooler_output
+         # Pass the pooled output through the MLP layers
+         mlp_output = self.mlp(pooled_output)
+         # Pass the final MLP output through the classifier
+         logits = self.classifier(mlp_output)
+         return logits
+
+ # Load configuration file
+ repo_path = "govtech/stsb-roberta-base-off-topic"
+ # config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
+ config_path = "config.json"
+
+ with open(config_path, 'r') as f:
+     config = json.load(f)
+
+ def predict(sentence1, sentence2):
+     """
+     Predicts the label for a pair of sentences using the fine-tuned model with SafeTensors weights.
+
+     Args:
+         sentence1 (str): The first input sentence (the system prompt).
+         sentence2 (str): The second input sentence (the user prompt).
+
+     Returns:
+         tuple:
+             predicted_label (int): The predicted label (0 or 1).
+             probabilities (numpy.ndarray): The probabilities for each class.
+     """
+     # Load model configuration
+     model_name = config['classifier']['embedding']['model_name']
+     max_length = config['classifier']['embedding']['max_length']
+     model_weights_fp = config['classifier']['embedding']['model_weights_fp']
+
+     # Load tokenizer and base model
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     base_model = AutoModel.from_pretrained(model_name)
+     model = CrossEncoderWithMLP(base_model, num_labels=2)
+
+     # Load weights into the model
+     weights = load_file(model_weights_fp)
+     model.load_state_dict(weights)
+     model.to(device)
+     model.eval()
+
+     # Tokenize the two sentences as a single pair
+     encoding = tokenizer(
+         sentence1, sentence2,
+         return_tensors="pt",
+         truncation=True,
+         padding="max_length",
+         max_length=max_length,
+         return_token_type_ids=False
+     )
+     input_ids = encoding["input_ids"].to(device)
+     attention_mask = encoding["attention_mask"].to(device)
+
+     # Get outputs
+     with torch.no_grad():
+         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+         probabilities = torch.softmax(outputs, dim=1)
+         predicted_label = torch.argmax(probabilities, dim=1).item()
+
+     return predicted_label, probabilities.cpu().numpy()
+
+ if __name__ == "__main__":
+     # Load data
+     input_data = sys.argv[1]
+     sentence_pairs = json.loads(input_data)
+
+     # Validate input data format
+     if not all(isinstance(pair[0], str) and isinstance(pair[1], str) for pair in sentence_pairs):
+         raise ValueError("Each pair must contain two strings.")
+
+     for idx, (sentence1, sentence2) in enumerate(sentence_pairs):
+         # Generate prediction and scores
+         predicted_label, probabilities = predict(sentence1, sentence2)
+
+         # Print the results
+         print(f"Pair {idx + 1}:")
+         print(f"  Sentence 1: {sentence1}")
+         print(f"  Sentence 2: {sentence2}")
+         print(f"  Predicted Label: {predicted_label}")
+         print(f"  Probabilities: {probabilities}")
+         print('-' * 50)
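
For reference, the ONNX file shipped under `models/` could have been exported from `CrossEncoderWithMLP` along these lines. This is only a sketch under assumed settings (dummy input shapes, axis and output names), run from a clone of this repo so the weights resolve; the actual export script is not part of this commit:

```python
# Hypothetical export sketch: CrossEncoderWithMLP -> ONNX (not in this commit)
import torch
from safetensors.torch import load_file
from transformers import AutoModel
from inference_safetensors import CrossEncoderWithMLP

base_model = AutoModel.from_pretrained("cross-encoder/stsb-roberta-base")
model = CrossEncoderWithMLP(base_model, num_labels=2)
model.load_state_dict(load_file(
    "models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.safetensors"))
model.eval()

# Dummy inputs padded to the configured max_length of 512
dummy_ids = torch.zeros((1, 512), dtype=torch.long)
dummy_mask = torch.ones((1, 512), dtype=torch.long)

torch.onnx.export(
    model,
    (dummy_ids, dummy_mask),
    "models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch"}, "attention_mask": {0: "batch"}},
)
```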
models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c311c4a80aae3477b3688d52f3c5dfc9e2e761242f7884d2e713f028e3aa21c
+ size 500394446
govtech-stsb-roberta-base-off-topic → models/off-topic-cross-encoder-stsb-roberta-base-CrossEncoder.safetensors RENAMED
File without changes
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ huggingface_hub==0.26.2
+ numpy==2.1.3
+ onnxruntime==1.20.0
+ safetensors==0.4.5
+ torch==2.5.1
+ transformers==4.46.3