gururise committed
Commit 26fd787 · Parent: 53c517b

update gradio interface for iterative outputs

Files changed (1)
  1. app.py (+99, -85)
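
The "iterative outputs" in this commit come from Gradio's support for generator event handlers: `infer` is rewritten as a generator that, while the worker threads run, repeatedly yields the partial BLOOM and BLOOMZ strings, so the two output textboxes update as tokens arrive instead of only once at the end. A minimal sketch of that pattern, separate from the app.py diff below (the `count_up` function, its labels, and the sleep are illustrative only, not part of the commit):

# Illustrative sketch only: a Gradio event handler written as a generator
# streams partial results to its output component. Generator handlers
# require the queue to be enabled.
import time
import gradio as gr

def count_up(n):
    text = ""
    for i in range(int(n)):
        text += f"{i} "
        time.sleep(0.5)   # stand-in for generating one token
        yield text        # each yield refreshes the output textbox in place

with gr.Blocks() as demo:
    steps = gr.Number(value=5, label="Steps")
    out = gr.Textbox(label="Output")
    gr.Button("Run").click(count_up, inputs=steps, outputs=out)

demo.queue()
demo.launch()

In app.py, `infer` does the same thing with two outputs: it polls the worker threads with `thread.join(timeout=0.2)` and yields the current `output` strings for both models on each pass.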
app.py CHANGED
@@ -1,46 +1,73 @@
 import gradio as gr
 import threading
 import codecs
-#from ast import literal_eval
 from datetime import datetime
-
-import os
-#os.environ['TRANSFORMERS_CACHE'] = '/data/.modelcache/huggingface/hub/'
-#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:516"
-
 from transformers import BloomTokenizerFast
 from petals.client import DistributedBloomForCausalLM
 import torch
-import gc
+import time
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 TORCH_DTYPE = torch.bfloat16
 MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]
 
-models = {}
-output = {}
+models = {MODEL_NAMES[0]: None, MODEL_NAMES[1]: None}
+output = {MODEL_NAMES[0]: "", MODEL_NAMES[1]: ""}
+kill = threading.Event()
 
+def stop_threads():
+    global kill
+    print("Force stopping threads")
+    kill.set()
 
-def gen_thread(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty):
+def gen_thread(model_name, prompt, max_tokens, temperature, top_p, repetition_penalty, stop):
     global output
-    n_input_tokens = inputs.shape[1]
-    outputs = models[model_name][1].generate(inputs,
-        max_new_tokens=max_new_tokens,
-        min_length=min_length,
-        do_sample=True,
-        temperature=temperature,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty
-        )
-    output[model_name] = models[model_name][0].decode(outputs[0, n_input_tokens:])
+
+    if kill.is_set():
+        return
+
+    flag = False
+    token_cnt = 0
+    with models[model_name][1].inference_session(max_length=512) as sess:
+        print(f"Thread Start -> {threading.get_ident()}")
+        output[model_name] = ""
+        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
+        n_input_tokens = inputs.shape[1]
+        done = False
+        while not done and not kill.is_set():
+            outputs = models[model_name][1].generate(
+                inputs,
+                max_new_tokens=1,
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                session=sess
+            )
+            output[model_name] += models[model_name][0].decode(outputs[0, n_input_tokens:])
+            token_cnt += 1
+            print("\n[" + str(threading.get_ident()) + "]" + output[model_name], end="", flush=True)
+
+            for stop_word in stop:
+                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
+                if stop_word != '' and stop_word in output[model_name]:
+                    print(f"\nDONE (stop) -> {threading.get_ident()}")
+                    done = True
+            if flag or (token_cnt >= max_tokens):
+                print(f"\nDONE (max tokens) -> {threading.get_ident()}")
+                done = True
+            inputs = None  # Prefix is passed only for the 1st token of the bot's response
+            n_input_tokens = 0
+        print(f"\nThread End -> {threading.get_ident()}")
 
 def to_md(text):
-    # return text.replace("\n", "<br />")
     return text.replace("\n", "<br />")
 
+threads = list()
+
 def infer(
     prompt,
-    min_length=2,
+    model_idx=["BLOOM", "BLOOMZ"],
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
@@ -49,27 +76,32 @@ def infer(
     num_completions=1,
     seed=42,
 ):
+    global threads
+    global output
+    global models
 
-    #gc.collect()
-    #torch.cuda.empty_cache()
+    if len(model_idx) == 0:
+        return
 
-    if not models:
-        for model_name in MODEL_NAMES:
+    kill.clear()
+    print("Loading Models\n")
+    for idx in model_idx:
+        model_name = MODEL_NAMES[idx]
+        if models[model_name] == None:
             tokenizer = BloomTokenizerFast.from_pretrained(model_name)
             model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
             model = model.to(DEVICE)
             models[model_name] = tokenizer, model
+            output[model_name] = ""
 
     max_new_tokens = int(max_new_tokens)
-    num_completions = int(num_completions)
     temperature = float(temperature)
     top_p = float(top_p)
-    stop = stop.split(";")
+    stop = [x.strip(' ') for x in stop.split(',')]
     repetition_penalty = float(repetition_penalty)
     seed = seed
 
     assert 1 <= max_new_tokens <= 384
-    assert 0 <= min_length <= max_new_tokens
     assert 1 <= num_completions <= 5
     assert 0.0 <= temperature <= 1.0
     assert 0.0 <= top_p <= 1.0
@@ -80,45 +112,19 @@ def infer(
     if prompt == "":
         prompt = " "
 
-    threads = list()
     print(f"START -> ({datetime.now()})\n")
     print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")
-    for model_name in MODEL_NAMES:
-        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
-        x = threading.Thread(target=gen_thread, args=(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty))
+    for idx in model_idx:
+        model_name = MODEL_NAMES[idx]
+        x = threading.Thread(target=gen_thread, args=(model_name, prompt, max_new_tokens, temperature, top_p, repetition_penalty, stop))
         threads.append(x)
         x.start()
-        #n_input_tokens = inputs.shape[1]
-        # outputs = models[model_name][1].generate(inputs,
-        #     max_new_tokens=max_new_tokens,
-        #     min_length=min_length,
-        #     do_sample=True,
-        #     temperature=temperature,
-        #     top_p=top_p,
-        #     repetition_penalty=repetition_penalty
-        #     )
-        #output[model_name] = models[model_name][0].decode(outputs[0, n_input_tokens:])
-
-        #output[model_name] = outputs[len(prompt):]
 
     # Join Threads
     for model_name, thread in enumerate(threads):
-        print(f"waiting on: {model_name}\n")
-        thread.join()
-        print(f"{model_name} thread done\n")
-
-
-    for model_name in MODEL_NAMES:
-        stop = codecs.getdecoder("unicode_escape")(stop[0])[0]
-        stop = [x.strip(' ') for x in stop.split(',')]
-        for stop_word in stop:
-            if stop_word != '' and stop_word in output[model_name]:
-                output[model_name] = output[model_name][:output[model_name].find(stop_word)]
-
-        print(f"--- START: {model_name} --- \n{output[model_name]}\n--- END {model_name} ---\n\n")
-
-    print(f"DONE -> ({datetime.now()})\n")
-    return output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
+        while thread.is_alive():
+            thread.join(timeout=0.2)
+            yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
 
 
 examples = [
@@ -126,7 +132,7 @@ examples = [
     # Question Answering
     '''Please answer the following question:
 Question: What is the capital of Germany?
-Answer:''', 1, 3, 0.2, 1.0, 1.0, "\\n,</s>"],
+Answer:''', ["BLOOM","BLOOMZ"], 3, 0.2, 1.0, 1.0, "\\n,</s>", ["BLOOM","BLOOMZ"]],
     [
     # Natural Language Interface
     '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
@@ -136,28 +142,36 @@ Label: entailment
 Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
 Label: contradiction
 Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
-Label:''', 1, 2, 0.2, 1.0, 1.0, "\\n,</s>"]
+Label:''', ["BLOOM","BLOOMZ"], 2, 0.2, 1.0, 1.0, "\\n,</s>"]
 ]
 
-
-
-iface = gr.Interface(
-    fn=infer,
-    allow_flagging="never",
-    inputs=[
-        gr.Textbox(lines=20), # prompt
-        gr.Slider(0, 256, value=1), #min_length
-        gr.Slider(1, 384, value=20), # max_tokens
-        gr.Slider(0.0, 1.0, value=0.2), # temperature
-        gr.Slider(0.0, 1.0, value=0.9), # top_p
-        gr.Slider(0.9, 3.0, value=1.0), # repetition penalty
-        gr.Textbox(lines=1, value="\\n,</s>") # stop
-    ],
-    outputs=[gr.Textbox(lines=7, label="BLOOM OUTPUT:"), gr.Textbox(lines=7, label="BLOOMZ OUTPUT:")],
-
-    examples=examples,
-    cache_examples=True,
-    title="BLOOM vs BLOOMZ",
-    description='''<p>Compare outputs of the BLOOM and BLOOMZ 176 billion parameter models using the [Petals](https://petals.ml/) network. WARNING: Inference may take a long time. Keep the max_tokens low to speed things up.<p>
-<p>Please consider [joining](https://github.com/bigscience-workshop/petals) the Petals network to help speed up inference.</p><p>Big thanks to [RFTCapital](https://www.rftcapital.com) for providing initial compute resources.</p>'''
-).launch()
+def clear_prompt():
+    return "", "", ""
+
+with gr.Blocks() as demo:
+    gr.Markdown("Start typing below and then click **Run** to see the output.")
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(lines=17, label="Prompt", placeholder="Enter Prompt", interactive=True)
+            with gr.Box():
+                chk_boxes = gr.CheckboxGroup(choices=["BLOOM","BLOOMZ"], value=["BLOOM","BLOOMZ"], type="index", label="Model")
+                #min_length = gr.Slider(minimum=0, maximum=256, value=1, label="Minimum Length") #min_length
+                max_tokens = gr.Slider(minimum=1, maximum=256, value=15, label="Max Tokens") # max_tokens
+                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.2, label="Temperature") # temperature
+                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.9, label="Top P") # top_p
+                rep_penalty = gr.Slider(minimum=0.9, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty") # repetition penalty
+                stop = gr.Textbox(lines=1, value="\\n,</s>", label="Stop Token") # stop
+        with gr.Column():
+            bloom_out = gr.Textbox(lines=7, label="BLOOM OUTPUT:")
+            bloomz_out = gr.Textbox(lines=7, label="BLOOMZ OUTPUT:")
+            with gr.Row():
+                btn_clear = gr.Button("Clear", variant="secondary")
+                btn_run = gr.Button("Run", variant="primary")
+                btn_stop = gr.Button("Stop", variant="stop")
+    click_run = btn_run.click(infer, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop], outputs=[bloom_out, bloomz_out])
+    btn_clear.click(clear_prompt, outputs=[prompt, bloom_out, bloomz_out])
+    btn_stop.click(stop_threads, cancels=click_run)
+    gr.Examples(examples, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop])
+
+demo.queue(concurrency_count=3)
+demo.launch()
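
For reference, the Stop button combines two mechanisms: `cancels=click_run` asks Gradio to cancel the queued `infer` generator, while `stop_threads()` sets the shared `threading.Event` that each `gen_thread` checks between tokens so the worker can exit cleanly. A stripped-down sketch of that cooperative-stop pattern, independent of Gradio and Petals (the `worker` function and its timings are illustrative only):

# Illustrative sketch only: cooperative cancellation with threading.Event.
# The worker polls the event between steps, mirroring gen_thread's
# "while not done and not kill.is_set()" loop.
import threading
import time

kill = threading.Event()

def worker(steps=100):
    for i in range(steps):
        if kill.is_set():      # checked once per step, like once per token
            print(f"worker stopped at step {i}")
            return
        time.sleep(0.1)        # stand-in for generating one token
    print("worker finished all steps")

t = threading.Thread(target=worker)
t.start()
time.sleep(0.35)
kill.set()                     # what stop_threads() does on the Stop click
t.join()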