Spaces:
Runtime error
Runtime error
leonardlin
commited on
Commit
•
22b3942
1
Parent(s):
82777d3
better chat rebuild, added system prompt, load_in_4bit
Browse files
app.py
CHANGED
@@ -3,40 +3,61 @@
|
|
3 |
import gradio as gr
|
4 |
import logging
|
5 |
import html
|
|
|
6 |
import time
|
7 |
import torch
|
8 |
from threading import Thread
|
9 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
10 |
|
11 |
# Model
|
12 |
model_name = "augmxnt/shisa-7b-v1"
|
13 |
|
14 |
# UI Settings
|
15 |
title = "Shisa 7B"
|
16 |
-
description = "Test out Shisa 7B in either English or Japanese."
|
17 |
placeholder = "Type Here / ここに入力してください"
|
18 |
examples = [
|
19 |
-
"What
|
20 |
-
"
|
21 |
-
"東京でおすすめのラーメン屋ってどこ?",
|
|
|
22 |
]
|
23 |
|
24 |
# LLM Settings
|
25 |
-
|
26 |
-
|
|
|
|
|
27 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
28 |
model = AutoModelForCausalLM.from_pretrained(
|
29 |
model_name,
|
30 |
torch_dtype=torch.bfloat16,
|
31 |
device_map="auto",
|
32 |
-
load_in_8bit=True,
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
34 |
)
|
35 |
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
36 |
|
37 |
-
def chat(message, history):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
chat_history.append({"role": "user", "content": message})
|
|
|
39 |
input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt")
|
|
|
40 |
# for multi-gpu, find the device of the first parameter of the model
|
41 |
first_param_device = next(model.parameters()).device
|
42 |
input_ids = input_ids.to(first_param_device)
|
@@ -50,6 +71,7 @@ def chat(message, history):
|
|
50 |
repetition_penalty=1.15,
|
51 |
top_p=0.95,
|
52 |
eos_token_id=tokenizer.eos_token_id,
|
|
|
53 |
)
|
54 |
# https://www.gradio.app/main/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face
|
55 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
@@ -71,6 +93,9 @@ chat_interface = gr.ChatInterface(
|
|
71 |
cache_examples=False,
|
72 |
undo_btn="Delete Previous",
|
73 |
clear_btn="Clear",
|
|
|
|
|
|
|
74 |
)
|
75 |
|
76 |
# https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219 - we use this with construction b/c Gradio barfs on autoreload otherwise
|
|
|
3 |
import gradio as gr
|
4 |
import logging
|
5 |
import html
|
6 |
+
from pprint import pprint
|
7 |
import time
|
8 |
import torch
|
9 |
from threading import Thread
|
10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
|
11 |
|
12 |
# Model
|
13 |
model_name = "augmxnt/shisa-7b-v1"
|
14 |
|
15 |
# UI Settings
|
16 |
title = "Shisa 7B"
|
17 |
+
description = "Test out Shisa 7B in either English or Japanese. If you aren't getting the right language outputs, you can try changing the system prompt to the appropriate language. Note, we are running `load_in_4bit` to fit in 16GB of VRAM"
|
18 |
placeholder = "Type Here / ここに入力してください"
|
19 |
examples = [
|
20 |
+
["What are the best slices of pizza in New York City?"],
|
21 |
+
['How do I program a simple "hello world" in Python?'],
|
22 |
+
["東京でおすすめのラーメン屋ってどこ?"],
|
23 |
+
["Pythonでシンプルな「ハローワールド」をプログラムするにはどうすればいいですか?"],
|
24 |
]
|
25 |
|
26 |
# LLM Settings
|
27 |
+
# Initial
|
28 |
+
system_prompt = 'You are a helpful, bilingual assistant. Reply in the same language as the user.'
|
29 |
+
default_prompt = system_prompt
|
30 |
+
|
31 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
32 |
model = AutoModelForCausalLM.from_pretrained(
|
33 |
model_name,
|
34 |
torch_dtype=torch.bfloat16,
|
35 |
device_map="auto",
|
36 |
+
# load_in_8bit=True,
|
37 |
+
quantization_config = BitsAndBytesConfig(
|
38 |
+
load_in_4bit=True,
|
39 |
+
bnb_4bit_quant_type='nf4',
|
40 |
+
bnb_4bit_use_double_quant=True,
|
41 |
+
bnb_4bit_compute_dtype=torch.bfloat16
|
42 |
+
),
|
43 |
)
|
44 |
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
45 |
|
46 |
+
def chat(message, history, system_prompt):
|
47 |
+
print('---')
|
48 |
+
pprint(history)
|
49 |
+
if not system_prompt:
|
50 |
+
system_prompt = default_prompt
|
51 |
+
|
52 |
+
# Let's just rebuild every time it's easier
|
53 |
+
chat_history = [{"role": "system", "content": system_prompt}]
|
54 |
+
for h in history:
|
55 |
+
chat_history.append({"role": "user", "content": h[0]})
|
56 |
+
chat_history.append({"role": "assistant", "content": h[1]})
|
57 |
chat_history.append({"role": "user", "content": message})
|
58 |
+
|
59 |
input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt")
|
60 |
+
|
61 |
# for multi-gpu, find the device of the first parameter of the model
|
62 |
first_param_device = next(model.parameters()).device
|
63 |
input_ids = input_ids.to(first_param_device)
|
|
|
71 |
repetition_penalty=1.15,
|
72 |
top_p=0.95,
|
73 |
eos_token_id=tokenizer.eos_token_id,
|
74 |
+
pad_token_id=tokenizer.eos_token_id,
|
75 |
)
|
76 |
# https://www.gradio.app/main/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face
|
77 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
|
|
93 |
cache_examples=False,
|
94 |
undo_btn="Delete Previous",
|
95 |
clear_btn="Clear",
|
96 |
+
additional_inputs=[
|
97 |
+
gr.Textbox(system_prompt, label="System Prompt (Change the language of the prompt for better replies)"),
|
98 |
+
],
|
99 |
)
|
100 |
|
101 |
# https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219 - we use this with construction b/c Gradio barfs on autoreload otherwise
|