Pedro Cuenca commited on
Commit
0199604
·
unverified ·
2 Parent(s): bcd360f 58b9afd

Merge pull request #162 from borisdayma/demo-improvements

Browse files
Files changed (2) hide show
  1. app/streamlit/app.py +17 -23
  2. app/streamlit/backend.py +31 -0
app/streamlit/app.py CHANGED
@@ -1,28 +1,10 @@
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
 
4
- import base64
5
- from io import BytesIO
6
 
7
- import requests
8
  import streamlit as st
9
- from PIL import Image
10
-
11
-
12
- class ServiceError(Exception):
13
- def __init__(self, status_code):
14
- self.status_code = status_code
15
-
16
-
17
- def get_images_from_backend(prompt, backend_url):
18
- r = requests.post(backend_url, json={"prompt": prompt})
19
- if r.status_code == 200:
20
- images = r.json()["images"]
21
- images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
22
- return images
23
- else:
24
- raise ServiceError(r.status_code)
25
-
26
 
27
  st.sidebar.markdown(
28
  """
@@ -45,7 +27,7 @@ DALL·E mini is an AI model that generates images from any prompt you give!
45
  </p>
46
 
47
  <p style='text-align: center'>
48
- Created by Boris Dayma et al. 2021
49
  <br/>
50
  <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
51
  </p>
@@ -84,8 +66,8 @@ if prompt != "":
84
  )
85
 
86
  try:
87
- backend_url = st.secrets["BACKEND_SERVER"]
88
- print(f"Getting selections: {prompt}")
89
  selected = get_images_from_backend(prompt, backend_url)
90
 
91
  margin = 0.1 # for better position of zoom in arrow
@@ -95,6 +77,18 @@ if prompt != "":
95
  cols[(i % n_columns) * 2].image(img)
96
  container.markdown(f"**{prompt}**")
97
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  st.button("Again!", key="again_button")
99
 
100
  except ServiceError as error:
 
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
 
4
+ from datetime import datetime
 
5
 
 
6
  import streamlit as st
7
+ from backend import ServiceError, get_images_from_backend, get_model_version
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  st.sidebar.markdown(
10
  """
 
27
  </p>
28
 
29
  <p style='text-align: center'>
30
+ Created by Boris Dayma et al. 2021-2022
31
  <br/>
32
  <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
33
  </p>
 
66
  )
67
 
68
  try:
69
+ backend_url = st.secrets["BACKEND_SERVER"] + "/generate"
70
+ print(f"{datetime.now()} Getting selections: {prompt}")
71
  selected = get_images_from_backend(prompt, backend_url)
72
 
73
  margin = 0.1 # for better position of zoom in arrow
 
77
  cols[(i % n_columns) * 2].image(img)
78
  container.markdown(f"**{prompt}**")
79
 
80
+ version_url = st.secrets["BACKEND_SERVER"] + "/version"
81
+ version = get_model_version(version_url)
82
+ st.sidebar.markdown(
83
+ f"<small><center>{version}</center></small>", unsafe_allow_html=True
84
+ )
85
+
86
+ st.markdown(
87
+ f"""
88
+ These results have been obtained using model `{version}` from [an ongoing training run](https://wandb.ai/dalle-mini/dalle-mini/runs/mheh9e55).
89
+ """
90
+ )
91
+
92
  st.button("Again!", key="again_button")
93
 
94
  except ServiceError as error:
app/streamlit/backend.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Client requests to Dalle-Mini Backend server
2
+
3
+ import base64
4
+ from io import BytesIO
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+
10
+ class ServiceError(Exception):
11
+ def __init__(self, status_code):
12
+ self.status_code = status_code
13
+
14
+
15
+ def get_images_from_backend(prompt, backend_url):
16
+ r = requests.post(backend_url, json={"prompt": prompt})
17
+ if r.status_code == 200:
18
+ images = r.json()["images"]
19
+ images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
20
+ return images
21
+ else:
22
+ raise ServiceError(r.status_code)
23
+
24
+
25
+ def get_model_version(url):
26
+ r = requests.get(url)
27
+ if r.status_code == 200:
28
+ version = r.json()["version"]
29
+ return version
30
+ else:
31
+ raise ServiceError(r.status_code)