방재호
init
b5ba7a5
import os
from fastapi import FastAPI, Body
from pydantic import BaseModel
from typing import Any, Optional, List
import gradio as gr
from PIL import Image
import numpy as np
from modules.api.api import encode_pil_to_base64, decode_base64_to_image
from scripts.sam import sam_predict, dino_predict, update_mask, cnet_seg, categorical_mask
from scripts.sam import sam_model_list
def decode_to_pil(image):
if os.path.exists(image):
return Image.open(image)
elif type(image) is str:
return decode_base64_to_image(image)
elif type(image) is Image.Image:
return image
elif type(image) is np.ndarray:
return Image.fromarray(image)
else:
Exception("Not an image")
def encode_to_base64(image):
if type(image) is str:
return image
elif type(image) is Image.Image:
return encode_pil_to_base64(image).decode()
elif type(image) is np.ndarray:
pil = Image.fromarray(image)
return encode_pil_to_base64(pil).decode()
else:
Exception("Invalid type")
def sam_api(_: gr.Blocks, app: FastAPI):
@app.get("/sam/heartbeat")
async def heartbeat():
return {
"msg": "Success!"
}
@app.get("/sam/sam-model", description='Query available SAM model')
async def api_sam_model() -> List[str]:
return sam_model_list
class SamPredictRequest(BaseModel):
sam_model_name: str = "sam_vit_h_4b8939.pth"
input_image: str
sam_positive_points: List[List[float]] = []
sam_negative_points: List[List[float]] = []
dino_enabled: bool = False
dino_model_name: Optional[str] = "GroundingDINO_SwinT_OGC (694MB)"
dino_text_prompt: Optional[str] = None
dino_box_threshold: Optional[float] = 0.3
dino_preview_checkbox: bool = False
dino_preview_boxes_selection: Optional[List[int]] = None
@app.post("/sam/sam-predict")
async def api_sam_predict(payload: SamPredictRequest = Body(...)) -> Any:
print(f"SAM API /sam/sam-predict received request")
payload.input_image = decode_to_pil(payload.input_image).convert('RGBA')
sam_output_mask_gallery, sam_message = sam_predict(
payload.sam_model_name,
payload.input_image,
payload.sam_positive_points,
payload.sam_negative_points,
payload.dino_enabled,
payload.dino_model_name,
payload.dino_text_prompt,
payload.dino_box_threshold,
payload.dino_preview_checkbox,
payload.dino_preview_boxes_selection)
print(f"SAM API /sam/sam-predict finished with message: {sam_message}")
result = {
"msg": sam_message,
}
if len(sam_output_mask_gallery) == 9:
result["blended_images"] = list(map(encode_to_base64, sam_output_mask_gallery[:3]))
result["masks"] = list(map(encode_to_base64, sam_output_mask_gallery[3:6]))
result["masked_images"] = list(map(encode_to_base64, sam_output_mask_gallery[6:]))
return result
class DINOPredictRequest(BaseModel):
input_image: str
dino_model_name: str = "GroundingDINO_SwinT_OGC (694MB)"
text_prompt: str
box_threshold: float = 0.3
@app.post("/sam/dino-predict")
async def api_dino_predict(payload: DINOPredictRequest = Body(...)) -> Any:
print(f"SAM API /sam/dino-predict received request")
payload.input_image = decode_to_pil(payload.input_image)
dino_output_img, _, dino_msg = dino_predict(
payload.input_image,
payload.dino_model_name,
payload.text_prompt,
payload.box_threshold)
if "value" in dino_msg:
dino_msg = dino_msg["value"]
else:
dino_msg = "Done"
print(f"SAM API /sam/dino-predict finished with message: {dino_msg}")
return {
"msg": dino_msg,
"image_with_box": encode_to_base64(dino_output_img) if dino_output_img is not None else None,
}
class DilateMaskRequest(BaseModel):
input_image: str
mask: str
dilate_amount: int = 10
@app.post("/sam/dilate-mask")
async def api_dilate_mask(payload: DilateMaskRequest = Body(...)) -> Any:
print(f"SAM API /sam/dilate-mask received request")
payload.input_image = decode_to_pil(payload.input_image).convert("RGBA")
payload.mask = decode_to_pil(payload.mask)
dilate_result = list(map(encode_to_base64, update_mask(payload.mask, 0, payload.dilate_amount, payload.input_image)))
print(f"SAM API /sam/dilate-mask finished")
return {"blended_image": dilate_result[0], "mask": dilate_result[1], "masked_image": dilate_result[2]}
class AutoSAMConfig(BaseModel):
points_per_side: Optional[int] = 32
points_per_batch: int = 64
pred_iou_thresh: float = 0.88
stability_score_thresh: float = 0.95
stability_score_offset: float = 1.0
box_nms_thresh: float = 0.7
crop_n_layers: int = 0
crop_nms_thresh: float = 0.7
crop_overlap_ratio: float = 512 / 1500
crop_n_points_downscale_factor: int = 1
min_mask_region_area: int = 0
class ControlNetSegRequest(BaseModel):
sam_model_name: str = "sam_vit_h_4b8939.pth"
input_image: str
processor: str = "seg_ofade20k"
processor_res: int = 512
pixel_perfect: bool = False
resize_mode: Optional[int] = 1 # 0: just resize, 1: crop and resize, 2: resize and fill
target_W: Optional[int] = None
target_H: Optional[int] = None
@app.post("/sam/controlnet-seg")
async def api_controlnet_seg(payload: ControlNetSegRequest = Body(...),
autosam_conf: AutoSAMConfig = Body(...)) -> Any:
print(f"SAM API /sam/controlnet-seg received request")
payload.input_image = decode_to_pil(payload.input_image)
cnet_seg_img, cnet_seg_msg = cnet_seg(
payload.sam_model_name,
payload.input_image,
payload.processor,
payload.processor_res,
payload.pixel_perfect,
payload.resize_mode,
payload.target_W,
payload.target_H,
autosam_conf.points_per_side,
autosam_conf.points_per_batch,
autosam_conf.pred_iou_thresh,
autosam_conf.stability_score_thresh,
autosam_conf.stability_score_offset,
autosam_conf.box_nms_thresh,
autosam_conf.crop_n_layers,
autosam_conf.crop_nms_thresh,
autosam_conf.crop_overlap_ratio,
autosam_conf.crop_n_points_downscale_factor,
autosam_conf.min_mask_region_area)
cnet_seg_img = list(map(encode_to_base64, cnet_seg_img))
print(f"SAM API /sam/controlnet-seg finished with message {cnet_seg_msg}")
result = {
"msg": cnet_seg_msg,
}
if len(cnet_seg_img) == 3:
result["blended_images"] = cnet_seg_img[0]
result["random_seg"] = cnet_seg_img[1]
result["edit_anything_control"] = cnet_seg_img[2]
elif len(cnet_seg_img) == 4:
result["sem_presam"] = cnet_seg_img[0]
result["sem_postsam"] = cnet_seg_img[1]
result["blended_presam"] = cnet_seg_img[2]
result["blended_postsam"] = cnet_seg_img[3]
return result
class CategoryMaskRequest(BaseModel):
sam_model_name: str = "sam_vit_h_4b8939.pth"
processor: str = "seg_ofade20k"
processor_res: int = 512
pixel_perfect: bool = False
resize_mode: Optional[int] = 1
target_W: Optional[int] = None
target_H: Optional[int] = None
category: str
input_image: str
@app.post("/sam/category-mask")
async def api_category_mask(payload: CategoryMaskRequest = Body(...),
autosam_conf: AutoSAMConfig = Body(...)) -> Any:
print(f"SAM API /sam/category-mask received request")
payload.input_image = decode_to_pil(payload.input_image)
category_mask_img, category_mask_msg, resized_input_img = categorical_mask(
payload.sam_model_name,
payload.processor,
payload.processor_res,
payload.pixel_perfect,
payload.resize_mode,
payload.target_W,
payload.target_H,
payload.category,
payload.input_image,
autosam_conf.points_per_side,
autosam_conf.points_per_batch,
autosam_conf.pred_iou_thresh,
autosam_conf.stability_score_thresh,
autosam_conf.stability_score_offset,
autosam_conf.box_nms_thresh,
autosam_conf.crop_n_layers,
autosam_conf.crop_nms_thresh,
autosam_conf.crop_overlap_ratio,
autosam_conf.crop_n_points_downscale_factor,
autosam_conf.min_mask_region_area)
category_mask_img = list(map(encode_to_base64, category_mask_img))
print(f"SAM API /sam/category-mask finished with message {category_mask_msg}")
result = {
"msg": category_mask_msg,
}
if len(category_mask_img) == 3:
result["blended_image"] = category_mask_img[0]
result["mask"] = category_mask_img[1]
result["masked_image"] = category_mask_img[2]
if resized_input_img is not None:
result["resized_input"] = encode_to_base64(resized_input_img)
return result
try:
import modules.script_callbacks as script_callbacks
script_callbacks.on_app_started(sam_api)
except:
print("SAM Web UI API failed to initialize")