Spaces:
Runtime error
Runtime error
import base64 | |
import imghdr | |
import os | |
import cv2 | |
import numpy as np | |
import torch | |
from ultralytics import YOLO | |
from ultralytics.yolo.utils.ops import scale_image | |
import asyncio | |
from fastapi import FastAPI, File, UploadFile, Request, Response | |
from fastapi.responses import JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
import uvicorn | |
# from mangum import Mangum | |
from argparse import ArgumentParser | |
import lama_cleaner.server2 as server | |
from lama_cleaner.helper import ( | |
load_img, | |
) | |
# Uncomment to relocate the Transformers cache to a writable directory:
# os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/directory"

# ASGI application; create_app serves it through the import string "app:app".
app = FastAPI()

# Optional AWS Lambda adapter (needs the commented-out Mangum import):
# handler = Mangum(app)

# CORS policy: any origin, any method, any header, credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_origins=origins,
)
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """Encode an image array with OpenCV and return the encoded bytes.

    Args:
        image_numpy: image array handed to ``cv2.imencode``.
        ext: target format extension without the leading dot (e.g. "png").

    Returns:
        The encoded image as ``bytes``. The success flag returned by
        ``cv2.imencode`` is not inspected.
    """
    # Highest JPEG quality and no PNG compression. Both flags are always
    # passed; presumably OpenCV only honours the one matching ``ext``.
    params = [
        int(cv2.IMWRITE_JPEG_QUALITY), 100,
        int(cv2.IMWRITE_PNG_COMPRESSION), 0,
    ]
    _, encoded = cv2.imencode(f".{ext}", image_numpy, params)
    return encoded.tobytes()
def get_image_ext(img_bytes):
    """Guess an image format from its leading bytes.

    Only the first 32 bytes are inspected, via ``imghdr.what``.

    Args:
        img_bytes: raw image bytes.

    Returns:
        The format name reported by ``imghdr`` (e.g. "png", "gif"), or
        "jpeg" when the header is not recognised.

    Raises:
        ValueError: if ``img_bytes`` is empty.
    """
    if not img_bytes:
        raise ValueError("Empty input")
    # NOTE(review): imghdr is deprecated (PEP 594) and removed in Python 3.13.
    detected = imghdr.what("", img_bytes[:32])
    return detected if detected is not None else "jpeg"
def predict_on_image(model, img, conf, retina_masks):
    """Run a YOLOv8 segmentation model on one image and unpack the first result.

    Args:
        model: YOLOv8 model, called as
            ``model(img, conf=..., retina_masks=..., scale=1)``.
        img: input image forwarded to the model (the caller passes
            ``load_img`` output; layout presumably (H, W, C) — confirm).
        conf: confidence threshold forwarded to the model.
        retina_masks: forwarded to the model unchanged.

    Returns:
        ``(boxes, masks, cls, probs)``, or four ``None`` values when the model
        reports no detections:
            boxes: boxes in xyxy format, (N, 4).
            masks: per-instance masks rescaled to the original image, (N, H, W).
            cls: integer class ids, (N,).
            probs: confidence scores, one per detection.
    """
    with torch.no_grad():
        result = model(img, conf=conf, retina_masks=retina_masks, scale=1)[0]

    detections = result.boxes
    if detections.cls.size(0) == 0:
        return None, None, None, None

    class_ids = detections.cls.cpu().numpy().astype(np.int32)
    scores = detections.conf.cpu().numpy()
    xyxy = detections.xyxy.cpu().numpy()

    # scale_image is fed (H, W, N), so move the instance axis last and back.
    channel_last = result.masks.masks.cpu().numpy().transpose(1, 2, 0)
    rescaled = scale_image(channel_last.shape[:2], channel_last, result.masks.orig_shape)
    per_instance = rescaled.transpose(2, 0, 1)
    return xyxy, per_instance, class_ids, scores
def overlay(image, mask, color, alpha, id, resize=None):
    """Fill a mask region on an image and return a colored outline canvas.

    Steps:
    - The masked pixels of ``image`` are filled with ``color`` (channel order
      reversed).
    - The result is converted to grayscale and run through ``cv2.threshold``
      and ``cv2.findContours``.
    - The contours are drawn 8 px thick in white on a BGR copy of the
      grayscale image.
    - Every pixel of that canvas with any non-zero channel becomes
      (46, 36, 225); all other pixels become 0.

    Args:
        image: image to paint on; must broadcast with a 3-channel copy of mask.
        mask: mask whose non-zero entries are painted.
        color: fill color; its channel order is reversed before use.
        alpha: accepted but not used.
        id: accepted but not used.
        resize: when not None, ``image`` and the filled image are passed
            through ``cv2.resize`` after a (1, 2, 0) transpose; the resized
            arrays are discarded and do not affect the return value.

    Returns:
        Array from ``np.where`` holding (46, 36, 225) where the outline canvas
        is non-zero, else 0.
    """
    fill = color[::-1]
    # Repeat the mask along a new trailing axis so it lines up with 3 channels.
    mask3 = np.repeat(np.expand_dims(mask, -1), 3, axis=-1)
    filled = np.ma.MaskedArray(image, mask=mask3, fill_value=fill).filled()

    gray = cv2.cvtColor(filled, cv2.COLOR_BGR2GRAY)
    # NOTE(review): 255 is not a standard cv2.THRESH_* type — confirm intent.
    _, binary = cv2.threshold(gray, 255, 255, 255)
    contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)

    contour_thickness = 8
    canvas = cv2.drawContours(
        cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR),
        contours,
        -1,
        (255, 255, 255),
        contour_thickness,
    )
    painted = np.where(canvas.any(-1, keepdims=True), (46, 36, 225), 0)

    if resize is not None:
        # NOTE(review): these results are unused; only the calls themselves run.
        cv2.resize(image.transpose(1, 2, 0), resize)
        cv2.resize(filled.transpose(1, 2, 0), resize)
    return painted
async def process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls):
    """Render one detected instance and package it for the JSON response.

    Args:
        idx: index of the instance within boxes / probs / cls.
        mask_i: the instance's mask.
        boxes: detection boxes in xyxy format.
        probs: detection confidence scores.
        yolo_model: model whose ``names`` maps class ids to labels.
        blank_image: canvas passed to ``overlay``.
        cls: class ids of the detections.

    Returns:
        dict with keys:
            "confi": confidence * 100 formatted with two decimals.
            "boxe": the box coordinates converted to ints.
            "mask": base64 text of the PNG produced by ``cv2.imencode``.
            "cls": the class label as a string.
    """
    outline = overlay(blank_image, mask_i, color=(255, 155, 155), alpha=0.5, id=idx)
    # Pixels with any non-zero channel become opaque, the rest transparent.
    opacity = np.uint8((np.sum(outline, axis=-1) > 0) * 255)
    rgba = np.dstack((outline, opacity))

    loop = asyncio.get_running_loop()
    # The PNG encode runs in the default executor rather than on the loop.
    encoded = await loop.run_in_executor(None, cv2.imencode, '.png', rgba)
    return {
        "confi": f'{probs[idx] * 100:.2f}',
        "boxe": [int(coord) for coord in boxes[idx]],
        "mask": base64.b64encode(encoded[1]).decode('utf-8'),
        "cls": str(yolo_model.names[cls[idx]]),
    }
async def check_auth_header(request: Request, call_next):
    """HTTP middleware requiring ``Authorization`` to equal the SECRET env var.

    Fails closed: when SECRET is unset or empty, every request is rejected.
    Previously an unset SECRET compared ``None != None`` as False, so any
    request *without* the header was let through.

    Args:
        request: incoming request.
        call_next: downstream handler invoked for authorized requests.

    Returns:
        A 403 ``JSONResponse`` when the header is missing or wrong, otherwise
        the response produced by ``call_next``.
    """
    # NOTE(review): no @app.middleware decorator is visible for this function;
    # confirm it is actually registered.
    import hmac  # local import keeps this block self-contained

    expected = os.environ.get("SECRET")
    token = request.headers.get('Authorization')
    # Constant-time comparison of the encoded values avoids leaking how much
    # of the secret matched; bytes also sidestep compare_digest's ASCII-only
    # restriction on str.
    authorized = (
        bool(expected)
        and token is not None
        and hmac.compare_digest(
            token.encode("utf-8", "surrogateescape"),
            expected.encode("utf-8", "surrogateescape"),
        )
    )
    if not authorized:
        return JSONResponse(content={'error': 'Authorization header missing or incorrect.'}, status_code=403)
    return await call_next(request)
async def detect_mask(file: UploadFile = File()):
    """Detect and segment objects in an uploaded image.

    Parameters:
        file: uploaded image file.

    Returns:
        A dict (the HTTP status itself is whatever the framework uses for a
        plain return value). Its fields are:
        - {'code': 500, 'msg': 'No objects detected'} when the model finds
          nothing at confidence 0.55.
        - Otherwise {'code': 200, 'msg': "object detected", 'data': [...]},
          where each entry of data comes from ``process_mask`` and carries
          "confi", "boxe", "mask" (base64 PNG) and "cls".
    """
    # NOTE(review): no route decorator is visible for this handler, and
    # predict_on_image is a plain call that runs on the event loop thread.
    payload = await file.read()
    img, _ = load_img(payload)
    boxes, masks, cls, probs = predict_on_image(yolo_model, img, conf=0.55, retina_masks=True)
    if boxes is None:
        return {'code': 500, 'msg': 'No objects detected'}

    blank_image = np.zeros(img.shape, dtype=np.uint8)
    jobs = (
        process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls)
        for idx, mask_i in enumerate(masks)
    )
    data = list(await asyncio.gather(*jobs))
    return {'code': 200, 'msg': "object detected", 'data': data}
async def paint(img: UploadFile = File(), mask: UploadFile = File()):
    """Inpaint an uploaded image with an uploaded mask via lama_cleaner.

    Documented route: POST '/api/lama/paint' (registration not visible here).

    Parameters:
        img: input image file (JPEG or PNG).
        mask: mask file (JPEG or PNG).

    Returns:
        {"image": <result of server.process(image_bytes, mask_bytes)>},
        described upstream as the processed image in base64.
    """
    image_bytes = await img.read()
    mask_bytes = await mask.read()
    result = server.process(image_bytes, mask_bytes)
    return {"image": result}
async def remove(img: UploadFile = File()):
    """Pass an uploaded image's bytes to ``server.remove``.

    Returns:
        {"image": <whatever server.remove returns for the bytes>}.
    """
    raw = await img.read()
    result = server.remove(raw)
    return {"image": result}
def switch_model(new_name: str):
    """Ask lama_cleaner's server module to switch to ``new_name``; return its result."""
    outcome = server.switch_model(new_name)
    return outcome
def current_model():
    """Return whatever lama_cleaner's ``server.current_model()`` reports."""
    active = server.current_model()
    return active
def get_is_disable_model_switch():
    """Return lama_cleaner's ``server.get_is_disable_model_switch()`` value."""
    disabled = server.get_is_disable_model_switch()
    return disabled
def init_data(model_path: str = 'yolov8x-seg.pt', model_device: str = "cpu"):
    """Load the global YOLOv8 segmentation model and initialise lama_cleaner.

    Sets the module-level ``yolo_model`` used by the detection endpoint, then
    calls ``server.initModel()``.

    Args:
        model_path: weights file passed to ``YOLO``. The default matches the
            previous hard-coded value; a container build might use
            '/app/yolov8x-seg.pt'.
        model_device: device passed to ``yolo_model.to`` (previously always
            "cpu").
    """
    global yolo_model
    yolo_model = YOLO(model_path)
    yolo_model.to(model_device)
    print(f"YOLO model {model_path} loaded.")
    server.initModel()
def create_app(args):
    """Serve the module-level FastAPI ``app`` with uvicorn.

    Despite the name, nothing is created or registered here. The call blocks
    in ``uvicorn.run`` using the import string "app:app".

    Args:
        args: parsed CLI namespace; only ``host``, ``port`` and ``reload`` are read.
    """
    server_options = {
        "host": args.host,
        "port": args.port,
        "reload": args.reload,
    }
    uvicorn.run("app:app", **server_options)
def str_to_bool(value):
    """Parse a CLI boolean such as 'true', 'False', '1' or 'no'.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for
    any non-empty string, so ``--reload False`` used to enable reload. This
    converter accepts the usual spellings and raises ``ValueError`` for
    anything else, which argparse reports as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    # "" stays False, matching the old bool("") behaviour.
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise ValueError(f"not a boolean: {value!r}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str, default='lama', help='Model name')
    parser.add_argument('--host', type=str, default="0.0.0.0")
    parser.add_argument('--port', type=int, default=5000)
    parser.add_argument('--reload', type=str_to_bool, default=True)
    parser.add_argument('--model_device', type=str, default='cpu', help='Model device')
    parser.add_argument('--disable_model_switch', type=str_to_bool, default=False, help='Disable model switch')
    parser.add_argument('--gui', type=str_to_bool, default=False, help='Enable GUI')
    parser.add_argument('--cpu_offload', type=str_to_bool, default=False, help='Enable CPU offload')
    parser.add_argument('--disable_nsfw', type=str_to_bool, default=False, help='Disable NSFW')
    parser.add_argument('--enable_xformers', type=str_to_bool, default=False, help='Enable xformers')
    parser.add_argument('--hf_access_token', type=str, default='', help='Hugging Face access token')
    parser.add_argument('--local_files_only', type=str_to_bool, default=False, help='Enable local files only')
    parser.add_argument('--no_half', type=str_to_bool, default=False, help='Disable half')
    parser.add_argument('--sd_cpu_textencoder', type=str_to_bool, default=False, help='Enable CPU text encoder')
    parser.add_argument('--sd_disable_nsfw', type=str_to_bool, default=False, help='Disable NSFW')
    parser.add_argument('--sd_enable_xformers', type=str_to_bool, default=False, help='Enable xformers')
    parser.add_argument('--sd_run_local', type=str_to_bool, default=False, help='Enable local files only')
    args = parser.parse_args()
    create_app(args)