ImageAlfred / src /modal_app.py
mahan_ym
fixed up privacy preserve
9c43fab
raw
history blame
8.86 kB
import os
from io import BytesIO
import cv2
import modal
import numpy as np
from PIL import Image
# Modal application object; the functions in this file are registered on it
# through the @app.function decorator.
app = modal.App("ImageAlfred")
# Python version added to the base image, and the components of the
# nvidia/cuda image tag ("<cuda version>-<flavor>-<os>") assembled in `tag`.
PYTHON_VERSION = "3.12"
CUDA_VERSION = "12.4.0"
FLAVOR = "devel"
OPERATING_SYS = "ubuntu22.04"
tag = f"{CUDA_VERSION}-{FLAVOR}-{OPERATING_SYS}"
# Named Modal volume (created if it does not exist yet) and the path at which
# the functions below mount it.
volume = modal.Volume.from_name("image-alfred-volume", create_if_missing=True)
volume_path = "/vol"
# Cache directories located under the volume mount path.
# NOTE(review): MODEL_CACHE_DIR is not referenced in the visible part of this
# file — confirm whether it is still needed.
MODEL_CACHE_DIR = f"{volume_path}/models/cache"
TORCH_HOME = f"{volume_path}/torch/home"
HF_HOME = f"{volume_path}/huggingface"
# Container image shared by every function in this file: the nvidia/cuda base
# image selected by `tag`, with Python PYTHON_VERSION added, followed by the
# environment variables and packages below. The chained calls are applied in
# the order written.
image = (
modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python=PYTHON_VERSION)
# Cache locations are baked into the image environment; note that
# HF_HUB_CACHE is set to the same directory as the HF_HOME constant.
.env(
{
"HF_HUB_ENABLE_HF_TRANSFER": "1", # faster downloads
"HF_HUB_CACHE": HF_HOME,
"TORCH_HOME": TORCH_HOME,
}
)
# git is installed before the `git+https://...` pip requirement further down.
.apt_install("git")
# General-purpose dependencies.
# NOTE(review): gpu="A10G" is passed to each pip_install step; presumably this
# attaches a GPU while the step is built — confirm against the Modal docs.
.pip_install(
"huggingface-hub",
"hf_transfer",
"Pillow",
"numpy",
"opencv-contrib-python-headless",
gpu="A10G",
)
# Pinned torch/torchvision pulled from the PyTorch cu124 wheel index.
.pip_install(
"torch==2.4.1",
"torchvision==0.19.1",
index_url="https://download.pytorch.org/whl/cu124",
gpu="A10G",
)
# LangSAM, installed straight from its GitHub repository.
.pip_install(
"git+https://github.com/luca-medeiros/lang-segment-anything.git",
gpu="A10G",
)
)
@app.function(
    gpu="A10G",
    image=image,
    volumes={volume_path: volume},
)
def lang_sam_segment(
    image_pil: Image.Image,
    prompt: str,
    box_threshold: float = 0.3,
    text_threshold: float = 0.25,
) -> list:
    """Segment an image with LangSAM based on a text prompt.

    Args:
        image_pil: Image to segment.
        prompt: Text describing the objects to look for.
        box_threshold: Forwarded to ``LangSAM.predict`` as ``box_threshold``.
        text_threshold: Forwarded to ``LangSAM.predict`` as ``text_threshold``.

    Returns:
        Whatever ``LangSAM.predict`` returns for the one-image batch. The
        callers in this file read ``result[0]["labels"]`` and
        ``result[0]["masks"]`` from it.
    """  # noqa: E501
    # Configure the cache locations BEFORE importing lang_sam: huggingface_hub
    # resolves HF_HOME when it is first imported, so assigning it after the
    # import (as was done previously) has no effect on that library.
    os.environ["TORCH_HOME"] = TORCH_HOME
    os.environ["HF_HOME"] = HF_HOME
    os.makedirs(HF_HOME, exist_ok=True)
    os.makedirs(TORCH_HOME, exist_ok=True)

    # Imported lazily: the package is only installed inside the Modal image.
    from lang_sam import LangSAM  # type: ignore

    model = LangSAM(sam_type="sam2.1_hiera_large")
    # predict() takes batches, so the single image/prompt are wrapped in lists.
    return model.predict(
        images_pil=[image_pil],
        texts_prompt=[prompt],
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
@app.function(
    gpu="T4",
    image=image,
    volumes={volume_path: volume},
)
def change_image_objects_hsv(
    image_pil: Image.Image,
    targets_config: list[list[str | int | float]],
) -> Image.Image:
    """Change the hue and saturation of specified objects in an image.

    The objects are located with ``lang_sam_segment`` using the target names
    as prompts; every detected mask whose label equals a target name is then
    edited in OpenCV's 8-bit HSV colour space.

    Args:
        image_pil: Input image. It is converted to RGB before processing.
        targets_config: List of ``[target_name, hue, saturation_scale]``
            entries, with ``0 <= hue <= 179`` and ``saturation_scale >= 0``.

    Returns:
        A new RGB image with the matching objects recoloured. When
        ``targets_config`` is empty the input image is returned unchanged.

    Raises:
        ValueError: If ``targets_config`` does not have the shape described
            above.
    """  # noqa: E501
    if not isinstance(targets_config, list) or not all(
        (
            isinstance(target, list)
            and len(target) == 3
            and isinstance(target[0], str)
            and isinstance(target[1], (int, float))
            and isinstance(target[2], (int, float))
            and 0 <= target[1] <= 179
            and target[2] >= 0
        )
        for target in targets_config
    ):
        raise ValueError(
            "targets_config must be a list of lists, each containing [target_name, hue, saturation_scale]."  # noqa: E501
        )

    # Nothing to change: skip the remote segmentation call entirely.
    if not targets_config:
        return image_pil

    prompts = ". ".join(target[0] for target in targets_config)

    os.environ["TORCH_HOME"] = TORCH_HOME
    os.environ["HF_HOME"] = HF_HOME
    os.makedirs(HF_HOME, exist_ok=True)
    os.makedirs(TORCH_HOME, exist_ok=True)

    # cv2.COLOR_RGB2HSV cannot take single-channel or palette images, so
    # normalise the input to RGB up front.
    rgb_image = image_pil.convert("RGB")

    langsam_results = lang_sam_segment.remote(image_pil=rgb_image, prompt=prompts)
    detection = langsam_results[0]
    labels = detection["labels"]

    # Work in float32 so the saturation scaling does not wrap around in uint8.
    img_hsv = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2HSV).astype(np.float32)

    for target_obj, hue, saturation_scale in targets_config:
        # A label can be detected several times; edit every instance rather
        # than only the first one.
        matching = [idx for idx, label in enumerate(labels) if label == target_obj]
        if not matching:
            print(
                f"Warning: Label '{target_obj}' not found in the image. Skipping this target."  # noqa: E501
            )
            continue

        for mask_idx in matching:
            mask_bool = detection["masks"][mask_idx].astype(bool)
            img_hsv[mask_bool, 0] = float(hue)
            # Clamp to the top of the 8-bit saturation range.
            img_hsv[mask_bool, 1] = np.minimum(
                img_hsv[mask_bool, 1] * saturation_scale,
                255.0,
            )

    output_img = cv2.cvtColor(img_hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
    return Image.fromarray(output_img)
@app.function(
    gpu="T4",
    image=image,
    volumes={volume_path: volume},
)
def change_image_objects_lab(
    image_pil: Image.Image,
    targets_config: list[list[str | int | float]],
) -> Image.Image:
    """Change the colour of specified objects in an image.

    The objects are located with ``lang_sam_segment`` using the target names
    as prompts; every detected mask whose label equals a target name then has
    its ``a`` and ``b`` channels overwritten in OpenCV's 8-bit Lab colour
    space. The ``L`` channel is left as it is.

    Args:
        image_pil: Input image. It is converted to RGB before processing.
        targets_config: List of ``[target_name, new_a, new_b]`` entries where
            ``new_a`` and ``new_b`` are integers in ``0..255``.

    Returns:
        A new RGB image with the matching objects recoloured. When
        ``targets_config`` is empty the input image is returned unchanged.

    Raises:
        ValueError: If ``targets_config`` does not have the shape described
            above.
    """  # noqa: E501
    if not isinstance(targets_config, list) or not all(
        (
            isinstance(target, list)
            and len(target) == 3
            and isinstance(target[0], str)
            and isinstance(target[1], int)
            and isinstance(target[2], int)
            and 0 <= target[1] <= 255
            and 0 <= target[2] <= 255
        )
        for target in targets_config
    ):
        raise ValueError(
            "targets_config must be a list of lists, each containing [target_name, new_a, new_b]."  # noqa: E501
        )

    # Nothing to change: skip the remote segmentation call entirely.
    if not targets_config:
        return image_pil

    prompts = ". ".join(target[0] for target in targets_config)

    os.environ["TORCH_HOME"] = TORCH_HOME
    os.environ["HF_HOME"] = HF_HOME
    os.makedirs(HF_HOME, exist_ok=True)
    os.makedirs(TORCH_HOME, exist_ok=True)

    # cv2.COLOR_RGB2Lab cannot take single-channel or palette images, so
    # normalise the input to RGB up front.
    rgb_image = image_pil.convert("RGB")

    langsam_results = lang_sam_segment.remote(
        image_pil=rgb_image,
        prompt=prompts,
    )
    detection = langsam_results[0]
    labels = detection["labels"]

    img_lab = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2Lab).astype(np.float32)

    for target_obj, new_a, new_b in targets_config:
        # A label can be detected several times; edit every instance rather
        # than only the first one.
        matching = [idx for idx, label in enumerate(labels) if label == target_obj]
        if not matching:
            print(
                f"Warning: Label '{target_obj}' not found in the image. Skipping this target."  # noqa: E501
            )
            continue

        for mask_idx in matching:
            mask_bool = detection["masks"][mask_idx].astype(bool)
            img_lab[mask_bool, 1] = new_a
            img_lab[mask_bool, 2] = new_b

    output_img = cv2.cvtColor(img_lab.astype(np.uint8), cv2.COLOR_Lab2RGB)
    return Image.fromarray(output_img)
@app.function(
    gpu="T4",
    image=image,
    volumes={volume_path: volume},
)
def apply_mosaic_with_bool_mask(image, mask, intensity: int = 50):
    """Pixelate the masked part of an image.

    The whole image is shrunk by an integer factor and blown back up with
    nearest-neighbour interpolation, which yields a blocky copy; pixels
    selected by ``mask`` are then taken from that copy.

    Args:
        image: Array whose first two dimensions are height and width.
        mask: Boolean index selecting the pixels to replace.
        intensity: Requested shrink factor; it is limited to the range
            ``1 .. min(height, width)``.

    Returns:
        A copy of ``image`` with the masked pixels replaced; the input array
        is not modified.
    """
    height, width = image.shape[:2]

    # Keep the factor at least 1 and no larger than the shorter image side.
    shortest_side = min(height, width)
    block_size = max(1, min(intensity, shortest_side))

    reduced_size = (width // block_size, height // block_size)
    reduced = cv2.resize(image, reduced_size, interpolation=cv2.INTER_LINEAR)
    pixelated = cv2.resize(reduced, (width, height), interpolation=cv2.INTER_NEAREST)

    blended = image.copy()
    blended[mask] = pixelated[mask]
    return blended
@app.function(
    gpu="T4",
    image=image,
    volumes={volume_path: volume},
)
def preserve_privacy(
    image_pil: Image.Image,
    prompt: str,
) -> Image.Image:
    """Pixelate the regions of an image that match a text prompt.

    The regions are found with ``lang_sam_segment``; every returned mask whose
    score is at least 0.6 (or that has no score at all) is covered with a
    mosaic by ``apply_mosaic_with_bool_mask``.

    Args:
        image_pil: Input image.
        prompt: Text describing what should be obscured.

    Returns:
        A new image with the matching regions pixelated, or ``image_pil``
        itself when a segmentation result contains no masks.
    """
    os.environ["TORCH_HOME"] = TORCH_HOME
    os.environ["HF_HOME"] = HF_HOME
    os.makedirs(HF_HOME, exist_ok=True)
    os.makedirs(TORCH_HOME, exist_ok=True)

    langsam_results = lang_sam_segment.remote(
        image_pil=image_pil,
        prompt=prompt,
        box_threshold=0.35,
        text_threshold=0.40,
    )

    img_array = np.array(image_pil)

    for result in langsam_results:
        masks = result["masks"]
        print(f"Found {len(masks)} masks for label: {result['labels']}")
        # len() rather than truthiness: masks may be a NumPy array.
        if len(masks) == 0:
            print("No masks found for the given prompt.")
            return image_pil
        print(f"result: {result}")

        # Scores are optional. Previously a result without "mask_scores" left
        # `mask_score` unbound (or stale from an earlier iteration) while it
        # was still compared and printed below.
        mask_scores = result["mask_scores"] if "mask_scores" in result else None
        # An array with at least one dimension holds one score per mask;
        # anything else is treated as a single score shared by all masks.
        has_per_mask_scores = (
            mask_scores is not None
            and hasattr(mask_scores, "shape")
            and mask_scores.ndim > 0
        )

        for i, mask in enumerate(masks):
            if mask_scores is None:
                mask_score = None
            elif has_per_mask_scores:
                mask_score = mask_scores[i]
            else:
                mask_score = mask_scores

            if mask_score is not None and mask_score < 0.6:
                print(f"Skipping mask {i + 1}/{len(masks)} -> low score.")
                continue

            print(f"Processing mask {i + 1}/{len(masks)}")
            if mask_score is not None:
                print(f"Mask score: {mask_score}")

            mask_bool = mask.astype(bool)
            # Each call receives the output of the previous one, so the
            # mosaics of all accepted masks accumulate in img_array.
            img_array = apply_mosaic_with_bool_mask.remote(img_array, mask_bool)

    return Image.fromarray(img_array)