# https://www.gradio.app/guides/using-hugging-face-integrations

import gradio as gr
import logging
import html
import time
import torch
from   threading import Thread
from   transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Model
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

# UI Settings
title = "Shisa 7B"
description = "Test out Shisa 7B in either English or Japanese."
placeholder = "Type Here / ここに入力してください" 
examples = [
    "What's the best ramen in Tokyo?",
    "東京でおすすめのラーメン屋さんを教えていただけますか。",
    "東京でおすすめのラーメン屋ってどこ?",
]

# LLM Settings
system_prompt = 'You are a helpful, friendly assistant.'
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.chat_template = "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <<SYS>>\\n' + messages[idx]['content'] + '\\n<</SYS>>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' '  + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}\n"
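# For reference, a conversation of [system, user, assistant, user] renders through the
# template above roughly as (illustrative only; exact whitespace comes from the template string):
#   [INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user_1} [/INST] {assistant_1} </s><s>[INST] {user_2} [/INST]

# Note: load_in_8bit below requires the bitsandbytes and accelerate packages.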
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    load_in_8bit=True,
)
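# TextIteratorStreamer yields decoded text chunks as generate() (run in a background thread
# inside chat() below) produces them; timeout=10.0 raises if no new text arrives within 10 seconds.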
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

def chat(message, history):
    # Rebuild the conversation from Gradio's per-session history (a list of
    # [user, assistant] pairs) so assistant turns are included in the prompt
    # and sessions don't share a single global history.
    chat_history = [{"role": "system", "content": system_prompt}]
    for user_msg, assistant_msg in history:
        chat_history.append({"role": "user", "content": user_msg})
        if assistant_msg:
            chat_history.append({"role": "assistant", "content": assistant_msg})
    chat_history.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt")
    # for multi-gpu, find the device of the first parameter of the model
    first_param_device = next(model.parameters()).device
    input_ids = input_ids.to(first_param_device)

    generate_kwargs = dict(
        inputs=input_ids,
        streamer=streamer,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
    )
    # https://www.gradio.app/main/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token  # or html.escape(new_token) to show any HTML in the output literally
        yield partial_message


chat_interface = gr.ChatInterface(
    chat,
    chatbot=gr.Chatbot(height=400),
    textbox=gr.Textbox(placeholder=placeholder, container=False, scale=7),
    title=title,
    description=description,
    theme="soft",
    examples=examples,
    cache_examples=False,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)

# https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219 - we use this `with` construction because Gradio otherwise barfs on autoreload
with gr.Blocks() as demo:
    chat_interface.render()
    gr.Markdown("You can try asking this question in English, formal Japanese, and informal Japanese. You might need to ask it to reply informally with something like もっと友達みたいに話そうよ。あんまり堅苦しくなくて。to get informal replies. We limit output to 200 tokens.")

demo.queue().launch()
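
# To run the demo locally (a rough sketch; assumes this file is saved as app.py and that a
# CUDA GPU with enough VRAM for the 8-bit 7B model is available):
#   pip install gradio torch transformers accelerate bitsandbytes
#   python app.py          # or `gradio app.py` for autoreload (see the gr.Blocks note above)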