owiedotch committed
Commit 182f0d5 · verified · 1 Parent(s): f77813c

Update app.py

Files changed (1):
  app.py (+40 -24)
app.py CHANGED
@@ -2,13 +2,15 @@ import os
 import sys
 import torch
 import gradio as gr
-import spaces
 from PIL import Image
 import numpy as np
 from omegaconf import OmegaConf
-import requests
-from tqdm import tqdm
 import subprocess
+from tqdm import tqdm
+import requests
+
+# Assuming spaces is a valid module
+import spaces
 
 def download_file(url, filename):
     response = requests.get(url, stream=True)
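The hunk above shows only the signature and first line of `download_file`. For context, here is a minimal sketch of the usual requests + tqdm streaming pattern such a helper presumably follows; the actual body is not part of this diff:

```python
import requests
from tqdm import tqdm

def download_file(url, filename):
    # Illustrative sketch only: the real body is not shown in the hunk.
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total = int(response.headers.get("content-length", 0))
    with open(filename, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filename) as bar:
        for chunk in response.iter_content(chunk_size=8192):  # stream in 8 KiB chunks
            f.write(chunk)
            bar.update(len(chunk))
```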
@@ -45,38 +47,52 @@ def setup_environment():
 
 setup_environment()
 
-from ccsr.models.ccsr import CCSR
-from ccsr.utils.util import instantiate_from_config
+# Importing from the CCSR folder
+from CCSR.ldm.xformers_state import disable_xformers
+from CCSR.model.q_sampler import SpacedSampler
+from CCSR.model.ccsr_stage1 import ControlLDM
+from CCSR.utils.common import instantiate_from_config, load_state_dict
 
 config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
-model = instantiate_from_config(config.model)
+model = instantiate_from_config(config)
 ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
-model.load_state_dict(ckpt["state_dict"], strict=False)
-model.cuda().eval()
+load_state_dict(model, ckpt, strict=True)
+model.freeze()
+model.to("cuda")
+
+@spaces.GPU  # Decorate the inference function with @spaces.GPU
+@torch.no_grad()
+def process(image, steps, t_max, t_min, color_fix_type):
+    image = Image.open(image).convert("RGB")
+    image = image.resize((256, 256), Image.LANCZOS)
+    image = np.array(image)
+
+    sampler = SpacedSampler(model, var_type="fixed_small")
+    control = torch.tensor(np.stack([image]) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
+    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
 
-@spaces.GPU
-@torch.inference_mode()
-def infer(image, sr_scale, t_max, t_min, color_fix_type):
-    image = Image.open(image).convert("RGB").resize((256, 256), Image.LANCZOS)
-    image = torch.from_numpy(np.array(image)).float().cuda() / 127.5 - 1
-    image = image.permute(2, 0, 1).unsqueeze(0)
+    model.control_scales = [1.0] * 13
 
-    output = model.super_resolution(
-        image,
-        sr_scale=sr_scale,
-        t_max=t_max,
-        t_min=t_min,
-        color_fix_type=color_fix_type
+    height, width = control.size(-2), control.size(-1)
+    shape = (1, 4, height // 8, width // 8)
+    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
+
+    samples = sampler.sample_ccsr(
+        steps=steps, t_max=t_max, t_min=t_min, shape=shape, cond_img=control,
+        positive_prompt="", negative_prompt="", x_T=x_T,
+        cfg_scale=1.0, color_fix_type=color_fix_type
     )
 
-    output = ((output.squeeze().permute(1, 2, 0).cpu().numpy() + 1) * 127.5).clip(0, 255).astype(np.uint8)
-    return Image.fromarray(output)
+    x_samples = samples.clamp(0, 1)
+    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+    return Image.fromarray(x_samples[0])
 
 interface = gr.Interface(
-    fn=infer,
+    fn=process,
     inputs=[
         gr.Image(type="filepath"),
-        gr.Slider(minimum=1, maximum=8, step=1, value=4),
+        gr.Slider(minimum=1, maximum=100, step=1, value=45),
         gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.6667),
         gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.3333),
         gr.Dropdown(choices=["adain", "wavelet", "none"], value="adain"),
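The commit swaps PyTorch's `model.load_state_dict(...)` for CCSR's own `instantiate_from_config`/`load_state_dict` helpers, and passes the whole config rather than `config.model`, presumably because `ccsr_stage2.yaml` places the target at the top level. In latent-diffusion derivatives such as CCSR, `instantiate_from_config` conventionally resolves a dotted `target` path from the YAML and constructs it with the accompanying `params`; a minimal sketch of that convention, which is an assumption about the helper rather than the repo's exact code:

```python
import importlib

def instantiate_from_config(config):
    # Resolve e.g. target: "model.ccsr_stage1.ControlLDM" and build it
    # with the accompanying params mapping (assumed YAML layout).
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", dict()))
```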
 
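Two details in the rewritten `process` are worth noting: it keeps pixels in [0, 1] (the old `infer` normalized to [-1, 1] via `/ 127.5 - 1`), and it calls `einops.rearrange` even though this hunk never adds `import einops`, so that import must exist elsewhere in app.py. A small worked example of the tensor shapes, assuming the fixed 256x256 resize above:

```python
import numpy as np
import torch
import einops  # not added by this hunk; assumed imported elsewhere in app.py

image = np.zeros((256, 256, 3), dtype=np.uint8)             # stand-in for the resized RGB input
control = torch.tensor(np.stack([image]) / 255.0,
                       dtype=torch.float32).clamp_(0, 1)    # (1, 256, 256, 3), values in [0, 1]
control = einops.rearrange(control, "n h w c -> n c h w")   # (1, 3, 256, 256)

# The sampler works in the VAE's latent space: 4 channels, 8x spatial
# downsampling, hence a (1, 4, 32, 32) noise tensor for a 256x256 input.
shape = (1, 4, 256 // 8, 256 // 8)
print(control.shape, shape)  # torch.Size([1, 3, 256, 256]) (1, 4, 32, 32)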
 
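The `t_max`/`t_min` sliders (defaults 0.6667 and 0.3333) bound the portion of the diffusion trajectory that `sample_ccsr` traverses; CCSR samples a truncated interval rather than the full schedule. A rough illustration of how fractional bounds could map onto a 1000-step DDPM schedule, which is an assumption about the sampler's internals, not code from the repo:

```python
import numpy as np

num_train_timesteps = 1000               # assumed length of the training noise schedule
steps, t_max, t_min = 45, 0.6667, 0.3333  # the Gradio defaults above

# Space `steps` sampling points across the truncated [t_min, t_max] interval.
hi = int(num_train_timesteps * t_max)    # ~666
lo = int(num_train_timesteps * t_min)    # ~333
timesteps = np.linspace(hi, lo, steps).round().astype(int)
print(timesteps[:5])                     # [666 658 651 643 636]
```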
 
 
 
 
 
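The hunk ends midway through the `inputs` list, so the `outputs` argument and the launch call are not visible in this diff. A hedged sketch of how such an interface is typically completed; the labels and output spec here are illustrative assumptions, not part of the commit:

```python
# Hypothetical completion: outputs spec, labels, and launch() are not shown in the hunk.
interface = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="filepath"),
        gr.Slider(minimum=1, maximum=100, step=1, value=45, label="Steps"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.6667, label="t_max"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.3333, label="t_min"),
        gr.Dropdown(choices=["adain", "wavelet", "none"], value="adain", label="Color fix"),
    ],
    outputs=gr.Image(type="pil"),  # process() returns a PIL.Image
)

if __name__ == "__main__":
    interface.launch()
```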