import os
import sys
import json
import time

# Install the pinned Gradio version at runtime, before importing it.
os.system("pip install gradio==3.36.1")

import gradio as gr
from PIL import Image
import numpy as np
import torch
import cv2
import io
import multiprocessing
import random
from loguru import logger

from utils import *
from share_btn import *

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config

# Disable the torch JIT fusers to avoid fuser-related crashes; ignore failures on older torch.
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass

from lama_cleaner.helper import (
    load_img,
    numpy_to_bytes,
    resize_max_size,
)

NUM_THREADS = str(multiprocessing.cpu_count())

# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]

HF_TOKEN_SD = os.environ.get('HF_TOKEN_SD')

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'device = {device}')


def read_content(file_path: str) -> str:
    """Read the content of the target file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content


def get_image_enhancer(scale=2, device='cuda:0'):
    """Build a GFPGAN face enhancer backed by a Real-ESRGAN background upsampler."""
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    from realesrgan.archs.srvgg_arch import SRVGGNetCompact
    from gfpgan import GFPGANer

    realesrgan_model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=4,
    )
    netscale = scale
    model_realesrgan = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_realesrgan,
        model=realesrgan_model,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=False if device == 'cpu' else True,
        device=device,
    )

    model_GFPGAN = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
    img_enhancer = GFPGANer(
        model_path=model_GFPGAN,
        upscale=scale,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=upsampler,
        device=device,
    )
    return img_enhancer


image_enhancer = None
if sys.platform == 'linux' and 0 == 1:  # the `and 0 == 1` keeps the enhancer disabled for now
    image_enhancer = get_image_enhancer(scale=1, device=device)

model = None


def model_process(image, mask, img_enhancer):
    global model, image_enhancer

    ori_image = image
    if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
        # rotate image so that it matches the mask orientation
        ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
        image = ori_image

    original_shape = ori_image.shape
    interpolation = cv2.INTER_CUBIC

    size_limit = 1080
    if size_limit == "Original":
        size_limit = max(image.shape)
    else:
        size_limit = int(size_limit)

    config = Config(
        ldm_steps=25,
        ldm_sampler='plms',
        zits_wireframe=True,
        hd_strategy='Original',
        hd_strategy_crop_margin=196,
        hd_strategy_crop_trigger_size=1280,
        hd_strategy_resize_limit=2048,
        prompt='',
        use_croper=False,
        croper_x=0,
        croper_y=0,
        croper_height=512,
        croper_width=512,
        sd_mask_blur=5,
        sd_strength=0.75,
        sd_steps=50,
        sd_guidance_scale=7.5,
        sd_sampler='ddim',
        sd_seed=42,
        cv2_flag='INPAINT_NS',
        cv2_radius=5,
    )

    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)

    logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    logger.info(f"Resized image shape_1_: {image.shape}")

    logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
    logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")

    if model is None:
        # model not initialized yet; return an empty result pair so callers can unpack safely
        return None, None

    res_np_img = model(image, mask, config)
    torch.cuda.empty_cache()

    image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
    if image_enhancer is not None and img_enhancer:
        start = time.time()
        input_img_rgb = np.array(image)
        input_img_bgr = input_img_rgb[..., [2, 1, 0]]
        _, _, enhance_img = image_enhancer.enhance(
            input_img_bgr, has_aligned=False, only_center_face=False, paste_back=True
        )
        input_img_rgb = enhance_img[..., [2, 1, 0]]
        img_enhance = Image.fromarray(np.uint8(input_img_rgb))
        image = img_enhance
        log_info = f"image_enhancer_: {(time.time() - start) * 1000}ms, {res_np_img.shape} "
        logger.info(log_info)

    return image, Image.fromarray(ori_image)


def resize_image(pil_image, new_width=400):
    width, height = pil_image.size
    new_height = int(height * (new_width / width))
    pil_image = pil_image.resize((new_width, new_height))
    return pil_image


model = ModelManager(
    name='lama',
    device=device,
)

image_type = 'pil'  # or 'filepath'


def predict(input, platform_radio, img_enhancer):
    if input is None:
        return None, [], gr.update(visible=False)

    if image_type == 'filepath':
        # input: {'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}
        origin_image_bytes = open(input["image"], 'rb').read()
        print('origin_image_bytes =', type(origin_image_bytes), len(origin_image_bytes))
        image, _ = load_img(origin_image_bytes)
        mask, _ = load_img(open(input["mask"], 'rb').read(), gray=True)
    elif image_type == 'pil':
        # input: {'image': pil, 'mask': pil}
        image_pil = input['image']
        mask_pil = input['mask']
        image = np.array(image_pil)
        mask = np.array(mask_pil.convert("L"))

    output, ori_image = model_process(image, mask, img_enhancer)
    if platform_radio == 'pc':
        return output, [ori_image, output], gr.update(visible=True)
    else:
        return output, [resize_image(ori_image, new_width=400), resize_image(output, new_width=400)], gr.update(visible=True)


image_blocks = gr.Blocks(css=css, title='Image Cleaner')
with image_blocks as demo:
    with gr.Group(elem_id="page_1", visible=True) as page_1:
        with gr.Box():
            with gr.Row(elem_id="gallery_row"):
                with gr.Column(elem_id="gallery_col"):
                    gallery = gr.Gallery(value=['./sample_00.jpg', './sample_00_e.jpg'], show_label=False)
                    gallery.style(grid=[2], height='500px')
            with gr.Row():
                with gr.Column():
                    begin_button = gr.Button("Let's GO!", elem_id="begin-btn", visible=True)
            with gr.Row():
                with gr.Column():
                    gr.HTML("""

                        <p>We solemnly promise: this application does not collect any user information or image resources.</p>
                        <p>The model comes from [Lama]. Thanks! ❤️</p>
                        <p>[huggingface.co] provides code hosting. Thanks! ❤️</p>
""" ) with gr.Group(elem_id="page_2", visible=False) as page_2: with gr.Box(elem_id="work-container"): with gr.Row(elem_id="input-container"): with gr.Column(): image_input = gr.Image(source='upload', elem_id="image_upload",tool='sketch', type=f'{image_type}', label="Upload", show_label=False).style(mobile_collapse=False) with gr.Row(elem_id="scroll_x_row"): with gr.Column(id="scroll_x_col"): gr.HTML("""
""" ) with gr.Row(elem_id="op-container").style(mobile_collapse=False, equal_height=True): with gr.Column(elem_id="erase-btn-container"): erase_btn = gr.Button(value = "Erase(⏬)",elem_id="erase-btn").style( margin=True, rounded=(True, True, True, True), full_width=True, ).style(width=100) with gr.Column(elem_id="enhancer-checkbox", visible=True if image_enhancer is not None else False): enhancer_label = 'Enhanced image(processing is very slow, please check only for blurred images)' img_enhancer = gr.Checkbox(label=enhancer_label).style(width=150) with gr.Row(elem_id="output-container"): with gr.Column(elem_id="image-output-container"): image_out = gr.Image(elem_id="image_output",label="Result", show_label=False, visible=False) with gr.Column(): gallery = gr.Gallery( label="Generated images", show_label=False, elem_id="gallery" ).style(grid=[2], height="600px") platform_radio = gr.Radio(["pc", "mobile"], elem_id="platform_radio",value="pc", label="platform:", show_label=True, visible=False) with gr.Row(elem_id="download-container", visible=False) as download_container: with gr.Column(elem_id="download-btn-container") as download_btn_container: download_button = gr.Button(elem_id="download-btn", value="Save(⏩)") with gr.Column(elem_id="share-container") as share_container: with gr.Group(elem_id="share-btn-container"): community_icon = gr.HTML(community_icon_html, elem_id="community-icon", visible=True) loading_icon = gr.HTML(loading_icon_html, elem_id="loading-icon", visible=True) share_button = gr.Button("Share to community", elem_id="share-btn", visible=True) with gr.Row(elem_id="log_row"): with gr.Column(id="log_col"): gr.HTML("""
""" ) erase_btn.click(fn=predict, inputs=[image_input, platform_radio, img_enhancer], outputs=[image_out, gallery, download_container]) download_button.click(None, [], [], _js=download_img) share_button.click(None, [], [], _js=share_js) begin_button.click(fn=None, inputs=[], outputs=[page_1, page_2], _js=start_cleaner) os.system("pip list") image_blocks.launch(server_name='0.0.0.0')