import base64
import imghdr
import os
import secrets

import cv2
import numpy as np
import torch
from ultralytics import YOLO
from ultralytics.yolo.utils.ops import scale_image
import asyncio
from fastapi import FastAPI, File, UploadFile, Request, Response
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

# from mangum import Mangum
from argparse import ArgumentParser, ArgumentTypeError

import lama_cleaner.server2 as server
from lama_cleaner.helper import (
    load_img,
)

# os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/directory"

app = FastAPI()
# handler = Mangum(app)

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """Encode an image array into the byte format named by ``ext``.

    JPEG quality is set to 100 and PNG compression to 0 (both flags are
    passed; OpenCV applies the one relevant to ``ext``).

    Args:
        image_numpy: numpy image.
        ext: image extension without the leading dot (e.g. ``"png"``).

    Returns:
        The encoded image bytes.
    """
    _, encoded = cv2.imencode(
        f".{ext}",
        image_numpy,
        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
    )
    return encoded.tobytes()


def get_image_ext(img_bytes):
    """Guess an image's format from its leading bytes.

    Args:
        img_bytes: image bytes.

    Returns:
        The format name reported by ``imghdr`` (e.g. ``"png"``), falling
        back to ``"jpeg"`` when the header is not recognised.

    Raises:
        ValueError: if ``img_bytes`` is empty.
    """
    # NOTE(review): ``imghdr`` is deprecated since Python 3.11 and removed in
    # 3.13; this helper needs a replacement before upgrading the runtime.
    if not img_bytes:
        raise ValueError("Empty input")
    w = imghdr.what("", img_bytes[:32])
    return "jpeg" if w is None else w


def predict_on_image(model, img, conf, retina_masks):
    """Run YOLOv8 segmentation on a single image.

    Args:
        model: YOLOv8 segmentation model.
        img: image array passed straight to ``model``; ``detect_mask`` builds
            a same-shaped 3-channel blank image from it, so it is (H, W, 3).
        conf: confidence threshold.
        retina_masks: use retina (high-resolution) masks or not.

    Returns:
        ``(boxes, masks, cls, probs)``; all four are ``None`` when nothing
        was detected.
            boxes: box with xyxy format, (N, 4)
            masks: masks rescaled to the original image, (N, H, W)
            cls: class index of each mask, (N,)
            probs: confidence score, (N,)
    """
    with torch.no_grad():
        result = model(img, conf=conf, retina_masks=retina_masks, scale=1)[0]

    boxes, masks, cls, probs = None, None, None, None

    if result.boxes.cls.size(0) > 0:
        # detection
        cls = result.boxes.cls.cpu().numpy().astype(np.int32)
        probs = result.boxes.conf.cpu().numpy()
        boxes = result.boxes.xyxy.cpu().numpy()

        # segmentation
        masks = result.masks.masks.cpu().numpy()  # (N, H, W)
        masks = np.transpose(masks, (1, 2, 0))  # (H, W, N)
        # rescale masks to original image
        masks = scale_image(masks.shape[:2], masks, result.masks.orig_shape)
        masks = np.transpose(masks, (2, 0, 1))  # (N, H, W)

    return boxes, masks, cls, probs


def overlay(image, mask, color, alpha, id, resize=None):
    """Render a mask as a solid silhouette with a thick outline.

    Every non-zero mask pixel is painted with ``color`` on ``image``; the
    result is converted to gray, its contours are drawn 8 px thick, and every
    non-black pixel of that is mapped to the fixed BGR colour (46, 36, 225).

    Args:
        image: (H, W, 3) image the mask is painted onto (a blank image in
            this file).
        mask: (H, W) mask; non-zero entries are treated as "inside".
        color: 3-tuple painted under the mask before the gray conversion;
            it only matters in that it must produce a non-zero gray level.
        alpha: unused; kept for interface compatibility.
        id: unused; kept for interface compatibility.
        resize: unused; kept for interface compatibility (the previous
            implementation computed resized images and then discarded them).

    Returns:
        (H, W, 3) integer array (not uint8) holding 0 or (46, 36, 225)
        per pixel.
    """
    color = color[::-1]  # reverse channel order to match the BGR conversion below
    colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
    colored_mask = np.moveaxis(colored_mask, 0, -1)  # (H, W, 3)
    masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
    image_overlay = masked.filled()

    imgray = cv2.cvtColor(image_overlay, cv2.COLOR_BGR2GRAY)
    contour_thickness = 8
    # NOTE(review): 255 is not a standard cv2.THRESH_* type. With a threshold
    # of 255 on 8-bit input this appears to act as a plain copy of ``imgray``;
    # confirm before changing, since the contour step depends on it.
    _, thresh = cv2.threshold(imgray, 255, 255, 255)
    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    imgray = cv2.cvtColor(imgray, cv2.COLOR_GRAY2BGR)
    imgray = cv2.drawContours(imgray, contours, -1, (255, 255, 255), contour_thickness)
    return np.where(imgray.any(-1, keepdims=True), (46, 36, 225), 0)


def _render_mask_png(mask_i, blank_image, idx):
    """Render one mask with ``overlay`` and encode it as a 4-channel PNG.

    Synchronous on purpose: ``process_mask`` runs it in the default executor
    so the OpenCV/NumPy work does not block the event loop.

    Returns:
        The encoded PNG as the numpy buffer returned by ``cv2.imencode``.

    Raises:
        RuntimeError: if OpenCV reports that encoding failed.
    """
    maskwith_back = overlay(blank_image, mask_i, color=(255, 155, 155), alpha=0.5, id=idx)
    # Alpha channel: opaque wherever the overlay drew anything.
    alpha = np.uint8((np.sum(maskwith_back, axis=-1) > 0) * 255)
    # All values are in 0..255, so the uint8 cast is lossless; it hands the
    # PNG encoder 8-bit data instead of relying on OpenCV's implicit
    # conversion of the integer array produced by ``overlay``.
    bgra = np.dstack((maskwith_back, alpha)).astype(np.uint8)
    ok, encoded = cv2.imencode('.png', bgra)
    if not ok:
        raise RuntimeError(f"PNG encoding failed for mask {idx}")
    return encoded


async def process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls):
    """Describe one detected object as a JSON-serialisable dict.

    Args:
        idx: index of the mask.
        mask_i: mask of the object, (H, W).
        boxes: box with xyxy format, (N, 4).
        probs: confidence score, (N,).
        yolo_model: YOLOv8 model (used for its ``names`` mapping).
        blank_image: blank (H, W, 3) uint8 image the mask is rendered on.
        cls: class of masks, (N,).

    Returns:
        dict with keys ``confi`` (percent, 2 decimals, as a string), ``boxe``
        (int xyxy list), ``mask`` (base64 PNG) and ``cls`` (class name).
    """
    loop = asyncio.get_running_loop()
    encoded = await loop.run_in_executor(None, _render_mask_png, mask_i, blank_image, idx)
    return {
        "confi": f'{probs[idx] * 100:.2f}',
        "boxe": [int(coord) for coord in boxes[idx]],
        "mask": base64.b64encode(encoded).decode('utf-8'),
        "cls": str(yolo_model.names[cls[idx]]),
    }


@app.middleware("http")
async def check_auth_header(request: Request, call_next):
    """Reject any request whose ``Authorization`` header differs from ``$SECRET``.

    Fails closed: when ``SECRET`` is unset or empty every request is
    rejected. (The previous ``token != os.environ.get("SECRET")`` check
    compared ``None`` with ``None`` in that case and let header-less requests
    through.)
    """
    # This middleware is registered after CORSMiddleware, which makes it the
    # outer layer. Browsers never send credentials on CORS preflights, so let
    # OPTIONS through for CORSMiddleware to answer.
    if request.method == "OPTIONS":
        return await call_next(request)

    expected = os.environ.get("SECRET")
    token = request.headers.get('Authorization')
    authorized = (
        bool(expected)
        and token is not None
        # Constant-time comparison; bytes avoid TypeError on non-ASCII input.
        and secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))
    )
    if not authorized:
        return JSONResponse(content={'error': 'Authorization header missing or incorrect.'}, status_code=403)
    return await call_next(request)


@app.post("/api/mask")
async def detect_mask(file: UploadFile = File()):
    """Detect objects in an uploaded image and return one mask per object.

    Parameters:
        file: uploaded image file.

    Returns:
        A JSON body (always sent with HTTP status 200):
            - code: 200 if objects were detected, 500 if none were
            - msg: human-readable status
            - data (only when code == 200): list of dicts with
                - confi: confidence percentage, as a string
                - boxe: bounding box [x1, y1, x2, y2]
                - mask: base64-encoded PNG of the object's mask
                - cls: class name of the object
    """
    contents = await file.read()
    img, _ = load_img(contents)

    # NOTE(review): lama_cleaner's load_img appears to return RGB, while
    # ultralytics treats numpy input as BGR; confirm the channel order.
    # NOTE(review): inference is synchronous and blocks the event loop.
    boxes, masks, cls, probs = predict_on_image(yolo_model, img, conf=0.55, retina_masks=True)
    if boxes is None:
        return {'code': 500, 'msg': 'No objects detected'}

    # Each mask is rendered onto its own copy of a blank canvas.
    blank_image = np.zeros(img.shape, dtype=np.uint8)
    results = await asyncio.gather(*(
        process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls)
        for idx, mask_i in enumerate(masks)
    ))
    return {'code': 200, 'msg': "object detected", 'data': list(results)}


@app.post("/api/lama/paint")
async def paint(img: UploadFile = File(), mask: UploadFile = File()):
    """Inpaint ``img`` inside ``mask`` using ``server.process``.

    Route: '/api/lama/paint'
    Method: POST

    Parameters:
        img: the input image file.
        mask: the mask file.

    Returns:
        ``{"image": ...}`` holding whatever ``server.process`` returns
        (presumably a base64-encoded image — confirm in lama_cleaner.server2).
    """
    # NOTE(review): server.process is called synchronously inside an async
    # route, so it blocks the event loop for the whole inpainting run.
    img_bytes = await img.read()
    mask_bytes = await mask.read()
    return {"image": server.process(img_bytes, mask_bytes)}


@app.post("/api/remove")
async def remove(img: UploadFile = File()):
    """Return ``{"image": server.remove(<uploaded bytes>)}``."""
    img_bytes = await img.read()
    return {"image": server.remove(img_bytes)}


@app.post("/api/lama/model")
def switch_model(new_name: str):
    """Switch the inpainting model to ``new_name`` via ``server.switch_model``."""
    return server.switch_model(new_name)


@app.get("/api/lama/model")
def current_model():
    """Return ``server.current_model()``."""
    return server.current_model()


@app.get("/api/lama/switchmode")
def get_is_disable_model_switch():
    """Return ``server.get_is_disable_model_switch()``."""
    return server.get_is_disable_model_switch()


@app.on_event("startup")
def init_data():
    """Load the YOLOv8 segmentation model and initialise the inpainting model.

    Sets the module-level ``yolo_model`` used by ``detect_mask``.
    """
    # NOTE(review): the --model_device CLI flag is not wired through; the
    # device is hard-coded here.
    model_device = "cpu"
    global yolo_model
    # TODO Update for local development
    yolo_model = YOLO('yolov8x-seg.pt')
    # yolo_model = YOLO('/app/yolov8x-seg.pt')
    yolo_model.to(model_device)
    print("YOLO model yolov8x-seg.pt loaded.")
    server.initModel()


def _str2bool(value):
    """Parse a command-line boolean such as true/false, yes/no, 1/0, on/off.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--reload False`` used to enable reload). An empty
    string still maps to False, as it did before.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def create_app(args):
    """Serve ``app:app`` with uvicorn (blocks until the server exits).

    Args:
        args: parsed arguments; only ``host``, ``port`` and ``reload`` are used.
    """
    uvicorn.run("app:app", host=args.host, port=args.port, reload=args.reload)


if __name__ == "__main__":
    parser = ArgumentParser()
    # NOTE(review): apart from --host/--port/--reload these options are parsed
    # but not used anywhere in this file.
    parser.add_argument('--model_name', type=str, default='lama', help='Model name')
    parser.add_argument('--host', type=str, default="0.0.0.0")
    parser.add_argument('--port', type=int, default=5000)
    parser.add_argument('--reload', type=_str2bool, default=True)
    parser.add_argument('--model_device', type=str, default='cpu', help='Model device')
    parser.add_argument('--disable_model_switch', type=_str2bool, default=False, help='Disable model switch')
    parser.add_argument('--gui', type=_str2bool, default=False, help='Enable GUI')
    parser.add_argument('--cpu_offload', type=_str2bool, default=False, help='Enable CPU offload')
    parser.add_argument('--disable_nsfw', type=_str2bool, default=False, help='Disable NSFW')
    parser.add_argument('--enable_xformers', type=_str2bool, default=False, help='Enable xformers')
    parser.add_argument('--hf_access_token', type=str, default='', help='Hugging Face access token')
    parser.add_argument('--local_files_only', type=_str2bool, default=False, help='Enable local files only')
    parser.add_argument('--no_half', type=_str2bool, default=False, help='Disable half')
    parser.add_argument('--sd_cpu_textencoder', type=_str2bool, default=False, help='Enable CPU text encoder')
    parser.add_argument('--sd_disable_nsfw', type=_str2bool, default=False, help='Disable NSFW')
    parser.add_argument('--sd_enable_xformers', type=_str2bool, default=False, help='Enable xformers')
    parser.add_argument('--sd_run_local', type=_str2bool, default=False, help='Enable local files only')
    args = parser.parse_args()
    create_app(args)