File size: 6,411 Bytes
ea10013 021bfcb ea10013 f9e31ba 0fde99f f9e31ba ea10013 56f00bc 6d5b21a 5670e2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
from spaces import GPU # 最初にインポート
import gradio as gr
import os
import sys
import math
from typing import List
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.utils.import_utils import is_xformers_available
from my_utils.testing_utils import parse_args_paired_testing
from de_net import DEResNet
from s3diff_tile import S3Diff
from torchvision import transforms
from utils.wavelet_color import wavelet_color_fix, adain_color_fix
# Preprocessing pipeline used by process(): a Compose wrapping a single
# ToTensor step that turns the input image into a tensor.
tensor_transforms = transforms.Compose([transforms.ToTensor()])
# Run options come from parse_args_paired_testing().
args = parse_args_paired_testing()

# Locate the S3Diff checkpoint: use args.pretrained_path when it is set,
# otherwise download "s3diff.pkl" from the "zhangap/S3Diff" Hub repository.
if args.pretrained_path is not None:
    pretrained_path = args.pretrained_path
else:
    from huggingface_hub import hf_hub_download

    pretrained_path = hf_hub_download(
        repo_id="zhangap/S3Diff",
        filename="s3diff.pkl",
    )

# Locate the base diffusion model: use args.sd_path when it is set, otherwise
# take a snapshot of the "stabilityai/sd-turbo" Hub repository.
if args.sd_path is not None:
    sd_path = args.sd_path
else:
    from huggingface_hub import snapshot_download

    sd_path = snapshot_download(repo_id="stabilityai/sd-turbo")

# Weights file for the degradation estimation network (relative path).
de_net_path = "assets/mm-realsr/de_net.pth"
# Target device and dtype for both networks.
weight_dtype = torch.float32
device = "cuda"

# Initialize the super-resolution network (net_sr) and switch it to eval mode.
net_sr = S3Diff(
    lora_rank_unet=args.lora_rank_unet,
    lora_rank_vae=args.lora_rank_vae,
    sd_path=sd_path,
    pretrained_path=pretrained_path,
    args=args,
)
net_sr.set_eval()

# Initialize the degradation estimation network (net_de) from its weights file.
# It is moved to the GPU once, together with net_sr, at the end of this section
# (the earlier hard-coded `.cuda()` call duplicated that move).
net_de = DEResNet(num_in_ch=3, num_degradation=2)
net_de.load_model(de_net_path)
net_de.eval()

# Optional xformers attention: fail fast when it was requested but is missing.
if args.enable_xformers_memory_efficient_attention:
    if not is_xformers_available():
        raise ValueError("xformers is not available. Make sure it is installed correctly")
    net_sr.unet.enable_xformers_memory_efficient_attention()

if args.gradient_checkpointing:
    net_sr.unet.enable_gradient_checkpointing()

# Move both networks to the GPU and cast them to weight_dtype.
net_sr.to(device, dtype=weight_dtype)
net_de.to(device, dtype=weight_dtype)
@GPU(duration=100)  # decorator added for the GPU-using function (duration raised from 60 to 100)
@torch.no_grad()
def process(
    input_image: Image.Image,
    scale_factor: float,
    cfg_scale: float,
    latent_tiled_size: int,
    latent_tiled_overlap: int,
    align_method: str,
) -> Image.Image:
    """Super-resolve ``input_image`` and return the result as a PIL image.

    The image is resized by ``scale_factor`` with bilinear interpolation,
    padded to a multiple of 64 in both dimensions, passed through ``net_sr``
    together with the degradation score from ``net_de``, cropped back to the
    resized dimensions and optionally colour-corrected against the resized
    input.

    Args:
        input_image: Image to upscale.
        scale_factor: Multiplier applied to the input height and width.
        cfg_scale: Accepted so the signature matches the UI inputs; it is not
            used anywhere in this function.
        latent_tiled_size: Forwarded to ``net_sr._set_latent_tile``.
        latent_tiled_overlap: Forwarded to ``net_sr._set_latent_tile``.
        align_method: ``'wavelet'`` or ``'adain'`` selects the matching colour
            fix; ``'no fix'`` (or any other value) returns the network output
            unchanged.

    Returns:
        The upscaled image. If inference or colour correction raises, the
        exception is printed and a 512x512 placeholder RGB image is returned
        instead.

    Raises:
        gr.Error: If ``input_image`` is ``None`` (nothing was uploaded).
    """
    if input_image is None:
        # Surface a readable message in the UI instead of letting the tensor
        # conversion fail with an opaque TypeError.
        raise gr.Error("Please upload an input image before pressing Run.")

    # net_de is constructed with num_in_ch=3, so PIL inputs in other modes
    # (RGBA, L, ...) are converted to three channels up front.
    if isinstance(input_image, Image.Image) and input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    net_sr._set_latent_tile(
        latent_tiled_size=latent_tiled_size,
        latent_tiled_overlap=latent_tiled_overlap,
    )

    im_lr = tensor_transforms(input_image).unsqueeze(0).to(device)
    ori_h, ori_w = im_lr.shape[2:]
    im_lr_resize = F.interpolate(
        im_lr,
        size=(int(ori_h * scale_factor), int(ori_w * scale_factor)),
        mode='bilinear',
        align_corners=True,
    )
    im_lr_resize = im_lr_resize.contiguous()
    # Rescale with x * 2 - 1 and clamp to [-1, 1] before feeding net_sr.
    im_lr_resize_norm = torch.clamp(im_lr_resize * 2 - 1.0, -1.0, 1.0)
    resize_h, resize_w = im_lr_resize_norm.shape[2:]

    # Pad on the right/bottom up to the next multiple of 64; the padding is
    # cropped off again after inference.
    pad_h = math.ceil(resize_h / 64) * 64 - resize_h
    pad_w = math.ceil(resize_w / 64) * 64 - resize_w
    im_lr_resize_norm = F.pad(im_lr_resize_norm, pad=(0, pad_w, 0, pad_h), mode='reflect')

    try:
        with torch.autocast("cuda"):
            deg_score = net_de(im_lr)
            pos_tag_prompt = [args.pos_prompt]
            neg_tag_prompt = [args.neg_prompt]
            x_tgt_pred = net_sr(
                im_lr_resize_norm,
                deg_score,
                pos_prompt=pos_tag_prompt,
                neg_prompt=neg_tag_prompt,
            )
            # Drop the padding added above.
            x_tgt_pred = x_tgt_pred[:, :, :resize_h, :resize_w]
            # Inverse of the x * 2 - 1 rescaling applied to the input.
            out_img = (x_tgt_pred * 0.5 + 0.5).cpu().detach()
        output_pil = transforms.ToPILImage()(out_img[0])

        if align_method == 'wavelet':
            reference = transforms.ToPILImage()(im_lr_resize[0])
            image = wavelet_color_fix(output_pil, reference)
        elif align_method == 'adain':
            reference = transforms.ToPILImage()(im_lr_resize[0])
            image = adain_color_fix(output_pil, reference)
        else:
            # 'no fix', and any unrecognised value: previously an unknown
            # value left `image` unbound and crashed at the return statement.
            image = output_pil
    except Exception as e:
        # Best-effort UI handler: report the failure and return a placeholder.
        print(e)
        image = Image.new(mode="RGB", size=(512, 512))
    return image
#
# Markdown text rendered in the first row of the page; currently a lone newline.
MARKDOWN = "\n"
#block = gr.Blocks().queue()
#with block:
# Forcibly skip API info generation (workaround for a TypeError).
# The stub accepts and ignores any extra arguments so it stays call-compatible
# with the method it replaces, whichever way Gradio invokes it.
gr.Blocks.get_api_info = lambda self, *_args, **_kwargs: {}
# Page layout: a Markdown row on top, then an input column (image, Run button,
# options) next to an output column holding the result image.
with gr.Blocks() as demo:  # changed: `block` -> `demo`, built inside an explicit `with`
    demo.queue()  # changed: .queue() is now called inside the `with` block

    with gr.Row():
        gr.Markdown(MARKDOWN)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            # `label="Run"` was removed; the button text is passed directly.
            run_button = gr.Button("Run")
            with gr.Accordion("Options", open=True):
                cfg_scale = gr.Slider(
                    label="Classifier Free Guidance Scale (Set a value larger than 1 to enable it!)",
                    minimum=1.0,
                    maximum=1.1,
                    value=1.07,
                    step=0.01,
                )
                scale_factor = gr.Number(label="SR Scale", value=4)
                latent_tiled_size = gr.Slider(
                    label="Tile Size",
                    minimum=64,
                    maximum=160,
                    value=96,
                    step=1,
                )
                latent_tiled_overlap = gr.Slider(
                    label="Tile Overlap",
                    minimum=16,
                    maximum=48,
                    value=32,
                    step=1,
                )
                align_method = gr.Dropdown(
                    label="Color Correction",
                    choices=["wavelet", "adain", "no fix"],
                    value="wavelet",
                )
        with gr.Column():
            result_image = gr.Image(
                label="Output",
                show_label=False,
                elem_id="result_image",
                width="100%",
                height="auto",
            )

    # Components are listed in the same order as process()'s parameters.
    inputs = [
        input_image,
        scale_factor,
        cfg_scale,
        latent_tiled_size,
        latent_tiled_overlap,
        align_method,
    ]
    run_button.click(
        fn=process,
        inputs=inputs,
        outputs=[result_image],
    )
import gradio_client.utils
# Wholesale replacement for the function that was the source of the exception.
def patched_get_type(schema):
    """Return a type string for a JSON-schema fragment.

    Anything that is not a dict (for example a bool or a str) yields "Any".
    For a dict, a "const" entry wins and yields the repr of its value; an
    "enum" entry yields "Literal[...]" built from the repr of each member;
    otherwise the "type" entry is returned, or "Any" when it is absent.
    """
    if isinstance(schema, dict):
        if "const" in schema:
            return repr(schema["const"])
        if "enum" in schema:
            members = ", ".join(map(repr, schema["enum"]))
            return f"Literal[{members}]"
        return schema.get("type", "Any")
    # Guard against bool, str and other non-dict schema values.
    return "Any"
# Install the replacement: gradio_client.utils.get_type now points at
# patched_get_type. This must happen before the app is launched below.
gradio_client.utils.get_type = patched_get_type
#block.launch()
# Start the Gradio app, passing show_api=False.
demo.launch(show_api=False)
|