# Blur-image-via-prompts / image_blurring.py
# byrkbrk's picture
# update files
# 497580e verified
import os
import torch
from transformers import pipeline
from ultralytics import SAM
from PIL import Image, ImageFilter
from torchvision import transforms
class BlurImage(object):
    """Blur the parts of an image that match free-text prompts.

    Bounding boxes for the prompts come from a zero-shot object-detection
    pipeline (``self.owlvit``); SAM (``self.sam``) turns those boxes into
    masks; only the pixels under the union of the masks are replaced by
    their Gaussian-blurred version.
    """

    def __init__(self,
                 device="mps",
                 owlvit_ckpt="google/owlv2-base-patch16-ensemble",
                 owlvit_task="zero-shot-object-detection",
                 sam_ckpt="mobile_sam.pt",
                 ):
        """Load both models and create the working directories.

        Args:
            device: Device spec handed to ``torch.device``; pass ``None`` to
                auto-select (see ``initialize_device``).
            owlvit_ckpt: Checkpoint name passed to ``transformers.pipeline``.
            owlvit_task: Pipeline task name.
            sam_ckpt: Checkpoint passed to ``ultralytics.SAM``.
        """
        # abspath so the folders below do not move if the working directory
        # changes after construction (``__file__`` may be a relative path).
        self.module_dir = os.path.dirname(os.path.abspath(__file__))
        self.device = self.initialize_device(device)
        self.owlvit = pipeline(model=owlvit_ckpt, task=owlvit_task, device=self.device)
        self.sam = SAM(model=sam_ckpt)
        self.create_dirs(self.module_dir)

    def blur(self, image, text_prompts, blur_intensity=50, labels=None, save=True, size=None):
        """Return ``image`` with every region matching ``text_prompts`` blurred.

        Args:
            image: Either a file name inside ``<module dir>/images-to-blur``
                or an already opened PIL image.
            text_prompts: Prompts forwarded to the object detector.
            blur_intensity: Radius of the Gaussian blur.
            labels: Optional per-box labels forwarded to SAM.
            save: When true, also write the result as a JPEG under
                ``<module dir>/blurred-images``.
            size: Optional ``(width, height)`` to resize to before processing.

        Returns:
            The blurred PIL image. If nothing is detected the pixels are
            returned unblurred.
        """
        if isinstance(image, str):
            image_name = image
            image = self.read_image(self.module_dir, "images-to-blur", image_name, size)
        else:
            image_name = "noname"
            if size:
                image = self.resize_image(image, size)

        aggregated_mask = self.get_aggregated_mask(image, text_prompts, labels)
        blurred_image = self.blur_entire_image(image, radius=blur_intensity)

        # Put the original pixels back everywhere outside the mask.
        # NOTE(review): assumes the SAM masks have the image's (H, W) - confirm.
        untouched = ~aggregated_mask
        blurred_image[:, untouched] = transforms.functional.pil_to_tensor(image)[:, untouched]
        blurred_image = transforms.functional.to_pil_image(blurred_image)

        if save:
            output_path = os.path.join(
                self.module_dir,
                "blurred-images",
                os.path.splitext(image_name)[0] + "_blurred_image.jpg",
            )
            # JPEG cannot store an alpha channel; convert only the copy that
            # is written so the returned image keeps its mode.
            if blurred_image.mode in ("L", "RGB"):
                blurred_image.save(output_path)
            else:
                blurred_image.convert("RGB").save(output_path)
        return blurred_image

    def get_aggregated_mask(self, image, text_prompts, labels=None):
        """Return the union of all predicted masks as a boolean tensor.

        When nothing is detected an all-``False`` mask is returned (see
        ``_empty_mask``) instead of a plain Python ``False``, so callers can
        always apply tensor operations such as ``~mask``.
        """
        annotations = self.get_annotations(image, text_prompts, labels)
        if not annotations:
            return self._empty_mask(image)
        masks = torch.stack([annotation["segmentation"] for annotation in annotations])
        return masks.any(dim=0)

    @staticmethod
    def _empty_mask(image):
        """Return an all-``False`` mask shaped like ``image`` when possible.

        Objects exposing a ``(width, height)`` tuple as ``size`` (as PIL
        images do) get a ``(height, width)`` mask; anything else gets a
        0-dim ``False`` tensor because its dimensions are not known here.
        """
        size = getattr(image, "size", None)
        if isinstance(size, tuple) and len(size) == 2:
            width, height = size
            return torch.zeros((height, width), dtype=torch.bool)
        return torch.zeros((), dtype=torch.bool)

    def get_annotations(self, image, text_prompts, labels=None):
        """Return the SAM annotations for the boxes matching ``text_prompts``.

        Returns an empty list when the detector finds no box.
        """
        box_prompts = self.predict_boxes(image, text_prompts)
        results = self.predict_masks(image, box_prompts, labels)
        if not results:
            return []
        # A single image is passed in, so only the first result is used.
        return self.annotate_sam_inference(results[0])

    def predict_masks(self, image, box_prompts, labels=None):
        """Return SAM's output for ``image`` prompted with ``box_prompts``.

        Args:
            image: Image forwarded to SAM.
            box_prompts: Sequence of boxes forwarded as ``bboxes``.
            labels: Per-box labels; defaults to ``1`` for every box.

        Returns:
            Whatever ``self.sam`` returns, or ``[]`` when there are no boxes
            (SAM is not called in that case).
        """
        if len(box_prompts) == 0:
            return []
        if labels is None:
            labels = [1] * len(box_prompts)
        return self.sam(image, bboxes=box_prompts, labels=labels)

    def predict_boxes(self, image, prompts):
        """Return ``[xmin, ymin, xmax, ymax]`` boxes for ``image`` and ``prompts``."""
        # Read the corners by name rather than relying on dict ordering.
        corner_keys = ("xmin", "ymin", "xmax", "ymax")
        return [[pred["box"][key] for key in corner_keys]
                for pred in self.owlvit(image, prompts)]

    def initialize_device(self, device):
        """Return a ``torch.device``; pick cuda, then mps, then cpu if ``device`` is None."""
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        return torch.device(device)

    def blur_entire_image(self, image, radius=50):
        """Return the Gaussian-blurred ``image`` as a tensor."""
        blurred = image.filter(ImageFilter.GaussianBlur(radius=radius))
        return transforms.functional.pil_to_tensor(blurred)

    def annotate_sam_inference(self, inference, area_threshold=0):
        """Return a list of annotation dicts.

        Args:
            inference (ultralytics.engine.results.Results): Output of the model
            area_threshold (int): Threshold for the segmentation area

        Returns:
            One dict per mask whose area is at least ``area_threshold``, with
            keys ``"id"`` (index of the mask in the model output),
            ``"segmentation"`` (boolean mask on CPU) and ``"area"`` (number of
            ``True`` pixels). Empty when the result carries no masks.
        """
        masks = inference.masks
        if masks is None:
            return []

        annotations = []
        for i, mask in enumerate(masks.data):
            segmentation = mask.cpu() == 1
            area = segmentation.sum()
            if area >= area_threshold:
                annotations.append({"id": i, "segmentation": segmentation, "area": area})
        return annotations

    def read_image(self, root, base_folder, image_name, size=None):
        """Return the image opened from ``root/base_folder/image_name``.

        The image is resized to ``size`` when one is given.
        """
        image = Image.open(os.path.join(root, base_folder, image_name))
        if size:
            image = self.resize_image(image, size)
        return image

    def resize_image(self, image, size):
        """Return ``image`` resized to ``size``."""
        return image.resize(size)

    def create_dirs(self, root):
        """Create the input and output folders under ``root`` if missing."""
        for dir_name in ("images-to-blur", "blurred-images"):
            os.makedirs(os.path.join(root, dir_name), exist_ok=True)
if __name__ == "__main__":
    # Demo run: blur everything matching the prompt in the file "dogs.jpg",
    # which blur() reads from the "images-to-blur" folder next to this module
    # and saves the result under "blurred-images".
    blur_image = BlurImage(device="cpu")
    blur_image.blur("dogs.jpg", ["jacket"])