Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,491 Bytes
369edb1 a50b44f 402fe71 6dbbd62 06642da 543d5ec 4f189ae a9deb63 b2790c5 a9deb63 b2790c5 a9deb63 b2790c5 a9deb63 b2790c5 a9deb63 b2790c5 a9deb63 b2790c5 5110943 b2790c5 4948a0e b2790c5 27419c1 402fe71 2b32e3d b2790c5 63bfd3a 906508c 543d5ec 906508c 543d5ec 63bfd3a ef92e8a 63bfd3a 543d5ec 63bfd3a 543d5ec 2b32e3d 63bfd3a 2b32e3d f091221 543d5ec 906508c 543d5ec f091221 2b32e3d b2790c5 a9deb63 2b32e3d 402fe71 543d5ec 402fe71 edcfd06 402fe71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import spaces
import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
from PIL import Image
import os
import time
from utils.dl_utils import dl_cn_model, dl_cn_config, dl_tagger_model, dl_lora_model
from utils.image_utils import resize_image_aspect_ratio, base_generation, background_removal
from utils.prompt_utils import execute_prompt, remove_color, remove_duplicates
from utils.tagger import modelLoad, analysis
path = os.getcwd()
cn_dir = f"{path}/controlnet"
tagger_dir = f"{path}/tagger"
lora_dir = f"{path}/lora"
os.makedirs(cn_dir, exist_ok=True)
os.makedirs(tagger_dir, exist_ok=True)
os.makedirs(lora_dir, exist_ok=True)
dl_cn_model(cn_dir)
dl_cn_config(cn_dir)
dl_tagger_model(tagger_dir)
dl_lora_model(lora_dir)
class Img2Img:
def __init__(self):
self.demo = self.layout()
self.tagger_model = None
self.input_image_path = None
self.bg_removed_image = None
def load_model(self, lora_model):
dtype = torch.float16
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
"cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
# LoRAモデルの設定
if lora_model == "とりにく風":
pipe.load_lora_weights(lora_dir, weight_name="tori29umai_line.safetensors")
elif lora_model == "少女漫画風":
pipe.load_lora_weights(lora_dir, weight_name="syoujomannga_line.safetensors")
elif lora_model == "劇画調風":
pipe.load_lora_weights(lora_dir, weight_name="gekiga_line.safetensors")
elif lora_model == "プレーン":
pass # プレーンの場合はLoRAを読み込まない
return pipe
@spaces.GPU(duration=120)
def predict(self, lora_model, input_image_path, prompt, negative_prompt, controlnet_scale):
pipe = self.load_model(lora_model)
input_image = Image.open(input_image_path)
base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
resize_image = resize_image_aspect_ratio(input_image)
resize_base_image = resize_image_aspect_ratio(base_image)
generator = torch.manual_seed(0)
last_time = time.time()
# プロンプト生成
prompt = "masterpiece, best quality, monochrome, greyscale, lineart, white background, star-shaped pupils, " + prompt
execute_tags = ["realistic", "nose", "asian"]
prompt = execute_prompt(execute_tags, prompt)
prompt = remove_duplicates(prompt)
prompt = remove_color(prompt)
print(prompt)
# 画像生成
output_image = pipe(
image=resize_base_image,
control_image=resize_image,
strength=1.0,
prompt=prompt,
negative_prompt=negative_prompt,
controlnet_conditioning_scale=float(controlnet_scale),
generator=generator,
num_inference_steps=30,
eta=1.0,
).images[0]
print(f"Time taken: {time.time() - last_time}")
output_image = output_image.resize(input_image.size, Image.LANCZOS)
return output_image
def process_prompt_analysis(self, input_image_path):
if self.tagger_model is None:
self.tagger_model = modelLoad(tagger_dir)
tags = analysis(input_image_path, tagger_dir, self.tagger_model)
prompt = remove_color(tags)
execute_tags = ["realistic", "nose", "asian"]
prompt = execute_prompt(execute_tags, prompt)
prompt = remove_duplicates(prompt)
return prompt
def layout(self):
css = """
#intro{
max-width: 32rem;
text-align: center;
margin: 0 auto;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Row():
with gr.Column():
# LoRAモデル選択ドロップダウン
self.lora_model = gr.Dropdown(label="Image Style", choices=["プレーン", "とりにく風", "少女漫画風", "劇画調風"], value="プレーン")
self.input_image_path = gr.Image(label="Input image", type='filepath')
self.bg_removed_image_path = gr.Image(label="Background Removed Image", type='filepath')
# 自動背景除去トリガー
self.input_image_path.change(
fn=self.auto_background_removal,
inputs=[self.input_image_path],
outputs=[self.bg_removed_image_path]
)
self.prompt = gr.Textbox(label="Prompt", lines=3)
self.negative_prompt = gr.Textbox(label="Negative prompt", lines=3, value="nose, asian, realistic, lowres, error, extra digit, fewer digits, cropped, worst quality,low quality, normal quality, jpeg artifacts, blurry")
prompt_analysis_button = gr.Button("Prompt analysis")
self.controlnet_scale = gr.Slider(minimum=0.4, maximum=1.0, value=0.55, step=0.01, label="Photo fidelity")
generate_button = gr.Button(value="Generate", variant="primary")
with gr.Column():
self.output_image = gr.Image(type="pil", label="Output image")
prompt_analysis_button.click(
fn=self.process_prompt_analysis,
inputs=[self.bg_removed_image_path],
outputs=self.prompt
)
generate_button.click(
fn=self.predict,
inputs=[self.lora_model, self.bg_removed_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
outputs=self.output_image
)
return demo
def auto_background_removal(self, input_image_path):
if input_image_path is not None:
bg_removed_image = background_removal(input_image_path)
return bg_removed_image
return None
img2img = Img2Img()
img2img.demo.queue()
img2img.demo.launch(share=True)
|