Spaces:
Runtime error
Runtime error
File size: 5,156 Bytes
beaa9e8 497580e beaa9e8 497580e beaa9e8 497580e beaa9e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import os
import torch
from transformers import pipeline
from ultralytics import SAM
from PIL import Image, ImageFilter
from torchvision import transforms
class BlurImage(object):
def __init__(self,
device="mps",
owlvit_ckpt="google/owlv2-base-patch16-ensemble",
owlvit_task="zero-shot-object-detection",
sam_ckpt="mobile_sam.pt",
):
self.module_dir = os.path.dirname(__file__)
self.device = self.initialize_device(device)
self.owlvit = pipeline(model=owlvit_ckpt, task=owlvit_task, device=self.device)
self.sam = SAM(model=sam_ckpt)
self.create_dirs(self.module_dir)
def blur(self, image, text_prompts, blur_intensity=50, labels=None, save=True, size=None):
"""Returns blurred image based on given text prompt"""
if type(image) == str:
image_name = image
image = self.read_image(self.module_dir, "images-to-blur", image_name, size)
else:
image_name = "noname"
if size: image = self.resize_image(image, size)
aggregated_mask = self.get_aggregated_mask(image, text_prompts, labels)
blurred_image = self.blur_entire_image(image, radius=blur_intensity)
blurred_image[:, ~aggregated_mask] = transforms.functional.pil_to_tensor(image)[:, ~aggregated_mask]
blurred_image = transforms.functional.to_pil_image(blurred_image)
if save:
blurred_image.save(os.path.join(self.module_dir, "blurred-images", os.path.splitext(image_name)[0] + "_blurred_image.jpg"))
return blurred_image
def get_aggregated_mask(self, image, text_prompts, labels=None):
"""Returns aggregated mask for given image, text prompts, and labels"""
aggregated_mask = 0
for annotation in self.get_annotations(image, text_prompts, labels):
aggregated_mask += annotation["segmentation"].float()
return aggregated_mask > 0
def get_annotations(self, image, text_prompts, labels=None):
"""Returns annotations predicted by SAM"""
return self.annotate_sam_inference(*self.predict_masks(image,
self.predict_boxes(image, text_prompts),
labels))
def predict_masks(self, image, box_prompts, labels=None):
"""Returns predicted masks by SAM"""
if labels is None:
labels = [1]*len(box_prompts)
return self.sam(image, bboxes=box_prompts, labels=labels)
def predict_boxes(self, image, prompts):
"""Returns bounding boxes for given image and prompts"""
return [list(pred["box"].values()) for pred in self.owlvit(image, prompts)]
def initialize_device(self, device):
"""Initializes device based on availability"""
if device is None:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
return torch.device(device)
def blur_entire_image(self, image, radius=50):
"""Returns Gaussian-blurred image"""
return transforms.functional.pil_to_tensor(image.filter(ImageFilter.GaussianBlur(radius=radius)))
def annotate_sam_inference(self, inference, area_threshold=0):
"""Returns list of annotation dicts
Args:
inference (ultralytics.engine.results.Results): Output of the model
area_threshold (int): Threshold for the segmentation area
"""
annotations = []
for i in range(len(inference.masks.data)):
annotation = {}
annotation["id"] = i
annotation["segmentation"] = inference.masks.data[i].cpu()==1
annotation["area"] = annotation["segmentation"].sum()
if annotation["area"] >= area_threshold:
annotations += [annotation]
return annotations
def read_image(self, root, base_folder, image_name, size=None):
"""Returns the openned for given image name base folder, and root"""
image = Image.open(os.path.join(root, base_folder, image_name))
if size:
image = self.resize_image(image, size)
return image
def resize_image(self, image, size):
"""Returns resized image"""
return image.resize(size)
def create_dirs(self, root):
"""Creates required directories under the given root"""
dir_names = ["images-to-blur", "blurred-images"]
for dir_name in dir_names:
os.makedirs(os.path.join(root, dir_name), exist_ok=True)
if __name__ == "__main__":
image = Image.open("images-to-blur/dogs.jpg")
image = image.resize((1024, 1024))
blur_image = BlurImage(device="cpu")
#annotations = blur_image.get_annotations(image, ["small nose"])
#print(annotations)
#print(blur_image.get_aggregated_mask(image, ["small nose"]).shape)
image = "dogs.jpg"
blur_image.blur(image, ["jacket"])
|