File size: 5,156 Bytes
beaa9e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497580e
beaa9e8
 
 
 
 
 
 
 
 
497580e
beaa9e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497580e
beaa9e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import torch
from transformers import pipeline
from ultralytics import SAM
from PIL import Image, ImageFilter
from torchvision import transforms



class BlurImage(object):
    """Blurs the regions of an image that match free-text prompts.

    Pipeline: OWL-ViT zero-shot detection turns the text prompts into
    bounding boxes, SAM turns those boxes into segmentation masks, and the
    masked regions are Gaussian-blurred while the rest of the image stays
    sharp.
    """

    def __init__(self, 
                 device="mps",
                 owlvit_ckpt="google/owlv2-base-patch16-ensemble",
                 owlvit_task="zero-shot-object-detection",
                 sam_ckpt="mobile_sam.pt",
                 ):
        """Loads the detector and the segmenter and prepares the I/O folders.

        Args:
            device (str|None): torch device name; None auto-selects cuda/mps/cpu.
            owlvit_ckpt (str): HuggingFace checkpoint for the OWL-ViT detector.
            owlvit_task (str): transformers pipeline task for the detector.
            sam_ckpt (str): weight file name/path for the ultralytics SAM model.
        """
        self.module_dir = os.path.dirname(__file__)
        self.device = self.initialize_device(device)
        self.owlvit = pipeline(model=owlvit_ckpt, task=owlvit_task, device=self.device)
        self.sam = SAM(model=sam_ckpt)
        self.create_dirs(self.module_dir)

    def blur(self, image, text_prompts, blur_intensity=50, labels=None, save=True, size=None):
        """Returns the image with every region matching *text_prompts* blurred.

        Args:
            image (str|PIL.Image.Image): file name under "images-to-blur",
                or an already-opened PIL image.
            text_prompts (list[str]): phrases describing what to blur.
            blur_intensity (int): Gaussian blur radius.
            labels (list[int]|None): optional SAM labels; defaults to
                foreground (1) for every detected box.
            save (bool): when True, writes the result under "blurred-images".
            size (tuple[int, int]|None): optional (width, height) resize.
        """
        if isinstance(image, str):  # fix: isinstance instead of type(...) == str
            image_name = image
            image = self.read_image(self.module_dir, "images-to-blur", image_name, size)
        else:
            image_name = "noname"
            if size:
                image = self.resize_image(image, size)

        aggregated_mask = self.get_aggregated_mask(image, text_prompts, labels)
        blurred_image = self.blur_entire_image(image, radius=blur_intensity)
        # Restore the original sharp pixels everywhere outside the mask.
        blurred_image[:, ~aggregated_mask] = transforms.functional.pil_to_tensor(image)[:, ~aggregated_mask]
        blurred_image = transforms.functional.to_pil_image(blurred_image)
        if save:
            blurred_image.save(os.path.join(self.module_dir, "blurred-images", os.path.splitext(image_name)[0] + "_blurred_image.jpg"))
        return blurred_image

    def get_aggregated_mask(self, image, text_prompts, labels=None):
        """Returns one boolean (H, W) mask: the union of all predicted masks.

        Previously an empty annotation list made this return the bare Python
        bool False (from ``0 > 0``); ``~False`` is -1, so indexing with it in
        ``blur`` silently selected the wrong slice. Now an all-False mask of
        the image's size is returned instead, which is safe to index with.
        """
        annotations = self.get_annotations(image, text_prompts, labels)
        if not annotations:
            # PIL .size is (width, height); masks are indexed (height, width).
            return torch.zeros((image.size[1], image.size[0]), dtype=torch.bool)
        aggregated_mask = 0
        for annotation in annotations:
            aggregated_mask += annotation["segmentation"].float()
        return aggregated_mask > 0

    def get_annotations(self, image, text_prompts, labels=None):
        """Returns annotation dicts for every region SAM segments from the
        boxes OWL-ViT detects for *text_prompts*."""
        boxes = self.predict_boxes(image, text_prompts)
        return self.annotate_sam_inference(*self.predict_masks(image, boxes, labels))

    def predict_masks(self, image, box_prompts, labels=None):
        """Returns SAM inference results for the given box prompts."""
        if labels is None:
            labels = [1] * len(box_prompts)  # treat every box as foreground
        return self.sam(image, bboxes=box_prompts, labels=labels)

    def predict_boxes(self, image, prompts):
        """Returns bounding boxes detected by OWL-ViT for the given prompts."""
        # pred["box"] is a dict of coordinates; its values form one box list.
        return [list(pred["box"].values()) for pred in self.owlvit(image, prompts)]

    def initialize_device(self, device):
        """Returns a torch.device; auto-selects cuda > mps > cpu when None."""
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        return torch.device(device)

    def blur_entire_image(self, image, radius=50):
        """Returns the whole image Gaussian-blurred, as a (C, H, W) tensor."""
        return transforms.functional.pil_to_tensor(image.filter(ImageFilter.GaussianBlur(radius=radius)))

    def annotate_sam_inference(self, inference, area_threshold=0):
        """Returns a list of annotation dicts built from a SAM result.

        Args:
            inference (ultralytics.engine.results.Results): output of the model.
            area_threshold (int): minimum pixel area a mask must cover to be kept.
        """
        annotations = []
        for i in range(len(inference.masks.data)):
            segmentation = inference.masks.data[i].cpu() == 1  # boolean mask
            annotation = {
                "id": i,
                "segmentation": segmentation,
                "area": segmentation.sum(),
            }
            if annotation["area"] >= area_threshold:
                annotations.append(annotation)
        return annotations

    def read_image(self, root, base_folder, image_name, size=None):
        """Returns the opened (and optionally resized) image at root/base_folder/image_name."""
        image = Image.open(os.path.join(root, base_folder, image_name))
        if size:
            image = self.resize_image(image, size)
        return image

    def resize_image(self, image, size):
        """Returns the image resized to (width, height)."""
        return image.resize(size)

    def create_dirs(self, root):
        """Creates the input/output directories under the given root."""
        for dir_name in ("images-to-blur", "blurred-images"):
            os.makedirs(os.path.join(root, dir_name), exist_ok=True)
        

if __name__ == "__main__":
    # Demo: blur every "jacket" detected in images-to-blur/dogs.jpg.
    # NOTE: the previous manual Image.open/resize was dead code — the local
    # `image` was immediately overwritten by the file-name string, and
    # BlurImage.blur opens the file itself when given a name.
    blur_image = BlurImage(device="cpu")
    blur_image.blur("dogs.jpg", ["jacket"])