sinashish committed on
Commit
5bec537
·
1 Parent(s): a0a6e09

adds depth, anatomy segmentation viewer

Browse files
Files changed (1) hide show
  1. gradio_app.py +112 -93
gradio_app.py CHANGED
@@ -1,6 +1,6 @@
1
  from functools import partial
2
  import gradio as gr
3
-
4
  from PIL import Image
5
  import numpy as np
6
  import gradio as gr
@@ -129,6 +129,59 @@ background_ds = Background2d(
129
  image_filenames=None,
130
  )
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  def prepare_ds_renderer(
134
  randomize,
@@ -168,30 +221,30 @@ def prepare_ds_renderer(
168
  config=None,
169
  is_blended=is_blend,
170
  )
171
- blended3d = Blended3d(
172
- mesh_filename=os.path.join(
173
- "./DermSynth3D/data/3dbodytex-1.1-highres/",
174
- mesh_name,
175
- "model_highres_0_normalized.obj",
176
- ),
177
- device=DEVICE,
178
- dir_blended_textures=dir_blended_textures,
179
- dir_anatomy=dir_anatomy,
180
- extension=extension ,
181
- )
182
- normal_texture = load_texture_map(
183
- mesh, mesh_name, "No Lesion", 0, device
184
- ).maps_padded()
185
- if num_lesion > 0:
186
- blended_texture_image = load_texture_map(
187
- mesh, mesh_name, "Blended Lesion", num_lesion, device
188
- ).maps_padded()
189
- pasted_texture_image = load_texture_map(
190
- mesh, mesh_name, "Pasted Lesion", num_lesion, device
191
- ).maps_padded()
192
- dilated_texture_image = load_texture_map(
193
- mesh, mesh_name, "Dilated Lesion", num_lesion, device
194
- ).maps_padded()
195
 
196
  # texture_lesion_mask = blended3d.lesion_texture_mask(astensor=True).to(device)
197
  # non_skin_texture_mask = blended3d.nonskin_texture_mask(astensor=True).to(device)
@@ -210,16 +263,8 @@ def prepare_ds_renderer(
210
  mat_sh,
211
  mat_sc,
212
  )
213
- # mesh_renderer.mesh = mesh
214
- # mesh_renderer.cameras = cameras
215
- # mesh_renderer.lights = lights
216
- # mesh_renderer.materials = materials
217
- # mesh_renderer.renderer = renderer
218
  gr.Info("Successfully prepared renderer.")
219
- # render normal images
220
  gr.Info("Rendering Images...")
221
- # if num_views > 1:
222
- # mesh_renderer.mesh = mesh.extend(num_views)
223
  gr.Info(f"Rendering {num_views} views on {DEVICE}. Please wait...")
224
  img_count = 0
225
  view2d = []
@@ -227,10 +272,11 @@ def prepare_ds_renderer(
227
  anatomy2d = []
228
  seg2d = []
229
  view_size = (224, 224)
 
230
  while img_count < num_views:
231
  if randomize:
232
  gr.Info("Finding suitable parameters...")
233
- success = gen2d.randomize_parameters(config=None)
234
  if not success:
235
  gr.Info("Could not find suitable parameters. Trying again.")
236
  continue
@@ -238,9 +284,7 @@ def prepare_ds_renderer(
238
  raster_settings = RasterizationSettings(
239
  image_size=view_size[0],
240
  blur_radius=0.0,
241
- faces_per_pixel=1,
242
- # max_faces_per_bin=100,
243
- # bin_size=0,
244
  perspective_correct=True,
245
  )
246
  gen2d.mesh_renderer.cameras = cameras
@@ -259,60 +303,19 @@ def prepare_ds_renderer(
259
  print("***Not enough skin or unable to paste lesion. Skipping.")
260
  continue
261
  paste_img = (paste_img * 255).astype(np.uint8)
 
262
  depth_view = target[:, :, 4]
263
- depth_img = (depth_view - depth_view.min()) / (
264
- depth_view.max() - depth_view.min()
265
- )
266
- depth_img = (depth_img * 255).astype(np.uint8)
267
  view2d.append(paste_img)
268
  depth2d.append(depth_img)
269
- anatomy2d.append(target[:, :, 5])
270
- seg2d.append(target[:, :, 3])
 
 
271
  gr.Info(f"Successfully rendered {img_count+1}/{num_views} image+annotations.")
272
  img_count += 1
273
  return view2d, depth2d, anatomy2d, seg2d
274
 
275
- # mesh_renderer.compute_fragments()
276
- # view2d = mesh_renderer.render_view(asnumpy=True, asRGB=True)
277
- # gr.Info("Successfully rendered images.")
278
- # gr.Info("Preparing annotations...")
279
- # # breakpoint()
280
- # pix2face = torch.from_numpy(mesh_renderer.pixels_to_face()).to(
281
- # mesh_renderer.mesh.device
282
- # )
283
- # pix2vert = torch.stack(
284
- # [a[i] for a, i in zip(mesh_renderer.mesh.faces_padded().squeeze(), pix2face)]
285
- # )
286
- # pix2vert = pix2vert.detach().cpu().numpy()
287
- # anatomy_image = [
288
- # vertices_to_anatomy[pix2vert[i]] * mesh_renderer.body_mask()
289
- # for i in range(num_views)
290
- # ]
291
- # anatomy_image = np.stack(anatomy_image)
292
-
293
- # anatomy_image = mesh_renderer.anatomy_image(vertices_to_anatomy)
294
- # depth_img = mesh_renderer.depth_view(asnumpy=True)
295
- # mesh_renderer.set_texture_image(texture_lesion_mask[:, :, np.newaxis])
296
- # mask2d = mesh_renderer.render_view(asnumpy=True, asRGB=True)
297
- # lesion_mask = mesh_renderer.lesion_mask(mask2d[:, :, 0], lesion_mask_id=None)
298
- # # skin mask
299
- # mesh_renderer.set_texture_image(non_skin_texture_mask)
300
- # nonskin_mask = mesh_renderer.render_view(asnumpy=True, asRGB=True)
301
- # skin_mask = mesh_renderer.skin_mask(nonskin_mask[:, :, 0] > 0.5)
302
- # segmentation_mask = make_masks(lesion_mask, skin_mask)
303
- # gr.Info("Successfully prepared annotations.")
304
- # print(view2d.shape, anatomy_image.shape, depth_img.shape, segmentation_mask.shape)
305
- # convert anatomy image with labels for each pixel to an image with RGB values
306
- # map labels to pixels
307
-
308
- # return (
309
- # view2d,
310
- # anatomy_image,
311
- # depth_img,
312
- # skin_mask,
313
- # ) # segmentation_mask
314
-
315
-
316
  # define the list of all the examples
317
  def get_examples():
318
  # setup_paths()
@@ -392,7 +395,7 @@ def plotly_mesh(verts, faces, vc, mesh_name):
392
  )
393
  ]
394
  )
395
- fig.update_layout(scene_aspectmode="manual", scene_aspectratio=dict(x=1, y=1, z=1))
396
  fig.update_layout(scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False)))
397
  fig.update_layout(scene=dict(zaxis=dict(visible=False)))
398
  fig.update_layout(scene=dict(camera=dict(up=dict(x=1, y=0, z=1))))
@@ -624,7 +627,6 @@ def process_examples(mesh_name, tex_name, n_lesion):
624
  mesh_path = get_mesh_path(mesh_name)
625
  texture_path = get_texture_module(tex_name)(mesh_name, n_lesion)
626
  mesh_to_view = plotly_mesh(*get_trimesh_attrs(mesh_name, tex_name, n_lesion))
627
- # mesh = load_mesh_and_texture(mesh_name, tex_name, n_lesion)
628
  return mesh_to_view, texture_path, n_lesion
629
 
630
 
@@ -640,7 +642,7 @@ def update_plots(mesh_name, texture_name, num_lesion):
640
  )
641
  return default_mesh_plot, default_texture, num_lesion
642
  mesh_path = get_mesh_path(mesh_name)
643
- texture_path = get_texture_module(texture_name)(mesh_name, num_lesion)
644
  mesh_to_view = plotly_mesh(*get_trimesh_attrs(mesh_name, texture_name, num_lesion))
645
  gr.Info("Successfully updated mesh and texture.")
646
  return mesh_to_view, texture_path, num_lesion
@@ -852,9 +854,26 @@ def run_demo():
852
  )
853
  # rendered views panel
854
  with gr.Row(variant="panel"):
855
- render_block = gr.Gallery(
856
- label="Rendered Views", columns=4, height="auto", object_fit="contain"
857
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858
 
859
  @gr.on(
860
  triggers=[render_button.click],
@@ -874,7 +893,7 @@ def run_demo():
874
  mat_sh,
875
  mat_sc,
876
  ],
877
- outputs=[render_block],
878
  )
879
  def render_views(
880
  randomize,
@@ -912,7 +931,7 @@ def run_demo():
912
  # renderer, mesh, lights, cameras, materials, num_views
913
  # )
914
  # return [_ for _ in images.detach().cpu().numpy()]
915
- view2d, anatomy, depth, segmentation = prepare_ds_renderer(
916
  randomize,
917
  select_mesh,
918
  select_texture,
@@ -928,7 +947,7 @@ def run_demo():
928
  mat_sh,
929
  mat_sc,
930
  )
931
- return view2d
932
 
933
  # examples panel when the iuser does not want to input
934
  with gr.Row(variant="panel"):
 
1
  from functools import partial
2
  import gradio as gr
3
+ import pdb
4
  from PIL import Image
5
  import numpy as np
6
  import gradio as gr
 
129
  image_filenames=None,
130
  )
131
 
132
+ from dermsynth3d.utils.anatomy import SimpleAnatomy
133
+ color_labels = {
134
+ 0: (0., 0., 0.), # background
135
+ 1: (174., 199., 232.), # head
136
+ 2: (152., 223., 138.), # torso
137
+ 3: (31., 119., 180.), # hips
138
+ 4: (255., 187., 120.), # legs
139
+ 5: (188., 189., 34.), # feet
140
+ 6: (140., 86., 75.), # arms
141
+ 7: (255., 152., 150.), # hands
142
+ }
143
+
144
+
145
+ def to_simple_anatomy(anatomy):
146
+ for i in range(16+1):
147
+ if i in [0,1]:
148
+ continue
149
+ if i in [2,3]:
150
+ anatomy[anatomy==i] = 2
151
+ if i == 4:
152
+ anatomy[anatomy==i] = 3
153
+ if i in [5,6,7,8]:
154
+ anatomy[anatomy==i] = 4
155
+ if i in [9,10]:
156
+ anatomy[anatomy==i] = 5
157
+ if i in [11,12,13,14]:
158
+ anatomy[anatomy==i] = 6
159
+ if i in [15,16]:
160
+ anatomy[anatomy==i] = 7
161
+ return anatomy
162
+
163
+ def convert_anatomy_to_rgb(anatomy):
164
+ anatomy = to_simple_anatomy(anatomy)
165
+ anatomy_rgb = np.zeros((anatomy.shape[0], anatomy.shape[1], 3))
166
+ for k, v in color_labels.items():
167
+ anatomy_rgb[anatomy == k] = v
168
+ return anatomy_rgb.astype(np.uint8)
169
+
170
+ import PIL.Image as pil
171
+ import numpy as np
172
+ import matplotlib as mpl
173
+ import matplotlib.cm as cm
174
+ def convert_depth_to_rgb(depth):
175
+ mask = depth != 0
176
+ disp_map = 1 / depth
177
+ vmax = np.percentile(disp_map[mask], 95)
178
+ vmin = np.percentile(disp_map[mask], 5)
179
+ normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
180
+ mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
181
+ mask = np.repeat(np.expand_dims(mask,-1), 3, -1)
182
+ colormapped_im = (mapper.to_rgba(disp_map)[:, :, :3] * 255).astype(np.uint8)
183
+ colormapped_im[~mask] = 255
184
+ return colormapped_im
185
 
186
  def prepare_ds_renderer(
187
  randomize,
 
221
  config=None,
222
  is_blended=is_blend,
223
  )
224
+ # blended3d = Blended3d(
225
+ # mesh_filename=os.path.join(
226
+ # "./DermSynth3D/data/3dbodytex-1.1-highres/",
227
+ # mesh_name,
228
+ # "model_highres_0_normalized.obj",
229
+ # ),
230
+ # device=DEVICE,
231
+ # dir_blended_textures=dir_blended_textures,
232
+ # dir_anatomy=dir_anatomy,
233
+ # extension=extension ,
234
+ # )
235
+ # normal_texture = load_texture_map(
236
+ # mesh, mesh_name, "No Lesion", 0, device
237
+ # ).maps_padded()
238
+ # if num_lesion > 0:
239
+ # blended_texture_image = load_texture_map(
240
+ # mesh, mesh_name, "Blended Lesion", num_lesion, device
241
+ # ).maps_padded()
242
+ # pasted_texture_image = load_texture_map(
243
+ # mesh, mesh_name, "Pasted Lesion", num_lesion, device
244
+ # ).maps_padded()
245
+ # dilated_texture_image = load_texture_map(
246
+ # mesh, mesh_name, "Dilated Lesion", num_lesion, device
247
+ # ).maps_padded()
248
 
249
  # texture_lesion_mask = blended3d.lesion_texture_mask(astensor=True).to(device)
250
  # non_skin_texture_mask = blended3d.nonskin_texture_mask(astensor=True).to(device)
 
263
  mat_sh,
264
  mat_sc,
265
  )
 
 
 
 
 
266
  gr.Info("Successfully prepared renderer.")
 
267
  gr.Info("Rendering Images...")
 
 
268
  gr.Info(f"Rendering {num_views} views on {DEVICE}. Please wait...")
269
  img_count = 0
270
  view2d = []
 
272
  anatomy2d = []
273
  seg2d = []
274
  view_size = (224, 224)
275
+ gen2d.view_size = view_size
276
  while img_count < num_views:
277
  if randomize:
278
  gr.Info("Finding suitable parameters...")
279
+ success = gen2d.randomize_parameters(config=None, view_size=view_size)
280
  if not success:
281
  gr.Info("Could not find suitable parameters. Trying again.")
282
  continue
 
284
  raster_settings = RasterizationSettings(
285
  image_size=view_size[0],
286
  blur_radius=0.0,
287
+ faces_per_pixel=10,
 
 
288
  perspective_correct=True,
289
  )
290
  gen2d.mesh_renderer.cameras = cameras
 
303
  print("***Not enough skin or unable to paste lesion. Skipping.")
304
  continue
305
  paste_img = (paste_img * 255).astype(np.uint8)
306
+ anatomy_view = target[:, :, 3]
307
  depth_view = target[:, :, 4]
308
+ depth_img = convert_depth_to_rgb(depth_view)
 
 
 
309
  view2d.append(paste_img)
310
  depth2d.append(depth_img)
311
+ anatomy_img = convert_anatomy_to_rgb(anatomy_view)
312
+ anatomy2d.append(anatomy_img)
313
+ mask = target[:, :, 0]
314
+ seg2d.append(mask)
315
  gr.Info(f"Successfully rendered {img_count+1}/{num_views} image+annotations.")
316
  img_count += 1
317
  return view2d, depth2d, anatomy2d, seg2d
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  # define the list of all the examples
320
  def get_examples():
321
  # setup_paths()
 
395
  )
396
  ]
397
  )
398
+ # fig.update_layout(scene_aspectmode="manual", scene_aspectratio=dict(x=1, y=1, z=1))
399
  fig.update_layout(scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False)))
400
  fig.update_layout(scene=dict(zaxis=dict(visible=False)))
401
  fig.update_layout(scene=dict(camera=dict(up=dict(x=1, y=0, z=1))))
 
627
  mesh_path = get_mesh_path(mesh_name)
628
  texture_path = get_texture_module(tex_name)(mesh_name, n_lesion)
629
  mesh_to_view = plotly_mesh(*get_trimesh_attrs(mesh_name, tex_name, n_lesion))
 
630
  return mesh_to_view, texture_path, n_lesion
631
 
632
 
 
642
  )
643
  return default_mesh_plot, default_texture, num_lesion
644
  mesh_path = get_mesh_path(mesh_name)
645
+ texture_path = Image.open(get_texture_module(texture_name)(mesh_name, num_lesion)).convert("RGB").resize((512, 512))
646
  mesh_to_view = plotly_mesh(*get_trimesh_attrs(mesh_name, texture_name, num_lesion))
647
  gr.Info("Successfully updated mesh and texture.")
648
  return mesh_to_view, texture_path, num_lesion
 
854
  )
855
  # rendered views panel
856
  with gr.Row(variant="panel"):
857
+ with gr.Tab("Rendered RGB Views"):
858
+ render_block = gr.Gallery(
859
+ label="Rendered Views", columns=4, height="auto", object_fit="contain"
860
+ )
861
+ with gr.Tab("Rendered Depth Views"):
862
+ depth_block = gr.Gallery(
863
+ label="Depth Maps", columns=4, height="auto", object_fit="contain"
864
+ )
865
+ with gr.Tab("Rendered Anatomy Views"):
866
+ anatomy_block = gr.Gallery(
867
+ label="Anatomy Labels", columns=4, height="auto", object_fit="contain"
868
+ )
869
+ with gr.Tab("Rendered Segmentation Views"):
870
+ seg_block = gr.Gallery(
871
+ label="Segmentation Masks", columns=4, height="auto", object_fit="contain"
872
+ )
873
+ #
874
+ # render_block = gr.Gallery(
875
+ # label="Rendered Views", columns=4, height="auto", object_fit="contain"
876
+ # )
877
 
878
  @gr.on(
879
  triggers=[render_button.click],
 
893
  mat_sh,
894
  mat_sc,
895
  ],
896
+ outputs=[render_block, depth_block, anatomy_block, seg_block],
897
  )
898
  def render_views(
899
  randomize,
 
931
  # renderer, mesh, lights, cameras, materials, num_views
932
  # )
933
  # return [_ for _ in images.detach().cpu().numpy()]
934
+ view2d, depth, anatomy, segmentation = prepare_ds_renderer(
935
  randomize,
936
  select_mesh,
937
  select_texture,
 
947
  mat_sh,
948
  mat_sc,
949
  )
950
+ return view2d, depth, anatomy, segmentation
951
 
952
  # examples panel when the user does not want to input
953
  with gr.Row(variant="panel"):