|
|
|
"""Now that we've built a powerful LLM-based classifier, let's showcase it to the world by creating an interactive demo. In this chapter, we'll learn how to: |
|
- Create a user-friendly web interface using Gradio |
|
- Package our demo for deployment |
|
- Deploy it on Hugging Face Spaces for free |
|
- Use the Hugging Face Inference API for model access |
|
""" |
|
|
|
import json |
|
import time |
|
import os |
|
import sys |
|
from retry import retry |
|
from rich.progress import track |
|
from huggingface_hub import InferenceClient |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.metrics import confusion_matrix, classification_report |
|
import pandas as pd |
|
import gradio as gr |
|
|
|
|
|
# Hugging Face access token taken from the HF_TOKEN environment variable
# (None when the variable is not set).
api_key = os.environ.get("HF_TOKEN")

# Single shared client used for every chat-completion request below.
client = InferenceClient(token=api_key)

# Labelled data; the code below reads its `payee` and `category` columns.
sample_df = pd.read_csv("sample.csv")
|
|
|
def get_batch_list(li, n=10):
    """Split the provided sequence into consecutive batches of size `n`.

    The last batch holds whatever remains and may be shorter than `n`.

    Args:
        li: Any sliceable sequence (list, tuple, str, ...).
        n: Maximum batch size; must be a positive integer.

    Returns:
        A list of slices of `li`, in their original order. An empty `li`
        yields an empty list.

    Raises:
        ValueError: If `n` is less than 1.
    """
    # Guard explicitly: with a negative step, range(0, len(li), n) is empty,
    # which would silently drop every item instead of reporting the mistake.
    if n < 1:
        raise ValueError(f"Batch size must be a positive integer, got {n}")
    return [li[start : start + n] for start in range(0, len(li), n)]
|
|
|
|
|
# Hold out 33% of the labelled rows as a test set; the fixed seed makes the
# split reproducible between runs.
training_input, test_input, training_output, test_output = train_test_split(
    sample_df[["payee"]],
    sample_df["category"],
    random_state=42,
    test_size=0.33,
)
|
|
|
|
|
def get_fewshots(training_input, training_output, batch_size=10):
    """Turn the training split into alternating user/assistant few-shot messages.

    Each user message is a newline-joined batch of payee names. The assistant
    message that follows is the JSON-encoded list of those names' categories.

    Args:
        training_input: DataFrame with a `payee` column.
        training_output: Category labels, presumably aligned row-for-row with
            `training_input` (as returned by train_test_split).
        batch_size: Number of names per few-shot exchange.

    Returns:
        A flat list of chat-message dicts, ordered user, assistant, user, ...
    """
    payee_batches = get_batch_list(list(training_input.payee), n=batch_size)
    label_batches = get_batch_list(list(training_output), n=batch_size)

    messages = []
    # zip stops at the shorter batch list, so unmatched trailing batches are ignored.
    for payees, labels in zip(payee_batches, label_batches):
        messages.append({"role": "user", "content": "\n".join(payees)})
        messages.append({"role": "assistant", "content": json.dumps(labels)})
    return messages
|
|
|
# Few-shot chat history built from the training split, in batches of 10 names;
# it is prepended to every classification request.
fewshot_list = get_fewshots(training_input, training_output, batch_size=10)
|
|
|
# System prompt sent with every classification request. It is defined once here
# instead of being rebuilt on each call.
CLASSIFICATION_PROMPT = """You are an AI model trained to categorize businesses based on their names.

You will be given a list of business names, each separated by a new line.

Your task is to analyze each name and classify it into one of the following categories: Restaurant, Bar, Hotel, or Other.

It is extremely critical that there is a corresponding category output for each business name provided as an input.

If a business does not clearly fall into Restaurant, Bar, or Hotel categories, you should classify it as "Other".

Even if the type of business is not immediately clear from the name, it is essential that you provide your best guess based on the information available to you. If you can't make a good guess, classify it as Other.

For example, if given the following input:

"Intercontinental Hotel\nPizza Hut\nCheers\nWelsh's Family Restaurant\nKTLA\nDirect Mailing"

Your output should be a JSON list in the following format:

["Hotel", "Restaurant", "Bar", "Restaurant", "Other", "Other"]

This means that you have classified "Intercontinental Hotel" as a Hotel, "Pizza Hut" as a Restaurant, "Cheers" as a Bar, "Welsh's Family Restaurant" as a Restaurant, and both "KTLA" and "Direct Mailing" as Other.

If a business name contains both the word "Restaurant" and the word "Bar", you should classify it as a Restaurant.

Ensure that the number of classifications in your output matches the number of business names in the input. It is very important that the length of the JSON list you return is exactly the same as the number of business names you receive.
"""

# Categories the model may return; any other value is rejected.
ACCEPTABLE_ANSWERS = ("Restaurant", "Bar", "Hotel", "Other")


def _parse_answer_list(answer_str):
    """Decode the model's reply into a Python list.

    Surrounding whitespace and a Markdown code fence (``` or ```json) are
    tolerated, because chat models sometimes wrap JSON that way.

    Raises:
        ValueError: If the reply is empty, is not valid JSON, or is not a list.
    """
    if not answer_str:
        raise ValueError("Model returned an empty response")
    text = answer_str.strip()
    if text.startswith("```"):
        text = text.strip("`")
        # Drop the optional language tag left over from the opening fence.
        if text[:4].lower() == "json":
            text = text[4:]
    answer_list = json.loads(text)  # json.JSONDecodeError subclasses ValueError
    if not isinstance(answer_list, list):
        raise ValueError(f"Expected a JSON list, got {type(answer_list).__name__}")
    return answer_list


@retry(ValueError, tries=2, delay=2)
def classify_payees(name_list):
    """Classify business names with the LLM, one category per name.

    Args:
        name_list: Business names to classify (a list of str).

    Returns:
        A dict mapping each name to one of ACCEPTABLE_ANSWERS. Duplicate names
        collapse into a single key. An empty `name_list` returns `{}` without
        calling the API.

    Raises:
        ValueError: On any failure, whether an API error, an unparseable reply,
            an unknown category, or a length mismatch. The `retry` decorator
            makes at most 2 attempts, 2 seconds apart, before propagating it.
    """
    if len(name_list) == 0:
        return {}

    try:
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": CLASSIFICATION_PROMPT},
                *fewshot_list,
                {"role": "user", "content": "\n".join(name_list)},
            ],
            model="meta-llama/Llama-3.3-70B-Instruct",
            temperature=0,
        )
        answer_list = _parse_answer_list(response.choices[0].message.content)

        for answer in answer_list:
            if answer not in ACCEPTABLE_ANSWERS:
                raise ValueError(f"{answer} not in list of acceptable answers")

        if len(name_list) != len(answer_list):
            raise ValueError(
                f"Number of inputs ({len(name_list)}) does not equal "
                f"the number of outputs ({len(answer_list)})"
            )
    except Exception as e:
        # Normalize every failure to ValueError so @retry (which only catches
        # ValueError) also retries transient API errors. The original exception
        # stays attached as __cause__ for debugging.
        raise ValueError(f"Error during classification: {e}") from e

    return dict(zip(name_list, answer_list))
|
|
|
def classify_batches(name_list, batch_size=10, wait=2):
    """Classify `name_list` in batches and collect the results in a DataFrame.

    Each batch goes through `classify_payees`. Failures are best-effort: the
    error is printed to stderr and that batch's names are omitted from the
    result.

    Args:
        name_list: Business names to classify.
        batch_size: Number of names sent per request.
        wait: Seconds to pause between consecutive requests.

    Returns:
        A DataFrame with columns `payee` and `category`, one row per
        successfully classified (unique) name.
    """
    all_results = {}
    batch_list = get_batch_list(name_list, n=batch_size)
    last_index = len(batch_list) - 1

    for index, batch in enumerate(track(batch_list)):
        try:
            all_results.update(classify_payees(batch))
        except Exception as e:
            print(f"Error processing batch: {e}", file=sys.stderr)
        # Pace requests whether or not the batch succeeded; skipping the pause
        # after a failure (e.g. a rate-limit error) would fire the next request
        # immediately. No pause is needed after the final batch.
        if index < last_index:
            time.sleep(wait)

    return pd.DataFrame(all_results.items(), columns=["payee", "category"])
|
|
|
|
|
# Classify every held-out payee with the LLM.
# NOTE(review): this executes at import time, so each app start sends the whole
# test split to the Inference API, and nothing in the Gradio demo reads llm_df.
# Consider gating it or moving it into a separate evaluation script.
llm_df = classify_batches(test_input["payee"].tolist())
|
|
|
|
|
def classify_business_names(input_text):
    """Gradio callback: classify newline-separated business names.

    Blank lines and surrounding whitespace are ignored.

    Args:
        input_text: Raw textbox contents; None is treated like empty input.

    Returns:
        A JSON string. On success it maps each name to its category. Otherwise
        it is an object with an "error" message.
    """
    # `or ""` keeps a missing (None) value from crashing on .splitlines().
    stripped_lines = (line.strip() for line in (input_text or "").splitlines())
    name_list = [name for name in stripped_lines if name]

    if not name_list:
        return json.dumps({"error": "No business names provided. Please enter at least one business name."})

    try:
        result = classify_payees(name_list)
        return json.dumps(result, indent=2)
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        return json.dumps({"error": f"Classification failed: {str(e)}"})
|
|
|
|
|
# Web UI: a 10-line textbox goes in, and the JSON classification result comes out.
demo = gr.Interface(
    fn=classify_business_names,
    inputs=gr.Textbox(
        placeholder="Enter business names, one per line",
        lines=10,
    ),
    outputs=gr.JSON(),
    examples=[
        ["Marriott Hotel\nTaco Bell\nThe Tipsy Cow\nStarbucks\nApple Store"],
    ],
    title="Business Category Classifier",
    description="Enter business names and get a classification: Restaurant, Bar, Hotel, or Other.",
)

if __name__ == "__main__":
    # Start the app only when this file is run directly, not when imported.
    demo.launch(share=True)