|
import json
import os
import tempfile
import webbrowser
from typing import Optional, Tuple, Union, cast

import aiohttp
import click
import gradio as gr
import uvicorn
from asyncer import asyncify
from fastapi import Depends, FastAPI, File, Form, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response

from .._version import get_versions
from ..bg import remove
from ..session_factory import new_session
from ..sessions import sessions_names
from ..sessions.base import BaseSession
|
|
|
|
|
@click.command(
    name="s",
    help="for a http server",
)
@click.option(
    "-p",
    "--port",
    default=5000,
    type=int,
    show_default=True,
    help="port",
)
@click.option(
    "-l",
    "--log_level",
    default="info",
    type=str,
    show_default=True,
    help="log level",
)
@click.option(
    "-t",
    "--threads",
    default=None,
    type=int,
    show_default=True,
    help="number of worker threads",
)
def s_command(port: int, log_level: str, threads: Optional[int]) -> None:
    """Serve the rembg HTTP API (under /api) and a Gradio UI (mounted at /).

    Args:
        port: TCP port uvicorn listens on (bound to 0.0.0.0).
        log_level: uvicorn log level.
        threads: when given, the capacity of the thread limiter installed at
            startup for the blocking background-removal calls; when None the
            library default is left in place.
    """
    # Model name -> loaded session, shared by the API routes and the Gradio UI
    # so each model is constructed once per process rather than per request.
    sessions: dict[str, BaseSession] = {}

    # Anchored so only exact model names validate; without ^...$ the
    # alternation could also match longer strings that merely begin with
    # (or contain) a model name.
    model_regex = r"^(" + "|".join(sessions_names) + ")$"

    def _get_session(model: str) -> BaseSession:
        """Return the cached session for ``model``, creating it on first use."""
        session = sessions.get(model)
        if session is None:
            # Construct only when missing: ``setdefault(k, new_session(k))``
            # evaluates new_session() (and discards the result) on every call.
            session = sessions.setdefault(model, new_session(model))
        return session

    def _parse_bgc(bgc: Optional[str]) -> Optional[Tuple[int, int, int, int]]:
        """Parse an ``"R,G,B,A"`` string into a 4-tuple of ints in 0-255.

        Returns None when ``bgc`` is empty or absent.

        Raises:
            HTTPException: 422 when ``bgc`` is not exactly four integers in
                0-255, instead of letting a ValueError escape unhandled.
        """
        if not bgc:
            return None
        try:
            channels = tuple(int(part) for part in bgc.split(","))
        except ValueError:
            channels = ()
        if len(channels) != 4 or not all(0 <= c <= 255 for c in channels):
            raise HTTPException(
                status_code=422,
                detail="bgc must be four comma-separated integers in 0-255 (R,G,B,A)",
            )
        return cast(Tuple[int, int, int, int], channels)

    def _parse_extras(extras: Optional[str]) -> dict:
        """Decode the ``extras`` JSON object into extra keyword arguments.

        Returns an empty dict when ``extras`` is empty or absent.

        Raises:
            HTTPException: 422 when ``extras`` is not valid JSON or is not a
                JSON object. (Previously such input was silently ignored, so
                the caller's options were dropped without any indication.)
        """
        if not extras:
            return {}
        try:
            parsed = json.loads(extras)
        except json.JSONDecodeError as err:
            raise HTTPException(
                status_code=422, detail=f"extras is not valid JSON: {err}"
            ) from err
        if not isinstance(parsed, dict):
            raise HTTPException(
                status_code=422, detail="extras must be a JSON object"
            )
        return parsed

    tags_metadata = [
        {
            "name": "Background Removal",
            "description": "Endpoints that perform background removal with different image sources.",
            "externalDocs": {
                "description": "GitHub Source",
                "url": "https://github.com/danielgatis/rembg",
            },
        },
    ]

    app = FastAPI(
        title="Rembg",
        description="Rembg is a tool to remove images background. That is it.",
        version=get_versions()["version"],
        contact={
            "name": "Daniel Gatis",
            "url": "https://github.com/danielgatis",
            "email": "[email protected]",
        },
        license_info={
            "name": "MIT License",
            "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
        },
        openapi_tags=tags_metadata,
        docs_url="/api",
    )

    app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    class CommonQueryParams:
        """Processing options for GET /api/remove, read from the query string."""

        def __init__(
            self,
            model: str = Query(
                description="Model to use when processing image",
                regex=model_regex,
                default="u2net",
            ),
            a: bool = Query(default=False, description="Enable Alpha Matting"),
            af: int = Query(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Query(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Query(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Query(default=False, description="Only Mask"),
            ppm: bool = Query(default=False, description="Post Process Mask"),
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            self.bgc = _parse_bgc(bgc)

    class CommonQueryPostParams:
        """Processing options for POST /api/remove.

        NOTE: unlike the other fields (form data), ``bgc`` and ``extras`` are
        read from the query string; this is left as-is so existing clients
        keep working.
        """

        def __init__(
            self,
            model: str = Form(
                description="Model to use when processing image",
                regex=model_regex,
                default="u2net",
            ),
            a: bool = Form(default=False, description="Enable Alpha Matting"),
            af: int = Form(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Form(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Form(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Form(default=False, description="Only Mask"),
            ppm: bool = Form(default=False, description="Post Process Mask"),
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            self.bgc = _parse_bgc(bgc)

    def im_without_bg(
        content: bytes,
        commons: Union[CommonQueryParams, CommonQueryPostParams],
    ) -> Response:
        """Remove the background from ``content`` and return it as a PNG response.

        Blocking; the routes run it in a worker thread via ``asyncify``.

        Raises:
            HTTPException: 422 when ``extras`` is malformed or tries to set one
                of the options this function already passes explicitly.
        """
        extras = _parse_extras(commons.extras)
        options = {
            "alpha_matting": commons.a,
            "alpha_matting_foreground_threshold": commons.af,
            "alpha_matting_background_threshold": commons.ab,
            "alpha_matting_erode_size": commons.ae,
            "only_mask": commons.om,
            "post_process_mask": commons.ppm,
            "bgcolor": commons.bgc,
        }
        # A key passed both explicitly and through **extras would raise
        # TypeError ("got multiple values for keyword argument"); reject it
        # up front (and before loading a model) with a clear message.
        overridden = sorted(extras.keys() & {"session", *options})
        if overridden:
            raise HTTPException(
                status_code=422,
                detail=f"extras may not override: {', '.join(overridden)}",
            )

        return Response(
            remove(
                content,
                session=_get_session(commons.model),
                **options,
                **extras,
            ),
            media_type="image/png",
        )

    @app.on_event("startup")
    def startup():
        # Best effort: failing to open a browser must not stop the server.
        try:
            webbrowser.open(f"http://localhost:{port}")
        except Exception:
            pass

        if threads is not None:
            from anyio import CapacityLimiter
            from anyio.lowlevel import RunVar

            # RunVar values are scoped to the running event loop, so this is
            # done in a startup handler (which runs on that loop) rather than
            # at CLI time, when no loop is running yet.
            # NOTE(review): "_default_thread_limiter" is an AnyIO-internal
            # name; confirm it still matches the installed AnyIO version.
            RunVar("_default_thread_limiter").set(CapacityLimiter(threads))

    @app.get(
        path="/api/remove",
        tags=["Background Removal"],
        summary="Remove from URL",
        description="Removes the background from an image obtained by retrieving an URL.",
    )
    async def get_index(
        url: str = Query(
            default=..., description="URL of the image that has to be processed."
        ),
        commons: CommonQueryParams = Depends(),
    ):
        # NOTE(review): `url` is fetched server-side without any restriction,
        # so a client can make this server request arbitrary (including
        # internal) addresses. Restrict hosts if exposed beyond a trusted net.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    # Don't feed an upstream error page to the model.
                    if not response.ok:
                        raise HTTPException(
                            status_code=400,
                            detail=f"Fetching {url} failed with HTTP status {response.status}",
                        )
                    file = await response.read()
        except aiohttp.ClientError as err:
            raise HTTPException(
                status_code=400, detail=f"Could not fetch {url}: {err}"
            ) from err

        return await asyncify(im_without_bg)(file, commons)

    @app.post(
        path="/api/remove",
        tags=["Background Removal"],
        summary="Remove from Stream",
        description="Removes the background from an image sent within the request itself.",
    )
    async def post_index(
        file: bytes = File(
            default=...,
            description="Image file (byte stream) that has to be processed.",
        ),
        commons: CommonQueryPostParams = Depends(),
    ):
        return await asyncify(im_without_bg)(file, commons)

    def gr_app(app):
        """Mount a Gradio UI for background removal at "/" on ``app`` and return it."""

        def inference(input_path, model):
            """Remove the background of the image at ``input_path``; return the PNG path."""
            with open(input_path, "rb") as src:
                image_bytes = src.read()

            output = remove(image_bytes, session=_get_session(model))

            # A unique file per call: the queue below runs several jobs
            # concurrently, and a fixed "output.png" let them overwrite each
            # other's results.
            fd, output_path = tempfile.mkstemp(suffix=".png")
            with os.fdopen(fd, "wb") as dst:
                dst.write(output)
            return output_path

        interface = gr.Interface(
            inference,
            [
                gr.components.Image(type="filepath", label="Input"),
                gr.components.Dropdown(
                    [
                        "u2net",
                        "u2netp",
                        "u2net_human_seg",
                        "u2net_cloth_seg",
                        "silueta",
                        "isnet-general-use",
                        "isnet-anime",
                    ],
                    value="u2net",
                    label="Models",
                ),
            ],
            gr.components.Image(type="filepath", label="Output"),
        )

        interface.queue(concurrency_count=3)
        app = gr.mount_gradio_app(app, interface, path="/")
        return app

    print(f"To access the API documentation, go to http://localhost:{port}/api")
    print(f"To access the UI, go to http://localhost:{port}")

    uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)
|
|