farrell236 commited on
Commit
3b253f9
·
1 Parent(s): f083f3b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -0
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+ import numpy as np
4
+ import streamlit as st
5
+ import tensorflow as tf
6
+
7
+ from models.cstylegan import cStyleGAN
8
+ from models.gaugan import GauGAN
9
+ from utils import fix_pred_label, onehot_to_rgb, rgb_to_onehot, color_dict
10
+
11
+ from skimage import io
12
+
13
+
14
+ @st.cache_resource
15
+ def load_cstylegan():
16
+ conditional_style_gan = cStyleGAN(start_res=4, target_res=1024)
17
+ conditional_style_gan.grow_model(256)
18
+ conditional_style_gan.load_weights('checkpoints/cstylegan/cstylegan_256x256.ckpt').expect_partial()
19
+ print('Conditional StyleGAN Model Loaded!')
20
+ return conditional_style_gan
21
+
22
+
23
+ @st.cache_resource
24
+ def load_gaugan(batch_size):
25
+ gaugan = GauGAN(image_size=1024, num_classes=7, batch_size=batch_size, latent_dim=512)
26
+ gaugan.load_weights('checkpoints/gaugan/gaugan_1024x1024.ckpt').expect_partial()
27
+ print('GauGAN Model Loaded!')
28
+ return gaugan
29
+
30
+
31
+ def set_seed():
32
+ tf.random.set_seed(seed=st.session_state.seed)
33
+
34
+
35
+ def main():
36
+
37
+ st.title('RetinaGAN')
38
+
39
+ st.sidebar.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/sample.jpeg'), cv2.COLOR_BGR2RGB))
40
+
41
+ st.sidebar.title('Menu')
42
+ options = st.sidebar.selectbox('Select Option:', ('About', 'Random', 'Upload your own', 'Retina Template'))
43
+
44
+ if options == 'About':
45
+ st.write('Online Demo for **High-Fidelity Diabetic Retina Fundus Image Synthesis from Freestyle Lesion Maps**')
46
+
47
+ st.write('''
48
+ Paper: https://opg.optica.org/abstract.cfm?uri=boe-14-2-533
49
+
50
+ Github: http://github.com/farrell236/RetinaGAN
51
+
52
+ 👈 Select an Option From the drop down menu
53
+
54
+ ---
55
+ ''')
56
+
57
+ st.write('''
58
+ RetinaGAN a two-step process for generating photo-realistic retinal
59
+ Fundus images based on artificially generated or free-hand drawn semantic lesion maps.
60
+ ''')
61
+
62
+ st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/RetinaGAN_pipeline.png'), cv2.COLOR_BGR2RGB),
63
+ caption='RetinaGAN Pipeline')
64
+
65
+ st.write('''
66
+ StyleGAN is modified to be conditional in to synthesize pathological lesion maps
67
+ based on a specified DR grade (i.e., grades 0 to 4). The DR Grades are defined by the
68
+ International Clinical Diabetic Retinopathy (ICDR) disease severity scale;
69
+ no apparent retinopathy, {mild, moderate, severe} Non-Proliferative Diabetic Retinopathy (NPDR),
70
+ and Proliferative Diabetic Retinopathy (PDR). The output of the network is a binary image with
71
+ seven channels instead of class colors to avoid ambiguity.
72
+ ''')
73
+
74
+ st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/cStyleGAN.png'), cv2.COLOR_BGR2RGB),
75
+ caption='Conditional StyleGAN Model')
76
+
77
+ st.write('''
78
+ The generated label maps are then passed through GauGAN, an image-to-image translation network,
79
+ to turn them into photo-realistic retina fundus images. The input to the network are one-hot
80
+ encoded labels.
81
+ ''')
82
+
83
+ st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/GauGAN.png'), cv2.COLOR_BGR2RGB),
84
+ caption='GauGAN Model')
85
+
86
+
87
+ elif options == 'Random':
88
+
89
+ st.session_state.seed = st.sidebar.number_input('Sampling Seed:', value=42, on_change=set_seed)
90
+
91
+ ## Load Models
92
+ conditional_style_gan = load_cstylegan()
93
+ gaugan = load_gaugan(4)
94
+
95
+ for idx, col in enumerate(st.columns(5)):
96
+
97
+ z = tf.random.normal((1, conditional_style_gan.z_dim))
98
+ w = conditional_style_gan.mapping([z, conditional_style_gan.embedding(idx)])
99
+ noise = conditional_style_gan.generate_noise(batch_size=1)
100
+ labels = conditional_style_gan.call({"style_code": w, "noise": noise, "alpha": 1.0, "class_label": idx})
101
+
102
+ labels = tf.keras.backend.softmax(labels)
103
+ labels = tf.cast(labels > 0.5, dtype=tf.float32)
104
+ labels = tf.image.resize(labels, (1024, 1024), method='nearest')
105
+
106
+ fixed_labels = fix_pred_label(labels)
107
+ fixed_labels = tf.tile(fixed_labels, (4, 1, 1, 1))
108
+
109
+ latent_vector = tf.random.normal(shape=(4, 512), mean=0.0, stddev=2.0)
110
+ fake_image = gaugan.predict([latent_vector, fixed_labels])
111
+
112
+ with col:
113
+ st.text(f'DR Grade {idx}')
114
+ st.image(onehot_to_rgb(fixed_labels[0], color_dict), output_format='PNG')
115
+ for im in fake_image:
116
+ st.image(im)
117
+
118
+ # Run again?
119
+ st.button('Regenerate Images')
120
+
121
+ elif options == 'Upload your own':
122
+
123
+ st.session_state.seed = st.sidebar.number_input('Sampling Seed:', value=42, on_change=set_seed)
124
+
125
+ st.sidebar.info('PRIVACY POLICY: Uploaded images are never stored on disk.')
126
+
127
+ ## Load Models
128
+ gaugan = load_gaugan(1)
129
+
130
+ uploaded_file = st.file_uploader('Choose an image...', type=('png'))
131
+
132
+ if uploaded_file:
133
+ col1, col2 = st.columns(2)
134
+
135
+ # Read input image with size [H, W, 3] and range (0, 255)
136
+ img_array = io.imread(uploaded_file)[..., 0:3]
137
+
138
+ # Test for valid mask
139
+ test_colours = np.unique(img_array.reshape(-1, img_array.shape[2]), axis=0)
140
+ if not all([tuple(x) in color_dict.values() for x in test_colours]):
141
+ st.info('Mask Contains invalid Class Colours')
142
+ return
143
+
144
+ # Resize image with padding to [1024, 1024, 3]
145
+ img_array = tf.image.resize_with_pad(img_array, 1024, 1024, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
146
+
147
+ # Display input image
148
+ with col1:
149
+ st.image(img_array.numpy(), caption='Uploaded Image')
150
+
151
+ img_label = rgb_to_onehot(img_array.numpy(), color_dict)[None, ...]
152
+ latent_vector = tf.random.normal(shape=(1, 512), mean=0.0, stddev=2.0)
153
+ fake_image = gaugan.predict([latent_vector, img_label])[0]
154
+
155
+ with col2:
156
+ st.image(fake_image, caption='Generated Image')
157
+
158
+ # Run again?
159
+ st.button('Regenerate Image')
160
+
161
+ elif options == 'Retina Template':
162
+
163
+ st.header('Template')
164
+
165
+ st.write('Download the Retina Template image below. '
166
+ 'Using an image editor of your choice, paint lesions '
167
+ 'into the Vitreous Body and upload it to the model. '
168
+ 'NB: Images must be stored as lossless PNGs')
169
+
170
+ template = np.uint8(cv2.circle(np.zeros((1024, 1024, 3)), [512, 512], 512, (255, 255, 255), -1))
171
+ st.columns([1, 5, 1])[1].image(template, use_column_width=True, output_format='PNG')
172
+
173
+ st.header('Class Colours')
174
+ cols = st.columns(7)
175
+ for idx, cls in enumerate(color_dict):
176
+ with cols[idx]:
177
+ st.image(image=np.tile(color_dict[cls], (32, 32, 1)),
178
+ caption=cls,
179
+ output_format='PNG')
180
+ # st.caption(color_dict[cls])
181
+
182
+
183
+ data = {'Class Name': [
184
+ 'Background',
185
+ 'Hard Exudate',
186
+ 'Hemohedge',
187
+ 'Soft Exudate',
188
+ 'Micro Aneurysms',
189
+ 'Optical Disc',
190
+ 'Vitreous Body'],
191
+ 'RGB Colour': [
192
+ str(color_dict[0]), # BG
193
+ str(color_dict[1]), # EX
194
+ str(color_dict[2]), # HE
195
+ str(color_dict[3]), # SE
196
+ str(color_dict[4]), # MA
197
+ str(color_dict[5]), # OD
198
+ str(color_dict[6])] # VB
199
+ }
200
+
201
+ st.table(data)
202
+
203
+
204
+ if __name__ == '__main__':
205
+
206
+ # tf.config.set_visible_devices([], 'GPU')
207
+
208
+ main()
209
+
210
+