leemeng committed
Commit 54ac395 · 1 Parent(s): c74831d

add inference logic and package requirements

Files changed (2)
  1. app.py +205 -4
  2. requirements.txt +6 -0
app.py CHANGED
@@ -1,7 +1,208 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"
-
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import os
+ import threading
+ import time
+ import argparse
+ import logging
+ from dataclasses import dataclass
+
+ import torch
+ import sentencepiece as spm
+ from transformers import GPTNeoXForCausalLM, GPTNeoXConfig
+ from transformers.generation.streamers import BaseStreamer
+ from huggingface_hub import hf_hub_download, login
+
+
+ logger = logging.getLogger()
+ logger.setLevel("INFO")
+
+ gr_interface = None
+
+ @dataclass
+ class DefaultArgs:
+     hf_model_name_or_path: str = None
+     spm_model_path: str = None
+     env: str = "dev"
+     port: int = 7860
+     make_public: bool = False
+
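+ # On a Hugging Face Space, configuration comes from environment variables;
+ # locally it is parsed from the command line.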
+ if os.getenv("RUNNING_ON_HF_SPACE"):
+     login(token=os.getenv("HF_TOKEN"))
+     hf_repo = "leemeng/stablelm-jp-alpha"
+
+     args = DefaultArgs()
+     args.hf_model_name_or_path = hf_repo
+     args.spm_model_path = hf_hub_download(repo_id=hf_repo, filename="sentencepiece.model")
+ else:
+     parser = argparse.ArgumentParser(description="")
+     parser.add_argument("--hf_model_name_or_path", type=str, required=True)
+     parser.add_argument("--spm_model_path", type=str, required=True)
+     parser.add_argument("--env", type=str, default="dev")
+     parser.add_argument("--port", type=int, default=7860)
+     parser.add_argument("--make_public", action='store_true')
+     args = parser.parse_args()
+
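+ # Illustrative local invocation (values are placeholders):
+ #   python app.py --hf_model_name_or_path <model_repo_or_dir> --spm_model_path <sentencepiece.model>
+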
+ def load_model(
+     model_dir,
+ ):
+     config = GPTNeoXConfig.from_pretrained(model_dir)
+     config.is_decoder = True
+     model = GPTNeoXForCausalLM.from_pretrained(model_dir, config=config, torch_dtype=torch.bfloat16)
+     if torch.cuda.is_available():
+         model = model.to("cuda:0")
+     return model
+
+ logging.info("Loading model")
+ model = load_model(args.hf_model_name_or_path)
+ sp = spm.SentencePieceProcessor(model_file=args.spm_model_path)
+ logging.info("Finished loading model")
+
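+ # Custom streamer that decodes ids with SentencePiece as they arrive.
+ # `generate` calls `put` first with the prompt ids, then once per new token,
+ # and `end` when generation finishes.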
+ class SentencePieceStreamer(BaseStreamer):
+     def __init__(self, sp: spm.SentencePieceProcessor):
+         self.sp = sp
+         self.num_invoked = 0
+         self.prompt = ""
+         self.generated_text = ""
+         self.ended = False
+
+     def put(self, t: torch.Tensor):
+         d = t.dim()
+         if d == 1:
+             pass  # a single step's token ids
+         elif d == 2:
+             t = t[0]  # take the first (only) sequence in the batch
+         else:
+             raise NotImplementedError
+
+         # defensive copy to CPU (a no-op if the ids are already on the CPU)
+         t = [int(x) for x in t.cpu().numpy()]
+
+         text = self.sp.decode_ids(t)
+
+         # the first call delivers the prompt ids; don't echo them back
+         if self.num_invoked == 0:
+             self.prompt = text
+             self.num_invoked += 1
+             return
+
+         self.generated_text += text
+
+     def end(self):
+         self.ended = True
+
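+ # Gradio callback: clear the textbox and append the user's turn to the history.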
+ def user(user_message, history):
+     logging.debug(f"[user] user_message: {user_message}")
+     logging.debug(f"[user] history: {history}")
+
+     res = ("", history + [[user_message, None]])
+     return res
+
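+ # Chat callback: encodes the latest prompt, runs `model.generate` in a
+ # background thread, and yields partial completions for streaming output.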
+ def bot(
+     history,
+     temperature,
+     max_new_tokens,
+ ):
+     logging.debug(f"[bot] history: {history}")
+     logging.debug(f"temperature: {temperature}")
+
+     # TODO: convert `<br>` back to `\n` based on the original user prompt
+     prompt = history[-1][0]
+
+     tokens = sp.encode(prompt)
+     input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(model.device)
+
+     # TODO: parametrize setting on UI
+     do_sample = True
+
+     streamer = SentencePieceStreamer(sp=sp)
+
+     # cap generation so prompt + new tokens fit in the context window
+     max_possible_new_tokens = model.config.max_position_embeddings - len(tokens)
+     max_possible_new_tokens = min(max_possible_new_tokens, max_new_tokens)
+
+     # run generation in a background thread so the UI loop below can poll the streamer
+     thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
+         input_ids=input_ids,
+         temperature=temperature,
+         max_new_tokens=max_possible_new_tokens,
+         do_sample=do_sample,
+         streamer=streamer,
+         # max_length=4096,
+         # top_k=100,
+         # top_p=0.9,
+         # repetition_penalty=1.0,
+         # num_return_sequences=2,
+         # num_beams=2,
+     ))
+     thr.start()
+
+     # push partial output to the chatbot until generation ends
+     history[-1][1] = ""
+     while not streamer.ended:
+         history[-1][1] = streamer.generated_text
+         time.sleep(0.05)
+         yield history
+
+     # TODO: optimize for final few tokens
+     history[-1][1] = streamer.generated_text
+     yield history
+
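+ # Close a previously launched interface, e.g. when re-running the script interactively.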
+ if gr_interface:
+     gr_interface.close(verbose=False)
+
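+ # Chat UI: chatbot pane, generation sliders, prompt box, and control buttons.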
+ with gr.Blocks() as gr_interface:
+     chatbot = gr.Chatbot(label="StableLM JP Alpha").style(height=500)
+
+     # generation params
+     with gr.Row():
+         temperature = gr.Slider(0, 1, value=0.7, step=0.05, label="Temperature")
+         max_new_tokens = gr.Slider(
+             128,
+             model.config.max_position_embeddings,
+             value=128, step=64, label="Max New Tokens")
+
+     # prompt
+     # TODO: add more options
+     # Commented-out quick-start dropdown with Japanese example prompts:
+     # a free-form completion prompt and a multiple-choice QA prompt.
+     # prompt_options = gr.Dropdown(
+     #     choices=[
+     #         "運が良かったのか悪かったのか日本に帰ってきたタイミングでコロナが猛威を振るい始め、",
+     #         """[問題]に対する[答え]を[選択肢]の中から選んでください。
+
+     # [問題]: ある場所の周辺地域を指す言葉は?
+     # [選択肢]: [空, オレゴン州, 街, 歩道橋, 近辺]
+     # [答え]: 近辺
+
+     # [問題]: 若くて世間に慣れていないことを何という?
+     # [選択肢]: [青っぽい, 若い, ベテラン, 生々しい, 玄人]
+     # [答え]: """
+     #     ],
+     #     label="Prompt Options",
+     #     info="Select 1 option for quick start",
+     #     allow_custom_value=False,
+     # )
+     prompt = gr.Textbox(label="Prompt", info="Pro tip: press Enter to submit directly")
+
+     # def on_prompt_options_change(pmt_opts, pmt):
+     #     return pmt_opts
+
+     # prompt_options.change(on_prompt_options_change, [prompt_options, prompt], prompt)
+
+     with gr.Row():
+         submit = gr.Button("Submit")
+         stop = gr.Button("Stop")
+
+     clear = gr.Button("Clear History")
+
+     # event handling: append the user turn first, then stream the bot reply
+     submit_event = prompt.submit(user, [prompt, chatbot], [prompt, chatbot], queue=False)\
+         .then(bot, [chatbot, temperature, max_new_tokens], chatbot, queue=True)
+
+     submit_click_event = submit.click(user, [prompt, chatbot], [prompt, chatbot], queue=False)\
+         .then(bot, [chatbot, temperature, max_new_tokens], chatbot, queue=True)
+
+     # Stop cancels any in-flight generation; Clear resets the chat history
+     stop.click(None, None, None, cancels=[submit_event, submit_click_event], queue=False)
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ gr_interface.queue()
+ gr_interface.launch(server_port=args.port, share=args.make_public)
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio==3.28.1
+ gradio_client==0.1.4
+ torch==2.0.0
+ sentencepiece==0.1.97
+ transformers==4.28.1
+ huggingface-hub==0.14.1
+ huggingface-hub==0.14.1