cnmoro commited on
Commit
370cbd8
·
verified ·
1 Parent(s): 3075b79

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +54 -5
README.md CHANGED
@@ -1,9 +1,58 @@
1
  ---
2
- library_name: transformers.js
3
- base_model:
4
- - cnmoro/bert-tiny-question-classifier
 
 
 
 
 
 
 
 
 
 
 
 
5
  ---
6
 
7
- # bert-tiny-question-classifier (ONNX)
8
 
9
- This is an ONNX version of [cnmoro/bert-tiny-question-classifier](https://huggingface.co/cnmoro/bert-tiny-question-classifier). It was automatically converted and uploaded using [this space](https://huggingface.co/spaces/onnx-community/convert-to-onnx).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - cnmoro/QuestionClassification
5
+ tags:
6
+ - classification
7
+ - questioning
8
+ - directed
9
+ - generic
10
+ language:
11
+ - en
12
+ - pt
13
+ library_name: transformers
14
+ pipeline_tag: text-classification
15
+ widget:
16
+ - text: "What is the summary of the text?"
17
  ---
18
 
19
+ (This model has a v2, use it instead: https://huggingface.co/cnmoro/granite-question-classifier)
20
 
21
+ A finetuned version of prajjwal1/bert-tiny.
22
+
23
+ The goal is to classify questions into "Directed" or "Generic".
24
+
25
+ If a question is not directed, we would change the actions we perform on a RAG pipeline (if it is generic, semantic search wouldn't be useful directly; e.g. asking for a summary).
26
+
27
+ (Class 0 is Generic; Class 1 is Directed)
28
+
29
+ The accuracy on the training dataset is around 87.5%
30
+
31
+ ```python
32
+ from transformers import BertForSequenceClassification, BertTokenizerFast
33
+ import torch
34
+
35
+ # Load the model and tokenizer
36
+ model = BertForSequenceClassification.from_pretrained("cnmoro/bert-tiny-question-classifier")
37
+ tokenizer = BertTokenizerFast.from_pretrained("cnmoro/bert-tiny-question-classifier")
38
+
39
+ def is_question_generic(question):
40
+ # Tokenize the sentence and convert to PyTorch tensors
41
+ inputs = tokenizer(
42
+ question.lower(),
43
+ truncation=True,
44
+ padding=True,
45
+ return_tensors="pt",
46
+ max_length=512
47
+ )
48
+
49
+ # Get the model's predictions
50
+ with torch.no_grad():
51
+ outputs = model(**inputs)
52
+
53
+ # Extract the prediction
54
+ predictions = outputs.logits
55
+ predicted_class = torch.argmax(predictions).item()
56
+
57
+ return int(predicted_class) == 0
58
+ ```