Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
import os | |
import sys | |
# import traceback | |
__dir__ = os.path.dirname(os.path.abspath(__file__)) | |
sys.path.append(__dir__) | |
sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) | |
import base64 | |
import logging | |
import multiprocessing | |
import os | |
import random | |
import time | |
import imghdr | |
from pathlib import Path | |
import cv2 | |
import torch | |
import numpy as np | |
from loguru import logger | |
from lama_cleaner.interactive_seg import InteractiveSeg | |
from lama_cleaner.model_manager import ModelManager | |
from lama_cleaner.schema import Config | |
from lama_cleaner.file_manager import FileManager | |
from lama_cleaner.plugins.remove_bg import RemoveBG | |
# Best-effort: switch off the TorchScript fusers. These are private torch
# hooks, so any of them may be missing or may raise depending on the installed
# torch build. Each switch is applied on its own so that one unavailable hook
# does not prevent the remaining ones from being applied.
for _fuser_switch in (
    "_jit_override_can_fuse_on_cpu",
    "_jit_override_can_fuse_on_gpu",
    "_jit_set_texpr_fuser_enabled",
    "_jit_set_nvfuser_enabled",
):
    try:
        getattr(torch._C, _fuser_switch)(False)
    except Exception:
        # Deliberately ignored: the switch is an optimisation toggle, not a
        # requirement. `Exception` (rather than a bare `except`) keeps
        # KeyboardInterrupt / SystemExit propagating.
        continue
# Disable ability for Flask to display warning about using a development server in a production environment. | |
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356 | |
# cli.show_server_banner = lambda *_: None | |
# from flask_cors import CORS | |
from lama_cleaner.helper import ( | |
load_img, | |
resize_max_size, | |
) | |
# Thread count handed to the native math backends: one per logical CPU.
# NOTE(review): these variables are exported after torch / numpy / cv2 have
# already been imported above -- confirm the backends still pick them up.
NUM_THREADS = str(multiprocessing.cpu_count())
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
for _thread_var in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[_thread_var] = NUM_THREADS

# Let CACHE_DIR (when set and non-empty) redirect torch's download cache.
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]

BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # directory one level above this file's folder
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
try:
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
except ValueError:
    # os.path.relpath raises ValueError on Windows when ROOT and the current
    # directory are on different drives; keep the absolute path in that case.
    pass
class NoFlaskwebgui(logging.Filter):
    """Logging filter that drops flaskwebgui keep-alive records."""

    # Substring that identifies a keep-alive log line.
    _KEEP_ALIVE_MARKER = "flaskwebgui-keep-server-alive"

    def filter(self, record):
        """Return False (drop) when the rendered message contains the marker."""
        rendered = record.getMessage()
        return self._KEEP_ALIVE_MARKER not in rendered


# Install the filter on the "werkzeug" logger.
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
# app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) | |
# app.config["JSON_AS_ASCII"] = False | |
# CORS(app, expose_headers=["Content-Disposition"]) | |
# ---------------------------------------------------------------------------
# Module-level runtime state. The objects start out as None placeholders;
# `initModel()` assigns `model` and `device` and registers entries in
# `plugins`.
# ---------------------------------------------------------------------------

# Heavy objects.
model: ModelManager = None  # invoked as model(image, mask, config) in process()
thumb: FileManager = None  # not assigned anywhere in the visible code
device = None  # torch.device set by initModel()
plugins = {}  # plugin name -> plugin instance

# Plain settings / feature flags.
input_image_path: str = None
is_disable_model_switch: bool = False  # when true, switch_model() refuses
is_enable_file_manager: bool = False
is_desktop: bool = False
def _sniff_image_format(header):
    """Return a format name for a few common image signatures, else None."""
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    if header[:6] in (b"GIF87a", b"GIF89a"):
        return "gif"
    if header[:2] in (b"MM", b"II"):
        return "tiff"
    if header.startswith(b"BM"):
        return "bmp"
    if header.startswith(b"RIFF") and header[8:12] == b"WEBP":
        return "webp"
    return None


def get_image_ext(img_bytes):
    """Guess the image format name for the encoded image in ``img_bytes``.

    Args:
        img_bytes: Raw bytes of an encoded image.

    Returns:
        The format name reported by ``imghdr.what`` (for example ``"png"``),
        or ``"jpeg"`` when the format is not recognised.

    ``imghdr`` was removed from the standard library in Python 3.13
    (PEP 594). When it cannot be imported, a minimal built-in signature check
    is used instead; on interpreters that still ship ``imghdr`` the result is
    unchanged.
    """
    try:
        import imghdr as _imghdr
    except ImportError:
        _imghdr = None

    if _imghdr is not None:
        detected = _imghdr.what("", img_bytes)
    else:
        detected = _sniff_image_format(img_bytes)

    # Unknown formats are reported as JPEG, as before.
    return "jpeg" if detected is None else detected
def diffuser_callback(i, t, latents):
    """Do nothing; placeholder progress hook passed to ``ModelManager``.

    The three arguments are accepted and ignored.
    """
    # socketio.emit('diffusion_step', {'diffusion_step': step})
    return None
# Inpainting settings shared by every `process()` call. The keyword values
# are fixed here; only the two seed fields are ever rewritten (by `process`,
# when they equal -1).
config = Config(
    # -- LDM --------------------------------------------------------------
    ldm_steps=25,
    ldm_sampler="plms",
    # -- ZITS -------------------------------------------------------------
    zits_wireframe=True,
    # -- High-resolution strategy -------------------------------------------
    hd_strategy="Resize",  # Original, Resize, Crop
    hd_strategy_crop_margin=196,
    hd_strategy_crop_trigger_size=1280,
    hd_strategy_resize_limit=2048,
    # -- Prompts ------------------------------------------------------------
    prompt="",
    negative_prompt="",
    # -- Cropper ------------------------------------------------------------
    use_croper=False,
    croper_x=None,
    croper_y=None,
    croper_height=None,
    croper_width=None,
    # -- "sd_" settings -----------------------------------------------------
    sd_scale=1,
    sd_mask_blur=5,
    sd_strength=0.75,
    sd_steps=50,
    sd_guidance_scale=7.5,
    sd_sampler="pndm",
    sd_seed=42,
    sd_match_histograms=False,
    # -- "cv2_" settings ----------------------------------------------------
    cv2_flag="INPAINT_NS",
    cv2_radius=40,
    # -- "paint_by_example_" settings ---------------------------------------
    paint_by_example_steps=50,
    paint_by_example_guidance_scale=7.5,
    paint_by_example_mask_blur=5,
    paint_by_example_seed=42,
    paint_by_example_match_histograms=False,
    paint_by_example_example_image=None,
)
def process(origin_image_bytes, mask, size_limit=2048):
    """Inpaint the masked area of an image and return it base64-encoded.

    Args:
        origin_image_bytes: Encoded source image, handed to ``load_img``.
        mask: Encoded mask image, loaded in grayscale. Every pixel greater
            than zero is treated as "to be inpainted".
        size_limit: Limit passed to ``resize_max_size`` for both image and
            mask. The string ``"Original"`` uses the largest dimension of the
            input instead. Defaults to 2048, the previously hard-coded value.

    Returns:
        On success, a ``str`` holding the base64 text of the JPEG-encoded
        result. On failure, a ``(message, status_code)`` tuple: status 400
        for an image/mask size mismatch, 500 for a ``RuntimeError`` raised by
        the model.
    """
    image, alpha_channel = load_img(origin_image_bytes)
    mask, _ = load_img(mask, gray=True)
    # Binarise the mask: anything non-zero becomes 255.
    mask = np.where(mask > 0, 255, 0).astype(np.uint8)

    if image.shape[:2] != mask.shape[:2]:
        return f"Mask shape {mask.shape[:2]} not queal to Image shape {image.shape[:2]}", 400

    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    if size_limit == "Original":
        size_limit = max(image.shape)
    else:
        size_limit = int(size_limit)

    # A seed of -1 means "pick one at random". The chosen value is written
    # back to the shared module-level config, so it persists across calls.
    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)
    if config.paint_by_example_seed == -1:
        config.paint_by_example_seed = random.randint(1, 999999999)

    logger.info(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit,
                            interpolation=interpolation)
    logger.info(f"Resized image shape: {image.shape}")
    mask = resize_max_size(mask, size_limit=size_limit,
                           interpolation=interpolation)

    start = time.time()
    try:
        with torch.no_grad():
            res_np_img = model(image, mask, config)
    except RuntimeError as e:
        if "CUDA out of memory. " in str(e):
            # NOTE: the string may change?
            return "CUDA out of memory", 500
        logger.exception(e)
        return "Internal Server Error", 500
    finally:
        # Runs on success and on both error returns above, so the cache is
        # released exactly once per call.
        torch.cuda.empty_cache()
    logger.info(f"process time: {(time.time() - start)}s")

    if alpha_channel is not None:
        if alpha_channel.shape[:2] != res_np_img.shape[:2]:
            # cv2.resize takes dsize as (width, height) and interpolates.
            # np.resize (used previously) only repeats/truncates the
            # flattened data and produced a (width, height)-shaped array that
            # cannot be concatenated below for non-square images.
            alpha_channel = cv2.resize(
                alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
            )
        res_np_img = np.concatenate(
            (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
        )

    # NOTE(review): the result is encoded as JPEG, which has no alpha
    # channel -- confirm whether the alpha appended above is meant to survive.
    img = cv2.imencode('.jpg', res_np_img)[1]
    return base64.b64encode(img).decode('utf-8')
def current_model():
    """Return ``(name, 200)`` where name is the active model's ``name``."""
    active_name = model.name
    return active_name, 200
def get_is_disable_model_switch():
    """Report the model-switch lock as ``("true" | "false", 200)``."""
    if is_disable_model_switch:
        return 'true', 200
    return 'false', 200
def switch_model(new_name):
    """Switch the global model to ``new_name``.

    Returns:
        A ``(message, status_code)`` tuple: 400 when switching is disabled,
        200 when ``new_name`` is already active or the switch succeeded, and
        403 when ``model.switch`` raises ``NotImplementedError``.
    """
    # Guard clauses: locked, or nothing to do.
    if is_disable_model_switch:
        return "Switch model is disabled", 400
    if new_name == model.name:
        return "Same model", 200

    try:
        model.switch(new_name)
    except NotImplementedError:
        outcome = (f"{new_name} not implemented", 403)
    else:
        outcome = (f"ok, switch to {new_name}", 200)
    return outcome
def remove(origin_image_bytes):
    """Run the ``RemoveBG`` plugin on an image and return it base64-encoded.

    Args:
        origin_image_bytes: Encoded source image, handed to ``load_img``.

    Returns:
        On success, a ``str`` with the base64 text of the PNG-encoded plugin
        output. When the plugin raises ``RuntimeError``, a
        ``(message, 500)`` tuple.
    """
    name = RemoveBG.name
    # The second value returned by load_img is not used here.
    rgb_np_img, _unused_alpha = load_img(origin_image_bytes)

    started_at = time.time()
    try:
        bgr_res = plugins[name](rgb_np_img)
    except RuntimeError as err:
        torch.cuda.empty_cache()
        if "CUDA out of memory. " not in str(err):
            logger.exception(err)
            return "Internal Server Error", 500
        return "CUDA out of memory", 500

    elapsed_ms = (time.time() - started_at) * 1000
    logger.info(f"{name} process time: {elapsed_ms}ms")

    png_buffer = cv2.imencode('.png', bgr_res)[1]
    return base64.b64encode(png_buffer).decode('utf-8')
def initModel(model_device="cpu", disable_model_switch=False):
    """Create the global model manager and register the RemoveBG plugin.

    Must run before ``process``, ``remove``, ``current_model`` or
    ``switch_model`` are used, because those read the globals set here.

    Args:
        model_device: Device name used to build the global ``device`` and
            passed to ``ModelManager``. Defaults to ``"cpu"``, the previously
            hard-coded value.
        disable_model_switch: When true, ``switch_model`` will refuse to
            change models. Defaults to ``False``, the previously hard-coded
            value.
    """
    # Only names that are re-bound need a `global` declaration; `plugins` is
    # mutated in place.
    global model
    global device
    global is_disable_model_switch
    global is_desktop

    device = torch.device(model_device)
    is_disable_model_switch = bool(disable_model_switch)
    is_desktop = False

    if is_disable_model_switch:
        logger.info(
            "Start with --disable-model-switch, model switch on frontend is disable")

    # NOTE(review): the device *name* is the only positional argument given to
    # ModelManager here -- confirm this matches the constructor signature of
    # the lama_cleaner version in use.
    model = ModelManager(model_device, callback=diffuser_callback)
    plugins[RemoveBG.name] = RemoveBG()