File size: 4,452 Bytes
45b7dee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e04da3e
45b7dee
 
 
59390e7
45b7dee
 
 
307dfdb
45b7dee
 
 
 
 
307dfdb
e04da3e
 
45b7dee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import cv2
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation


class SegmentationTool:
    """Semantic-segmentation helper producing masks, label names and a room guess.

    Wraps a Hugging Face image-processor/model pair and turns the per-pixel
    label map of an image into:

    * a binary mode-``'L'`` mask: 255 for pixels whose label is NOT in
      ``mask_items``, grown by a square buffer; 0 elsewhere,
    * an RGBA image where those same pixels are white and fully transparent,
    * the names of all detected labels not listed in ``mask_items``,
    * a heuristic room type derived from which label ids are present.
    """

    def __init__(self,
                 segmentation_version='nvidia/segformer-b5-finetuned-ade-640-640'):
        """Load the processor/model pair for ``segmentation_version``.

        Args:
            segmentation_version: Hugging Face checkpoint id. Supported ids are
                ``"openmmlab/upernet-convnext-tiny"`` and
                ``"nvidia/segformer-b5-finetuned-ade-640-640"``.

        Raises:
            ValueError: if ``segmentation_version`` is not a supported id.
        """
        self.segmentation_version = segmentation_version

        if segmentation_version == "openmmlab/upernet-convnext-tiny":
            processor_cls = AutoImageProcessor
            model_cls = UperNetForSemanticSegmentation
        elif segmentation_version == "nvidia/segformer-b5-finetuned-ade-640-640":
            processor_cls = SegformerFeatureExtractor
            model_cls = SegformerForSemanticSegmentation
        else:
            # Fail fast: otherwise the instance has no model and every later
            # call dies with an unhelpful AttributeError.
            raise ValueError(
                f"unsupported segmentation_version: {segmentation_version!r}")

        self.feature_extractor = processor_cls.from_pretrained(self.segmentation_version)
        self.segmentation_model = model_cls.from_pretrained(self.segmentation_version)

    def _predict(self, image):
        """Return the per-pixel label map for ``image`` at its own resolution.

        The map is the first element of the processor's
        ``post_process_semantic_segmentation`` output, resized to
        ``image.size[::-1]`` (height, width).
        """
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.segmentation_model(**inputs)
        return self.feature_extractor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1]])[0]

    def _save_mask(self, prediction_array, mask_items=None, buffer_size=10):
        """Build a binary mask of every pixel whose label is not in ``mask_items``.

        Args:
            prediction_array: 2-D per-pixel label map.
            mask_items: label ids whose pixels stay 0 (default: none).
            buffer_size: side length of the square dilation kernel used to grow
                the 255 region; values < 1 disable the buffer.

        Returns:
            A ``PIL.Image`` in mode ``'L'``: 255 outside ``mask_items`` (plus
            the dilation buffer around it), 0 elsewhere.
        """
        if mask_items is None:
            mask_items = []
        keep = np.isin(prediction_array, mask_items)
        mask = np.where(keep, 0, 255).astype(np.uint8)

        # Dilating grows the 255 region by a buffer band. The kernel contains
        # its own anchor, so dilate(mask) >= mask everywhere and no separate
        # "add the buffer back" step is needed.
        if buffer_size > 0:
            kernel = np.ones((buffer_size, buffer_size), np.uint8)
            mask = cv2.dilate(mask, kernel, iterations=1)

        # A 2-D uint8 array becomes a mode-'L' image.
        return Image.fromarray(mask)

    def _save_transparent_mask(self, img, prediction_array, mask_items=None):
        """Return ``img`` as RGBA with every pixel outside ``mask_items`` cleared.

        Pixels whose label is in ``mask_items`` keep their RGB value and are
        fully opaque. All other pixels become white and fully transparent.
        Transparency is taken from the label map, not from the pixel colour,
        so kept pixels that happen to be pure red stay opaque.
        """
        if mask_items is None:
            mask_items = []
        keep = np.isin(prediction_array, mask_items)

        rgb = np.array(img.convert('RGB'))
        rgb[~keep] = 255
        alpha = np.where(keep, 255, 0).astype(np.uint8)
        # (H, W, 3) + (H, W) -> (H, W, 4) uint8, i.e. an RGBA image.
        return Image.fromarray(np.dstack((rgb, alpha)))

    @staticmethod
    def _classify_room(label_ids):
        """Guess a room type from the set of label ids present in an image.

        Rules are checked in order and the first match wins. The ids refer to
        the model's ``id2label`` mapping (presumably ADE20K for both bundled
        checkpoints, given the names — TODO confirm for the UperNet model).
        """
        present = {int(i) for i in label_ids}
        if present & {73, 50, 61}:
            return 'kitchen'
        if (present & {37, 65}) or {27, 47, 70} <= present:
            return 'bathroom'
        if 7 in present:
            return 'bedroom'
        if present & {23, 49}:
            return 'living room'
        if {15, 19} <= present:
            return 'dining room'
        return 'room'

    def get_mask(self, image_path=None, image=None, mask_items=None):
        """Segment an image and derive masks, detected label names and a room type.

        Args:
            image_path: path to an image file; takes precedence over ``image``.
            image: a ``PIL.Image`` used when ``image_path`` is falsy.
            mask_items: label ids to keep. They stay 0 in the binary mask, are
                opaque in the transparent mask, and are excluded from ``items``.

        Returns:
            ``(mask_image, transparent_mask_image, image, items, room)``, where
            ``image`` is the RGB image that was segmented.

        Raises:
            ValueError: if neither ``image_path`` nor ``image`` is provided.
        """
        if image_path:
            # Decode fully inside the context so the file handle is released.
            with Image.open(image_path) as opened:
                image = opened.convert('RGB')
        elif image is None:
            raise ValueError("no image provided")
        elif image.mode != 'RGB':
            # Normalise so every downstream step sees a 3-channel image.
            image = image.convert('RGB')

        if mask_items is None:
            mask_items = []

        # Convert once to NumPy for the array operations below. This assumes the
        # model runs on CPU; nothing in this class moves it to another device.
        prediction = np.asarray(self._predict(image))
        label_ids = np.unique(prediction)
        room = self._classify_room(label_ids)

        excluded = set(mask_items)
        id2label = self.segmentation_model.config.id2label
        items = [id2label[int(i)] for i in label_ids if int(i) not in excluded]

        mask_image = self._save_mask(prediction, mask_items)
        transparent_mask_image = self._save_transparent_mask(image, prediction, mask_items)
        return mask_image, transparent_mask_image, image, items, room