# Recovered from a Hugging Face Spaces page scrape (the page showed a
# "Runtime error" status banner); table-column artifacts removed below.
from statistics import mean
import random

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizerFast

# Minimum mean cosine similarity for a message to match an intent.
threshold = 0.4

# LaBSE: language-agnostic BERT sentence embeddings; pooler_output is the
# per-sentence embedding vector.
tokenizer = BertTokenizerFast.from_pretrained("setu4993/LaBSE")
model = BertModel.from_pretrained("setu4993/LaBSE")
model = model.eval()  # inference only: disables dropout / batch-norm updates
# Example utterances for the "order food" intent; the incoming message is
# compared against each of these by embedding similarity.
order_food_ex = [
    "food",
    "I am hungry, I want to order food",
    "How do I order food",
    "What are the food options",
    "I need dinner",
    "I want lunch",
    "What are the menu options",
    "I want a hamburger",
]
# Example utterances for the "talk to a human" intent.
talk_to_human_ex = [
    "I need to talk to someone",
    "Connect me with a human",
    "I need to speak with a person",
    "Put me on with a human",
    "Connect me with customer service",
    "human",
]
def embed(text, tokenizer, model):
    """Return the pooled sentence embedding(s) for `text`.

    Args:
        text: A string (or list of strings) to embed.
        tokenizer: Hugging Face tokenizer that maps text to model inputs.
        model: BERT-style model whose `pooler_output` holds the sentence vector.

    Returns:
        Tensor of shape (batch, hidden_size) — one pooled embedding per input.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True)
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.pooler_output
def similarity(embeddings_1, embeddings_2):
    """Cosine-similarity matrix between two batches of embeddings.

    Each row is L2-normalized (dim=1, the default for F.normalize — made
    explicit here), so the matmul of normalized rows yields cosine similarity.

    Returns:
        Tensor of shape (len(embeddings_1), len(embeddings_2)) with values
        in [-1, 1].
    """
    normalized_embeddings_1 = F.normalize(embeddings_1, p=2, dim=1)
    normalized_embeddings_2 = F.normalize(embeddings_2, p=2, dim=1)
    return torch.matmul(
        normalized_embeddings_1, normalized_embeddings_2.transpose(0, 1)
    )
# Pre-compute one embedding per example utterance at startup so each chat
# turn only has to embed the single incoming message.
order_food_embed = [embed(x, tokenizer, model) for x in order_food_ex]
talk_to_human_embed = [embed(x, tokenizer, model) for x in talk_to_human_ex]
def chat(message, history):
    """Run one chatbot turn: classify `message` by intent and append a reply.

    Args:
        message: The user's utterance for this turn.
        history: List of (user, bot) message tuples, or None on the first turn.

    Returns:
        (history, history) — the updated transcript twice, because the Gradio
        interface wires it to both the chatbot display and the session state.
    """
    history = history or []
    message_embed = embed(message, tokenizer, model)

    # Cosine similarity of the message to every example of each intent.
    order_sim = [float(similarity(em, message_embed)) for em in order_food_embed]
    human_sim = [float(similarity(em, message_embed)) for em in talk_to_human_embed]

    # An intent matches when the MEAN similarity over its examples clears
    # the threshold; "order food" is checked first and wins ties.
    if mean(order_sim) > threshold:
        response = random.choice([
            "We have hamburgers or pizza! Which one do you want?",
            "Do you want a hamburger or a pizza?"])
    elif mean(human_sim) > threshold:
        response = random.choice([
            "Sure, a customer service agent will jump into this convo shortly!",
            "No problem. Let me forward on this conversation to a person that can respond."])
    else:
        # Fallback reply; fixed typo ("Could your rephrase?" -> "Could you rephrase?").
        response = "Sorry, I didn't catch that. Could you rephrase?"

    history.append((message, response))
    return history, history
# Wire the chat function into a simple text-in / chatbot-out UI; the "state"
# input and output carry the conversation history between turns.
# NOTE(review): allow_screenshot was removed in Gradio 3.x — confirm the
# pinned gradio version supports this argument.
iface = gr.Interface(
    chat,
    ["text", "state"],
    ["chatbot", "state"],
    allow_screenshot=False,
    allow_flagging="never",
)
iface.launch()