Spaces:
Running
Running
import asyncio
import base64
import os
from typing import Annotated

import uvicorn
from fastapi import FastAPI, UploadFile, Request, File, Form
from fastapi.responses import RedirectResponse, HTMLResponse, JSONResponse, Response
from fastapi.templating import Jinja2Templates
from rembg import remove, new_session
# Set before the first new_session() call below; presumably rembg reads this
# variable to locate (and download into) its model directory — TODO confirm.
os.environ['U2NET_HOME'] = './models/'

# Model names this app offers.
model_names = [
    'u2net',
    'u2net_human_seg',
    'u2net_cloth_seg',
    'isnet-general-use',
]

# Loaded rembg sessions keyed by model name. Only the first model is created
# eagerly at import time; the request handlers add further entries lazily.
sessions = {
    model_names[0]: new_session(model_name=model_names[0]),
}

app = FastAPI()
templates = Jinja2Templates(directory='templates')
def health(request: Request):
    """Render ``dynamic.html`` with only the request in its template context.

    Args:
        request: The incoming request; the template engine requires it.

    Returns:
        The rendered ``dynamic.html`` template response.
    """
    context = {"request": request}
    return templates.TemplateResponse('dynamic.html', context)
async def remove_bg(
    request: Request,
    file: Annotated[UploadFile, File()],
    mask_only: Annotated[str, Form()] = 'off',
    name_of_model: Annotated[str, Form()] = 'u2net'
) -> HTMLResponse:
    """Remove the background from an uploaded image and render the result page.

    Args:
        request: The incoming request, passed through to the template.
        file: The uploaded image (multipart form field).
        mask_only: ``'on'`` to produce only the segmentation mask; any other
            value produces the cut-out image.
        name_of_model: rembg model to use; must be one of ``model_names``.

    Returns:
        ``dynamic.html`` rendered with the base64-encoded input (``image``)
        and output (``output_img``). If any step fails, ``error.html`` is
        rendered instead, with the exception text as ``error_msg``.
    """
    try:
        # The model name comes straight from the client. Without this check,
        # every distinct value would create (and possibly download) another
        # model and keep it cached in ``sessions`` for the process lifetime.
        if name_of_model not in model_names:
            raise ValueError(
                f"Unknown model {name_of_model!r}; expected one of {model_names}"
            )
        if name_of_model not in sessions:
            # Creating a session is slow, blocking work; keep it off the event loop.
            sessions[name_of_model] = await asyncio.to_thread(
                new_session, model_name=name_of_model
            )
        current_session = sessions[name_of_model]
        only_mask = mask_only == 'on'
        data = await file.read()
        # rembg inference is synchronous and CPU-bound; run it in a worker
        # thread so the event loop can keep serving other requests.
        output_array = await asyncio.to_thread(
            remove, data, only_mask=only_mask, session=current_session
        )
        output_img = base64.b64encode(output_array).decode('utf-8')
        encoded_image = base64.b64encode(data).decode('utf-8')
        return templates.TemplateResponse(
            'dynamic.html',
            {
                "request": request,
                "image": encoded_image,
                "output_img": output_img,
            },
        )
    except Exception as error_msg:
        return templates.TemplateResponse(
            "error.html",
            {
                "request": request,
                "error_msg": str(error_msg),
            },
        )
    finally:
        # Release the upload's underlying file even when processing fails.
        await file.close()
async def remove_bg(
    request: Request,
    file: UploadFile,
    mask_only: str = 'off',
    name_of_model: str = 'u2net'
):
    """Remove the background from an uploaded image and return the raw result.

    Args:
        request: The incoming request (not used by the body).
        file: The uploaded image.
        mask_only: ``'on'`` to produce only the segmentation mask.
        name_of_model: rembg model to use; must be one of ``model_names``.

    Returns:
        On success, a Response containing rembg's output bytes, declared as
        ``image/png``. On failure, the JSON string ``"Oopss!!!! <reason>"``
        with status 400 for an unknown model or 500 for any other error.
        Previously errors were sent with status 200, so clients could not
        tell them apart from an image.
    """
    try:
        # Untrusted input: restrict to the advertised models so clients cannot
        # make the server load and cache arbitrary/unbounded sessions.
        if name_of_model not in model_names:
            return JSONResponse(
                f"Oopss!!!! Unknown model {name_of_model!r}; expected one of {model_names}",
                status_code=400,
            )
        if name_of_model not in sessions:
            # Creating a session is slow, blocking work; keep it off the event loop.
            sessions[name_of_model] = await asyncio.to_thread(
                new_session, model_name=name_of_model
            )
        current_session = sessions[name_of_model]
        only_mask = mask_only == 'on'
        data = await file.read()
        # Synchronous, CPU-bound inference: run it in a worker thread.
        output_array = await asyncio.to_thread(
            remove, data, only_mask=only_mask, session=current_session
        )
        return Response(content=output_array, media_type="image/png")
    except Exception as error:
        # Same JSON-string body the endpoint always produced, now with an
        # error status so it cannot be mistaken for image data.
        return JSONResponse(f"Oopss!!!! {error}", status_code=500)
    finally:
        # Release the upload's underlying file on every path.
        await file.close()
def remove_bg_redirect():
    """Send the client back to the site root ``/``.

    Returns:
        A RedirectResponse to ``'/'`` using RedirectResponse's default status.
    """
    target = '/'
    return RedirectResponse(target)
if __name__ == '__main__':
    # Serve the module-level app on every interface, port 7860.
    uvicorn.run(
        app,
        host='0.0.0.0',
        port=7860,
    )