pcuenca commited on
Commit
686197a
·
unverified ·
2 Parent(s): 18f5a29 6783773

Merge pull request #55 from abidlabs/main

Browse files
app/{app_gradio.py → gradio/app_gradio.py} RENAMED
@@ -18,12 +18,16 @@ from PIL import Image
18
  import numpy as np
19
  import matplotlib.pyplot as plt
20
 
21
-
22
  from vqgan_jax.modeling_flax_vqgan import VQModel
23
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
24
 
 
 
 
25
  import gradio as gr
26
 
 
 
27
 
28
  DALLE_REPO = 'flax-community/dalle-mini'
29
  DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
@@ -58,34 +62,12 @@ def generate(input, rng, params):
58
  def get_images(indices, params):
59
  return vqgan.decode_code(indices, params=params)
60
 
61
- def plot_images(images):
62
- fig = plt.figure(figsize=(40, 20))
63
- columns = 4
64
- rows = 2
65
- plt.subplots_adjust(hspace=0, wspace=0)
66
-
67
- for i in range(1, columns*rows +1):
68
- fig.add_subplot(rows, columns, i)
69
- plt.imshow(images[i-1])
70
- plt.gca().axes.get_yaxis().set_visible(False)
71
- plt.show()
72
-
73
- def stack_reconstructions(images):
74
- w, h = images[0].size[0], images[0].size[1]
75
- img = Image.new("RGB", (len(images)*w, h))
76
- for i, img_ in enumerate(images):
77
- img.paste(img_, (i*w,0))
78
- return img
79
-
80
  p_generate = jax.pmap(generate, "batch")
81
  p_get_images = jax.pmap(get_images, "batch")
82
 
83
  bart_params = replicate(model.params)
84
  vqgan_params = replicate(vqgan.params)
85
 
86
- # ## CLIP Scoring
87
- from transformers import CLIPProcessor, FlaxCLIPModel
88
-
89
  clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
90
  print("Initialize FlaxCLIPModel")
91
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
@@ -137,48 +119,30 @@ def top_k_predictions(prompt, num_candidates=32, k=8):
137
 
138
  def run_inference(prompt, num_images=32, num_preds=8):
139
  images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
140
- predictions = compose_predictions(images)
141
  output_title = f"""
142
- <p style="font-size:22px; font-style:bold">Best predictions</p>
143
- <p>We asked our model to generate 32 candidates for your prompt:</p>
144
-
145
- <pre>
146
-
147
  <b>{prompt}</b>
148
- </pre>
149
- <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
150
- similarity of the text and the image representations.</p>
151
-
152
- <p>This is the result:</p>
153
  """
154
- output_description = """
155
- <p>Read more about the process <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">in our report</a>.<p>
156
- <p style='text-align: center'>Created with <a href="https://github.com/borisdayma/dalle-mini">DALLE·mini</a></p>
157
- """
158
- return (output_title, predictions, output_description)
159
 
160
  outputs = [
161
  gr.outputs.HTML(label=""), # To be used as title
162
  gr.outputs.Image(label=''),
163
- gr.outputs.HTML(label=""), # Additional text that appears in the screenshot
164
  ]
165
 
166
  description = """
167
- Welcome to our demo of DALL·E-mini. This project was created on TPU v3-8s during the 🤗 Flax / JAX Community Week.
168
- It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.
169
-
170
- Please, write what you would like the model to generate, or select one of the examples below.
171
  """
172
  gr.Interface(run_inference,
173
- inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
174
  outputs=outputs,
175
  title='DALL·E mini',
176
  description=description,
177
- article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
178
  layout='vertical',
179
  theme='huggingface',
180
  examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
181
  allow_flagging=False,
182
  live=False,
183
  # server_port=8999
184
- ).launch()
 
18
  import numpy as np
19
  import matplotlib.pyplot as plt
20
 
 
21
  from vqgan_jax.modeling_flax_vqgan import VQModel
22
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
23
 
24
+ # ## CLIP Scoring
25
+ from transformers import CLIPProcessor, FlaxCLIPModel
26
+
27
  import gradio as gr
28
 
29
+ from dalle_mini.helpers import captioned_strip
30
+
31
 
32
  DALLE_REPO = 'flax-community/dalle-mini'
33
  DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
 
62
  def get_images(indices, params):
63
  return vqgan.decode_code(indices, params=params)
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  p_generate = jax.pmap(generate, "batch")
66
  p_get_images = jax.pmap(get_images, "batch")
67
 
68
  bart_params = replicate(model.params)
69
  vqgan_params = replicate(vqgan.params)
70
 
 
 
 
71
  clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
72
  print("Initialize FlaxCLIPModel")
73
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
119
 
120
  def run_inference(prompt, num_images=32, num_preds=8):
121
  images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
122
+ predictions = captioned_strip(images)
123
  output_title = f"""
 
 
 
 
 
124
  <b>{prompt}</b>
 
 
 
 
 
125
  """
126
+ return (output_title, predictions)
 
 
 
 
127
 
128
  outputs = [
129
  gr.outputs.HTML(label=""), # To be used as title
130
  gr.outputs.Image(label=''),
 
131
  ]
132
 
133
  description = """
134
+ DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
 
 
 
135
  """
136
  gr.Interface(run_inference,
137
+ inputs=[gr.inputs.Textbox(label='What do you want to see?')],
138
  outputs=outputs,
139
  title='DALL·E mini',
140
  description=description,
141
+ article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
142
  layout='vertical',
143
  theme='huggingface',
144
  examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
145
  allow_flagging=False,
146
  live=False,
147
  # server_port=8999
148
+ ).launch(share=True)
app/{app_gradio_ngrok.py → gradio/app_gradio_ngrok.py} RENAMED
@@ -7,25 +7,15 @@ import numpy as np
7
  import matplotlib.pyplot as plt
8
  from io import BytesIO
9
  import base64
 
10
 
11
  import gradio as gr
12
 
13
- # If we use streamlit, this would be exported as a streamlit secret
14
- import os
15
- backend_url = os.environ["BACKEND_SERVER"]
16
 
17
- def compose_predictions(images, caption=None):
18
- increased_h = 0 if caption is None else 48
19
- w, h = images[0].size[0], images[0].size[1]
20
- img = Image.new("RGB", (len(images)*w, h + increased_h))
21
- for i, img_ in enumerate(images):
22
- img.paste(img_, (i*w, increased_h))
23
 
24
- if caption is not None:
25
- draw = ImageDraw.Draw(img)
26
- font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
27
- draw.text((20, 3), caption, (255,255,255), font=font)
28
- return img
29
 
30
  class ServiceError(Exception):
31
  def __init__(self, status_code):
@@ -46,7 +36,7 @@ def get_images_from_ngrok(prompt):
46
  def run_inference(prompt):
47
  try:
48
  images = get_images_from_ngrok(prompt)
49
- predictions = compose_predictions(images)
50
  output_title = f"""
51
  <p style="font-size:22px; font-style:bold">Best predictions</p>
52
  <p>We asked our model to generate 128 candidates for your prompt:</p>
 
7
  import matplotlib.pyplot as plt
8
  from io import BytesIO
9
  import base64
10
+ import os
11
 
12
  import gradio as gr
13
 
14
+ from dalle_mini.helpers import captioned_strip
 
 
15
 
 
 
 
 
 
 
16
 
17
+ backend_url = os.environ["BACKEND_SERVER"]
18
+
 
 
 
19
 
20
  class ServiceError(Exception):
21
  def __init__(self, status_code):
 
36
  def run_inference(prompt):
37
  try:
38
  images = get_images_from_ngrok(prompt)
39
+ predictions = captioned_strip(images)
40
  output_title = f"""
41
  <p style="font-size:22px; font-style:bold">Best predictions</p>
42
  <p>We asked our model to generate 128 candidates for your prompt:</p>
app/gradio/dalle_mini ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../dalle_mini/
app/gradio/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Requirements for huggingface spaces
2
+ gradio>=2.2.3
3
+ flax
4
+ transformers
app/sample_images/image_0.jpg DELETED
Binary file (9.02 kB)
 
app/sample_images/image_1.jpg DELETED
Binary file (9.71 kB)
 
app/sample_images/image_2.jpg DELETED
Binary file (14.1 kB)
 
app/sample_images/image_3.jpg DELETED
Binary file (9.38 kB)
 
app/sample_images/image_4.jpg DELETED
Binary file (9.97 kB)
 
app/sample_images/image_5.jpg DELETED
Binary file (15.3 kB)
 
app/sample_images/image_6.jpg DELETED
Binary file (11.1 kB)
 
app/sample_images/image_7.jpg DELETED
Binary file (8.55 kB)
 
app/sample_images/readme.txt DELETED
@@ -1 +0,0 @@
1
- These images were generated by one of our checkpoints, as responses to the prompt "snowy mountains by the sea".
 
 
app/ui_gradio.py DELETED
@@ -1,91 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- from PIL import Image
5
- import gradio as gr
6
-
7
- def compose_predictions(images, caption=None):
8
- increased_h = 0 if caption is None else 48
9
- w, h = images[0].size[0], images[0].size[1]
10
- img = Image.new("RGB", (len(images)*w, h + increased_h))
11
- for i, img_ in enumerate(images):
12
- img.paste(img_, (i*w, increased_h))
13
-
14
- if caption is not None:
15
- draw = ImageDraw.Draw(img)
16
- font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
17
- draw.text((20, 3), caption, (255,255,255), font=font)
18
- return img
19
-
20
- def compose_predictions_grid(images):
21
- cols = 4
22
- rows = len(images) // cols
23
- w, h = images[0].size[0], images[0].size[1]
24
- img = Image.new("RGB", (w * cols, h * rows))
25
- for i, img_ in enumerate(images):
26
- row = i // cols
27
- col = i % cols
28
- img.paste(img_, (w * col, h * row))
29
- return img
30
-
31
- def top_k_predictions_real(prompt, num_candidates=32, k=8):
32
- images = hallucinate(prompt, num_images=num_candidates)
33
- images = clip_top_k(prompt, images, k=num_preds)
34
- return images
35
-
36
- def top_k_predictions(prompt, num_candidates=32, k=8):
37
- images = []
38
- for i in range(k):
39
- image = Image.open(f"sample_images/image_{i}.jpg")
40
- images.append(image)
41
- return images
42
-
43
- def run_inference(prompt, num_images=32, num_preds=8):
44
- images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
45
- predictions = compose_predictions(images)
46
- output_title = f"""
47
- <p style="font-size:22px; font-style:bold">Best predictions</p>
48
- <p>We asked our model to generate 32 candidates for your prompt:</p>
49
-
50
- <pre>
51
-
52
- <b>{prompt}</b>
53
- </pre>
54
- <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
55
- similarity of the text and the image representations.</p>
56
-
57
- <p>This is the result:</p>
58
- """
59
- output_description = """
60
- <p>Read more about the process <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">in our report</a>.<p>
61
- <p style='text-align: center'>Created with <a href="https://github.com/borisdayma/dalle-mini">DALLE·mini</a></p>
62
- """
63
- return (output_title, predictions, output_description)
64
-
65
- outputs = [
66
- gr.outputs.HTML(label=""), # To be used as title
67
- gr.outputs.Image(label=''),
68
- gr.outputs.HTML(label=""), # Additional text that appears in the screenshot
69
- ]
70
-
71
- description = """
72
- Welcome to our demo of DALL·E-mini. This project was created on TPU v3-8s during the 🤗 Flax / JAX Community Week.
73
- It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.
74
-
75
- Please, write what you would like the model to generate, or select one of the examples below.
76
- """
77
- gr.Interface(run_inference,
78
- inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
79
- outputs=outputs,
80
- title='DALL·E mini',
81
- description=description,
82
- article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
83
- layout='vertical',
84
- theme='huggingface',
85
- examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
86
- allow_flagging=False,
87
- live=False,
88
- server_port=8999
89
- ).launch(
90
- share=True # Creates temporary public link if true
91
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
  # Requirements for huggingface spaces
2
- streamlit>=0.84.2
 
1
  # Requirements for huggingface spaces
2
+ streamlit>=0.84.2