AdilEsset commited on
Commit
a621450
·
1 Parent(s): d00edc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -31
app.py CHANGED
@@ -6,14 +6,12 @@ from PIL import Image, ImageDraw
6
  import numpy as np
7
  from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
8
 
9
- import numpy as np
10
  from scipy.ndimage import center_of_mass
11
 
12
-
13
  def combine_ims(im1, im2, val=128):
14
- p = Image.new("L", im1.size, val)
15
- im = Image.composite(im1, im2, p)
16
- return im
17
 
18
  def get_class_centers(segmentation_mask, class_dict):
19
  segmentation_mask = segmentation_mask.numpy() + 1
@@ -24,44 +22,47 @@ def get_class_centers(segmentation_mask, class_dict):
24
 
25
  class_centers[class_index] = center_of_mass_list
26
 
27
- class_centers = {k:list(map(int, v)) for k,v in class_centers.items() if not np.isnan(sum(v))}
28
  return class_centers
29
 
30
  def visualize_mask(predicted_semantic_map, class_ids, class_colors):
31
- h, w = predicted_semantic_map.shape
32
- color_indexes = np.zeros((h, w), dtype=np.uint8)
33
- color_indexes[:] = predicted_semantic_map.numpy()
34
- color_indexes = color_indexes.flatten()
35
-
36
- colors = class_colors[class_ids[color_indexes]]
37
- output = colors.reshape(h, w, 3).astype(np.uint8)
38
- image_mask = Image.fromarray(output)
39
- return image_mask
40
 
 
 
 
 
41
 
42
  def get_out_image(image, predicted_semantic_map):
43
- class_centers = get_class_centers(predicted_semantic_map, class_dict)
44
- mask = visualize_mask(predicted_semantic_map, class_ids, class_colors)
45
- image_mask = combine_ims(image, mask, val=128)
46
- draw = ImageDraw.Draw(image_mask)
47
- for id, (y, x) in class_centers.items():
48
- draw.text((x, y), str(class_names[id-1]), fill='black')
 
 
 
 
49
 
50
- return image_mask
51
 
52
  def gradio_process(image):
53
- inputs = processor(images=image, return_tensors="pt")
54
 
55
- with torch.no_grad():
56
- outputs = model(**inputs)
57
 
58
- predicted_semantic_map = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
59
 
60
- out_image = get_out_image(image, predicted_semantic_map)
61
- return out_image
62
 
63
  with open('ade20k_classes.pickle', 'rb') as f:
64
- class_names, class_ids, class_colors = pickle.load(f)
65
  class_names, class_ids, class_colors = np.array(class_names), np.array(class_ids), np.array(class_colors)
66
  class_dict = dict(zip(class_ids, class_names))
67
 
@@ -73,11 +74,10 @@ model.eval()
73
  demo = gr.Interface(
74
  gradio_process,
75
  inputs=gr.inputs.Image(type="pil"),
76
- outputs=gr.outputs.Image(type="pil"),
77
  title="Semantic Segmentation",
78
  examples=glob.glob('./examples/*.jpg'),
79
  allow_flagging="never",
80
-
81
  )
82
 
83
  demo.launch()
 
6
  import numpy as np
7
  from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
8
 
 
9
  from scipy.ndimage import center_of_mass
10
 
 
11
  def combine_ims(im1, im2, val=128):
12
+ p = Image.new("L", im1.size, val)
13
+ im = Image.composite(im1, im2, p)
14
+ return im
15
 
16
  def get_class_centers(segmentation_mask, class_dict):
17
  segmentation_mask = segmentation_mask.numpy() + 1
 
22
 
23
  class_centers[class_index] = center_of_mass_list
24
 
25
+ class_centers = {k: list(map(int, v)) for k, v in class_centers.items() if not np.isnan(sum(v))}
26
  return class_centers
27
 
28
  def visualize_mask(predicted_semantic_map, class_ids, class_colors):
29
+ h, w = predicted_semantic_map.shape
30
+ color_indexes = np.zeros((h, w), dtype=np.uint8)
31
+ color_indexes[:] = predicted_semantic_map.numpy()
32
+ color_indexes = color_indexes.flatten()
 
 
 
 
 
33
 
34
+ colors = class_colors[class_ids[color_indexes]]
35
+ output = colors.reshape(h, w, 3).astype(np.uint8)
36
+ image_mask = Image.fromarray(output)
37
+ return image_mask
38
 
39
  def get_out_image(image, predicted_semantic_map):
40
+ class_centers = get_class_centers(predicted_semantic_map, class_dict)
41
+ mask = visualize_mask(predicted_semantic_map, class_ids, class_colors)
42
+ image_mask = combine_ims(image, mask, val=128)
43
+ draw = ImageDraw.Draw(image_mask)
44
+
45
+ extracted_tags = []
46
+ for id, (y, x) in class_centers.items():
47
+ class_name = str(class_names[id - 1])
48
+ extracted_tags.append({"class_name": class_name, "coordinates": (x, y)})
49
+ draw.text((x, y), class_name, fill='black')
50
 
51
+ return image_mask, extracted_tags
52
 
53
  def gradio_process(image):
54
+ inputs = processor(images=image, return_tensors="pt")
55
 
56
+ with torch.no_grad():
57
+ outputs = model(**inputs)
58
 
59
+ predicted_semantic_map = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
60
 
61
+ out_image, extracted_tags = get_out_image(image, predicted_semantic_map)
62
+ return out_image, extracted_tags
63
 
64
  with open('ade20k_classes.pickle', 'rb') as f:
65
+ class_names, class_ids, class_colors = pickle.load(f)
66
  class_names, class_ids, class_colors = np.array(class_names), np.array(class_ids), np.array(class_colors)
67
  class_dict = dict(zip(class_ids, class_names))
68
 
 
74
  demo = gr.Interface(
75
  gradio_process,
76
  inputs=gr.inputs.Image(type="pil"),
77
+ outputs=[gr.outputs.Image(type="pil"), gr.outputs.JSON()],
78
  title="Semantic Segmentation",
79
  examples=glob.glob('./examples/*.jpg'),
80
  allow_flagging="never",
 
81
  )
82
 
83
  demo.launch()