|
import os |
|
from fastapi import FastAPI, Body |
|
from pydantic import BaseModel |
|
from typing import Any, Optional, List |
|
import gradio as gr |
|
from PIL import Image |
|
import numpy as np |
|
|
|
from modules.api.api import encode_pil_to_base64, decode_base64_to_image |
|
from scripts.sam import sam_predict, dino_predict, update_mask, cnet_seg, categorical_mask |
|
from scripts.sam import sam_model_list |
|
|
|
|
|
def decode_to_pil(image):
    """Convert a supported image representation into a PIL image.

    Args:
        image: One of the following:
            * a path (``str``, ``bytes`` or ``os.PathLike``) naming an
              existing file -- opened with ``Image.open``;
            * any other ``str`` -- handed to ``decode_base64_to_image``;
            * a ``PIL.Image.Image`` (or subclass) -- returned unchanged;
            * a ``numpy.ndarray`` -- converted with ``Image.fromarray``.

    Returns:
        The resulting PIL image.

    Raises:
        TypeError: If ``image`` is none of the supported kinds.
    """
    # Only path-like values may reach os.path.exists(): calling it on a PIL
    # image or an ndarray raises TypeError, which previously made the
    # branches below unreachable for exactly the inputs they were meant for.
    #
    # NOTE(security): the route handlers in this file pass request-supplied
    # strings here, so an API client can make the server open any image file
    # the process can read. Kept for backward compatibility -- review before
    # exposing the API to untrusted clients.
    if isinstance(image, (str, bytes, os.PathLike)) and os.path.exists(image):
        return Image.open(image)
    if isinstance(image, str):
        return decode_base64_to_image(image)
    # isinstance (not an exact type comparison) so that Image subclasses,
    # such as the objects returned by Image.open, are accepted as well.
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    # The exception used to be constructed but never raised, which made
    # unsupported input silently yield None.
    raise TypeError("Not an image")
|
|
|
|
|
def encode_to_base64(image):
    """Return ``image`` as an encoded string.

    Args:
        image: One of the following:
            * a ``str`` -- returned unchanged (treated as already encoded);
            * a ``numpy.ndarray`` -- converted with ``Image.fromarray`` and
              then encoded like a PIL image;
            * a ``PIL.Image.Image`` (or subclass) -- encoded with
              ``encode_pil_to_base64`` and the result ``.decode()``-d.

    Returns:
        The string form of the image.

    Raises:
        TypeError: If ``image`` is none of the supported kinds.
    """
    if isinstance(image, str):
        return image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # isinstance (not an exact type comparison) so that Image subclasses,
    # such as the objects returned by Image.open, are encoded too.
    if isinstance(image, Image.Image):
        return encode_pil_to_base64(image).decode()
    # The exception used to be constructed but never raised, which made
    # unsupported input silently yield None.
    raise TypeError("Invalid type")
|
|
|
|
|
def sam_api(_: gr.Blocks, app: FastAPI):
    """Register the ``/sam/*`` HTTP routes on ``app``.

    Args:
        _: The Gradio ``Blocks`` instance; unused.
        app: The FastAPI application the routes are attached to.

    Routes registered:
        GET  /sam/heartbeat       -- fixed success message.
        GET  /sam/sam-model       -- ``sam_model_list``.
        POST /sam/sam-predict     -- wraps ``sam_predict``.
        POST /sam/dino-predict    -- wraps ``dino_predict``.
        POST /sam/dilate-mask     -- wraps ``update_mask``.
        POST /sam/controlnet-seg  -- wraps ``cnet_seg``.
        POST /sam/category-mask   -- wraps ``categorical_mask``.

    Handlers and request models below are documented with ``#`` comments
    rather than docstrings on purpose: FastAPI and pydantic copy docstrings
    into the generated OpenAPI schema, and that schema is left unchanged.

    NOTE(review): every handler is ``async def`` but calls its predictor
    synchronously (nothing is awaited) -- presumably a request occupies the
    event loop for the whole prediction; confirm whether that is intended.
    """

    def _autosam_args(conf):
        """Return ``conf``'s fields in the positional order shared by the
        trailing parameters passed to ``cnet_seg`` and ``categorical_mask``."""
        return (
            conf.points_per_side,
            conf.points_per_batch,
            conf.pred_iou_thresh,
            conf.stability_score_thresh,
            conf.stability_score_offset,
            conf.box_nms_thresh,
            conf.crop_n_layers,
            conf.crop_nms_thresh,
            conf.crop_overlap_ratio,
            conf.crop_n_points_downscale_factor,
            conf.min_mask_region_area,
        )

    # Liveness probe.
    @app.get("/sam/heartbeat")
    async def heartbeat():
        return {"msg": "Success!"}

    @app.get("/sam/sam-model", description='Query available SAM model')
    async def api_sam_model() -> List[str]:
        return sam_model_list

    # Request body of /sam/sam-predict.
    class SamPredictRequest(BaseModel):
        sam_model_name: str = "sam_vit_h_4b8939.pth"
        input_image: str  # file path or base64 string, see decode_to_pil
        sam_positive_points: List[List[float]] = []
        sam_negative_points: List[List[float]] = []
        dino_enabled: bool = False
        dino_model_name: Optional[str] = "GroundingDINO_SwinT_OGC (694MB)"
        dino_text_prompt: Optional[str] = None
        dino_box_threshold: Optional[float] = 0.3
        dino_preview_checkbox: bool = False
        dino_preview_boxes_selection: Optional[List[int]] = None

    # Runs sam_predict on the decoded image (converted to RGBA).
    @app.post("/sam/sam-predict")
    async def api_sam_predict(payload: SamPredictRequest = Body(...)) -> Any:
        print("SAM API /sam/sam-predict received request")
        # Keep the decoded image in a local instead of overwriting the
        # str-typed request field with a PIL object.
        input_image = decode_to_pil(payload.input_image).convert('RGBA')
        sam_output_mask_gallery, sam_message = sam_predict(
            payload.sam_model_name,
            input_image,
            payload.sam_positive_points,
            payload.sam_negative_points,
            payload.dino_enabled,
            payload.dino_model_name,
            payload.dino_text_prompt,
            payload.dino_box_threshold,
            payload.dino_preview_checkbox,
            payload.dino_preview_boxes_selection)
        print(f"SAM API /sam/sam-predict finished with message: {sam_message}")
        result = {
            "msg": sam_message,
        }
        # Images are only returned when the gallery holds exactly nine
        # entries, which are split into three consecutive groups of three.
        if len(sam_output_mask_gallery) == 9:
            result["blended_images"] = list(map(encode_to_base64, sam_output_mask_gallery[:3]))
            result["masks"] = list(map(encode_to_base64, sam_output_mask_gallery[3:6]))
            result["masked_images"] = list(map(encode_to_base64, sam_output_mask_gallery[6:]))
        return result

    # Request body of /sam/dino-predict.
    class DINOPredictRequest(BaseModel):
        input_image: str  # file path or base64 string, see decode_to_pil
        dino_model_name: str = "GroundingDINO_SwinT_OGC (694MB)"
        text_prompt: str
        box_threshold: float = 0.3

    # Runs dino_predict and returns its message plus the annotated image.
    @app.post("/sam/dino-predict")
    async def api_dino_predict(payload: DINOPredictRequest = Body(...)) -> Any:
        print("SAM API /sam/dino-predict received request")
        input_image = decode_to_pil(payload.input_image)
        dino_output_img, _, dino_msg = dino_predict(
            input_image,
            payload.dino_model_name,
            payload.text_prompt,
            payload.box_threshold)
        # dino_msg supports `in` and item access; when it has no "value"
        # entry the call is reported as "Done".
        if "value" in dino_msg:
            dino_msg = dino_msg["value"]
        else:
            dino_msg = "Done"
        print(f"SAM API /sam/dino-predict finished with message: {dino_msg}")
        return {
            "msg": dino_msg,
            "image_with_box": encode_to_base64(dino_output_img) if dino_output_img is not None else None,
        }

    # Request body of /sam/dilate-mask.
    class DilateMaskRequest(BaseModel):
        input_image: str  # file path or base64 string, see decode_to_pil
        mask: str  # file path or base64 string, see decode_to_pil
        dilate_amount: int = 10

    # Runs update_mask and returns its first three outputs, encoded.
    @app.post("/sam/dilate-mask")
    async def api_dilate_mask(payload: DilateMaskRequest = Body(...)) -> Any:
        print("SAM API /sam/dilate-mask received request")
        input_image = decode_to_pil(payload.input_image).convert("RGBA")
        mask = decode_to_pil(payload.mask)
        dilate_result = list(map(encode_to_base64, update_mask(mask, 0, payload.dilate_amount, input_image)))
        print("SAM API /sam/dilate-mask finished")
        return {"blended_image": dilate_result[0], "mask": dilate_result[1], "masked_image": dilate_result[2]}

    # Second body parameter of /sam/controlnet-seg and /sam/category-mask;
    # its fields are forwarded positionally via _autosam_args.
    class AutoSAMConfig(BaseModel):
        points_per_side: Optional[int] = 32
        points_per_batch: int = 64
        pred_iou_thresh: float = 0.88
        stability_score_thresh: float = 0.95
        stability_score_offset: float = 1.0
        box_nms_thresh: float = 0.7
        crop_n_layers: int = 0
        crop_nms_thresh: float = 0.7
        crop_overlap_ratio: float = 512 / 1500
        crop_n_points_downscale_factor: int = 1
        min_mask_region_area: int = 0

    # Request body of /sam/controlnet-seg.
    class ControlNetSegRequest(BaseModel):
        sam_model_name: str = "sam_vit_h_4b8939.pth"
        input_image: str  # file path or base64 string, see decode_to_pil
        processor: str = "seg_ofade20k"
        processor_res: int = 512
        pixel_perfect: bool = False
        resize_mode: Optional[int] = 1
        target_W: Optional[int] = None
        target_H: Optional[int] = None

    # Runs cnet_seg; the response keys depend on how many images it returns.
    @app.post("/sam/controlnet-seg")
    async def api_controlnet_seg(payload: ControlNetSegRequest = Body(...),
                                 autosam_conf: AutoSAMConfig = Body(...)) -> Any:
        print("SAM API /sam/controlnet-seg received request")
        input_image = decode_to_pil(payload.input_image)
        cnet_seg_img, cnet_seg_msg = cnet_seg(
            payload.sam_model_name,
            input_image,
            payload.processor,
            payload.processor_res,
            payload.pixel_perfect,
            payload.resize_mode,
            payload.target_W,
            payload.target_H,
            *_autosam_args(autosam_conf))
        cnet_seg_img = list(map(encode_to_base64, cnet_seg_img))
        print(f"SAM API /sam/controlnet-seg finished with message {cnet_seg_msg}")
        result = {
            "msg": cnet_seg_msg,
        }
        # Three images and four images map to two different key sets; any
        # other count yields only the message.
        if len(cnet_seg_img) == 3:
            result["blended_images"] = cnet_seg_img[0]
            result["random_seg"] = cnet_seg_img[1]
            result["edit_anything_control"] = cnet_seg_img[2]
        elif len(cnet_seg_img) == 4:
            result["sem_presam"] = cnet_seg_img[0]
            result["sem_postsam"] = cnet_seg_img[1]
            result["blended_presam"] = cnet_seg_img[2]
            result["blended_postsam"] = cnet_seg_img[3]
        return result

    # Request body of /sam/category-mask.
    class CategoryMaskRequest(BaseModel):
        sam_model_name: str = "sam_vit_h_4b8939.pth"
        processor: str = "seg_ofade20k"
        processor_res: int = 512
        pixel_perfect: bool = False
        resize_mode: Optional[int] = 1
        target_W: Optional[int] = None
        target_H: Optional[int] = None
        category: str
        input_image: str  # file path or base64 string, see decode_to_pil

    # Runs categorical_mask and returns the encoded images it produced.
    @app.post("/sam/category-mask")
    async def api_category_mask(payload: CategoryMaskRequest = Body(...),
                                autosam_conf: AutoSAMConfig = Body(...)) -> Any:
        print("SAM API /sam/category-mask received request")
        input_image = decode_to_pil(payload.input_image)
        category_mask_img, category_mask_msg, resized_input_img = categorical_mask(
            payload.sam_model_name,
            payload.processor,
            payload.processor_res,
            payload.pixel_perfect,
            payload.resize_mode,
            payload.target_W,
            payload.target_H,
            payload.category,
            input_image,
            *_autosam_args(autosam_conf))
        category_mask_img = list(map(encode_to_base64, category_mask_img))
        print(f"SAM API /sam/category-mask finished with message {category_mask_msg}")
        result = {
            "msg": category_mask_msg,
        }
        # Images are only included when exactly three were produced.
        if len(category_mask_img) == 3:
            result["blended_image"] = category_mask_img[0]
            result["mask"] = category_mask_img[1]
            result["masked_image"] = category_mask_img[2]
        if resized_input_img is not None:
            result["resized_input"] = encode_to_base64(resized_input_img)
        return result
|
|
|
|
|
# Hook the routes into the web UI: sam_api is registered as an
# on_app_started callback. Registration is best-effort -- a failure (for
# example when modules.script_callbacks cannot be imported) is reported on
# stdout and does not propagate.
try:
    import modules.script_callbacks as script_callbacks
    script_callbacks.on_app_started(sam_api)
except Exception as err:
    # `except Exception` rather than a bare `except` so KeyboardInterrupt and
    # SystemExit are no longer swallowed; the cause is printed so the failure
    # can be diagnosed.
    print("SAM Web UI API failed to initialize")
    print(f"SAM Web UI API initialization error: {err!r}")