Spaces:
Runtime error
Runtime error
upload files
Browse files- app.py +13 -0
- image_blurring.py +121 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from image_blurring import BlurImage
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
if __name__ == "__main__":
|
7 |
+
blur_image = BlurImage(device=None)
|
8 |
+
gr_interface = gr.Interface(
|
9 |
+
fn=lambda image, prompt, save=False: blur_image.blur(image, prompt.split("\n"), save=save),
|
10 |
+
inputs=[gr.Image(type="pil"), gr.Textbox(lines=3, placeholder="jacket\ndog head\netc...")],
|
11 |
+
outputs=gr.Image(type="pil")
|
12 |
+
)
|
13 |
+
gr_interface.launch()
|
image_blurring.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from transformers import pipeline
|
4 |
+
from ultralytics import SAM
|
5 |
+
from PIL import Image, ImageFilter
|
6 |
+
from torchvision import transforms
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
class BlurImage(object):
|
11 |
+
def __init__(self,
|
12 |
+
device="mps",
|
13 |
+
owlvit_ckpt="google/owlv2-base-patch16-ensemble",
|
14 |
+
owlvit_task="zero-shot-object-detection",
|
15 |
+
sam_ckpt="mobile_sam.pt",
|
16 |
+
):
|
17 |
+
self.module_dir = os.path.dirname(__file__)
|
18 |
+
self.device = self.initialize_device(device)
|
19 |
+
self.owlvit = pipeline(model=owlvit_ckpt, task=owlvit_task, device=self.device)
|
20 |
+
self.sam = SAM(model=sam_ckpt)
|
21 |
+
self.create_dirs(self.module_dir)
|
22 |
+
|
23 |
+
def blur(self, image, text_prompts, labels=None, save=True, size=None):
|
24 |
+
"""Returns blurred image based on given text prompt"""
|
25 |
+
if type(image) == str:
|
26 |
+
image_name = image
|
27 |
+
image = self.read_image(self.module_dir, "images-to-blur", image_name, size)
|
28 |
+
else:
|
29 |
+
image_name = "noname"
|
30 |
+
if size: image = self.resize_image(image, size)
|
31 |
+
|
32 |
+
aggregated_mask = self.get_aggregated_mask(image, text_prompts, labels)
|
33 |
+
blurred_image = self.blur_entire_image(image)
|
34 |
+
blurred_image[:, ~aggregated_mask] = transforms.functional.pil_to_tensor(image)[:, ~aggregated_mask]
|
35 |
+
blurred_image = transforms.functional.to_pil_image(blurred_image)
|
36 |
+
if save:
|
37 |
+
blurred_image.save(os.path.join(self.module_dir, "blurred-images", os.path.splitext(image_name)[0] + "_blurred_image.jpg"))
|
38 |
+
return blurred_image
|
39 |
+
|
40 |
+
def get_aggregated_mask(self, image, text_prompts, labels=None):
|
41 |
+
"""Returns aggregated mask for given image, text prompts, and labels"""
|
42 |
+
aggregated_mask = 0
|
43 |
+
for annotation in self.get_annotations(image, text_prompts, labels):
|
44 |
+
aggregated_mask += annotation["segmentation"].float()
|
45 |
+
return aggregated_mask > 0
|
46 |
+
|
47 |
+
def get_annotations(self, image, text_prompts, labels=None):
|
48 |
+
"""Returns annotations predicted by SAM"""
|
49 |
+
return self.annotate_sam_inference(*self.predict_masks(image,
|
50 |
+
self.predict_boxes(image, text_prompts),
|
51 |
+
labels))
|
52 |
+
|
53 |
+
def predict_masks(self, image, box_prompts, labels=None):
|
54 |
+
"""Returns predicted masks by SAM"""
|
55 |
+
if labels is None:
|
56 |
+
labels = [1]*len(box_prompts)
|
57 |
+
return self.sam(image, bboxes=box_prompts, labels=labels)
|
58 |
+
|
59 |
+
def predict_boxes(self, image, prompts):
|
60 |
+
"""Returns bounding boxes for given image and prompts"""
|
61 |
+
return [list(pred["box"].values()) for pred in self.owlvit(image, prompts)]
|
62 |
+
|
63 |
+
def initialize_device(self, device):
|
64 |
+
"""Initializes device based on availability"""
|
65 |
+
if device is None:
|
66 |
+
if torch.cuda.is_available():
|
67 |
+
device = "cuda"
|
68 |
+
elif torch.backends.mps.is_available():
|
69 |
+
device = "mps"
|
70 |
+
else:
|
71 |
+
device = "cpu"
|
72 |
+
return torch.device(device)
|
73 |
+
|
74 |
+
def blur_entire_image(self, image, radius=50):
|
75 |
+
"""Returns Gaussian-blurred image"""
|
76 |
+
return transforms.functional.pil_to_tensor(image.filter(ImageFilter.GaussianBlur(radius=radius)))
|
77 |
+
|
78 |
+
def annotate_sam_inference(self, inference, area_threshold=0):
|
79 |
+
"""Returns list of annotation dicts
|
80 |
+
Args:
|
81 |
+
inference (ultralytics.engine.results.Results): Output of the model
|
82 |
+
area_threshold (int): Threshold for the segmentation area
|
83 |
+
"""
|
84 |
+
annotations = []
|
85 |
+
for i in range(len(inference.masks.data)):
|
86 |
+
annotation = {}
|
87 |
+
annotation["id"] = i
|
88 |
+
annotation["segmentation"] = inference.masks.data[i].cpu()==1
|
89 |
+
annotation["area"] = annotation["segmentation"].sum()
|
90 |
+
|
91 |
+
if annotation["area"] >= area_threshold:
|
92 |
+
annotations += [annotation]
|
93 |
+
return annotations
|
94 |
+
|
95 |
+
def read_image(self, root, base_folder, image_name, size=None):
|
96 |
+
"""Returns the openned for given image name base folder, and root"""
|
97 |
+
image = Image.open(os.path.join(root, base_folder, image_name))
|
98 |
+
if size:
|
99 |
+
image = self.resize_image(image, size)
|
100 |
+
return image
|
101 |
+
|
102 |
+
def resize_image(self, image, size):
|
103 |
+
"""Returns resized image"""
|
104 |
+
return image.resize(size)
|
105 |
+
|
106 |
+
def create_dirs(self, root):
|
107 |
+
"""Creates required directories under the given root"""
|
108 |
+
dir_names = ["images-to-blur", "blurred-images"]
|
109 |
+
for dir_name in dir_names:
|
110 |
+
os.makedirs(os.path.join(root, dir_name), exist_ok=True)
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
image = Image.open("images-to-blur/dogs.jpg")
|
115 |
+
image = image.resize((1024, 1024))
|
116 |
+
blur_image = BlurImage(device="cpu")
|
117 |
+
#annotations = blur_image.get_annotations(image, ["small nose"])
|
118 |
+
#print(annotations)
|
119 |
+
#print(blur_image.get_aggregated_mask(image, ["small nose"]).shape)
|
120 |
+
image = "dogs.jpg"
|
121 |
+
blur_image.blur(image, ["jacket"])
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
ultralytics
|
4 |
+
gradio
|
5 |
+
scipy
|
6 |
+
torchvision
|