Serj commited on
Commit
9d3538c
1 Parent(s): bfb4729

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +34 -31
README.md CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  # Model Card for Model ID
2
 
3
  Intent classification is the act of classifying customer's in to different pre defined categories.
@@ -30,41 +34,41 @@ This is the model card of a 🤗 transformers model that has been pushed on the
30
 
31
  ## How to Get Started with the Model
32
 
33
- class IntentClassifier:
34
- def __init__(self, model_name="serj/intent-classifier", device="cuda"):
35
- self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
36
- self.tokenizer = T5Tokenizer.from_pretrained(model_name)
37
- self.device = device
38
 
39
 
40
- def build_prompt(text, prompt="", company_name="", company_specific=""):
41
- if company_name == "Pizza Mia":
42
- company_specific = "This company is a pizzeria place."
43
- if company_name == "Online Banking":
44
- company_specific = "This company is an online banking."
 
 
45
 
46
- return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- def predict(self, text, prompt_options, company_name, company_portion) -> str:
50
- input_text = build_prompt(text, prompt_options, company_name, company_portion)
51
- # print(input_text)
52
- # Tokenize the concatenated inp_ut text
53
- input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
54
 
55
- # Generate the output
56
- output = self.model.generate(input_ids)
57
-
58
- # Decode the output tokens
59
- decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
60
-
61
- return decoded_output
62
-
63
-
64
- m = IntentClassifier("serj/intent-classifier")
65
- print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
66
- "OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n", "Company",
67
- "Products and subscriptions"))
68
 
69
 
70
  [More Information Needed]
@@ -122,5 +126,4 @@ F1 AVG on the train set is 0.69
122
 
123
  #### Hardware
124
 
125
- Nvidia RTX3060 12Gb
126
-
 
1
+ ---
2
+ tags:
3
+ - intent, topic-discovery
4
+ ---
5
  # Model Card for Model ID
6
 
7
  Intent classification is the act of classifying customer's in to different pre defined categories.
 
34
 
35
  ## How to Get Started with the Model
36
 
37
+ class IntentClassifier:
38
+ def __init__(self, model_name="serj/intent-classifier", device="cuda"):
39
+ self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
40
+ self.tokenizer = T5Tokenizer.from_pretrained(model_name)
41
+ self.device = device
42
 
43
 
44
+ def build_prompt(text, prompt="", company_name="", company_specific=""):
45
+ if company_name == "Pizza Mia":
46
+ company_specific = "This company is a pizzeria place."
47
+ if company_name == "Online Banking":
48
+ company_specific = "This company is an online banking."
49
+
50
+ return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "
51
 
 
52
 
53
+ def predict(self, text, prompt_options, company_name, company_portion) -> str:
54
+ input_text = build_prompt(text, prompt_options, company_name, company_portion)
55
+ # print(input_text)
56
+ # Tokenize the concatenated inp_ut text
57
+ input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
58
+
59
+ # Generate the output
60
+ output = self.model.generate(input_ids)
61
+
62
+ # Decode the output tokens
63
+ decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
64
+
65
+ return decoded_output
66
 
 
 
 
 
 
67
 
68
+ m = IntentClassifier("serj/intent-classifier")
69
+ print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
70
+ "OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n", "Company",
71
+ "Products and subscriptions"))
 
 
 
 
 
 
 
 
 
72
 
73
 
74
  [More Information Needed]
 
126
 
127
  #### Hardware
128
 
129
+ Nvidia RTX3060 12Gb