---
tags:
- intent
- topic-discovery
datasets:
- clinc/clinc_oos
widget:
- text: "Topic %% Customer: How do I get my money back?.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n# renew subscription # account deletion # cancel subscription # resume subscription # refund requests # other # general # item damaged # malfunction # hello # intro # question\nClass name: "
  example_title: "Open Label Intent Classification"
---


# Model Card for serj/intent-classifier

Intent classification is the task of classifying customer requests into different predefined categories.
Sometimes intent classification is referred to as topic classification.
By fine-tuning a T5 model on prompts containing synthetic data that resembles customer requests, this
model is able to classify intents dynamically: all of the candidate categories are simply listed in the prompt.
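
For example, the widget prompt above lists every candidate class directly in the input and asks the model to generate the matching class name. A minimal sketch of building a prompt in that format (the `build_widget_prompt` helper and the topic list are illustrative only; the `build_prompt` helper shown under "How to Get Started" uses a slightly different template):

```
def build_widget_prompt(message, topics):
    # Candidate class names are appended to the prompt, so the label set can change per request.
    topic_str = " ".join(f"# {t}" for t in topics)
    return (f"Topic %% Customer: {message}.\nEND MESSAGE\n"
            "Choose one topic that matches customer's issue.\n"
            f"{topic_str}\nClass name: ")


print(build_widget_prompt("How do I get my money back?",
                          ["refund requests", "cancel subscription", "other"]))
```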


## Model Details
Fine-tuned Flan-T5-Base

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.

- **Developed by:** Serj Smorodinsky
- **Model type:** Flan-T5-Base
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** Flan-T5-Base

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** https://github.com/SerjSmor/intent_classification


## How to Get Started with the Model
```
from transformers import T5ForConditionalGeneration, T5Tokenizer


class IntentClassifier:
    def __init__(self, model_name="serj/intent-classifier", device="cuda"):
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.device = device

    def predict(self, text, prompt_options, company_name, company_portion) -> str:
        input_text = build_prompt(text, prompt_options, company_name, company_portion)

        # Tokenize the prompt
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)

        # Generate the output
        output = self.model.generate(input_ids)

        # Decode the output tokens
        decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)

        return decoded_output


def build_prompt(text, prompt="", company_name="", company_specific=""):
    if company_name == "Pizza Mia":
        company_specific = "This company is a pizzeria place."
    if company_name == "Online Banking":
        company_specific = "This company is an online banking."

    return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "


m = IntentClassifier("serj/intent-classifier")
print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
                "OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n", "Company",
                "Products and subscriptions"))
```
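`device="cuda"` is the default; pass `device="cpu"` to the constructor if no GPU is available. The snippet requires the `transformers` package, and the slow `T5Tokenizer` additionally depends on `sentencepiece`.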

## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
https://github.com/SerjSmor/intent_classification
An HF dataset will be added in the future.


### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
https://github.com/SerjSmor/intent_classification/blob/main/t5_generator_trainer.py

Training uses the HF `Trainer`:

    from transformers import Trainer, TrainingArguments

    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        evaluation_strategy="epoch"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        # compute_metrics=compute_metrics
    )

    # Run fine-tuning
    trainer.train()
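
The snippet above assumes `model`, `tokenizer`, `train_dataset`, and `val_dataset` are already defined (see the linked `t5_generator_trainer.py`). As a rough sketch of what such a dataset could look like for this prompt-to-class-name setup (the `PromptDataset` class below is illustrative, not the repository's exact code):

```
from torch.utils.data import Dataset


class PromptDataset(Dataset):
    def __init__(self, pairs, tokenizer, max_length=512):
        # pairs: list of (prompt_text, class_name) tuples
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        prompt, label = self.pairs[idx]
        inputs = self.tokenizer(prompt, max_length=self.max_length,
                                truncation=True, padding="max_length", return_tensors="pt")
        targets = self.tokenizer(label, max_length=16,
                                 truncation=True, padding="max_length", return_tensors="pt")
        labels = targets["input_ids"].squeeze(0)
        # Replace padding token ids with -100 so they are ignored by the seq2seq loss
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": labels,
        }
```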



## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->
The newest version of the model is fine-tuned on 2 synthetic datasets and the first 41 classes of clinc_oos in a few-shot manner.
All datasets have 10-20 samples per class. The training data did not include the Atis dataset.

Atis zero-shot test set evaluation: weighted F1 87%.
Evaluation on the Clinc test set is next.
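
Weighted F1 is computed over the generated class names. A minimal sketch of such an evaluation loop, reusing the `IntentClassifier` from above (the helper and the use of scikit-learn's `f1_score` are assumptions, not the repository's exact evaluation code):

```
from sklearn.metrics import f1_score


def evaluate_weighted_f1(classifier, examples, prompt_options, company_name, company_portion):
    # examples: list of (customer_message, true_class_name) pairs from the test set
    y_true, y_pred = [], []
    for text, true_label in examples:
        prediction = classifier.predict(text, prompt_options, company_name, company_portion)
        y_true.append(true_label)
        y_pred.append(prediction.strip())
    # Weighted F1 averages per-class F1 scores, weighted by class support
    return f1_score(y_true, y_pred, average="weighted")
```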


#### Summary



#### Hardware

Nvidia RTX 3060, 12 GB