gururise committed on
Commit 2812e92 · 1 Parent(s): a74cf5e
Files changed (1)
  1. app.py +99 -121
app.py CHANGED
@@ -1,7 +1,7 @@
import gradio as gr
- import threading
import codecs
from datetime import datetime
+ import gc
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
@@ -11,63 +11,18 @@
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

- models = {MODEL_NAMES[0]:None,MODEL_NAMES[1]:None}
+ models = {"model":None,"model_name":None}
output = {MODEL_NAMES[0]:"",MODEL_NAMES[1]:""}
- kill = threading.Event()

- def stop_threads():
-     global kill
-     print("Force stopping threads")
-     kill.set()

- def gen_thread(model_name, prompt, max_tokens, temperature, top_p, repetition_penalty, stop):
-     global output
-
-     if kill.is_set():
-         return
-
-     flag = False
-     token_cnt = 0
-     with models[model_name][1].inference_session(max_length=512) as sess:
-         print(f"Thread Start -> {threading.get_ident()}")
-         output[model_name] = ""
-         inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
-         n_input_tokens = inputs.shape[1]
-         done = False
-         while not done and not kill.is_set():
-             outputs = models[model_name][1].generate(
-                 inputs,
-                 max_new_tokens=1,
-                 do_sample=True,
-                 top_p=top_p,
-                 temperature=temperature,
-                 repetition_penalty=repetition_penalty,
-                 session=sess
-             )
-             output[model_name] += models[model_name][0].decode(outputs[0, n_input_tokens:])
-             token_cnt += 1
-             print("\n["+ str(threading.get_ident()) + "]" + output[model_name], end="", flush=True)
-
-             for stop_word in stop:
-                 stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
-                 if stop_word != '' and stop_word in output[model_name]:
-                     print(f"\nDONE (stop) -> {threading.get_ident()}")
-                     done = True
-             if flag or (token_cnt >= max_tokens):
-                 print(f"\nDONE (max tokens) -> {threading.get_ident()}")
-                 done = True
-             inputs = None # Prefix is passed only for the 1st token of the bot's response
-             n_input_tokens = 0
-         print(f"\nThread End -> {threading.get_ident()}")
+ print (DEVICE)

def to_md(text):
    return text.replace("\n", "<br />")

- threads = list()
-
def infer(
    prompt,
-     model_idx = ["BLOOM","BLOOMZ"],
+     model_idx = 0,
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
@@ -76,24 +31,22 @@ def infer(
    num_completions=1,
    seed=42,
):
-     global threads
    global output
    global models

-     if len(model_idx) == 0:
-         return
-
-     kill.clear()
    print("Loading Models\n")
-     for idx in model_idx:
-         model_name = MODEL_NAMES[idx]
-         if models[model_name] == None:
-             print ("Initializing " + model_name)
-             tokenizer = BloomTokenizerFast.from_pretrained(model_name)
-             model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
-             model = model.to(DEVICE)
-             models[model_name] = tokenizer, model
-         output[model_name] = ""
+     model_name = MODEL_NAMES[model_idx]
+     if (models["model_name"] == None or models["model_name"] != model_name):
+         models = {"model":None,"model_name":None}
+         gc.collect()
+         if (DEVICE == "cuda"):
+             torch.cuda.empty_cache()
+         tokenizer = BloomTokenizerFast.from_pretrained(model_name)
+         model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE, request_timeout=300)
+         model = model.to(DEVICE)
+         models["model"] = tokenizer, model
+         models["model_name"] = model_name
+     output[model_name] = ""

    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
@@ -115,71 +68,96 @@ def infer(

    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")
-     for idx in model_idx:
-         model_name = MODEL_NAMES[idx]
-         x = threading.Thread(target=gen_thread, args=(model_name, prompt, max_new_tokens, temperature, top_p, repetition_penalty, stop))
-         threads.append(x)
-         x.start()

-     # Join Threads
-     for model_name, thread in enumerate(threads):
-         while thread.is_alive():
-             thread.join(timeout=0.2)
-             yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
-
+     flag = False
+     token_cnt = 0
+     with models["model"][1].inference_session(max_length=512) as sess:
+         print(f"Encode Input Prompt")
+         output[model_name] = ""
+         inputs = models["model"][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
+         n_input_tokens = inputs.shape[1]
+         done = False
+         print(f"Start Inference ({sess})")
+         while not done:
+             outputs = models["model"][1].generate(
+                 inputs,
+                 max_new_tokens=1,
+                 do_sample=True,
+                 top_p=top_p,
+                 temperature=temperature,
+                 repetition_penalty=repetition_penalty,
+                 session=sess
+             )
+             output[model_name] += models["model"][0].decode(outputs[0, n_input_tokens:])
+             token_cnt += 1
+             print("\n["+ str(model_name) + "]" + output[model_name], end="", flush=True)
+             yield output[model_name]
+             for stop_word in stop:
+                 stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
+                 if stop_word != '' and stop_word in output[model_name]:
+                     print(f"\nDONE (stop)")
+                     done = True
+             if flag or (token_cnt >= max_new_tokens):
+                 print(f"\nDONE (max tokens)")
+                 done = True
+             inputs = None # Prefix is passed only for the 1st token of the bot's response
+             n_input_tokens = 0
+     print(f"\nEnd")
+     yield output[model_name]

examples = [
    [
    # Question Answering
    '''Please answer the following question:
Question: What is the capital of Germany?
-     Answer:''',["BLOOM","BLOOMZ"] , 3, 0.2, 1.0, 1.0, "\\n,</s>", ["BLOOM","BLOOMZ"]],
+     Answer:''',"BLOOMZ" , 3, 0.2, 1.0, 1.0, "\\n,</s>", ["BLOOM","BLOOMZ"]],
    [
-     # Natural Language Interface
-     '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
- Possible labels: 1. entailment 2. contradiction
- Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
- Label: entailment
- Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
- Label: contradiction
- Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
- Label:''',["BLOOM","BLOOMZ"] , 2, 0.2, 1.0, 1.0, "\\n,</s>"]
+     # Chatbot 1
+     '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
+ Alex: Good morning, Fritz!
+ Fritz:''',"BLOOM" , 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"],
+     [
+     # Chatbot 1
+     '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
+ Alex: Good morning, Fritz!
+ Fritz:''',"BLOOMZ" , 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"],
+     [
+     # Expert Answers
+     '''Expert Questions & Helpful Answers
+ Ask Research Experts
+ Question:
+ Are humans good or bad?
+
+ Full Answer:''',"BLOOM" , 120, 0.85, 0.9, 1.0, "</s>"],
+     [
+     # G
+     '''You are the writing assistant for Stephen King. You have worked in the fiction/horror genre for 30 years. You are a Pulitzer Prize-winning author, and now you are tasked with developing a skeletal outline for his newest novel, set to be completed in the spring of 2024. Create a title and brief description for the first 5 chapters of this work.\n\nTitle:''',"BLOOM" , 120, 0.85, 0.9, 1.0, "</s>"
+     ]
]

- def clear_prompt():
-     return "","",""
-
- with gr.Blocks() as demo:
-     gr.Markdown("# <p style='text-align: center;'>BLOOM vs BLOOMZ Comparison</p>")
-     gr.Markdown("")
-     gr.Markdown("Test Inference on the [BLOOM](https://huggingface.co/bigscience/bloom) and [BLOOMZ](https://huggingface.co/bigscience/bloomz) 176 Billion Parameter models using Petals. \
-     Please consider contributing your unused GPU cycles to the [Petals Swarm](https://github.com/bigscience-workshop/petals) to speed up inference. <br />\n \
-     Due to heavy resource requirements of these large models, token generation can take upwards of 3-5 seconds per token. Try to keep Max Tokens to a minimum.")
-     gr.Markdown("")
-     gr.Markdown("Special thanks to [RFT Capital](https://www.rftcapital.com/) for supporting our experiments with compute time dontations.")
-     gr.Markdown("Type a Prompt and then click **Run** to see the output.")
-     with gr.Row():
-         with gr.Column():
-             prompt = gr.Textbox(lines=17,label="Prompt",placeholder="Enter Prompt", interactive=True)
-             with gr.Box():
-                 chk_boxes = gr.CheckboxGroup(choices=["BLOOM","BLOOMZ"],value=["BLOOM","BLOOMZ"], type="index", label="Model")
-                 #min_length = gr.Slider(minimum=0, maximum=256, value=1, label="Minimum Length") #min_length
-                 max_tokens = gr.Slider(minimum=1, maximum=256, value=15, label="Max Tokens") # max_tokens
-                 temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.2, label="Temperature") # temperature
-                 top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.9, label="Top P") # top_p
-                 rep_penalty = gr.Slider(minimum=0.9, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty") # repetition penalty
-                 stop = gr.Textbox(lines=1, value="\\n,</s>", label="Stop Token") # stop
-         with gr.Column():
-             bloom_out = gr.Textbox(lines=7, label="BLOOM OUTPUT:")
-             bloomz_out = gr.Textbox(lines=7,label="BLOOMZ OUTPUT:")
-     with gr.Row():
-         btn_clear = gr.Button("Clear", variant="secondary")
-         btn_run = gr.Button("Run", variant="primary")
-         btn_stop = gr.Button("Stop", variant="stop")
-     click_run = btn_run.click(infer, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop], outputs=[bloom_out,bloomz_out])
-     btn_clear.click(clear_prompt, outputs=[prompt, bloom_out, bloomz_out])
-     btn_stop.click(stop_threads,cancels=click_run)
-     gr.Examples(examples, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop])
-
- demo.queue(concurrency_count=1)
- demo.launch()
+
+
+ iface = gr.Interface(
+     fn=infer,
+     allow_flagging="never",
+     inputs=[
+         gr.Textbox(lines=20,label="Input Prompt", max_lines=10), # prompt
+         gr.Radio(["BLOOM","BLOOMZ"], value="BLOOM", type="index", label="Choose 176 billion parameter Model"),
+         gr.Slider(1, 256, value=15), # max_tokens
+         gr.Slider(0.0, 1.0, value=0.2), # temperature
+         gr.Slider(0.0, 1.0, value=0.9), # top_p
+         gr.Slider(0.9, 3.0, value=1.0), # repetition penalty
+         gr.Textbox(lines=1, value="\\n\\n,</s>") # stop
+     ],
+     outputs=gr.Textbox(lines=20, label="Generated Output:"),
+
+     examples=examples,
+     #cache_examples=True,
+     title="BLOOM vs BLOOMZ",
+     description='''<p>Compare outputs of the BLOOM and BLOOMZ 176 billion parameter models using the Petals network. <b>WARNING:</b> Initial inference may take a long time. Keep the input prompt to a minimum size to speed things up.<p>
+ <p>Please consider contributing your unused GPU cycles to the <a href='https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity'>petals swarm</a> to help speed up inference. Check the <a href='http://health.petals.ml/'>Health</a> of the Petals Swarm.</p>
+ <p>Big thanks to <a href='https://www.rftcapital.com/'>RFT Capital</a> for providing initial compute resources.</p>'''
+ )
+
+ iface.queue(concurrency_count=2)
+ iface.launch()
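
Note on the new structure: the rewritten app.py drops the thread-per-model Blocks UI and instead streams tokens by turning infer() into a generator. Each pass through the loop asks the Petals inference session for one more token and yields the accumulated text, which Gradio pushes to the output textbox while queuing is enabled. Below is a minimal sketch of that streaming pattern only; the stream_tokens function, its sleep delay, and its placeholder tokens are illustrative stand-ins, not part of this commit, and in the real app the tokens come from the distributed BLOOM model.

import time
import gradio as gr

def stream_tokens(prompt, max_new_tokens=15):
    # Generator function: Gradio treats each yield as a partial result and
    # updates the output component in place.
    text = ""
    for i in range(int(max_new_tokens)):
        time.sleep(0.2)        # stand-in for one single-token generate() call on the Petals session
        text += f" token{i}"   # stand-in for decoding the newly sampled token
        yield text

demo = gr.Interface(
    fn=stream_tokens,
    inputs=[gr.Textbox(lines=4, label="Prompt"), gr.Slider(1, 64, value=15, label="Max Tokens")],
    outputs=gr.Textbox(lines=8, label="Streamed Output"),
)

if __name__ == "__main__":
    demo.queue(concurrency_count=2)  # queuing is required for generator (streaming) functions
    demo.launch()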