fffiloni commited on
Commit
7acb2a5
1 Parent(s): b46a257

Create app_with_diffusers.py

Browse files
Files changed (1) hide show
  1. app_with_diffusers.py +68 -0
app_with_diffusers.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+
3
+ hf_hub_download(repo_id="InstantX/InstantIR", filename="models/adapter.pt", local_dir=".")
4
+ hf_hub_download(repo_id="InstantX/InstantIR", filename="models/aggregator.pt", local_dir=".")
5
+ hf_hub_download(repo_id="InstantX/InstantIR", filename="models/previewer_lora_weights.bin", local_dir=".")
6
+
7
+ import torch
8
+ from PIL import Image
9
+
10
+ from diffusers import DDPMScheduler
11
+ from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
12
+
13
+ from module.ip_adapter.utils import load_adapter_to_pipe
14
+ from pipelines.sdxl_instantir import InstantIRPipeline
15
+
16
+ # prepare models under ./models
17
+ instantir_path = f'./models'
18
+
19
+ # load pretrained models
20
+ pipe = InstantIRPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)
21
+
22
+ # load adapter
23
+ load_adapter_to_pipe(
24
+ pipe,
25
+ f"{instantir_path}/adapter.pt",
26
+ image_encoder_or_path = 'facebook/dinov2-large',
27
+ )
28
+
29
+ # load previewer lora
30
+ pipe.prepare_previewers(instantir_path)
31
+ pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
32
+ lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
33
+
34
+ # load aggregator weights
35
+ pretrained_state_dict = torch.load(f"{instantir_path}/aggregator.pt")
36
+ pipe.aggregator.load_state_dict(pretrained_state_dict)
37
+
38
+ # send to GPU and fp16
39
+ pipe.to(device='cuda', dtype=torch.float16)
40
+ pipe.aggregator.to(device='cuda', dtype=torch.float16)
41
+
42
+ def infer(input_image):
43
+ # load a broken image
44
+ low_quality_image = Image.open(input_image).convert("RGB")
45
+
46
+ # InstantIR restoration
47
+ image = pipe(
48
+ image=low_quality_image,
49
+ previewer_scheduler=lcm_scheduler,
50
+ ).images[0]
51
+
52
+ return image
53
+
54
+ import gradio as gr
55
+
56
+ with gr.Blocks() as demo:
57
+ with gr.Column():
58
+ with gr.Row():
59
+ with gr.Column():
60
+ lq_img = gr.Image(label="Low-quality image", type="filepath")
61
+ submit_btn = gr.Button("InstantIR magic!")
62
+ output_img = gr.Image(label="InstantIR restored")
63
+ submit_btn.click(
64
+ fn=infer,
65
+ inputs=[lq_img],
66
+ outputs=[output_img]
67
+ )
68
+ demo.launch(show_error=True)