Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
1 |
# Model Card for Model ID
|
2 |
|
3 |
Intent classification is the act of classifying customers' requests into different predefined categories.
|
@@ -30,41 +34,41 @@ This is the model card of a 🤗 transformers model that has been pushed on the
|
|
30 |
|
31 |
## How to Get Started with the Model
|
32 |
|
33 |
-
class IntentClassifier:
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
|
39 |
|
40 |
-
def build_prompt(text, prompt="", company_name="", company_specific=""):
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
45 |
|
46 |
-
return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
def predict(self, text, prompt_options, company_name, company_portion) -> str:
|
50 |
-
input_text = build_prompt(text, prompt_options, company_name, company_portion)
|
51 |
-
# print(input_text)
|
52 |
-
# Tokenize the concatenated inp_ut text
|
53 |
-
input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
|
54 |
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
60 |
-
|
61 |
-
return decoded_output
|
62 |
-
|
63 |
-
|
64 |
-
m = IntentClassifier("serj/intent-classifier")
|
65 |
-
print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
|
66 |
-
"OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n", "Company",
|
67 |
-
"Products and subscriptions"))
|
68 |
|
69 |
|
70 |
[More Information Needed]
|
@@ -122,5 +126,4 @@ F1 AVG on the train set is 0.69
|
|
122 |
|
123 |
#### Hardware
|
124 |
|
125 |
-
NVIDIA RTX 3060 12 GB
|
126 |
-
|
|
|
1 |
+
---
|
2 |
+
tags:
|
3 |
+
- intent
- topic-discovery
|
4 |
+
---
|
5 |
# Model Card for Model ID
|
6 |
|
7 |
Intent classification is the act of classifying customers' requests into different predefined categories.
|
|
|
34 |
|
35 |
## How to Get Started with the Model
|
36 |
|
37 |
+
# NOTE(review): T5ForConditionalGeneration and T5Tokenizer are used below but
# are not imported anywhere in the visible snippet. They are presumably meant
# to come from Hugging Face `transformers`, e.g.
#   from transformers import T5ForConditionalGeneration, T5Tokenizer
# -- confirm and add the import next to this example.
class IntentClassifier:
    """Pick the topic that matches a customer's message.

    Wraps a T5 conditional-generation model and its tokenizer: a text prompt
    is built from the customer message, the candidate topics and a short
    company description, and the model's decoded output is returned as the
    predicted class name.
    """

    def __init__(self, model_name="serj/intent-classifier", device="cuda"):
        """Load the model and tokenizer and remember the target device.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
            device: Device the model (and later the encoded inputs) is moved
                to. Defaults to ``"cuda"``.
        """
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.device = device

    @staticmethod
    def build_prompt(text, prompt="", company_name="", company_specific=""):
        """Return the prompt string fed to the model.

        Args:
            text: The customer's message.
            prompt: The block listing the candidate topics (e.g. a string
                starting with ``"OPTIONS:"``).
            company_name: Name of the company the customer is writing to.
            company_specific: Short description of what the company does.
                It is overridden for the two company names handled below.

        Returns:
            The formatted prompt, ending in ``"Class name: "``.
        """
        # Built-in descriptions for known companies take precedence over
        # whatever the caller passed in ``company_specific``.
        if company_name == "Pizza Mia":
            company_specific = "This company is a pizzeria place."
        if company_name == "Online Banking":
            company_specific = "This company is an online banking."

        return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "

    def predict(self, text, prompt_options, company_name, company_portion) -> str:
        """Generate the class name for a single customer message.

        Args:
            text: The customer's message.
            prompt_options: The candidate-topics block passed to
                ``build_prompt`` as ``prompt``.
            company_name: Name of the company.
            company_portion: Company description passed to ``build_prompt``
                as ``company_specific``.

        Returns:
            The model output decoded with special tokens stripped.
        """
        # build_prompt is defined on the class, so it has to be reached
        # through self; a bare build_prompt(...) call raises NameError here.
        input_text = self.build_prompt(text, prompt_options, company_name, company_portion)

        # Tokenize the prompt, truncating to at most 512 tokens, and move the
        # tensor to the same device as the model.
        input_ids = self.tokenizer.encode(
            input_text, return_tensors="pt", max_length=512, truncation=True
        ).to(self.device)

        # Generate the output
        output = self.model.generate(input_ids)

        # Decode the first generated sequence back to text
        decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)

        return decoded_output
|
66 |
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
# Example: load the published checkpoint and classify one customer message
# against four candidate topics.
m = IntentClassifier(model_name="serj/intent-classifier")
print(
    m.predict(
        text="Hey, after recent changes, I want to cancel subscription, please help.",
        prompt_options="OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n",
        company_name="Company",
        company_portion="Products and subscriptions",
    )
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
|
74 |
[More Information Needed]
|
|
|
126 |
|
127 |
#### Hardware
|
128 |
|
129 |
+
NVIDIA RTX 3060 12 GB
|
|