#!/usr/bin/env python3
"""Inpainting and background-removal helpers built on ``lama_cleaner``.

The Flask app setup is commented out; the functions below take raw encoded
image bytes and return either a base64-encoded result string or a
``(message, http_status)`` tuple.  :func:`initModel` must run first, because it
populates the module-level ``model`` and ``plugins`` that the others use.
"""
import os
import sys

# import traceback

# Put this file's directory and its parent on sys.path (presumably so the
# ``lama_cleaner`` package resolves when run from a source checkout).
# NOTE: deliberately not named ``__dir__``: a module-level ``__dir__`` is the
# PEP 562 hook that ``dir(module)`` calls, so binding it to a str breaks dir().
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_THIS_DIR)
sys.path.append(os.path.abspath(os.path.join(_THIS_DIR, '../')))

import base64
import logging
import multiprocessing
import random
import time
import imghdr  # NOTE: deprecated by PEP 594 and removed in Python 3.13
from pathlib import Path

import cv2
import torch
import numpy as np
from loguru import logger

from lama_cleaner.interactive_seg import InteractiveSeg
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
from lama_cleaner.file_manager import FileManager
from lama_cleaner.plugins.remove_bg import RemoveBG

# Disable the TorchScript fusers.  Best-effort: these are private torch hooks
# that may be missing or may reject the call on some torch builds, and that
# must not stop the module from loading (but never swallow KeyboardInterrupt).
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass

# Disable ability for Flask to display warning about using a development server in a production environment.
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
# cli.show_server_banner = lambda *_: None
# from flask_cors import CORS

from lama_cleaner.helper import (
    load_img,
    resize_max_size,
)

NUM_THREADS = str(multiprocessing.cpu_count())

# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Ask the common OpenMP/BLAS backends to use every CPU core.
# NOTE(review): these are set after torch/numpy are imported; some backends
# only read them at load time -- confirm they actually take effect.
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS

# Point TORCH_HOME at CACHE_DIR when the latter is set.
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]

BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # parent of the directory containing this file
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative


class NoFlaskwebgui(logging.Filter):
    """Logging filter that drops records mentioning the flaskwebgui keep-alive."""

    def filter(self, record):
        return "flaskwebgui-keep-server-alive" not in record.getMessage()


logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())

# app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
# app.config["JSON_AS_ASCII"] = False
# CORS(app, expose_headers=["Content-Disposition"])

# Module-level state; ``model``, ``device`` and ``plugins`` are filled by initModel().
model: ModelManager = None
thumb: FileManager = None
device = None
input_image_path: str = None
is_disable_model_switch: bool = False
is_enable_file_manager: bool = False
is_desktop: bool = False
plugins = {}


def get_image_ext(img_bytes):
    """Return the ``imghdr`` type name of *img_bytes*, or ``"jpeg"`` if undetected."""
    w = imghdr.what("", img_bytes)
    if w is None:
        w = "jpeg"
    return w


def diffuser_callback(i, t, latents):
    """Per-step diffusion callback handed to ModelManager; currently a no-op."""
    pass
    # socketio.emit('diffusion_step', {'diffusion_step': step})


# Inference settings shared by every process() call.  A seed of -1 means
# "draw a fresh random seed for each call" (see process()).
config = Config(
    ldm_steps=25,
    ldm_sampler='plms',
    hd_strategy='Resize',  # Original, Resize, Crop
    zits_wireframe=True,
    hd_strategy_crop_margin=196,
    hd_strategy_crop_trigger_size=1280,
    hd_strategy_resize_limit=2048,
    prompt="",
    negative_prompt="",
    use_croper=False,
    croper_x=None,
    croper_y=None,
    croper_height=None,
    croper_width=None,
    sd_scale=1,
    sd_mask_blur=5,
    sd_strength=0.75,
    sd_steps=50,
    sd_guidance_scale=7.5,
    sd_sampler="pndm",
    sd_seed=42,
    sd_match_histograms=False,
    cv2_flag="INPAINT_NS",
    cv2_radius=40,
    paint_by_example_steps=50,
    paint_by_example_guidance_scale=7.5,
    paint_by_example_mask_blur=5,
    paint_by_example_seed=42,
    paint_by_example_match_histograms=False,
    paint_by_example_example_image=None,
)


def _runtime_error_response(e):
    """Map a RuntimeError raised during inference to a (message, status) pair.

    Must be called from inside an ``except`` clause so ``logger.exception`` can
    attach the active traceback.
    """
    torch.cuda.empty_cache()
    if "CUDA out of memory. " in str(e):
        # NOTE: matching on message text is brittle; the string may change.
        return "CUDA out of memory", 500
    logger.exception(e)
    return "Internal Server Error", 500


def process(origin_image_bytes, mask, size_limit=2048):
    """Inpaint the masked region of an encoded image with the loaded model.

    Args:
        origin_image_bytes: encoded source image, decoded via ``load_img``.
        mask: encoded mask image; every non-zero pixel marks a region to fill.
        size_limit: limit passed to ``resize_max_size`` for both image and
            mask before inference, or ``"Original"`` to use the image's own
            longest side.  Defaults to 2048 (the previously hard-coded value).

    Returns:
        The result JPEG-encoded and base64'd as ``str`` on success, otherwise a
        ``(message, http_status)`` tuple (400 for a mask/image size mismatch,
        500 for inference failures).
    """
    image, alpha_channel = load_img(origin_image_bytes)
    mask, _ = load_img(mask, gray=True)
    # Binarise the mask: any non-zero pixel becomes 255.
    mask = np.where(mask > 0, 255, 0).astype(np.uint8)
    if image.shape[:2] != mask.shape[:2]:
        return f"Mask shape {mask.shape[:2]} not equal to Image shape {image.shape[:2]}", 400

    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    if size_limit == "Original":
        size_limit = max(image.shape[:2])
    else:
        size_limit = int(size_limit)

    logger.info(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    logger.info(f"Resized image shape: {image.shape}")
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)

    # `config` is module-global and shared by every call: resolve a -1
    # ("random") seed for this call only and restore it in `finally`, so the
    # next call draws a new seed instead of silently reusing this one.
    saved_seeds = (config.sd_seed, config.paint_by_example_seed)
    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)
    if config.paint_by_example_seed == -1:
        config.paint_by_example_seed = random.randint(1, 999999999)

    start = time.time()
    try:
        with torch.no_grad():
            res_np_img = model(image, mask, config)
    except RuntimeError as e:
        return _runtime_error_response(e)
    finally:
        config.sd_seed, config.paint_by_example_seed = saved_seeds
        torch.cuda.empty_cache()
        logger.info(f"process time: {(time.time() - start)}s")

    if alpha_channel is not None:
        if alpha_channel.shape[:2] != res_np_img.shape[:2]:
            # cv2.resize takes dsize as (width, height) and interpolates;
            # np.resize would instead treat it as an array shape and
            # tile/truncate the raw data.
            alpha_channel = cv2.resize(
                alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
            )
        res_np_img = np.concatenate(
            (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
        )

    # NOTE(review): JPEG has no alpha channel, so the alpha appended above is
    # likely discarded by the encoder -- confirm whether '.png' was intended.
    img = cv2.imencode('.jpg', res_np_img)[1]
    return base64.b64encode(img).decode('utf-8')


def current_model():
    """Return ``(model name, 200)`` for the loaded model."""
    return model.name, 200


def get_is_disable_model_switch():
    """Return ``('true' | 'false', 200)`` reflecting ``is_disable_model_switch``."""
    res = 'true' if is_disable_model_switch else 'false'
    return res, 200


def switch_model(new_name):
    """Switch the loaded model to *new_name*; returns a ``(message, status)`` tuple."""
    if is_disable_model_switch:
        return "Switch model is disabled", 400
    if new_name == model.name:
        return "Same model", 200
    try:
        model.switch(new_name)
    except NotImplementedError:
        return f"{new_name} not implemented", 403
    return f"ok, switch to {new_name}", 200


def remove(origin_image_bytes):
    """Run the RemoveBG plugin on an encoded image.

    Returns the plugin output PNG-encoded and base64'd as ``str`` on success,
    or a ``(message, 500)`` tuple if the plugin raises ``RuntimeError``.
    """
    name = RemoveBG.name
    rgb_np_img, _ = load_img(origin_image_bytes)
    start = time.time()
    try:
        bgr_res = plugins[name](rgb_np_img)
    except RuntimeError as e:
        return _runtime_error_response(e)
    logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
    img = cv2.imencode('.png', bgr_res)[1]
    return base64.b64encode(img).decode('utf-8')


def initModel():
    """Load the inpainting model (on CPU) and the RemoveBG plugin into module state."""
    global model
    global device
    global input_image_path
    global is_disable_model_switch
    global is_enable_file_manager
    global is_desktop
    global thumb
    global plugins

    model_device = "cpu"
    device = torch.device(model_device)
    is_disable_model_switch = False
    is_desktop = False
    # NOTE(review): always False here, so this branch never runs; it appears
    # to be a remnant of a former --disable-model-switch CLI flag.
    if is_disable_model_switch:
        logger.info(
            f"Start with --disable-model-switch, model switch on frontend is disable")

    model = ModelManager(model_device, callback=diffuser_callback)
    plugins[RemoveBG.name] = RemoveBG()