Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,064 Bytes
3542be4 d472855 3542be4 31877a7 25db6c9 00c86c6 0c23f7f 27419c1 3542be4 953a099 402fe71 25db6c9 31877a7 402fe71 2bd18a2 953a099 c7df0ab 25db6c9 953a099 c7df0ab 953a099 85949b2 6e57b9c 953a099 402fe71 25db6c9 402fe71 25db6c9 402fe71 0c23f7f 3515ae5 402fe71 25db6c9 402fe71 d472855 402fe71 d472855 402fe71 a50b44f 402fe71 6dbbd62 06642da 25db6c9 4948a0e 6dbbd62 776de3e 6dbbd62 27419c1 25db6c9 27419c1 402fe71 2b32e3d 5826348 25db6c9 4469e93 402fe71 72dca33 402fe71 72dca33 402fe71 72dca33 2b32e3d 72dca33 2b32e3d 25db6c9 f091221 6dbbd62 f091221 25db6c9 2b32e3d 2098e9c 25db6c9 2b32e3d 402fe71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import spaces
import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, ControlNetModel, AutoencoderKL
from PIL import Image
import os
import time
from utils.dl_utils import dl_cn_model, dl_cn_config, dl_tagger_model, dl_lora_model
from utils.image_utils import resize_image_aspect_ratio, base_generation, canny_process
from utils.prompt_utils import execute_prompt, remove_color, remove_duplicates
from utils.tagger import modelLoad, analysis
# Working directories live under the current working directory. Each one is
# created up front (no error if it already exists) before the downloads run.
path = os.getcwd()
cn_dir, tagger_dir, lora_dir = (
    f"{path}/{sub}" for sub in ("controlnet", "tagger", "lora")
)
for _work_dir in (cn_dir, tagger_dir, lora_dir):
    os.makedirs(_work_dir, exist_ok=True)
del _work_dir

# The ControlNet is fetched from the Hub inside load_model, so the local
# ControlNet download helpers (dl_cn_model / dl_cn_config) are not called here.
dl_tagger_model(tagger_dir)
dl_lora_model(lora_dir)
def load_model(lora_dir, cn_dir):
    """Build the SDXL ControlNet img2img pipeline with the line-art LoRA fused in.

    Args:
        lora_dir: Directory that contains ``sdxl_BWLine.safetensors``.
        cn_dir: Local ControlNet directory. Currently unused because the canny
            ControlNet is loaded from the Hub; kept so existing callers work.

    Returns:
        The pipeline, moved to ``"cuda"`` when available, otherwise ``"cpu"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only makes sense on the GPU; many float16 kernels are
    # missing or extremely slow on CPU, so fall back to float32 there.
    dtype = torch.float16 if device == "cuda" else torch.float32

    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype
    )
    controlnet = ControlNetModel.from_pretrained(
        "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=dtype
    )
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "cagliostrolab/animagine-xl-3.1",
        controlnet=controlnet,
        vae=vae,
        torch_dtype=dtype,
    )

    # Name the adapter explicitly so set_adapters() below can find it by name
    # instead of relying on whatever default name the loader would assign.
    pipe.load_lora_weights(
        lora_dir, weight_name="sdxl_BWLine.safetensors", adapter_name="sdxl_BWLine"
    )
    pipe.set_adapters(["sdxl_BWLine"], adapter_weights=[1.4])
    # Bake the weighted LoRA into the base weights.
    pipe.fuse_lora()

    return pipe.to(device)
@spaces.GPU
def predict(input_image_path, canny_image, prompt, negative_prompt, controlnet_scale):
    """Render a line-art version of an image, guided by a canny edge map.

    Args:
        input_image_path: File path of the uploaded source image (the Gradio
            component uses ``type='filepath'``); ``None`` when nothing is uploaded.
        canny_image: PIL edge map used as the ControlNet ``control_image``;
            ``None`` when the canny component is empty.
        prompt: User prompt. Quality/line-art tags are prepended, then tags
            are filtered by the prompt utilities before generation.
        negative_prompt: Passed to the pipeline unchanged.
        controlnet_scale: ControlNet conditioning scale (cast to float).

    Returns:
        The generated PIL image, resized back to the input image's size.

    Raises:
        gr.Error: If the input image or the canny image is missing.
    """
    # Validate before the (expensive) model load; Gradio shows gr.Error text
    # to the user instead of an opaque AttributeError from None.resize().
    if input_image_path is None:
        raise gr.Error("Please upload an input image.")
    if canny_image is None:
        raise gr.Error("Please generate or upload a canny image first.")

    pipe = load_model(lora_dir, cn_dir)

    input_image_pil = Image.open(input_image_path)
    base_size = input_image_pil.size
    resize_image = resize_image_aspect_ratio(input_image_pil)
    # Plain white RGB canvas as the img2img init image; with strength=1.0 the
    # ControlNet edge map is what shapes the output.
    white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")
    canny_image = canny_image.resize(resize_image.size, Image.LANCZOS)

    # Fixed seed -> reproducible output for identical inputs.
    generator = torch.manual_seed(0)
    start = time.perf_counter()

    prompt = "masterpiece, best quality, monochrome, lineart, white background, " + (prompt or "")
    execute_tags = ["sketch", "transparent background"]
    prompt = execute_prompt(execute_tags, prompt)
    prompt = remove_duplicates(prompt)
    prompt = remove_color(prompt)
    print(prompt)

    output_image = pipe(
        image=white_base_pil,
        control_image=canny_image,
        strength=1.0,
        prompt=prompt,
        negative_prompt=negative_prompt,
        controlnet_conditioning_scale=float(controlnet_scale),
        generator=generator,
        num_inference_steps=30,
        eta=1.0,
    ).images[0]
    print(f"Time taken: {time.perf_counter() - start}")

    # Undo the aspect-ratio resize so the result matches the uploaded image.
    return output_image.resize(base_size, Image.LANCZOS)
class Img2Img:
    """Gradio UI: upload an image, build a canny map, tag a prompt, generate line art."""

    def __init__(self):
        # Plain state is set BEFORE layout(): layout() replaces the component
        # placeholders below with the real Gradio components, and nothing may
        # overwrite them afterwards.
        self.post_filter = True  # strip color tags from tagger output
        self.tagger_model = None  # loaded lazily on first prompt analysis
        self.input_image_path = None
        self.canny_image = None
        self.demo = self.layout()

    def process_prompt_analysis(self, input_image_path):
        """Tag the input image and return the tags for the prompt textbox.

        Args:
            input_image_path: File path of the uploaded image, or ``None``.

        Returns:
            The tagger output, passed through ``remove_color`` when
            ``self.post_filter`` is true.

        Raises:
            gr.Error: If no input image has been uploaded.
        """
        if input_image_path is None:
            raise gr.Error("Please upload an input image.")
        if self.tagger_model is None:
            # Load once and reuse across clicks.
            self.tagger_model = modelLoad(tagger_dir)
        tags = analysis(input_image_path, tagger_dir, self.tagger_model)
        return remove_color(tags) if self.post_filter else tags

    def _make_canny(self, img_path, canny_threshold1, canny_threshold2):
        """Build the canny edge image from the uploaded file.

        Slider values are cast to int before being handed to ``canny_process``.

        Raises:
            gr.Error: If no input image has been uploaded.
        """
        if img_path is None:
            raise gr.Error("Please upload an input image.")
        return canny_process(img_path, int(canny_threshold1), int(canny_threshold2))

    def layout(self):
        """Build the Blocks UI, store its components on ``self``, wire events."""
        css = """
        #intro{
            max-width: 32rem;
            text-align: center;
            margin: 0 auto;
        }
        """
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                with gr.Column():
                    self.input_image_path = gr.Image(label="input_image", type='filepath')
                    self.canny_image = gr.Image(label="canny_image", type='pil')
                    with gr.Row():
                        canny_threshold1 = gr.Slider(minimum=0, value=20, maximum=253, show_label=False)
                        gr.HTML(value="<span>/</span>", show_label=False)
                        canny_threshold2 = gr.Slider(minimum=0, value=120, maximum=254, show_label=False)
                    canny_generate_button = gr.Button("canny_generate", interactive=False)
                    self.prompt = gr.Textbox(label="prompt", lines=3)
                    self.negative_prompt = gr.Textbox(label="negative_prompt", lines=3, value="lowres, error, extra digit, fewer digits, cropped, worst quality,low quality, normal quality, jpeg artifacts, blurry")
                    prompt_analysis_button = gr.Button("prompt_analysis")
                    self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="controlnet_scale")
                    generate_button = gr.Button("generate")
                with gr.Column():
                    self.output_image = gr.Image(type="pil", label="output_image")

            # The canny button starts disabled; enable it only while an input
            # image is present (otherwise it could never be clicked).
            self.input_image_path.change(
                fn=lambda image_path: gr.update(interactive=image_path is not None),
                inputs=[self.input_image_path],
                outputs=canny_generate_button,
            )
            canny_generate_button.click(
                self._make_canny,
                inputs=[self.input_image_path, canny_threshold1, canny_threshold2],
                outputs=self.canny_image,
            )
            prompt_analysis_button.click(
                self.process_prompt_analysis,
                inputs=[self.input_image_path],
                outputs=self.prompt,
            )
            generate_button.click(
                fn=predict,
                inputs=[self.input_image_path, self.canny_image, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image,
            )
        return demo
# Build the UI at import time and start the Gradio server.
img2img = Img2Img()
demo = img2img.demo
demo.launch(share=True)
|