File size: 3,553 Bytes
812f70a
 
 
0e02ca5
 
22b3942
0e02ca5
 
 
22b3942
c6bb1bc
0e02ca5
1a6b000
ba913e7
0e02ca5
764f34e
5b2ea5a
764f34e
 
22b3942
 
bd17394
22b3942
764f34e
ba913e7
0e02ca5
22b3942
89eb83f
22b3942
 
0e02ca5
 
 
 
 
22b3942
82d8190
0e02ca5
 
22b3942
 
 
 
89eb83f
 
 
c48336f
89eb83f
22b3942
 
 
 
 
0e02ca5
22b3942
f00ac1d
22b3942
f00ac1d
 
 
 
0e02ca5
 
 
 
 
d259634
0e02ca5
 
22b3942
0e02ca5
9dfa458
 
 
 
 
0e02ca5
 
 
 
 
 
 
 
 
 
 
 
 
22b3942
 
 
0e02ca5
 
 
 
 
1a6b000
0e02ca5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# https://www.gradio.app/guides/using-hugging-face-integrations

import gradio as gr
import logging
import html
from   pprint import pprint
import time
import torch
from   threading import Thread
from   transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

# --- Model -------------------------------------------------------------------
# Hugging Face Hub repository id used for both the tokenizer and the weights.
model_name = "augmxnt/shisa-7b-v1"

# --- UI settings -------------------------------------------------------------
title = "Shisa 7B"
description = (
    "Test out <a href='https://huggingface.co/augmxnt/shisa-7b-v1'>Shisa 7B</a> "
    "in either English or Japanese. "
    "If you aren't getting the right language outputs, you can try changing "
    "the system prompt to the appropriate language. "
    "Note, we are running `load_in_4bit` to fit in 16GB of VRAM."
)
placeholder = "Type Here / ここに入力してください"
# Each example row is a one-element list holding only the user message.
examples = [
    [question]
    for question in (
        "What are the best slices of pizza in New York City?",
        "東京でおすすめのラーメン屋ってどこ?",
        'How do I program a simple "hello world" in Python?',
        "Pythonでシンプルな「ハローワールド」をプログラムするにはどうすればいいですか?",
    )
]

# --- LLM settings ------------------------------------------------------------
# Initial system prompt; also kept as the fallback when the UI box is empty.
system_prompt = 'You are a helpful, bilingual assistant. Reply in same language as the user.'
default_prompt = system_prompt

# Load the tokenizer and weights once, at import time.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # bfloat16 for the weights that bitsandbytes does not quantize.
    torch_dtype=torch.bfloat16,
    # Let accelerate place the layers on the available device(s).
    device_map="auto",
    # 4-bit bitsandbytes quantization so the 7B model fits in 16GB of VRAM.
    # Passing `load_in_4bit=True` straight to from_pretrained is deprecated;
    # this config is the equivalent, supported form.  For 8-bit loading use
    # BitsAndBytesConfig(load_in_8bit=True) instead.
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

def _build_chat_messages(message, history, system_prompt):
    """Assemble the role/content message list fed to the chat template.

    Args:
        message: The new user message.
        history: Earlier ``[user, assistant]`` pairs from ``gr.ChatInterface``
            (the default "tuples" history format of ``gr.Chatbot``).
        system_prompt: Text for the leading system message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts: the system message,
        every complete prior exchange, then the new user message.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for user_text, assistant_text in ((h[0], h[1]) for h in history):
        # Gradio leaves a [message, None] pair in history when a previous
        # call failed.  A None content breaks chat templates, and dropping only
        # one side would break user/assistant alternation, so skip the pair.
        if user_text is None or assistant_text is None:
            continue
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    return messages


def chat(message, history, system_prompt):
    """Generate one assistant reply for ``gr.ChatInterface``.

    Args:
        message: The user's new message.
        history: Earlier ``[user, assistant]`` pairs shown in the chatbot.
        system_prompt: System prompt from the UI textbox; empty or
            whitespace-only values fall back to ``default_prompt``.

    Returns:
        The decoded text of the newly generated tokens (at most 200), with
        special tokens stripped.
    """
    if not (system_prompt and system_prompt.strip()):
        system_prompt = default_prompt

    # Console trace of each request.
    print('---')
    print('Prompt:', system_prompt)
    pprint(history)
    print(message)

    # The whole conversation is rebuilt on every call; no state is cached.
    chat_history = _build_chat_messages(message, history, system_prompt)
    input_ids = tokenizer.apply_chat_template(
        chat_history, add_generation_prompt=True, return_tensors="pt"
    )

    # With device_map="auto" the model may be split across devices; send the
    # inputs to the device of the first parameter.
    first_param_device = next(model.parameters()).device
    input_ids = input_ids.to(first_param_device)

    generate_kwargs = dict(
        inputs=input_ids,
        # One unpadded sequence, so every position is attended.  Passing the
        # mask explicitly avoids transformers' warning that it cannot infer
        # the mask when pad_token_id == eos_token_id.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        repetition_penalty=1.15,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    output_ids = model.generate(**generate_kwargs)
    # generate() returns prompt + completion; keep only the new tokens.
    new_tokens = output_ids[0, input_ids.size(1):]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


# Widgets for the chat UI.  The system-prompt box is listed in
# `additional_inputs`, so its value reaches `chat` after (message, history).
chat_display = gr.Chatbot(height=400)
message_box = gr.Textbox(placeholder=placeholder, container=False, scale=7)
system_prompt_box = gr.Textbox(
    system_prompt,
    label="System Prompt (Change the language of the prompt for better replies)",
)

chat_interface = gr.ChatInterface(
    chat,
    chatbot=chat_display,
    textbox=message_box,
    title=title,
    description=description,
    theme="soft",
    examples=examples,
    cache_examples=False,
    undo_btn="Delete Previous",
    clear_btn="Clear",
    additional_inputs=[system_prompt_box],
)

# Build the interface first, then render it inside an outer Blocks; without
# this construction Gradio's auto-reload fails.  Pattern taken from
# https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219
demo = gr.Blocks()
with demo:
    chat_interface.render()
    gr.Markdown("You can try asking this question in Japanese or English. We limit output to 200 tokens.")

# Enable the request queue, then start the web server.
demo.queue()
demo.launch()