import os
import time

import gradio as gr
import spaces
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL, TCDScheduler
from PIL import Image

from utils.dl_utils import dl_cn_model, dl_cn_config, dl_tagger_model, dl_lora_model
from utils.image_utils import resize_image_aspect_ratio, base_generation, background_removal
from utils.prompt_utils import execute_prompt, remove_color, remove_duplicates
from utils.tagger import modelLoad, analysis
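
# Local working directories for the downloaded assets: the ControlNet weights
# and config, the tagger model, and the LoRA files.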
path = os.getcwd()
cn_dir = f"{path}/controlnet"
tagger_dir = f"{path}/tagger"
lora_dir = f"{path}/lora"
os.makedirs(cn_dir, exist_ok=True)
os.makedirs(tagger_dir, exist_ok=True)
os.makedirs(lora_dir, exist_ok=True)

dl_cn_model(cn_dir)
dl_cn_config(cn_dir)
dl_tagger_model(tagger_dir)
dl_lora_model(lora_dir)

# Module-level cache for the generation pipeline
pipe = None
current_lora_model = None
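
# Building the pipeline is expensive (SDXL UNet + ControlNet + fp16 VAE + LoRA
# loading), so it is cached in these globals and rebuilt by load_model() only
# when the selected LoRA model changes.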

def load_model(lora_model):
    global pipe, current_lora_model
    # Reuse the cached pipeline when the requested LoRA model has not changed
    if pipe is not None and current_lora_model == lora_model:
        return pipe  # return the cached pipeline
    # Build a new pipeline
    dtype = torch.float16
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
    controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=dtype
    )
    # TCDScheduler enables few-step sampling (see num_inference_steps=4 in predict)
    pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    # Configure the LoRA weights
    if lora_model == "とりにく風":
        pipe.load_lora_weights(lora_dir, weight_name="tcd-animaginexl-3_1.safetensors", adapter_name="tcd-animaginexl-3_1")
        pipe.load_lora_weights(lora_dir, weight_name="tori29umai_line.safetensors", adapter_name="tori29umai_line")
        pipe.set_adapters(["tcd-animaginexl-3_1", "tori29umai_line"], adapter_weights=[1.0, 1.0])
        pipe.fuse_lora()
    elif lora_model == "プレーン":
        pipe.load_lora_weights(lora_dir, weight_name="tcd-animaginexl-3_1.safetensors", adapter_name="tcd-animaginexl-3_1")
        pipe.set_adapters(["tcd-animaginexl-3_1"], adapter_weights=[1.0])
        pipe.fuse_lora()

    # Remember which LoRA model is currently loaded
    current_lora_model = lora_model
    return pipe
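
# Cache behaviour sketch (illustrative only):
#   load_model("プレーン")     # first call: builds the pipeline, loads the LoRA, caches it
#   load_model("プレーン")     # same model: returns the cached pipeline immediately
#   load_model("とりにく風")   # different model: rebuilds the pipeline from scratch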

# Request a GPU for each call when running on a ZeroGPU Space
# (this is what the `spaces` import at the top is for).
@spaces.GPU
def predict(lora_model, input_image_path, prompt, negative_prompt, controlnet_scale):
    # Fetch the pipeline (cached across calls, rebuilt on LoRA change)
    pipe = load_model(lora_model)
    # Load and resize the input image
    input_image = Image.open(input_image_path)
    base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
    resize_image = resize_image_aspect_ratio(input_image)
    resize_base_image = resize_image_aspect_ratio(base_image)
    generator = torch.manual_seed(0)
    last_time = time.time()
    # Build the prompt
    prompt = "masterpiece, best quality, monochrome, greyscale, lineart, white background, " + prompt
    execute_tags = ["realistic", "nose", "asian"]
    prompt = execute_prompt(execute_tags, prompt)
    prompt = remove_duplicates(prompt)
    prompt = remove_color(prompt)
    print(prompt)
    # Generate the image
    output_image = pipe(
        image=resize_base_image,
        control_image=resize_image,
        strength=1.0,
        prompt=prompt,
        negative_prompt=negative_prompt,
        controlnet_conditioning_scale=float(controlnet_scale),
        generator=generator,
        num_inference_steps=4,  # few-step sampling enabled by TCD
        guidance_scale=0,       # TCD runs without classifier-free guidance
        eta=0.3,                # gamma in the TCD paper; controls stochasticity
    ).images[0]
    print(f"Time taken: {time.time() - last_time}")
    output_image = output_image.resize(input_image.size, Image.LANCZOS)
    return output_image
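
# Headless usage sketch (hypothetical file names; bypasses the Gradio UI):
#   out = predict("とりにく風", "photo.png", "1girl, smile", "lowres, blurry", 0.55)
#   out.save("lineart.png")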

class Img2Img:
    def __init__(self):
        self.tagger_model = None
        # layout() assigns the Gradio components (input_image_path,
        # bg_removed_image_path, ...) as attributes, so build the UI last
        # instead of overwriting those attributes with None afterwards.
        self.demo = self.layout()

    def process_prompt_analysis(self, input_image_path):
        # Load the tagger model lazily on first use, then reuse it
        if self.tagger_model is None:
            self.tagger_model = modelLoad(tagger_dir)
        tags = analysis(input_image_path, tagger_dir, self.tagger_model)
        prompt = remove_color(tags)
        execute_tags = ["realistic", "nose", "asian"]
        prompt = execute_prompt(execute_tags, prompt)
        prompt = remove_duplicates(prompt)
        return prompt

    def layout(self):
        css = """
        #intro{
            max-width: 32rem;
            text-align: center;
            margin: 0 auto;
        }
        """
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                with gr.Column():
                    # LoRA model selection
                    self.lora_model = gr.Dropdown(label="Image Style", choices=["とりにく風", "プレーン"], value="とりにく風")
                    self.input_image_path = gr.Image(label="Input image", type='filepath')
                    self.bg_removed_image_path = gr.Image(label="Background Removed Image", type='filepath')
                    # Remove the background automatically whenever a new image is uploaded
                    self.input_image_path.change(
                        fn=self.auto_background_removal,
                        inputs=[self.input_image_path],
                        outputs=[self.bg_removed_image_path]
                    )
                    self.prompt = gr.Textbox(label="Prompt", lines=3)
                    self.negative_prompt = gr.Textbox(label="Negative prompt", lines=3, value="nose, asian, realistic, lowres, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, blurry")
                    prompt_analysis_button = gr.Button("Prompt analysis")
                    self.controlnet_scale = gr.Slider(minimum=0.4, maximum=1.0, value=0.55, step=0.01, label="Photo fidelity")
                    generate_button = gr.Button(value="Generate", variant="primary")
                with gr.Column():
                    self.output_image = gr.Image(type="pil", label="Output image")

            prompt_analysis_button.click(
                fn=self.process_prompt_analysis,
                inputs=[self.bg_removed_image_path],
                outputs=self.prompt
            )
            generate_button.click(
                fn=predict,
                inputs=[self.lora_model, self.bg_removed_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image
            )
        return demo

    def auto_background_removal(self, input_image_path):
        if input_image_path is not None:
            bg_removed_image = background_removal(input_image_path)
            return bg_removed_image
        return None

img2img = Img2Img()
img2img.demo.queue()
# Note: share=True has no effect on Hugging Face Spaces (the Space is already public).
img2img.demo.launch(share=True)