---
tags:
- intent
- topic-discovery
---

# Model Card for serj/intent-classifier

Intent classification is the task of assigning a customer's message to one of several predefined categories. It is sometimes also called topic classification.

This model is a T5 model (Flan-T5-Base) fine-tuned on prompts containing synthetic data that resembles customer requests. Because all of the candidate categories are listed in the prompt, the model classifies intents dynamically: the set of categories can change from one request to the next.

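Since the categories travel inside the prompt, a caller has to turn a list of category names into text. The sketch below is a hypothetical helper, `build_options`, that reproduces the `OPTIONS:` layout used in the usage example of this card (an `OPTIONS:` header, then one category per line, each prefixed with a space). It is an assumption that this layout matches what the model saw during training; check the prompt code in the repository if exact formatting matters.

```python
def build_options(categories: list[str]) -> str:
    """Format candidate intent names as an OPTIONS block for the prompt.

    Hypothetical helper: the layout ("OPTIONS:" header, one category per
    line, each prefixed with a single space) is copied from this card's
    usage example; the training-time format is an assumption.
    """
    lines = ["OPTIONS:"] + [f" {c.strip()}" for c in categories]
    # Trailing newline matches the example string "...return item\n"
    return "\n".join(lines) + "\n"


options = build_options(["refund", "cancel subscription", "damaged item", "return item"])
print(repr(options))
```

Passing a different list on each call is what lets the same model handle different category sets without a fixed label head.
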
## Model Details

Fine-tuned Flan-T5-Base.

### Model Description

This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.

- **Developed by:** Serj Smorodinsky
- **Model type:** Flan-T5-Base
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Fine-tuned from model:** Flan-T5-Base

### Model Sources

- **Repository:** https://github.com/SerjSmor/intent_classification

## How to Get Started with the Model

The snippet below wraps the model in a small helper class. It needs `torch` and `transformers`; `T5Tokenizer` also requires the `sentencepiece` package.

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer


def build_prompt(text, prompt="", company_name="", company_specific=""):
    # Preset descriptions for two demo companies; for any other company,
    # the caller-supplied `company_specific` text is used as-is.
    if company_name == "Pizza Mia":
        company_specific = "This company is a pizzeria place."
    if company_name == "Online Banking":
        company_specific = "This company is an online banking."

    return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "


class IntentClassifier:
    def __init__(self, model_name="serj/intent-classifier", device=None):
        # Use the GPU when one is available, otherwise fall back to the CPU
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(self.device)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)

    def predict(self, text, prompt_options, company_name, company_portion) -> str:
        input_text = build_prompt(text, prompt_options, company_name, company_portion)

        # Tokenize the prompt, truncating it to 512 tokens
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)

        # Generate the class name
        output = self.model.generate(input_ids)

        # Decode the generated tokens back to text
        return self.tokenizer.decode(output[0], skip_special_tokens=True)


m = IntentClassifier("serj/intent-classifier")
print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
                "OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n",
                "Company",
                "Products and subscriptions"))
```

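To see exactly what text the model receives, the `build_prompt` function from the usage example can be called on its own, without loading the model. The block below repeats that function (so it runs standalone) and prints the rendered prompt for a shortened request with two options; the request text and options here are illustrative only.

```python
def build_prompt(text, prompt="", company_name="", company_specific=""):
    # Same logic as `build_prompt` in the usage example, repeated so this block runs standalone.
    if company_name == "Pizza Mia":
        company_specific = "This company is a pizzeria place."
    if company_name == "Online Banking":
        company_specific = "This company is an online banking."
    return (
        f"Company name: {company_name} is doing: {company_specific}\n"
        f"Customer: {text}.\n"
        "END MESSAGE\n"
        "Choose one topic that matches customer's issue.\n"
        f"{prompt}\n"
        "Class name: "
    )


prompt = build_prompt(
    "I want to cancel subscription",
    "OPTIONS:\n refund\n cancel subscription\n",
    "Company",
    "Products and subscriptions",
)
print(prompt)
```

The customer text gets a period appended, the options block is followed by a blank line, and the prompt ends with `Class name: `, after which the model generates the label. For `"Pizza Mia"` and `"Online Banking"`, the company description argument is overridden by a preset.
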
## Training Details

### Training Data

Synthetic data resembling customer requests; see https://github.com/SerjSmor/intent_classification

A Hugging Face dataset will be added in the future.

### Training Procedure

Training script: https://github.com/SerjSmor/intent_classification/blob/main/t5_generator_trainer.py

Training uses the Hugging Face `Trainer` with the configuration below. `model`, `tokenizer`, `train_dataset`, `val_dataset`, `epochs`, and `batch_size` are not defined in this excerpt; they are assumed to be set up earlier in the training script.

```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",  # renamed to `eval_strategy` in newer transformers releases
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    # compute_metrics=compute_metrics
)

# Start fine-tuning
trainer.train()
```

## Evaluation

I've used the ATIS dataset for evaluation.

The average F1 score on the ATIS train set is 0.69.

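The card does not say which averaging "average F1" refers to. As a reference point, here is a minimal pure-Python sketch of macro-averaged F1 (the unweighted mean of per-class F1), one common choice for intent benchmarks. The function name `macro_f1` and the toy labels are illustrative, and it is not confirmed that this is the averaging behind the reported 0.69.

```python
from collections import Counter


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every label seen in y_true or y_pred."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    tp, fp, fn = Counter(), Counter(), Counter()
    for true, pred in zip(y_true, y_pred):
        if true == pred:
            tp[true] += 1
        else:
            fp[pred] += 1  # predicted `pred`, which was wrong
            fn[true] += 1  # the true label `true` was missed
    labels = set(y_true) | set(y_pred)
    if not labels:
        return 0.0
    # Per-class F1 = 2TP / (2TP + FP + FN), equivalent to 2PR / (P + R);
    # the denominator is never zero because every label occurs at least once.
    per_class = [2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]) for c in labels]
    return sum(per_class) / len(per_class)


y_true = ["flight", "flight", "airfare"]
y_pred = ["flight", "airfare", "airfare"]
print(macro_f1(y_true, y_pred))
```

With its default label set, scikit-learn's `f1_score(y_true, y_pred, average="macro")` should give the same value; a weighted average would instead weight each class by its support.
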
### Hardware

NVIDIA RTX 3060 (12 GB)