Prompt Structure

Topic %% Customer: text. END MESSAGE OPTIONS: each class separated by % Choose one topic that matches customer's issue. Class name:

The customer text must end with a period. The model was trained on prompts in this form, and omitting the period gives unreliable results.
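
For prompts assembled by hand, the trailing-period rule can be enforced before the prompt is built. The helper below is a minimal sketch and not part of the released code; the name `ensure_trailing_period` is hypothetical, and it assumes that a single final period is all the model needs.

```python
def ensure_trailing_period(text: str) -> str:
    """Return text with surrounding whitespace removed and a final period."""
    cleaned = text.strip()
    # Assumption: one period is enough; other punctuation is left in place.
    if cleaned.endswith("."):
        return cleaned
    return cleaned + "."


print(ensure_trailing_period("I want to cancel my subscription "))
```

Text that already ends with a period is returned unchanged, so the helper can be applied more than once without stacking periods.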

Model Card for Model ID

Intent classification is the task of classifying customer requests into predefined categories. It is sometimes referred to as topic classification. This model is a T5 model fine-tuned on prompts containing synthetic data that resembles customer requests. It classifies intents dynamically: all candidate categories are added to the prompt.

Model Details

Fine-tuned Flan-T5-Base

Model Description

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

  • Developed by: Serj Smorodinsky
  • Model type: Flan-T5-Base
  • Language(s) (NLP): [More Information Needed]
  • License: [More Information Needed]
  • Finetuned from model [optional]: Flan-T5-Base

Model Sources [optional]

How to Get Started with the Model

  from transformers import T5ForConditionalGeneration, T5Tokenizer


  def build_prompt(text, prompt="", company_name="", company_specific=""):
      # Known companies get a fixed description; otherwise the caller's value is used.
      if company_name == "Pizza Mia":
          company_specific = "This company is a pizzeria place."
      if company_name == "Online Banking":
          company_specific = "This company is an online banking."

      # The period after {text} is required: the model was trained with it.
      return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "


  class IntentClassifier:
      # Pass device="cpu" if no GPU is available.
      def __init__(self, model_name="serj/intent-classifier", device="cuda"):
          self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
          self.tokenizer = T5Tokenizer.from_pretrained(model_name)
          self.device = device

      def predict(self, text, prompt_options, company_name, company_portion) -> str:
          input_text = build_prompt(text, prompt_options, company_name, company_portion)
          # Tokenize the concatenated input text
          input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)

          # Generate the output
          output = self.model.generate(input_ids)

          # Decode the output tokens into the predicted class name
          return self.tokenizer.decode(output[0], skip_special_tokens=True)


  m = IntentClassifier("serj/intent-classifier")
  print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
                  "OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n", "Company",
                  "Products and subscriptions"))
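
The options string passed to `predict` in the example above follows a simple pattern: the header `OPTIONS:`, then one class per line with a leading space. A small helper can build it from a list of class names. This is only a sketch; `build_options` is a hypothetical name, and the exact whitespace is inferred from the example call rather than from the training code.

```python
def build_options(class_names):
    """Format class names like the OPTIONS string in the example call."""
    lines = ["OPTIONS:"]
    for name in class_names:
        # Each class sits on its own line, prefixed by a single space.
        lines.append(f" {name}")
    # The example string ends with a newline after the last class.
    return "\n".join(lines) + "\n"


options = build_options(["refund", "cancel subscription", "damaged item", "return item"])
print(repr(options))
```

Because the categories are plain strings in the prompt, the class list can change from one call to the next without retraining.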

Training Details

Training Data

See https://github.com/SerjSmor/intent_classification. A Hugging Face dataset will be added in the future.

Training Procedure

https://github.com/SerjSmor/intent_classification/blob/main/t5_generator_trainer.py

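Training uses the Hugging Face Trainer:

from transformers import Trainer, TrainingArguments

# epochs, batch_size, model, tokenizer, and the datasets are defined elsewhere in the training script.
training_args = TrainingArguments(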
    output_dir='./results',
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    # compute_metrics=compute_metrics
)

Evaluation

The newest version of the model is fine-tuned on 2 synthetic datasets and the first 41 classes of clinc_oos in a few-shot manner. All datasets have 10-20 samples per class. The training data did not include the ATIS dataset.
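
The few-shot setup described above (10-20 samples per class) can be approximated by capping the number of examples kept for each label. The sketch below is an illustration only and is not taken from the training script; the function name `sample_few_shot`, the `(text, label)` tuple layout, and the toy data are all assumptions.

```python
import random
from collections import defaultdict


def sample_few_shot(examples, per_class, seed=0):
    """Keep at most `per_class` (text, label) pairs for every label."""
    by_label = defaultdict(list)
    for text, label in examples:
        by_label[label].append((text, label))

    rng = random.Random(seed)  # fixed seed keeps the subset reproducible
    subset = []
    for label in sorted(by_label):
        items = by_label[label]
        # Take everything when a class has fewer examples than the cap.
        k = min(per_class, len(items))
        subset.extend(rng.sample(items, k))
    return subset


# Toy data, invented for illustration.
data = [(f"refund request {i}", "refund") for i in range(30)]
data += [(f"cancel request {i}", "cancel subscription") for i in range(5)]
few_shot = sample_few_shot(data, per_class=10)
print(len(few_shot))
```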

ATIS zero-shot test set evaluation: weighted F1 of 87%. Evaluation on the Clinc test set is next.
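
Weighted F1 averages the per-class F1 scores, weighting each class by its number of true examples (its support). The pure-Python sketch below shows the arithmetic on a made-up set of labels. It is not the evaluation code behind the figure reported above, and `weighted_f1` is a hypothetical helper name.

```python
from collections import Counter


def weighted_f1(y_true, y_pred):
    """Support-weighted mean of per-class F1 scores."""
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for label, n_true in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        n_pred = sum(1 for p in y_pred if p == label)
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / n_true
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        # Each class contributes in proportion to its share of true labels.
        score += f1 * n_true / total
    return score


# Invented labels, for illustration only.
y_true = ["flight", "flight", "flight", "airfare"]
y_pred = ["flight", "flight", "airfare", "airfare"]
print(round(weighted_f1(y_true, y_pred), 3))
```

Weighting by support means frequent intents dominate the score, which matters on datasets where a few classes account for most of the test set.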

Summary

Hardware

Nvidia RTX 3060 12GB
