Serj commited on
Commit
bfb4729
1 Parent(s): 7f73ead

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +35 -0
README.md CHANGED
@@ -30,6 +30,41 @@ This is the model card of a 🤗 transformers model that has been pushed on the
30
 
31
  ## How to Get Started with the Model
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  [More Information Needed]
 
30
 
31
  ## How to Get Started with the Model
32
 
33
+ class IntentClassifier:
34
+ def __init__(self, model_name="serj/intent-classifier", device="cuda"):
35
+ self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
36
+ self.tokenizer = T5Tokenizer.from_pretrained(model_name)
37
+ self.device = device
38
+
39
+
40
+ def build_prompt(text, prompt="", company_name="", company_specific=""):
41
+ if company_name == "Pizza Mia":
42
+ company_specific = "This company is a pizzeria place."
43
+ if company_name == "Online Banking":
44
+ company_specific = "This company is an online banking."
45
+
46
+ return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "
47
+
48
+
49
+ def predict(self, text, prompt_options, company_name, company_portion) -> str:
50
+ input_text = build_prompt(text, prompt_options, company_name, company_portion)
51
+ # print(input_text)
52
+ # Tokenize the concatenated inp_ut text
53
+ input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
54
+
55
+ # Generate the output
56
+ output = self.model.generate(input_ids)
57
+
58
+ # Decode the output tokens
59
+ decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
60
+
61
+ return decoded_output
62
+
63
+
64
+ m = IntentClassifier("serj/intent-classifier")
65
+ print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
66
+ "OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n", "Company",
67
+ "Products and subscriptions"))
68
 
69
 
70
  [More Information Needed]