# Hugging Face Space — runs on ZeroGPU ("Running on Zero") hardware.
from huggingface_hub import hf_hub_download
import torch
from PIL import Image
from diffusers import DDPMScheduler
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
from module.ip_adapter.utils import load_adapter_to_pipe
from pipelines.sdxl_instantir import InstantIRPipeline

# Fetch the InstantIR checkpoints from the Hub. With local_dir=".", each file
# lands at ./<filename>, i.e. the repo's "models/..." layout ends up under ./models.
for _ckpt_name in (
    "models/adapter.pt",
    "models/aggregator.pt",
    "models/previewer_lora_weights.bin",
):
    hf_hub_download(repo_id="InstantX/InstantIR", filename=_ckpt_name, local_dir=".")

# Directory holding the downloaded checkpoints (prepared under ./models above).
instantir_path = "./models"

# Load the SDXL base model into the InstantIR pipeline in fp16.
pipe = InstantIRPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Load the adapter weights, using DINOv2-large as the image encoder.
load_adapter_to_pipe(
    pipe,
    f"{instantir_path}/adapter.pt",
    image_encoder_or_path="facebook/dinov2-large",
)

# Load the previewer LoRA from the checkpoint directory.
pipe.prepare_previewers(instantir_path)

# Main sampler: the SDXL base DDPM scheduler. The LCM single-step scheduler is
# built from the same config and passed to the pipeline as previewer_scheduler
# at inference time.
pipe.scheduler = DDPMScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
)
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

# Load the aggregator weights.
# map_location="cpu" makes loading independent of the device the checkpoint
# was saved on; everything is moved to CUDA explicitly below.
# NOTE(security): torch.load unpickles the file, which was downloaded from the
# Hub. Only load checkpoints from a trusted repo/revision, and consider
# weights_only=True if the installed torch version supports it.
pretrained_state_dict = torch.load(
    f"{instantir_path}/aggregator.pt", map_location="cpu"
)
pipe.aggregator.load_state_dict(pretrained_state_dict)

# Send to GPU in fp16. The aggregator is moved explicitly as well; presumably
# it is not covered by pipe.to() — TODO confirm against InstantIRPipeline.
pipe.to(device="cuda", dtype=torch.float16)
pipe.aggregator.to(device="cuda", dtype=torch.float16)
def infer(input_image):
    """Restore a degraded image with the InstantIR pipeline.

    Args:
        input_image: Path to the uploaded image (the Gradio input component
            uses ``type="filepath"``), or ``None`` when the user submits
            without uploading anything.

    Returns:
        The first image of the pipeline output (``.images[0]``).

    Raises:
        gradio.Error: If no image was provided.
    """
    if input_image is None:
        # An empty gr.Image input arrives as None. Show a readable message in
        # the UI instead of an opaque error from Image.open(None).
        import gradio as gr

        raise gr.Error("Please upload a low-quality image first.")

    # Open the file in a context manager so the handle is released promptly.
    # convert("RGB") normalises modes such as RGBA or palette images.
    with Image.open(input_image) as img:
        low_quality_image = img.convert("RGB")

    # InstantIR restoration; lcm_scheduler is supplied as the previewer's scheduler.
    image = pipe(
        image=low_quality_image,
        previewer_scheduler=lcm_scheduler,
    ).images[0]
    return image
import gradio as gr

# Gradio UI: an upload box and trigger button, plus a panel for the restored
# result.
# NOTE(review): the original paste lost its indentation. output_img is placed
# beside the input column inside the Row — confirm against the deployed layout.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                lq_img = gr.Image(label="Low-quality image", type="filepath")
                submit_btn = gr.Button("InstantIR magic!")
            output_img = gr.Image(label="InstantIR restored")

    # Clicking the button calls infer() with the uploaded file path and
    # displays the returned image.
    submit_btn.click(fn=infer, inputs=[lq_img], outputs=[output_img])

# show_error=True surfaces exceptions raised in infer() in the browser.
demo.launch(show_error=True)