# Hugging Face Space status: Running
# File size: 3,382 bytes
# Commit hashes from the page's blame view: 6381c79, f77813c, 182f0d5, d527cc3, 33da899
import os
import subprocess
import sys

import einops
import gradio as gr
import numpy as np
import requests
import torch
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
def download_file(url, filename, chunk_size=1024 * 1024, timeout=60):
    """Stream ``url`` to ``filename`` while showing a tqdm progress bar.

    The body is written to ``filename + ".part"`` first and renamed to
    ``filename`` only after the whole response has been received. An
    interrupted download therefore never leaves a truncated file at the final
    path, which callers may check with ``os.path.exists`` to skip downloading.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path.
        chunk_size: Bytes requested per ``iter_content`` iteration.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    tmp_path = filename + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail loudly instead of saving an HTML error page as the payload.
        response.raise_for_status()
        # Content-Length may be absent; None makes tqdm show an open-ended bar.
        total_size = int(response.headers.get("content-length", 0)) or None
        with open(tmp_path, "wb") as file, tqdm(
            desc=filename,
            total=total_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress_bar:
            for data in response.iter_content(chunk_size):
                size = file.write(data)
                progress_bar.update(size)
    # Atomic rename: the final path only ever holds a complete file.
    os.replace(tmp_path, filename)
def setup_environment():
    """Fetch the CCSR code and checkpoint, and make the code importable.

    Side effects, in order:
    1. Clones the ``dev`` branch of camenduru/CCSR into ``./CCSR`` if that
       directory is missing.
    2. Changes the process working directory to it.
    3. Prepends it to ``sys.path``.
    4. Creates ``weights/`` and downloads ``weights/real-world_ccsr.ckpt`` if
       that file is absent.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with an error.
    """
    if not os.path.exists("CCSR"):
        print("Cloning CCSR repository...")
        # check=True: a failed clone raises here instead of surfacing later as
        # a confusing FileNotFoundError from os.chdir().
        subprocess.run(
            ["git", "clone", "-b", "dev", "https://github.com/camenduru/CCSR.git"],
            check=True,
        )
    # Relative paths used after this call (e.g. weights/) resolve inside the
    # cloned repository.
    os.chdir("CCSR")
    repo_dir = os.getcwd()
    # Prepend (once) so the repo's top-level packages such as `utils` and
    # `model` take precedence over same-named installed modules.
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)

    os.makedirs("weights", exist_ok=True)
    ckpt_path = "weights/real-world_ccsr.ckpt"
    if not os.path.exists(ckpt_path):
        print("Downloading model checkpoint...")
        download_file(
            "https://huggingface.co/camenduru/CCSR/resolve/main/real-world_ccsr.ckpt",
            ckpt_path,
        )
    else:
        print("Model checkpoint already exists. Skipping download.")
setup_environment()

# Importing from the CCSR folder (importable only after setup_environment()
# has put the cloned repository on sys.path).
from ldm.xformers_state import disable_xformers
from model.q_sampler import SpacedSampler
from model.ccsr_stage1 import ControlLDM
from utils.common import instantiate_from_config, load_state_dict

# Prefer the GPU, but fall back to CPU so the app can still start without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE != "cuda":
    # NOTE(review): xformers kernels are CUDA-only; this is called before the
    # model is built on the assumption that the attention implementation is
    # chosen at construction time. Confirm in ldm/xformers_state.py.
    disable_xformers()

config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
model = instantiate_from_config(config)
# NOTE(security): torch.load unpickles the file and can execute arbitrary
# code. Only load checkpoints from a source you trust.
ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
load_state_dict(model, ckpt, strict=True)
# The weights now live in `model`; drop the large CPU-side dict.
del ckpt
model.freeze()
model.to(DEVICE)
@torch.no_grad()
def process(image, steps, t_max, t_min, color_fix_type):
    """Restore one uploaded image with the globally loaded CCSR model.

    Args:
        image: Path to the input image (the Gradio input uses
            ``type="filepath"``). It is ``None`` when the user submits
            without uploading.
        steps: Sampler step count. Cast to ``int`` because UI sliders may
            hand back floats.
        t_max: Forwarded to ``sampler.sample_ccsr``. Presumably the upper end
            of CCSR's truncated sampling interval; confirm against the sampler.
        t_min: Forwarded to ``sampler.sample_ccsr``. Presumably the lower end
            of that interval, so it is rejected when larger than ``t_max``.
        color_fix_type: One of the UI choices "adain", "wavelet" or "none",
            forwarded to the sampler.

    Returns:
        The restored image as a ``PIL.Image.Image``.

    Raises:
        gradio.Error: If no image was given or ``t_min > t_max``. Gradio shows
            the message to the user instead of a generic failure.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if t_min > t_max:
        raise gr.Error("t_min must not be greater than t_max.")
    steps = int(steps)

    # The context manager closes the underlying file; convert() returns a copy
    # that no longer depends on it.
    with Image.open(image) as src:
        rgb = src.convert("RGB")
    # Fixed square working size: the input's aspect ratio is NOT preserved.
    rgb = rgb.resize((256, 256), Image.LANCZOS)
    pixels = np.array(rgb)  # (H, W, 3) uint8

    sampler = SpacedSampler(model, var_type="fixed_small")
    # Batch of one, scaled to [0, 1], then channels-first: (1, 3, H, W).
    control = torch.tensor(
        np.stack([pixels]) / 255.0, dtype=torch.float32, device=model.device
    ).clamp_(0, 1)
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
    # NOTE(review): presumably one unit scale per ControlNet residual (13);
    # confirm against the CCSR model definition.
    model.control_scales = [1.0] * 13

    # Initial noise: 4 channels at 1/8 of the control image's height/width.
    height, width = control.size(-2), control.size(-1)
    shape = (1, 4, height // 8, width // 8)
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
    samples = sampler.sample_ccsr(
        steps=steps, t_max=t_max, t_min=t_min, shape=shape, cond_img=control,
        positive_prompt="", negative_prompt="", x_T=x_T,
        cfg_scale=1.0, color_fix_type=color_fix_type,
    )

    # Back to (B, H, W, C) uint8 for PIL.
    x_samples = samples.clamp(0, 1)
    x_samples = (
        (einops.rearrange(x_samples, "b c h w -> b h w c") * 255)
        .cpu()
        .numpy()
        .clip(0, 255)
        .astype(np.uint8)
    )
    return Image.fromarray(x_samples[0])
# Gradio UI: the inputs map positionally onto
# process(image, steps, t_max, t_min, color_fix_type).
interface = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="filepath", label="Input image"),
        gr.Slider(minimum=1, maximum=100, step=1, value=45, label="Steps"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.6667, label="t_max"),
        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.3333, label="t_min"),
        gr.Dropdown(
            choices=["adain", "wavelet", "none"], value="adain", label="Color fix"
        ),
    ],
    outputs=gr.Image(type="pil", label="Output"),
    # CCSR stands for "Content Consistent Super-Resolution" (Sun et al., 2023).
    title="CCSR: Content Consistent Super-Resolution",
)

if __name__ == "__main__":
    interface.launch()