|
import os
import subprocess
import sys

import gradio as gr
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from huggingface_hub import hf_hub_download
from realesrgan.utils import RealESRGANer
|
|
|
# Hugging Face Hub repositories that host the model weight files.
REALESRGAN_REPO_ID = "leonelhs/realesrgan"  # Real-ESRGAN background upsampler weights
GFPGAN_REPO_ID = "leonelhs/gfpgan"  # GFPGAN / RestoreFormer face-restoration weights
|
|
|
# Print the installed package versions at start-up (presumably kept to help
# debug the hosted environment). Invoking pip via ``sys.executable`` ensures the
# listing describes the interpreter actually running this app rather than
# whichever ``pip`` is first on PATH, and no shell is involved. Failures are
# ignored (check=False), matching the previous best-effort behavior.
subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)
|
|
|
|
|
# Background upsampler: the compact SRVGG network with the general-purpose
# Real-ESRGAN x4 v3 weights. predict() hands it to GFPGANer as bg_upsampler so
# the non-face parts of the picture are upscaled too.
model = SRVGGNetCompact(
    num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = hf_hub_download(repo_id=REALESRGAN_REPO_ID, filename='realesr-general-x4v3.pth')

# is_available() already returns a bool; fp16 inference is only enabled when a
# CUDA device is present.
half = torch.cuda.is_available()

# tile=0 disables tiled processing (the whole image is processed in one pass);
# scale=4 matches the network's native upscale factor above.
upsampler = RealESRGANer(
    scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
|
|
|
|
|
def download_model(file):
    """Return the local path of weight file *file* from ``GFPGAN_REPO_ID``.

    The file is obtained through ``hf_hub_download``, which downloads it from
    the Hugging Face Hub when needed and returns the path of the local copy.
    """
    weights_path = hf_hub_download(filename=file, repo_id=GFPGAN_REPO_ID)
    return weights_path
|
|
|
|
|
# UI version label -> (weight file in GFPGAN_REPO_ID, GFPGANer ``arch`` value).
_FACE_MODELS = {
    'v1.2': ('GFPGANv1.2.pth', 'clean'),
    'v1.3': ('GFPGANv1.3.pth', 'clean'),
    'v1.4': ('GFPGANv1.4.pth', 'clean'),
    'RestoreFormer': ('RestoreFormer.pth', 'RestoreFormer'),
}


def _swap_red_blue(img):
    """Return a copy of *img* with channels 0 and 2 exchanged (RGB <-> BGR).

    Applies to arrays shaped ``(H, W, C)`` with ``C >= 3``; a 4th (alpha)
    channel is left in place. Any other shape (e.g. 2-D grayscale) is returned
    unchanged. The copy leaves the caller's array untouched and is
    C-contiguous, which OpenCV-based code requires.
    """
    if img.ndim == 3 and img.shape[2] >= 3:
        img = img.copy()
        # The right-hand side is materialized first, so this is a true swap.
        img[..., [0, 2]] = img[..., [2, 0]]
    return img


def predict(image, version, scale):
    """Restore the faces in *image* with the selected face-restoration model.

    Args:
        image: Picture from ``gr.Image(type="numpy")``, i.e. an RGB array, or
            ``None`` when the user submitted without uploading anything.
        version: A key of ``_FACE_MODELS`` (the choices offered by the UI).
        scale: Output upscale factor; the UI dropdown passes it as a string.

    Returns:
        The restored image as an RGB array, with the background upscaled by
        the module-level Real-ESRGAN ``upsampler``.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
        ValueError: If *version* is not a known model.
    """
    if image is None:
        raise gr.Error('Please upload an input image.')
    try:
        filename, arch = _FACE_MODELS[version]
    except KeyError:
        # Previously this fell through with face_enhancer=None and crashed
        # with an opaque AttributeError.
        raise ValueError(f'Unknown model version: {version!r}') from None

    face_enhancer = GFPGANer(
        model_path=download_model(filename),
        upscale=int(scale),
        arch=arch,
        channel_multiplier=2,
        bg_upsampler=upsampler,
    )

    # GFPGANer (face detector, restoration network and the Real-ESRGAN
    # background pass) consumes and produces OpenCV-style BGR arrays, while
    # Gradio supplies and displays RGB: convert on the way in and out.
    bgr_input = _swap_red_blue(image)
    _, _, bgr_output = face_enhancer.enhance(
        bgr_input, has_aligned=False, only_center_face=False, paste_back=True)
    return _swap_red_blue(bgr_output)
|
|
|
|
|
# Text shown by the Gradio interface: page title, short description above the
# form, and the article footer below it (HTML inside the strings).
title = "GFPGAN"
description = r"""
<b>Practical Face Restoration Algorithm</b>
"""
# NOTE(review): "[email protected]" looks like e-mail-obfuscation residue from a
# web copy, not real addresses; restore the intended contact e-mails. "</br>" is
# nonstandard markup (browsers treat it as "<br>"), and the final sentence
# ("Github Repo ⭐ are welcome") reads awkwardly; consider rewording.
article = r"""
<center><span>[email protected] or [email protected]</span></center>
</br>
<center><a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Github Repo ⭐ </a> are welcome</center>
"""
|
|
|
# Input widgets, mapped positionally onto predict(image, version, scale).
_interface_inputs = [
    gr.Image(type="numpy", label="Input"),
    gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'], type="value", value='v1.4', label='version'),
    gr.Dropdown(["1", "2", "3", "4"], value="2", label="Rescaling factor"),
]
# Read-only widget that displays the array returned by predict().
_interface_outputs = [
    gr.Image(type="numpy", label="Output", interactive=False),
]

demo = gr.Interface(
    fn=predict,
    inputs=_interface_inputs,
    outputs=_interface_outputs,
    title=title,
    description=description,
    article=article,
)

# Enable Gradio's request queue, then start the web server.
demo.queue().launch()
|
|