boris commited on
Commit
b49f529
·
2 Parent(s): 5bf185b 50b9a44

Merge pull request #40 from borisdayma/app-ui

Browse files

Update demo to use Suraj's backend server

Former-commit-id: 176cdf8e7b05f983e22cc16ac884ff55b00e7ab7

Files changed (2) hide show
  1. README.md +1 -1
  2. app/app_gradio_ngrok.py +99 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🎨
4
  colorFrom: red
5
  colorTo: blue
6
  sdk: gradio
7
- app_file: app/app_gradio.py
8
  pinned: false
9
  ---
10
 
 
4
  colorFrom: red
5
  colorTo: blue
6
  sdk: gradio
7
+ app_file: app/app_gradio_ngrok.py
8
  pinned: false
9
  ---
10
 
app/app_gradio_ngrok.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import requests
5
+ from PIL import Image
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ from io import BytesIO
9
+ import base64
10
+
11
+ import gradio as gr
12
+
13
+ # If we use streamlit, this would be exported as a streamlit secret
14
+ import os
15
+ backend_url = os.environ["BACKEND_SERVER"]
16
+
17
+ def compose_predictions(images, caption=None):
18
+ increased_h = 0 if caption is None else 48
19
+ w, h = images[0].size[0], images[0].size[1]
20
+ img = Image.new("RGB", (len(images)*w, h + increased_h))
21
+ for i, img_ in enumerate(images):
22
+ img.paste(img_, (i*w, increased_h))
23
+
24
+ if caption is not None:
25
+ draw = ImageDraw.Draw(img)
26
+ font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
27
+ draw.text((20, 3), caption, (255,255,255), font=font)
28
+ return img
29
+
30
+ class ServiceError(Exception):
31
+ def __init__(self, status_code):
32
+ self.status_code = status_code
33
+
34
+ def get_images_from_ngrok(prompt):
35
+ r = requests.post(
36
+ backend_url,
37
+ json={"prompt": prompt}
38
+ )
39
+ if r.status_code == 200:
40
+ images = r.json()["images"]
41
+ images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
42
+ return images
43
+ else:
44
+ raise ServiceError(r.status_code)
45
+
46
+ def run_inference(prompt):
47
+ try:
48
+ images = get_images_from_ngrok(prompt)
49
+ predictions = compose_predictions(images)
50
+ output_title = f"""
51
+ <p style="font-size:22px; font-style:bold">Best predictions</p>
52
+ <p>We asked our model to generate 128 candidates for your prompt:</p>
53
+
54
+ <pre>
55
+
56
+ <b>{prompt}</b>
57
+ </pre>
58
+ <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
59
+ similarity of the text and the image representations.</p>
60
+
61
+ <p>This is the result:</p>
62
+ """
63
+
64
+ output_description = """
65
+ <p>Read our <a style="color:blue;" href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">full report</a> for more details on how this works.<p>
66
+ <p style='text-align: center'>Created with <a style="color:blue;" href="https://github.com/borisdayma/dalle-mini">DALL·E mini</a></p>
67
+ """
68
+
69
+ except ServiceError:
70
+ output_title = f"""
71
+ Sorry, there was an error retrieving the images. Please, try again later or <a href="mailto:[email protected]">contact us here</a>.
72
+ """
73
+ predictions = None
74
+ output_description = ""
75
+
76
+ return (output_title, predictions, output_description)
77
+
78
+ outputs = [
79
+ gr.outputs.HTML(label=""), # To be used as title
80
+ gr.outputs.Image(label=''),
81
+ gr.outputs.HTML(label=""), # Additional text that appears in the screenshot
82
+ ]
83
+
84
+ description = """
85
+ Welcome to DALL·E-mini, a text-to-image generation model.
86
+ """
87
+ gr.Interface(run_inference,
88
+ inputs=[gr.inputs.Textbox(label='Prompt')],
89
+ outputs=outputs,
90
+ title='DALL·E mini',
91
+ description=description,
92
+ article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
93
+ layout='vertical',
94
+ theme='huggingface',
95
+ examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
96
+ allow_flagging=False,
97
+ live=False,
98
+ # server_name="0.0.0.0", # Bind to all interfaces
99
+ ).launch()