import os
import sys
import threading
from itertools import chain

import gradio as gr
from gradio import ChatInterface, Request
from transformers import AutoModelForCausalLM, AutoTokenizer

LOG_LEVEL = "INFO"
TIMEOUT = 60  # seconds

model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


def generate_response(message, history):
    # `history` (the prior [user, assistant] turns) is accepted to match the
    # chat callback signature but is not used by this simple demo.
    inputs = tokenizer(message, return_tensors="pt")
    # max_new_tokens bounds the length of the continuation itself, regardless
    # of how long the prompt is.
    outputs = model.generate(**inputs, max_new_tokens=150, pad_token_id=tokenizer.eos_token_id)
    # generate() returns prompt + continuation for GPT-2, so skip the prompt
    # tokens; decoding outputs[0] whole would echo the user's message back.
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True)
    return response
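

# generate_response above ignores `history`. A minimal sketch of folding the
# chat history into the prompt instead; the helper name is hypothetical, and
# plain "User:/Bot:" framing is used because GPT-2 has no chat template.
def generate_response_with_history(message, history):
    # history is a list of [user, assistant] pairs, as Gradio passes it.
    prompt = "".join(f"User: {u}\nBot: {b}\n" for u, b in history) + f"User: {message}\nBot:"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=150, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens, as above.
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True)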


class myChatInterface(ChatInterface):
    # Overrides ChatInterface's internal _submit_fn so each submission calls
    # the local model directly. _submit_fn is a private Gradio API and may
    # change between Gradio versions.
    async def _submit_fn(
        self,
        message: str,
        history_with_input: list[list[str | None]],
        request: Request,
        *args,
    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
        history = history_with_input[:-1]
        response = generate_response(message, history)
        history.append([message, response])
        return history, history
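

# _submit_fn receives the values of additional_inputs (the model dropdown
# configured below) via *args but never reads them. A hypothetical helper for
# honoring the selection; checkpoints are cached because reloading on every
# request would be slow:
_model_cache = {model_name: (model, tokenizer)}


def load_model(name):
    # Load the requested Hugging Face checkpoint, or return the cached one.
    if name not in _model_cache:
        _model_cache[name] = (
            AutoModelForCausalLM.from_pretrained(name),
            AutoTokenizer.from_pretrained(name),
        )
    return _model_cache[name]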


with gr.Blocks() as demo:

    def flatten_chain(list_of_lists):
        # Flatten one level of nesting, e.g. [[a, b], [c]] -> [a, b, c].
        return list(chain.from_iterable(list_of_lists))

    class thread_with_trace(threading.Thread):
        # A thread that can be stopped from outside: start() installs a trace
        # function, and once kill() is called the trace raises SystemExit on
        # the next traced line, unwinding the thread.
        def __init__(self, *args, **keywords):
            threading.Thread.__init__(self, *args, **keywords)
            self.killed = False
            self._return = None

        def start(self):
            # Swap in the traced runner just before the thread starts.
            self.__run_backup = self.run
            self.run = self.__run
            threading.Thread.start(self)

        def __run(self):
            sys.settrace(self.globaltrace)
            self.__run_backup()
            self.run = self.__run_backup

        def run(self):
            # Capture the target's return value so join() can hand it back.
            if self._target is not None:
                self._return = self._target(*self._args, **self._kwargs)

        def globaltrace(self, frame, event, arg):
            if event == "call":
                return self.localtrace
            else:
                return None

        def localtrace(self, frame, event, arg):
            if self.killed:
                if event == "line":
                    raise SystemExit()
            return self.localtrace

        def kill(self):
            self.killed = True

        def join(self, timeout=None):
            # timeout=None blocks until the thread finishes; a timeout of 0
            # would make join return immediately without waiting.
            threading.Thread.join(self, timeout)
            return self._return
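
    # A sketch of how thread_with_trace and TIMEOUT could bound generation
    # time; the function name is hypothetical. Caveat: the trace-based kill
    # only fires on Python bytecode lines, so it cannot interrupt a blocking
    # C call inside model.generate and takes effect only once control returns
    # to Python.
    def generate_with_timeout(message, history):
        t = thread_with_trace(target=generate_response, args=(message, history))
        t.start()
        result = t.join(timeout=TIMEOUT)
        if t.is_alive():
            t.kill()
            t.join()
            return "Generation timed out."
        return result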

    def get_description_text():
        return """
# Hugging Face Model Chatbot Demo

This demo shows how to build a chatbot using models available on Hugging Face.
"""

    description = gr.Markdown(get_description_text())

    with gr.Row() as params:
        txt_model = gr.Dropdown(
            label="Model",
            # Hugging Face Hub checkpoint IDs for the GPT-2 family.
            choices=[
                "gpt2",
                "gpt2-medium",
                "gpt2-large",
                "gpt2-xl",
            ],
            allow_custom_value=True,
            value="gpt2",
            container=True,
        )

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        avatar_images=(
            "human.png",
            os.path.join(os.path.dirname(__file__), "autogen.png"),
        ),
        render=False,
        height=600,
    )

    txt_input = gr.Textbox(
        scale=4,
        show_label=False,
        placeholder="Enter text and press enter",
        container=False,
        render=False,
        autofocus=True,
    )

    chatiface = myChatInterface(
        # ChatInterface requires a callable fn; submissions are actually
        # handled by the _submit_fn override above.
        generate_response,
        chatbot=chatbot,
        textbox=txt_input,
        additional_inputs=[txt_model],
    )


if __name__ == "__main__":
    demo.launch(share=True, server_name="0.0.0.0")