# https://www.gradio.app/guides/using-hugging-face-integrations
import gradio as gr
import logging
import html
import time
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
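# NOTE: device_map="auto" below requires the accelerate package, and
# load_in_8bit requires bitsandbytes; both must be installed in the Space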
# Model
# model_name = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"  # smaller model, handy for quick local testing
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
# UI Settings
title = "Shisa 7B"
description = "Test out Shisa 7B in either English or Japanese."
placeholder = "Type Here / ここに入力してください"
examples = [
    "What's the best ramen in Tokyo?",
    "東京でおすすめのラーメン屋さんを教えていただけますか。",  # polite: "Could you tell me a good ramen shop in Tokyo?"
    "東京でおすすめのラーメン屋ってどこ?",  # casual: "Where's a good ramen place in Tokyo?"
]
# LLM Settings
system_prompt = 'You are a helpful, friendly assistant.'
# conversation state is rebuilt per request inside chat() from Gradio's history
tokenizer = AutoTokenizer.from_pretrained(model_name)
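# Llama-2 style prompt template: the system prompt opens the first [INST]
# block inside <<SYS>> tags, so only user turns after the first add their own
# bos_token + "[INST]"; assistant turns are closed with eos_token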
tokenizer.chat_template = "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <<SYS>>\\n' + messages[idx]['content'] + '\\n<</SYS>>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}\n"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    load_in_8bit=True,  # newer transformers versions prefer quantization_config=BitsAndBytesConfig(load_in_8bit=True)
)
def chat(message, history):
    # Gradio passes history as [user, assistant] pairs; rebuild the messages
    # per request rather than mutating shared module-level state
    messages = [{"role": "system", "content": system_prompt}]
    for user_message, assistant_message in history:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": assistant_message})
    messages.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    # for multi-gpu, find the device of the first parameter of the model
    first_param_device = next(model.parameters()).device
    input_ids = input_ids.to(first_param_device)
    # fresh streamer per request so concurrent chats don't interleave tokens
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs=input_ids,
        streamer=streamer,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
    )
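    # model.generate blocks until done, so run it on a background thread and
    # stream partial output back to Gradio as tokens arrive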
    # https://www.gradio.app/main/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token  # or html.escape(new_token) if raw HTML needs to be neutralized
        yield partial_message
chat_interface = gr.ChatInterface(
    chat,
    chatbot=gr.Chatbot(height=400),
    textbox=gr.Textbox(placeholder=placeholder, container=False, scale=7),
    title=title,
    description=description,
    theme="soft",
    examples=examples,
    cache_examples=False,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)
# https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219 - we use this `with` construction b/c Gradio barfs on autoreload otherwise
with gr.Blocks() as demo:
    chat_interface.render()
    gr.Markdown("You can try asking this question in English, formal Japanese, and informal Japanese. You might need to ask it to reply informally with something like もっと友達みたいに話そうよ。あんまり堅苦しくなくて。(\"Let's talk more like friends, don't be so stiff.\") to get informal replies. We limit output to 200 tokens.")

demo.queue().launch()