Commit 3b253f9 · Create app.py
Parent: f083f3b
app.py ADDED
@@ -0,0 +1,210 @@
import cv2

import numpy as np
import streamlit as st
import tensorflow as tf

from models.cstylegan import cStyleGAN
from models.gaugan import GauGAN
from utils import fix_pred_label, onehot_to_rgb, rgb_to_onehot, color_dict

from skimage import io


@st.cache_resource
def load_cstylegan():
    conditional_style_gan = cStyleGAN(start_res=4, target_res=1024)
    conditional_style_gan.grow_model(256)
    conditional_style_gan.load_weights('checkpoints/cstylegan/cstylegan_256x256.ckpt').expect_partial()
    print('Conditional StyleGAN Model Loaded!')
    return conditional_style_gan


@st.cache_resource
def load_gaugan(batch_size):
    gaugan = GauGAN(image_size=1024, num_classes=7, batch_size=batch_size, latent_dim=512)
    gaugan.load_weights('checkpoints/gaugan/gaugan_1024x1024.ckpt').expect_partial()
    print('GauGAN Model Loaded!')
    return gaugan


def set_seed():
    tf.random.set_seed(seed=st.session_state.seed)


def main():

    st.title('RetinaGAN')

    st.sidebar.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/sample.jpeg'), cv2.COLOR_BGR2RGB))

    st.sidebar.title('Menu')
    options = st.sidebar.selectbox('Select Option:', ('About', 'Random', 'Upload your own', 'Retina Template'))

    if options == 'About':
        st.write('Online Demo for **High-Fidelity Diabetic Retina Fundus Image Synthesis from Freestyle Lesion Maps**')

        st.write('''
        Paper: https://opg.optica.org/abstract.cfm?uri=boe-14-2-533

        Github: http://github.com/farrell236/RetinaGAN

        👈 Select an option from the drop-down menu

        ---
        ''')

        st.write('''
        RetinaGAN is a two-step process for generating photo-realistic retinal
        fundus images based on artificially generated or free-hand drawn semantic lesion maps.
        ''')

        st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/RetinaGAN_pipeline.png'), cv2.COLOR_BGR2RGB),
                                       caption='RetinaGAN Pipeline')

        st.write('''
        StyleGAN is modified to be conditional in order to synthesize pathological lesion maps
        based on a specified DR grade (i.e., grades 0 to 4). The DR grades are defined by the
        International Clinical Diabetic Retinopathy (ICDR) disease severity scale:
        no apparent retinopathy, {mild, moderate, severe} Non-Proliferative Diabetic Retinopathy (NPDR),
        and Proliferative Diabetic Retinopathy (PDR). The output of the network is a binary image with
        seven channels instead of class colours to avoid ambiguity.
        ''')

        st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/cStyleGAN.png'), cv2.COLOR_BGR2RGB),
                                       caption='Conditional StyleGAN Model')

        st.write('''
        The generated label maps are then passed through GauGAN, an image-to-image translation network,
        to turn them into photo-realistic retina fundus images. The inputs to the network are one-hot
        encoded labels.
        ''')

        st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/GauGAN.png'), cv2.COLOR_BGR2RGB),
                                       caption='GauGAN Model')

    elif options == 'Random':

        st.session_state.seed = st.sidebar.number_input('Sampling Seed:', value=42, on_change=set_seed)

        ## Load Models
        conditional_style_gan = load_cstylegan()
        gaugan = load_gaugan(4)

        for idx, col in enumerate(st.columns(5)):

            # Sample a lesion map conditioned on DR grade `idx`
            z = tf.random.normal((1, conditional_style_gan.z_dim))
            w = conditional_style_gan.mapping([z, conditional_style_gan.embedding(idx)])
            noise = conditional_style_gan.generate_noise(batch_size=1)
            labels = conditional_style_gan.call({"style_code": w, "noise": noise, "alpha": 1.0, "class_label": idx})

            # Binarize the seven-channel output and upsample to GauGAN resolution
            labels = tf.keras.backend.softmax(labels)
            labels = tf.cast(labels > 0.5, dtype=tf.float32)
            labels = tf.image.resize(labels, (1024, 1024), method='nearest')

            fixed_labels = fix_pred_label(labels)
            fixed_labels = tf.tile(fixed_labels, (4, 1, 1, 1))

            # Render four fundus images from the same lesion map
            latent_vector = tf.random.normal(shape=(4, 512), mean=0.0, stddev=2.0)
            fake_image = gaugan.predict([latent_vector, fixed_labels])

            with col:
                st.text(f'DR Grade {idx}')
                st.image(onehot_to_rgb(fixed_labels[0], color_dict), output_format='PNG')
                for im in fake_image:
                    st.image(im)

        # Run again?
        st.button('Regenerate Images')

    elif options == 'Upload your own':

        st.session_state.seed = st.sidebar.number_input('Sampling Seed:', value=42, on_change=set_seed)

        st.sidebar.info('PRIVACY POLICY: Uploaded images are never stored on disk.')

        ## Load Models
        gaugan = load_gaugan(1)

        uploaded_file = st.file_uploader('Choose an image...', type='png')

        if uploaded_file:
            col1, col2 = st.columns(2)

            # Read input image with size [H, W, 3] and range (0, 255)
            img_array = io.imread(uploaded_file)[..., 0:3]

            # Test for valid mask
            test_colours = np.unique(img_array.reshape(-1, img_array.shape[2]), axis=0)
            if not all([tuple(x) in color_dict.values() for x in test_colours]):
                st.info('Mask contains invalid class colours')
                return

            # Resize image with padding to [1024, 1024, 3]
            img_array = tf.image.resize_with_pad(img_array, 1024, 1024, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

            # Display input image
            with col1:
                st.image(img_array.numpy(), caption='Uploaded Image')

            # One-hot encode the mask and synthesize a fundus image with GauGAN
            img_label = rgb_to_onehot(img_array.numpy(), color_dict)[None, ...]
            latent_vector = tf.random.normal(shape=(1, 512), mean=0.0, stddev=2.0)
            fake_image = gaugan.predict([latent_vector, img_label])[0]

            with col2:
                st.image(fake_image, caption='Generated Image')

            # Run again?
            st.button('Regenerate Image')

    elif options == 'Retina Template':

        st.header('Template')

        st.write('Download the Retina Template image below. '
                 'Using an image editor of your choice, paint lesions '
                 'into the Vitreous Body and upload it to the model. '
                 'NB: Images must be stored as lossless PNGs.')

        # White disc on a black background marks the paintable Vitreous Body region
        template = np.uint8(cv2.circle(np.zeros((1024, 1024, 3)), (512, 512), 512, (255, 255, 255), -1))
        st.columns([1, 5, 1])[1].image(template, use_column_width=True, output_format='PNG')

        st.header('Class Colours')
        cols = st.columns(7)
        for idx, cls in enumerate(color_dict):
            with cols[idx]:
                st.image(image=np.tile(color_dict[cls], (32, 32, 1)),
                         caption=cls,
                         output_format='PNG')
                # st.caption(color_dict[cls])

        data = {'Class Name': [
                    'Background',
                    'Hard Exudate',
                    'Hemorrhage',
                    'Soft Exudate',
                    'Micro Aneurysms',
                    'Optic Disc',
                    'Vitreous Body'],
                'RGB Colour': [
                    str(color_dict[0]),  # BG
                    str(color_dict[1]),  # EX
                    str(color_dict[2]),  # HE
                    str(color_dict[3]),  # SE
                    str(color_dict[4]),  # MA
                    str(color_dict[5]),  # OD
                    str(color_dict[6])]  # VB
                }

        st.table(data)


if __name__ == '__main__':

    # tf.config.set_visible_devices([], 'GPU')

    main()
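
Note: the fix_pred_label, onehot_to_rgb, rgb_to_onehot and color_dict helpers imported from utils are not part of this commit. As a rough guide to what app.py expects from them, below is a minimal sketch of the two mask conversions, assuming color_dict maps a class index to an RGB tuple; the colour values shown are placeholders, not the project's actual palette.

# Hypothetical sketch only: the real utils module ships with the repository and may differ.
import numpy as np

color_dict = {0: (0, 0, 0), 1: (255, 0, 0)}  # class index -> RGB colour (placeholder values)

def rgb_to_onehot(rgb, color_dict):
    # [H, W, 3] colour mask -> [H, W, num_classes] one-hot encoding
    onehot = np.zeros(rgb.shape[:2] + (len(color_dict),), dtype=np.float32)
    for i, cls in enumerate(color_dict):
        onehot[..., i] = np.all(rgb == np.array(color_dict[cls]), axis=-1)
    return onehot

def onehot_to_rgb(onehot, color_dict):
    # [H, W, num_classes] one-hot encoding -> [H, W, 3] colour mask
    idx = np.argmax(onehot, axis=-1)
    rgb = np.zeros(idx.shape + (3,), dtype=np.uint8)
    for i, cls in enumerate(color_dict):
        rgb[idx == i] = color_dict[cls]
    return rgb

With the weights in place under checkpoints/cstylegan/ and checkpoints/gaugan/ (the paths referenced in load_cstylegan and load_gaugan above), the demo can be launched locally with: streamlit run app.py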