from transformers import TFAutoModelForCausalLM, AutoTokenizer
import tensorflow as tf
import gradio as gr
TITLE = "DialoGPT -- Chatbot"
DESCRIPTION = """<center>This application allows you to talk with a machine.
The back end uses the DialoGPT model from Microsoft.<br>
This model extends GPT-2 towards the conversational neural response generation domain.<br>
You can also read the <a href="https://arxiv.org/abs/1911.00536">arXiv paper</a>.<br></center>"""
EXAMPLES = [
    ["What is your favorite videogame?"],
    ["What do you do for work?"],
    ["What are your hobbies?"],
    ["What is your favorite food?"],
]
ARTICLE = r"""<center>
Done by Dr. Gabriel Lopez<br>
For more, please visit: <a href='https://sites.google.com/view/dr-gabriel-lopez/home'>My Page</a><br>
</center>"""
# model checkpoint (the commented-out alternatives are fine-tuned variants with PyTorch weights)
# checkpoint = "ericzhou/DialoGPT-Medium-Rick_v2"  # pytorch
# checkpoint = "epeicher/DialoGPT-medium-homer"  # pytorch
checkpoint = "microsoft/DialoGPT-medium"  # tf
# load the tokenizer and the TensorFlow model weights
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = TFAutoModelForCausalLM.from_pretrained(checkpoint)
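# Illustrative note (not part of the original Space): the encoding step used below turns
# one user turn into a (1, n_tokens) tf.Tensor whose last id is the EOS token, e.g.
#   tokenizer.encode("Hi" + tokenizer.eos_token, return_tensors="tf")
# ends with id 50256, the GPT-2/DialoGPT eos_token_id.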
# interaction function: one chat turn; the growing token history is threaded through gradio state
def chat_with_bot(user_input, chat_history_and_input=None):
    # encode the new user turn, terminated by the EOS token DialoGPT expects
    emb_user_input = tokenizer.encode(
        user_input + tokenizer.eos_token, return_tensors="tf"
    )
    # the state is None (or empty) on the first turn, then a numpy array of token ids
    if chat_history_and_input is None or len(chat_history_and_input) == 0:
        bot_input_ids = emb_user_input  # first iteration
    else:
        bot_input_ids = tf.concat(
            [chat_history_and_input, emb_user_input], axis=-1
        )  # other iterations: append the new turn to the stored history
    # generate the conversation so far plus the bot's reply
    chat_history_and_input = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    ).numpy()
    # decode only the newly generated tokens (everything after the input)
    bot_response = tokenizer.decode(
        chat_history_and_input[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True,
    )
    return bot_response, chat_history_and_input
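# Minimal local sketch (assumption: run outside the Gradio UI) of how the state is
# threaded across turns; the prompts are placeholders, not part of the original Space.
# history = None
# for turn in ["Hello!", "What is your favorite food?"]:
#     reply, history = chat_with_bot(turn, history)
#     print(f"User: {turn}\nBot: {reply}")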
# gradio interface
in_text = gr.Textbox(value="How was the class?", label="Start chatting!")
out_text = gr.Textbox(value="", label="Chatbot response:")
gr.Interface(
    fn=chat_with_bot,
    inputs=[in_text, "state"],
    outputs=[out_text, "state"],
    examples=EXAMPLES,
    title=TITLE,
    description=DESCRIPTION,
    article=ARTICLE,
    allow_flagging="never",
).launch()