bienom committed on
Commit
45b7dee
1 Parent(s): 122ca42

Init project

Files changed (3)
  1. app.py +15 -0
  2. model.py +104 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,15 @@
+ import numpy as np
+ import gradio as gr
+ from model import SegmentationTool
+
+ seg_tool = SegmentationTool()
+
+
+ def segment(input_img):
+     mask_image, transparent_mask_image, image, items, room = seg_tool.get_mask(image=input_img)
+     return mask_image
+
+
+ demo = gr.Interface(segment, gr.Image(type="pil"), "image")
+ if __name__ == "__main__":
+     demo.launch()
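The handler can be exercised without launching the UI by calling segment directly with a PIL image; a minimal sketch, assuming a local test image at the hypothetical path test.jpg (importing app instantiates SegmentationTool, which downloads the model weights on first use):

from PIL import Image

from app import segment  # hypothetical smoke test; does not launch the Gradio UI

mask = segment(Image.open("test.jpg"))  # binary mask as a PIL image, mode "L"
mask.save("mask.png")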
model.py ADDED
@@ -0,0 +1,104 @@
+ import cv2
+ import torch
+ from PIL import Image
+ import numpy as np
+ from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+ from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
+
+
+ class SegmentationTool:
+
+     def __init__(self,
+                  segmentation_version='nvidia/segformer-b5-finetuned-ade-640-640'):
+
+         self.segmentation_version = segmentation_version
+
+         if segmentation_version == "openmmlab/upernet-convnext-tiny":
+             self.feature_extractor = AutoImageProcessor.from_pretrained(self.segmentation_version)
+             self.segmentation_model = UperNetForSemanticSegmentation.from_pretrained(self.segmentation_version)
+         elif segmentation_version == "nvidia/segformer-b5-finetuned-ade-640-640":
+             self.feature_extractor = SegformerFeatureExtractor.from_pretrained(self.segmentation_version)
+             self.segmentation_model = SegformerForSemanticSegmentation.from_pretrained(self.segmentation_version)
+         else:
+             raise ValueError(f"unsupported segmentation_version: {segmentation_version}")
+
+     def _predict(self, image):
+         inputs = self.feature_extractor(images=image, return_tensors="pt")
+         with torch.no_grad():
+             outputs = self.segmentation_model(**inputs)
+         prediction = \
+             self.feature_extractor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+         # post_process_semantic_segmentation returns a torch tensor; convert to
+         # numpy so the np.isin / cv2 calls below operate on a plain array
+         return prediction.cpu().numpy()
+
+     def _save_mask(self, prediction_array, mask_items=None):
+         if mask_items is None:
+             mask_items = []
+         # Binary mask: 0 where a masked item was predicted, 255 everywhere else
+         mask = np.zeros_like(prediction_array, dtype=np.uint8)
+         mask[~np.isin(prediction_array, mask_items)] = 255
+
+         buffer_size = 10
+
+         # Dilate the binary image to grow the white region by buffer_size pixels
+         kernel = np.ones((buffer_size, buffer_size), np.uint8)
+         dilated_image = cv2.dilate(mask, kernel, iterations=1)
+
+         # Subtract the original binary image to isolate the buffer ring
+         buffer_area = dilated_image - mask
+
+         # Apply the buffer area to the original mask
+         mask = cv2.bitwise_or(mask, buffer_area)
+
+         # Create a PIL Image object from the mask
+         mask_image = Image.fromarray(mask, mode='L')
+         return mask_image
+
+     def _save_transparent_mask(self, img, prediction_array, mask_items=None):
+         if mask_items is None:
+             mask_items = []
+         mask = np.array(img)
+         # Paint every pixel that is not a masked item white
+         mask[~np.isin(prediction_array, mask_items), :] = 255
+         mask_image = Image.fromarray(mask).convert('RGBA')
+
+         # Make the white (non-masked) pixels fully transparent, keep the rest opaque
+         mask_data = mask_image.getdata()
+         mask_data = [(r, g, b, 0) if (r, g, b) == (255, 255, 255) else (r, g, b, 255) for (r, g, b, a) in mask_data]
+         mask_image.putdata(mask_data)
+
+         return mask_image
+
+     def get_mask(self, image_path=None, image=None):
+         if image_path:
+             image = Image.open(image_path)
+         elif image is None:
+             raise ValueError("no image provided")
+
+         prediction = self._predict(image)
+
+         label_ids = np.unique(prediction)
+
+         mask_items = [8]  # ADE20K label 8: windowpane
+
+         # Guess the room type from the ADE20K labels present in the scene
+         if 73 in label_ids or 50 in label_ids or 61 in label_ids:
+             room = 'kitchen'
+         elif 37 in label_ids or 65 in label_ids or (27 in label_ids and 47 in label_ids and 70 in label_ids):
+             room = 'bathroom'
+         elif 7 in label_ids:
+             room = 'bedroom'
+         elif 23 in label_ids or 49 in label_ids:
+             room = 'living room'
+         elif 15 in label_ids and 19 in label_ids:
+             room = 'dining room'
+         else:
+             room = 'room'
+
+         label_ids_without_mask = [i for i in label_ids if i not in mask_items]
+         items = [self.segmentation_model.config.id2label[i] for i in label_ids_without_mask]
+
+         mask_image = self._save_mask(prediction, mask_items)
+         transparent_mask_image = self._save_transparent_mask(image, prediction, mask_items)
+         return mask_image, transparent_mask_image, image, items, room
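get_mask also accepts a file path and returns more than the Gradio demo displays; a minimal sketch of standalone use, with hypothetical file names:

from model import SegmentationTool

tool = SegmentationTool()
mask, transparent_mask, image, items, room = tool.get_mask(image_path="kitchen.jpg")
print(room)   # heuristic room guess, e.g. 'kitchen'
print(items)  # id2label names for every label left outside the mask
mask.save("mask.png")                          # masked items black, rest white
transparent_mask.save("transparent_mask.png")  # masked items opaque, rest transparent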
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ numpy
+ Pillow
+ opencv-python
+ transformers
+ torch
+ gradio