byrkbrk committed · Commit beaa9e8 · verified · 1 Parent(s): 43d05f7

upload files

Files changed (3):
  1. app.py +13 -0
  2. image_blurring.py +121 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,13 @@
+import gradio as gr
+from image_blurring import BlurImage
+
+
+
+if __name__ == "__main__":
+    blur_image = BlurImage(device=None)  # None lets BlurImage auto-select cuda/mps/cpu
+    gr_interface = gr.Interface(
+        fn=lambda image, prompt, save=False: blur_image.blur(image, prompt.split("\n"), save=save),
+        inputs=[gr.Image(type="pil"), gr.Textbox(lines=3, placeholder="jacket\ndog head\netc...")],
+        outputs=gr.Image(type="pil"),
+    )
+    gr_interface.launch()
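
For reference, the lambda above just splits the multi-line textbox into one prompt per line and forwards it to BlurImage.blur. A minimal sketch of the equivalent programmatic call, assuming a hypothetical local test image dogs.jpg:

    from PIL import Image
    from image_blurring import BlurImage

    blur_image = BlurImage(device=None)     # None -> auto-select cuda/mps/cpu
    image = Image.open("dogs.jpg")          # hypothetical test image
    prompt = "jacket\ndog head"             # what the Textbox would deliver
    result = blur_image.blur(image, prompt.split("\n"), save=False)
    result.show()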
image_blurring.py ADDED
@@ -0,0 +1,121 @@
+import os
+import torch
+from transformers import pipeline
+from ultralytics import SAM
+from PIL import Image, ImageFilter
+from torchvision import transforms
+
+
+
+class BlurImage(object):
+    def __init__(self,
+                 device="mps",
+                 owlvit_ckpt="google/owlv2-base-patch16-ensemble",
+                 owlvit_task="zero-shot-object-detection",
+                 sam_ckpt="mobile_sam.pt",
+                 ):
+        self.module_dir = os.path.dirname(__file__)
+        self.device = self.initialize_device(device)
+        self.owlvit = pipeline(model=owlvit_ckpt, task=owlvit_task, device=self.device)  # zero-shot box detector
+        self.sam = SAM(model=sam_ckpt)  # box-prompted segmenter
+        self.create_dirs(self.module_dir)
+
+    def blur(self, image, text_prompts, labels=None, save=True, size=None):
+        """Returns the image with the regions matching the text prompts blurred"""
+        if isinstance(image, str):
+            image_name = image
+            image = self.read_image(self.module_dir, "images-to-blur", image_name, size)
+        else:
+            image_name = "noname"
+            if size: image = self.resize_image(image, size)
+
+        aggregated_mask = self.get_aggregated_mask(image, text_prompts, labels)
+        blurred_image = self.blur_entire_image(image)  # CHW uint8 tensor
+        blurred_image[:, ~aggregated_mask] = transforms.functional.pil_to_tensor(image)[:, ~aggregated_mask]  # restore pixels outside the mask
+        blurred_image = transforms.functional.to_pil_image(blurred_image)
+        if save:
+            blurred_image.save(os.path.join(self.module_dir, "blurred-images", os.path.splitext(image_name)[0] + "_blurred_image.jpg"))
+        return blurred_image
+
+    def get_aggregated_mask(self, image, text_prompts, labels=None):
+        """Returns the union of all predicted masks for the given image, text prompts, and labels"""
+        aggregated_mask = 0  # note: stays 0 (not a tensor) if nothing is detected
+        for annotation in self.get_annotations(image, text_prompts, labels):
+            aggregated_mask += annotation["segmentation"].float()
+        return aggregated_mask > 0
+
+    def get_annotations(self, image, text_prompts, labels=None):
+        """Returns annotations predicted by SAM for the boxes detected from the text prompts"""
+        return self.annotate_sam_inference(*self.predict_masks(image,  # SAM returns a one-element Results list for a single image
+                                                               self.predict_boxes(image, text_prompts),
+                                                               labels))
+
+    def predict_masks(self, image, box_prompts, labels=None):
+        """Returns the masks predicted by SAM for the given box prompts"""
+        if labels is None:
+            labels = [1] * len(box_prompts)  # 1 marks every box as foreground
+        return self.sam(image, bboxes=box_prompts, labels=labels)
+
+    def predict_boxes(self, image, prompts):
+        """Returns bounding boxes ([xmin, ymin, xmax, ymax]) for the given image and prompts"""
+        return [list(pred["box"].values()) for pred in self.owlvit(image, prompts)]
+
+    def initialize_device(self, device):
+        """Initializes the device, falling back to the best available backend"""
+        if device is None:
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
+        return torch.device(device)
+
+    def blur_entire_image(self, image, radius=50):
+        """Returns the Gaussian-blurred image as a CHW uint8 tensor"""
+        return transforms.functional.pil_to_tensor(image.filter(ImageFilter.GaussianBlur(radius=radius)))
+
+    def annotate_sam_inference(self, inference, area_threshold=0):
+        """Returns a list of annotation dicts
+        Args:
+            inference (ultralytics.engine.results.Results): Output of the model
+            area_threshold (int): Threshold for the segmentation area
+        """
+        annotations = []
+        for i in range(len(inference.masks.data)):
+            annotation = {}
+            annotation["id"] = i
+            annotation["segmentation"] = inference.masks.data[i].cpu() == 1
+            annotation["area"] = annotation["segmentation"].sum()
+
+            if annotation["area"] >= area_threshold:
+                annotations.append(annotation)
+        return annotations
+
+    def read_image(self, root, base_folder, image_name, size=None):
+        """Returns the opened image for the given root, base folder, and image name"""
+        image = Image.open(os.path.join(root, base_folder, image_name))
+        if size:
+            image = self.resize_image(image, size)
+        return image
+
+    def resize_image(self, image, size):
+        """Returns the resized image"""
+        return image.resize(size)
+
+    def create_dirs(self, root):
+        """Creates the required directories under the given root"""
+        dir_names = ["images-to-blur", "blurred-images"]
+        for dir_name in dir_names:
+            os.makedirs(os.path.join(root, dir_name), exist_ok=True)
+
+
+if __name__ == "__main__":
+    # Smoke test on CPU: blur the jacket region of images-to-blur/dogs.jpg
+    blur_image = BlurImage(device="cpu")
+    image = Image.open("images-to-blur/dogs.jpg").resize((1024, 1024))  # feeds the debug helpers below
+    # Debug helpers kept from development:
+    # annotations = blur_image.get_annotations(image, ["small nose"])
+    # print(annotations)
+    # print(blur_image.get_aggregated_mask(image, ["small nose"]).shape)
+    blur_image.blur("dogs.jpg", ["jacket"])
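
The compositing step in blur is the core trick: the whole image is Gaussian-blurred once, and the original pixels are copied back wherever the aggregated mask is False, so only the detected regions stay blurred. A standalone sketch of just that step, with a hand-made boolean mask standing in for the OWLv2 + SAM output and a hypothetical dogs.jpg input:

    import torch
    from PIL import Image, ImageFilter
    from torchvision import transforms

    image = Image.open("dogs.jpg")                                   # hypothetical input
    mask = torch.zeros(image.height, image.width, dtype=torch.bool)  # (H, W), like an aggregated SAM mask
    mask[100:300, 100:300] = True                                    # pretend-detected region

    blurred = transforms.functional.pil_to_tensor(image.filter(ImageFilter.GaussianBlur(radius=50)))
    blurred[:, ~mask] = transforms.functional.pil_to_tensor(image)[:, ~mask]  # keep originals outside the mask
    transforms.functional.to_pil_image(blurred).show()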
requirements.txt ADDED
@@ -0,0 +1,6 @@
+torch
+transformers
+ultralytics
+gradio
+scipy
+torchvision
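
With these dependencies installed (pip install -r requirements.txt), running python app.py should launch the Gradio UI, by default at http://127.0.0.1:7860. On first run, expect the OWLv2 checkpoint to be downloaded from the Hugging Face Hub and, as far as I can tell, mobile_sam.pt to be fetched automatically by ultralytics.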