Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,062 Bytes
369edb1 0c98b49 369edb1 f98bb34 369edb1 f68a8d0 7108a6a f68a8d0 7108a6a 0a4dd2d f68a8d0 7108a6a 0a4dd2d f68a8d0 0511af3 98a50a9 f68a8d0 7108a6a f68a8d0 a50b44f 402fe71 6dbbd62 06642da 543d5ec 4948a0e b2790c5 27419c1 402fe71 2b32e3d b2790c5 dd5cc2b 63bfd3a 906508c 543d5ec 906508c 543d5ec 63bfd3a ef92e8a 63bfd3a 543d5ec 63bfd3a 543d5ec 2b32e3d 63bfd3a 2b32e3d f091221 543d5ec 906508c 543d5ec f091221 2b32e3d f68a8d0 a9deb63 2b32e3d 402fe71 543d5ec 402fe71 edcfd06 402fe71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import spaces
import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL, TCDScheduler
from PIL import Image
import os
import time
from utils.dl_utils import dl_cn_model, dl_cn_config, dl_tagger_model, dl_lora_model
from utils.image_utils import resize_image_aspect_ratio, base_generation, background_removal
from utils.prompt_utils import execute_prompt, remove_color, remove_duplicates
from utils.tagger import modelLoad, analysis
import pygit2
import shutil
import os
# All downloaded assets are stored in sub-directories of the current working directory.
path = os.getcwd()
cn_dir, tagger_dir, lora_dir = f"{path}/controlnet", f"{path}/tagger", f"{path}/lora"
for _asset_dir in (cn_dir, tagger_dir, lora_dir):
    os.makedirs(_asset_dir, exist_ok=True)
del _asset_dir  # keep the loop variable out of the module namespace

# Populate the directories through the project's download helpers.
dl_cn_model(cn_dir)
dl_cn_config(cn_dir)
dl_tagger_model(tagger_dir)
dl_lora_model(lora_dir)

# The pipeline is managed through module-level globals: `pipe` holds the loaded
# pipeline and `current_lora_model` the style it was loaded for.
pipe, current_lora_model = None, None
# Maps each "Image Style" choice to the LoRA adapters loaded for it, as
# (weight file name, adapter name) pairs. Every adapter is applied with weight 1.0.
_LORA_ADAPTERS = {
    "とりにく風": (
        ("tcd-animaginexl-3_1.safetensors", "tcd-animaginexl-3_1"),
        ("tori29umai_line.safetensors", "tori29umai_line"),
    ),
    "プレーン": (
        ("tcd-animaginexl-3_1.safetensors", "tcd-animaginexl-3_1"),
    ),
}


def load_model(lora_model):
    """Return the img2img ControlNet pipeline for ``lora_model``, reusing the cached one.

    The pipeline is cached in the module-level globals ``pipe`` and
    ``current_lora_model``. A new pipeline is built whenever nothing is cached
    or a different style is requested.

    Args:
        lora_model: Style name selecting which LoRA adapters to load (a key of
            ``_LORA_ADAPTERS``). For any other value no LoRA is loaded and the
            base pipeline is returned.

    Returns:
        The ``StableDiffusionXLControlNetImg2ImgPipeline`` for the requested style.
    """
    global pipe, current_lora_model

    # Reuse the cached pipeline when it was built for the same style.
    if pipe is not None and current_lora_model == lora_model:
        return pipe

    # Invalidate the cache before building. If anything below raises, the
    # globals must not describe a half-configured pipeline, and dropping the
    # old reference first avoids holding two pipelines at once.
    pipe = None
    current_lora_model = None

    dtype = torch.float16
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
    controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
    new_pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=dtype
    )
    new_pipe.scheduler = TCDScheduler.from_config(new_pipe.scheduler.config)
    new_pipe.enable_model_cpu_offload()

    # Load, activate and fuse the LoRA adapters configured for this style.
    adapters = _LORA_ADAPTERS.get(lora_model, ())
    for weight_name, adapter_name in adapters:
        new_pipe.load_lora_weights(lora_dir, weight_name=weight_name, adapter_name=adapter_name)
    if adapters:
        adapter_names = [adapter_name for _, adapter_name in adapters]
        new_pipe.set_adapters(adapter_names, adapter_weights=[1.0] * len(adapter_names))
        new_pipe.fuse_lora()

    # Publish to the cache only once the pipeline is fully configured.
    pipe = new_pipe
    current_lora_model = lora_model
    return pipe
@spaces.GPU(duration=120)
def predict(lora_model, input_image_path, prompt, negative_prompt, controlnet_scale):
    """Run the ControlNet img2img pipeline on the image at ``input_image_path``.

    Args:
        lora_model: Style name forwarded to ``load_model``.
        input_image_path: File path of the image used as the ControlNet control image.
        prompt: User prompt; fixed quality/line-art tags are prepended to it.
        negative_prompt: Negative prompt passed to the pipeline unchanged.
        controlnet_scale: ControlNet conditioning scale (converted to ``float``).

    Returns:
        The generated PIL image, resized back to the input image's size.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Fail early, before the model is loaded, with a message Gradio can show the user.
    if not input_image_path:
        raise gr.Error("Please upload an input image before generating.")

    # Fetch the (possibly cached) pipeline for the selected style.
    pipe = load_model(lora_model)

    # Load the image and resize both it and a white base canvas of the same size.
    input_image = Image.open(input_image_path)
    base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
    resize_image = resize_image_aspect_ratio(input_image)
    resize_base_image = resize_image_aspect_ratio(base_image)
    generator = torch.manual_seed(0)  # fixed seed
    start_time = time.perf_counter()

    # Build the prompt; a None prompt is treated as empty.
    prompt = "masterpiece, best quality, monochrome, greyscale, lineart, white background, " + (prompt or "")
    execute_tags = ["realistic", "nose", "asian"]
    prompt = execute_prompt(execute_tags, prompt)
    prompt = remove_duplicates(prompt)
    prompt = remove_color(prompt)
    print(prompt)

    # Generate: strength=1.0 on the white canvas, with the input as the control image.
    output_image = pipe(
        image=resize_base_image,
        control_image=resize_image,
        strength=1.0,
        prompt=prompt,
        negative_prompt=negative_prompt,
        controlnet_conditioning_scale=float(controlnet_scale),
        generator=generator,
        num_inference_steps=4,
        guidance_scale=0,
        eta=0.3,
    ).images[0]
    print(f"Time taken: {time.perf_counter() - start_time}")

    # Return the result at the original input resolution.
    return output_image.resize(input_image.size, Image.LANCZOS)
class Img2Img:
    """Gradio UI wiring background removal, prompt analysis and image generation."""

    def __init__(self):
        # Initialise plain state BEFORE building the layout: layout() stores the
        # gr.Image component in self.input_image_path, and resetting that
        # attribute afterwards would discard the component reference.
        self.tagger_model = None  # loaded on first prompt analysis
        self.input_image_path = None
        self.bg_removed_image = None
        self.demo = self.layout()

    def process_prompt_analysis(self, input_image_path):
        """Tag the image at ``input_image_path`` and return the cleaned-up prompt.

        The tagger model is loaded on first use and kept on the instance.

        Raises:
            gr.Error: If no image is available to analyse.
        """
        if not input_image_path:
            raise gr.Error("Please upload an input image before running prompt analysis.")
        if self.tagger_model is None:
            self.tagger_model = modelLoad(tagger_dir)
        tags = analysis(input_image_path, tagger_dir, self.tagger_model)
        prompt = remove_color(tags)
        execute_tags = ["realistic", "nose", "asian"]
        prompt = execute_prompt(execute_tags, prompt)
        return remove_duplicates(prompt)

    def layout(self):
        """Build the Blocks UI, register its event handlers and return it."""
        css = """
#intro{
max-width: 32rem;
text-align: center;
margin: 0 auto;
}
"""
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                with gr.Column():
                    # LoRA style selection dropdown
                    self.lora_model = gr.Dropdown(label="Image Style", choices=["とりにく風", "プレーン"], value="とりにく風")
                    self.input_image_path = gr.Image(label="Input image", type='filepath')
                    self.bg_removed_image_path = gr.Image(label="Background Removed Image", type='filepath')
                    # Run background removal automatically whenever the input image changes
                    self.input_image_path.change(
                        fn=self.auto_background_removal,
                        inputs=[self.input_image_path],
                        outputs=[self.bg_removed_image_path]
                    )
                    self.prompt = gr.Textbox(label="Prompt", lines=3)
                    self.negative_prompt = gr.Textbox(label="Negative prompt", lines=3, value="nose, asian, realistic, lowres, error, extra digit, fewer digits, cropped, worst quality,low quality, normal quality, jpeg artifacts, blurry")
                    prompt_analysis_button = gr.Button("Prompt analysis")
                    self.controlnet_scale = gr.Slider(minimum=0.4, maximum=1.0, value=0.55, step=0.01, label="Photo fidelity")
                    generate_button = gr.Button(value="Generate", variant="primary")
                with gr.Column():
                    self.output_image = gr.Image(type="pil", label="Output image")
            # Both actions read the background-removed image, not the raw upload.
            prompt_analysis_button.click(
                fn=self.process_prompt_analysis,
                inputs=[self.bg_removed_image_path],
                outputs=self.prompt
            )
            generate_button.click(
                fn=predict,
                inputs=[self.lora_model, self.bg_removed_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image
            )
        return demo

    def auto_background_removal(self, input_image_path):
        """Return the result of ``background_removal`` for the image, or ``None`` if there is no image."""
        if input_image_path is None:
            return None
        return background_removal(input_image_path)
# Build the UI, enable the request queue, then launch with share=True.
img2img = Img2Img()
img2img.demo.queue().launch(share=True)
|