Pedro Cuenca commited on
Commit
0e8338d
·
1 Parent(s): 7158e2e

Add a couple of sliders and prevent generating without a prompt

Browse files
Files changed (1) hide show
  1. app/app.py +31 -11
app/app.py CHANGED
@@ -26,7 +26,6 @@ from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
26
 
27
  import streamlit as st
28
 
29
- st.write("List GPU Device", jax.devices("gpu"))
30
  st.write("Loading model...")
31
 
32
  # TODO: set those args in a config file
@@ -82,8 +81,6 @@ model.config.forced_bos_token_id = None
82
  model.config.forced_eos_token_id = None
83
 
84
  vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
85
- st.write("VQModel")
86
- print("Initialize VqModel")
87
 
88
  def custom_to_pil(x):
89
  x = np.clip(x, 0., 1.)
@@ -137,11 +134,11 @@ vqgan_params = replicate(vqgan.params)
137
  from transformers import CLIPProcessor, FlaxCLIPModel
138
 
139
  clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
140
- st.write("FlaxCLIPModel")
141
- print("Initialize FlaxCLIPModel")
142
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
143
- st.write("CLIPProcessor")
144
- print("Initialize CLIPProcessor")
145
 
146
  def hallucinate(prompt, num_images=64):
147
  prompt = [prompt] * jax.device_count()
@@ -169,8 +166,31 @@ def clip_top_k(prompt, images, k=8):
169
  scores = np.array(logits[0]).argsort()[-k:][::-1]
170
  return [images[score] for score in scores]
171
 
172
- prompt = st.text_input("Input prompt", "rice fields by the mediterranean coast")
173
- st.write(f"Generating candidates for: {prompt}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
- images = hallucinate(prompt, num_images=1)
176
- st.image(images[0])
 
26
 
27
  import streamlit as st
28
 
 
29
  st.write("Loading model...")
30
 
31
  # TODO: set those args in a config file
 
81
  model.config.forced_eos_token_id = None
82
 
83
  vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
 
 
84
 
85
  def custom_to_pil(x):
86
  x = np.clip(x, 0., 1.)
 
134
  from transformers import CLIPProcessor, FlaxCLIPModel
135
 
136
  clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
137
+ # st.write("FlaxCLIPModel")
138
+ # print("Initialize FlaxCLIPModel")
139
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
140
+ # st.write("CLIPProcessor")
141
+ # print("Initialize CLIPProcessor")
142
 
143
  def hallucinate(prompt, num_images=64):
144
  prompt = [prompt] * jax.device_count()
 
166
  scores = np.array(logits[0]).argsort()[-k:][::-1]
167
  return [images[score] for score in scores]
168
 
169
+ def captioned_strip(images, caption):
170
+ increased_h = 0 if caption is None else 48
171
+ w, h = images[0].size[0], images[0].size[1]
172
+ img = Image.new("RGB", (len(images)*w, h + increased_h))
173
+ for i, img_ in enumerate(images):
174
+ img.paste(img_, (i*w, increased_h))
175
+
176
+ if caption is not None:
177
+ draw = ImageDraw.Draw(img)
178
+ font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
179
+ draw.text((20, 3), caption, (255,255,255), font=font)
180
+ return img
181
+
182
+ # Controls
183
+
184
+ num_images = st.sidebar.slider("Candidates to generate", 1, 64, 8, 1)
185
+ num_preds = st.sidebar.slider("Best predictions to show", 1, 8, 1, 1)
186
+
187
+
188
+ prompt = st.text_input("What do you want to see?")
189
+
190
+ if prompt != "":
191
+ st.write(f"Generating candidates for: {prompt}")
192
+ images = hallucinate(prompt, num_images=num_images)
193
+ images = clip_top_k(prompt, images, k=num_preds)
194
+ predictions_strip = captioned_strip(images, None)
195
 
196
+ st.image(predictions_strip)