import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))

import base64
import logging
import multiprocessing
import random
import time
import imghdr
from pathlib import Path

import cv2
import torch
import numpy as np
from loguru import logger

from lama_cleaner.interactive_seg import InteractiveSeg
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
from lama_cleaner.file_manager import FileManager
|
|
# Disable the torch JIT fusers, which can interfere with some TorchScript
# inpainting models. These are private torch APIs, so tolerate their absence.
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass
|
|
from lama_cleaner.helper import (
    load_img,
    resize_max_size,
)
|
|
NUM_THREADS = str(multiprocessing.cpu_count())

# Allow duplicate OpenMP runtimes (common with conda-installed MKL) and keep
# the numeric libraries from oversubscribing the CPU.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
|
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
|
|
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
|
|
class NoFlaskwebgui(logging.Filter):
    """Hide flaskwebgui's keep-alive polling from werkzeug's request log."""

    def filter(self, record):
        return "flaskwebgui-keep-server-alive" not in record.getMessage()


logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
|
|
# Global state, populated by initModel().
model: ModelManager = None
thumb: FileManager = None
interactive_seg_model: InteractiveSeg = None
device = None
input_image_path: str = None
is_disable_model_switch: bool = False
is_enable_file_manager: bool = False
is_desktop: bool = False
|
|
def get_image_ext(img_bytes):
    """Sniff the image format from raw bytes; fall back to jpeg if unknown."""
    # Note: imghdr was removed from the stdlib in Python 3.13.
    w = imghdr.what("", img_bytes)
    if w is None:
        w = "jpeg"
    return w
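

# On Python >= 3.13 the imghdr module no longer exists. A minimal fallback
# sketch that sniffs common formats from their magic bytes instead (the
# helper name is hypothetical, not part of lama_cleaner):
def _sniff_image_ext(img_bytes: bytes) -> str:
    if img_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    if img_bytes.startswith(b"\xff\xd8\xff"):
        return "jpeg"
    if img_bytes[:4] == b"RIFF" and img_bytes[8:12] == b"WEBP":
        return "webp"
    if img_bytes[:6] in (b"GIF87a", b"GIF89a"):
        return "gif"
    if img_bytes.startswith(b"BM"):
        return "bmp"
    return "jpeg"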
|
|
def diffuser_callback(i, t, latents):
    pass
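

# The diffusers pipelines invoke the hook above once per denoising step. A
# sketch of a non-trivial variant, assuming the classic (step, timestep,
# latents) callback signature of older diffusers releases:
def logging_diffuser_callback(i, t, latents):
    logger.info(f"diffusion step {i}, timestep {t}, latents shape {tuple(latents.shape)}")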
|
|
# Default inpainting configuration applied to every request.
config = Config(
    ldm_steps=25,
    ldm_sampler='plms',
    hd_strategy='Original',
    zits_wireframe=True,
    hd_strategy_crop_margin=196,
    hd_strategy_crop_trigger_size=1280,
    hd_strategy_resize_limit=2048,
    prompt="",
    negative_prompt="",
    use_croper=False,
    croper_x=None,
    croper_y=None,
    croper_height=None,
    croper_width=None,
    sd_scale=1,
    sd_mask_blur=5,
    sd_strength=0.75,
    sd_steps=50,
    sd_guidance_scale=7.5,
    sd_sampler="pndm",
    sd_seed=42,
    sd_match_histograms=False,
    cv2_flag="INPAINT_NS",
    cv2_radius=40,
    paint_by_example_steps=50,
    paint_by_example_guidance_scale=7.5,
    paint_by_example_mask_blur=5,
    paint_by_example_seed=42,
    paint_by_example_match_histograms=False,
    paint_by_example_example_image=None,
)
|
|
def process(origin_image_bytes, mask):
    """Inpaint the masked region of an image.

    Returns a base64-encoded image string on success, or an
    (error message, HTTP status code) tuple on failure.
    """
    image, alpha_channel = load_img(origin_image_bytes)

    mask, _ = load_img(mask, gray=True)
    mask = np.where(mask > 0, 255, 0).astype(np.uint8)

    if image.shape[:2] != mask.shape[:2]:
        return f"Mask shape {mask.shape[:2]} not equal to Image shape {image.shape[:2]}", 400

    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    # "Original" would keep the source resolution; otherwise the longer edge
    # is capped at size_limit pixels.
    size_limit = 2048
    if size_limit == "Original":
        size_limit = max(image.shape)
    else:
        size_limit = int(size_limit)

    # A seed of -1 means: draw a fresh random seed for this request.
    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)
    if config.paint_by_example_seed == -1:
        config.paint_by_example_seed = random.randint(1, 999999999)

    logger.info(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    logger.info(f"Resized image shape: {image.shape}")

    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)

    start = time.time()
    try:
        with torch.no_grad():
            res_np_img = model(image, mask, config)
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            return "CUDA out of memory", 500
        else:
            logger.exception(e)
            return "Internal Server Error", 500
    finally:
        # Release cached GPU memory whether or not inference succeeded.
        torch.cuda.empty_cache()
    logger.info(f"process time: {(time.time() - start)}s")

    if alpha_channel is not None:
        if alpha_channel.shape[:2] != res_np_img.shape[:2]:
            # cv2.resize, not np.resize: np.resize repeats raw data instead of
            # interpolating pixels. dsize is (width, height).
            alpha_channel = cv2.resize(
                alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
            )
        res_np_img = np.concatenate(
            (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
        )

    # Encode in the original format so a restored alpha channel survives
    # (hardcoding .jpg would silently drop it).
    ext = get_image_ext(origin_image_bytes)
    img = cv2.imencode(f".{ext}", res_np_img)[1]
    return base64.b64encode(img).decode('utf-8')
|
|
def current_model():
    return model.name, 200
|
|
def get_is_disable_model_switch():
    res = 'true' if is_disable_model_switch else 'false'
    return res, 200
|
|
def switch_model(new_name):
    if is_disable_model_switch:
        return "Model switching is disabled", 400

    if new_name == model.name:
        return "Same model", 200

    try:
        model.switch(new_name)
    except NotImplementedError:
        return f"{new_name} not implemented", 403
    return f"ok, switch to {new_name}", 200
|
|
def initModel():
    global model
    global interactive_seg_model
    global device
    global input_image_path
    global is_disable_model_switch
    global is_enable_file_manager
    global is_desktop
    global thumb

    # Prefer CUDA when available, otherwise fall back to the CPU.
    model_device = "cuda"
    if not torch.cuda.is_available():
        model_device = "cpu"

    device = torch.device(model_device)
    is_disable_model_switch = False
    is_desktop = False
    if is_disable_model_switch:
        logger.info(
            "Start with --disable-model-switch, model switching on the frontend is disabled")

    model = ModelManager(model_device, callback=diffuser_callback)

    interactive_seg_model = InteractiveSeg()
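

# Minimal usage sketch (not wired into the server): load an image and a mask
# from disk, inpaint, and save the result. The file names are hypothetical
# placeholders.
if __name__ == "__main__":
    initModel()
    with open("image.png", "rb") as f:
        image_bytes = f.read()
    with open("mask.png", "rb") as f:
        mask_bytes = f.read()
    result = process(image_bytes, mask_bytes)
    if isinstance(result, tuple):  # (error message, HTTP status code)
        logger.error(result[0])
    else:
        with open("result.png", "wb") as f:
            f.write(base64.b64decode(result))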
|