add inference logic and package requirements
- app.py +205 -4
- requirements.txt +6 -0
app.py
CHANGED
@@ -1,7 +1,208 @@
import gradio as gr

import os
import threading
import time
import argparse
import logging
from dataclasses import dataclass

import torch
import sentencepiece as spm
from transformers import GPTNeoXForCausalLM, GPTNeoXConfig
from transformers.generation.streamers import BaseStreamer
from huggingface_hub import hf_hub_download, login


logger = logging.getLogger()
logger.setLevel("INFO")

gr_interface = None

@dataclass
class DefaultArgs:
    hf_model_name_or_path: str = None
    spm_model_path: str = None
    env: str = "dev"
    port: int = 7860
    make_public: bool = False

if os.getenv("RUNNING_ON_HF_SPACE"):
    login(token=os.getenv("HF_TOKEN"))
    hf_repo = "leemeng/stablelm-jp-alpha"

    args = DefaultArgs()
    args.hf_model_name_or_path = hf_repo
    args.spm_model_path = hf_hub_download(repo_id=hf_repo, filename="sentencepiece.model")
else:
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--hf_model_name_or_path", type=str, required=True)
    parser.add_argument("--spm_model_path", type=str, required=True)
    parser.add_argument("--env", type=str, default="dev")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--make_public", action='store_true')
    args = parser.parse_args()

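# The two branches above: on Spaces (RUNNING_ON_HF_SPACE set) the model and
# tokenizer come from the Hub repo; locally they come from the CLI. A
# hypothetical local invocation (paths are placeholders, not from the commit):
#
#   python app.py --hf_model_name_or_path <model_dir> --spm_model_path <sentencepiece.model>
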
def load_model(model_dir):
    config = GPTNeoXConfig.from_pretrained(model_dir)
    config.is_decoder = True
    # bfloat16 halves memory; the model is moved to GPU when one is available.
    model = GPTNeoXForCausalLM.from_pretrained(model_dir, config=config, torch_dtype=torch.bfloat16)
    if torch.cuda.is_available():
        model = model.to("cuda:0")
    return model

logging.info("Loading model")
model = load_model(args.hf_model_name_or_path)
sp = spm.SentencePieceProcessor(model_file=args.spm_model_path)
logging.info("Finished loading model")

class SentencePieceStreamer(BaseStreamer):
    """Accumulates generated text by decoding token ids with SentencePiece."""

    def __init__(self, sp: spm.SentencePieceProcessor):
        self.sp = sp
        self.num_invoked = 0
        self.prompt = ""
        self.generated_text = ""
        self.ended = False

    def put(self, t: torch.Tensor):
        d = t.dim()
        if d == 1:
            pass
        elif d == 2:
            t = t[0]
        else:
            raise NotImplementedError

        # Defensive .cpu(): .numpy() raises if put() ever receives a CUDA tensor.
        t = [int(x) for x in t.cpu().numpy()]

        text = self.sp.decode_ids(t)

        # The first call delivers the prompt ids; record them separately so only
        # newly generated text accumulates.
        if self.num_invoked == 0:
            self.prompt = text
            self.num_invoked += 1
            return

        self.generated_text += text
        # print(f"[streamer]: {self.generated_text}")
        # yield text

    def end(self):
        self.ended = True

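# Note on the streamer contract (transformers' BaseStreamer): generate() calls
# put() once with the full prompt ids, then once per newly sampled token, and
# finally end() when generation completes; bot() below polls `ended` on that.
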
def user(user_message, history):
    logging.debug(f"[user] user_message: {user_message}")
    logging.debug(f"[user] history: {history}")

    res = ("", history + [[user_message, None]])
    return res

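# user() clears the Textbox (the "" output) and appends a [user_message, None]
# pair to the Chatbot history; bot() then streams text into the None slot.
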
def bot(
    history,
    temperature,
    max_new_tokens,
):
    logging.debug(f"[bot] history: {history}")
    logging.debug(f"temperature: {temperature}")

    # TODO: modify `<br>` back to `\n` based on the original user prompt
    prompt = history[-1][0]

    tokens = sp.encode(prompt)
    input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(model.device)

    # TODO: parametrize setting on UI
    do_sample = True

    streamer = SentencePieceStreamer(sp=sp)

    # Cap generation so prompt + new tokens stay within the context window.
    max_possible_new_tokens = model.config.max_position_embeddings - len(tokens)
    max_possible_new_tokens = min(max_possible_new_tokens, max_new_tokens)

    thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
        input_ids=input_ids,
        temperature=temperature,
        max_new_tokens=max_possible_new_tokens,
        do_sample=do_sample,
        streamer=streamer,
        # max_length=4096,
        # top_k=100,
        # top_p=0.9,
        # repetition_penalty=1.0,
        # num_return_sequences=2,
        # num_beams=2,
    ))
    thr.start()

    history[-1][1] = ""
    while not streamer.ended:
        history[-1][1] = streamer.generated_text
        time.sleep(0.05)
        yield history

    # TODO: optimize for final few tokens
    history[-1][1] = streamer.generated_text
    yield history

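# The thread-plus-polling pattern above is what enables streaming: generate()
# blocks inside the worker thread while this generator wakes every 50 ms to
# yield the partial transcript, until end() flips `streamer.ended`.
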
if gr_interface:
    gr_interface.close(verbose=False)

with gr.Blocks() as gr_interface:
    chatbot = gr.Chatbot(label="StableLM JP Alpha").style(height=500)

    # generation params
    with gr.Row():
        temperature = gr.Slider(0, 1, value=0.7, step=0.05, label="Temperature")
        max_new_tokens = gr.Slider(
            128,
            model.config.max_position_embeddings,
            value=128, step=64, label="Max New Tokens")

    # prompt
    # TODO: add more options
    # (commented-out quick-start choices: a Japanese free-continuation prompt
    # and a Japanese multiple-choice QA prompt)
    # prompt_options = gr.Dropdown(
    #     choices=[
    #         "運が良かったのか悪かったのか日本に帰ってきたタイミングでコロナが猛威を振るい始め、",
    #         """[問題]に対する[答え]を[選択肢]の中から選んでください。

    #         [問題]: ある場所の周辺地域を指す言葉は?
    #         [選択肢]: [空, オレゴン州, 街, 歩道橋, 近辺]
    #         [答え]: 近辺

    #         [問題]: 若くて世間に慣れていないことを何という?
    #         [選択肢]: [青っぽい, 若い, ベテラン, 生々しい, 玄人]
    #         [答え]: """
    #     ],
    #     label="Prompt Options",
    #     info="Select 1 option for quick start",
    #     allow_custom_value=False,
    # )
    prompt = gr.Textbox(label="Prompt", info="Pro tip: press Enter to submit directly")

    # def on_prompt_options_change(pmt_opts, pmt):
    #     return pmt_opts

    # prompt_options.change(on_prompt_options_change, [prompt_options, prompt], prompt)

    with gr.Row():
        submit = gr.Button("Submit")
        stop = gr.Button("Stop")

    clear = gr.Button("Clear History")

    # event handling
    submit_event = prompt.submit(user, [prompt, chatbot], [prompt, chatbot], queue=False)\
        .then(bot, [chatbot, temperature, max_new_tokens], chatbot, queue=True)

    submit_click_event = submit.click(user, [prompt, chatbot], [prompt, chatbot], queue=False)\
        .then(bot, [chatbot, temperature, max_new_tokens], chatbot, queue=True)

    # Stop cancels the streamed Gradio events; note the background generate()
    # thread itself still runs to completion.
    stop.click(None, None, None, cancels=[submit_event, submit_click_event], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)

gr_interface.queue()
gr_interface.launch(server_port=args.port, share=args.make_public)
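For reference, a minimal sketch of driving the streamer without the UI, assuming app.py's `model` and `sp` are loaded as above (the prompt string is purely illustrative):

    streamer = SentencePieceStreamer(sp=sp)
    # "こんにちは、" is "Hello," — an illustrative prompt for the Japanese model
    ids = torch.tensor(sp.encode("こんにちは、"), dtype=torch.long).unsqueeze(0).to(model.device)
    model.generate(input_ids=ids, max_new_tokens=16, do_sample=True, streamer=streamer)
    print(streamer.generated_text)  # continuation only; put() skipped the prompt ids
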
requirements.txt
ADDED
@@ -0,0 +1,6 @@
gradio==3.28.1
gradio_client==0.1.4
torch==2.0.0
sentencepiece==0.1.97
transformers==4.28.1
huggingface-hub==0.14.1
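A note on the pins, as an observation rather than part of the commit: app.py relies on Gradio 3.x idioms (for example `Chatbot(...).style(height=500)`, which Gradio 4 removed), so upgrading gradio past 3.x would require code changes, and transformers 4.28 is the first release with the generation streamer interface used above.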