# Spaces: Running on Zero
# `spaces` is kept as the very first import, matching the original order.
import spaces

import os
import subprocess
import sys

# Install the Python dependencies before importing them below.  pip is run
# through the current interpreter (rather than a shell string) so the packages
# land in the environment this script executes in.  Like the original
# best-effort call, a failure does not abort start-up, but it is now reported.
_pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
    check=False,
)
if _pip_result.returncode != 0:
    print(
        f"WARNING: 'pip install -r requirements.txt' exited with status "
        f"{_pip_result.returncode}",
        file=sys.stderr,
    )

# These imports must stay below the install step above.
from diffsynth import download_models

# Download the model weights this app offers before the UI is built.
download_models(["Kolors", "FLUX.1-dev"])

import gradio as gr
from diffsynth import (
    ModelManager,
    SDImagePipeline,
    SDXLImagePipeline,
    SD3ImagePipeline,
    HunyuanDiTImagePipeline,
    FluxImagePipeline,
)
import torch
from PIL import Image
import numpy as np
# Application configuration.
#
# "model_config" maps a model type name to:
#   - "model_folder":       directory scanned by load_model_list / load_model
#   - "pipeline_class":     pipeline built via `from_model_manager`
#   - "default_parameters": values looked up with `.get(...)` by
#                           model_path_to_default_params (cfg_scale,
#                           embedded_guidance, num_inference_steps, height,
#                           width); missing keys keep the current UI value.
# "max_num_painter_layers" is the number of painter tabs built in the UI and
# the group size used to split the handler's variadic arguments.
# "max_num_model_cache" bounds how many loaded models load_model keeps cached.
config = {
    "model_config": {
        "Stable Diffusion": {
            "model_folder": "models/stable_diffusion",
            "pipeline_class": SDImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
                "height": 512,
                "width": 512,
            },
        },
        "Stable Diffusion XL": {
            "model_folder": "models/stable_diffusion_xl",
            "pipeline_class": SDXLImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            },
        },
        "Stable Diffusion 3": {
            "model_folder": "models/stable_diffusion_3",
            "pipeline_class": SD3ImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            },
        },
        "Stable Diffusion XL Turbo": {
            "model_folder": "models/stable_diffusion_xl_turbo",
            "pipeline_class": SDXLImagePipeline,
            "default_parameters": {
                # NOTE: "negative_prompt" is not among the keys read by
                # model_path_to_default_params in this file.
                "negative_prompt": "",
                "cfg_scale": 1.0,
                "num_inference_steps": 1,
                "height": 512,
                "width": 512,
            },
        },
        "Kolors": {
            "model_folder": "models/kolors",
            "pipeline_class": SDXLImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            },
        },
        "HunyuanDiT": {
            "model_folder": "models/HunyuanDiT",
            "pipeline_class": HunyuanDiTImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            },
        },
        "FLUX": {
            "model_folder": "models/FLUX",
            "pipeline_class": FluxImagePipeline,
            "default_parameters": {
                "cfg_scale": 1.0,
            },
        },
    },
    "max_num_painter_layers": 3,
    "max_num_model_cache": 2,
}
def load_model_list(model_type):
    """Return the sorted names of the models available for ``model_type``.

    The model folder configured for ``model_type`` is scanned for entries
    ending in ``.safetensors``.  For the "HunyuanDiT", "Kolors" and "FLUX"
    types, sub-directories of the folder are listed as well.

    Args:
        model_type: A key of ``config["model_config"]``, or ``None``.

    Returns:
        A sorted list of entry names (not full paths).  The list is empty
        when ``model_type`` is ``None`` or the model folder does not exist.

    Raises:
        KeyError: If ``model_type`` is not a configured model type.
    """
    if model_type is None:
        return []
    folder = config["model_config"][model_type]["model_folder"]
    # A folder that has not been created yet simply offers no models; report
    # that instead of letting os.listdir raise FileNotFoundError.
    if not os.path.isdir(folder):
        return []
    include_dirs = model_type in ("HunyuanDiT", "Kolors", "FLUX")
    names = []
    # Single directory scan; each entry is listed at most once even if a
    # directory's name happens to end in ".safetensors".
    for entry in os.listdir(folder):
        if entry.endswith(".safetensors"):
            names.append(entry)
        elif include_dirs and os.path.isdir(os.path.join(folder, entry)):
            names.append(entry)
    return sorted(names)
def load_model(model_type, model_path):
    """Load a model and build its pipeline, reusing a cached one if present.

    Results are cached in the module-level ``model_dict`` under the key
    ``"{model_type}:{model_path}"``.  Before a newly loaded model is cached,
    the oldest cached entries are moved to the CPU and dropped until the cache
    fits within ``config["max_num_model_cache"]``.

    Args:
        model_type: A key of ``config["model_config"]``.
        model_path: Name of the model file or directory inside the model
            folder configured for ``model_type``.

    Returns:
        A ``(model_manager, pipe)`` tuple.
    """
    model_key = f"{model_type}:{model_path}"
    if model_key in model_dict:
        return model_dict[model_key]

    model_config = config["model_config"][model_type]
    model_path = os.path.join(model_config["model_folder"], model_path)
    model_manager = ModelManager()
    if model_type == "HunyuanDiT":
        model_manager.load_models([
            os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
            os.path.join(model_path, "mt5/pytorch_model.bin"),
            os.path.join(model_path, "model/pytorch_model_ema.pt"),
            os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
        ])
    elif model_type == "Kolors":
        model_manager.load_models([
            os.path.join(model_path, "text_encoder"),
            os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
            os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
        ])
    elif model_type == "FLUX":
        model_manager.torch_dtype = torch.bfloat16
        weight_files = [
            os.path.join(model_path, "text_encoder/model.safetensors"),
            os.path.join(model_path, "text_encoder_2"),
        ]
        # Plus every .safetensors file directly inside the model directory.
        weight_files += [
            os.path.join(model_path, file_name)
            for file_name in os.listdir(model_path)
            if file_name.endswith(".safetensors")
        ]
        model_manager.load_models(weight_files)
    else:
        model_manager.load_model(model_path)
    pipe = model_config["pipeline_class"].from_model_manager(model_manager)

    # Make room for the new entry, dropping the oldest inserted entries first
    # (dicts iterate in insertion order).  The `model_dict` guard stops the
    # loop on an empty cache, so a limit below 1 no longer raises
    # StopIteration; the new model is still cached in that case.
    while model_dict and len(model_dict) + 1 > config["max_num_model_cache"]:
        oldest_key = next(iter(model_dict))
        model_manager_to_release, _ = model_dict.pop(oldest_key)
        model_manager_to_release.to("cpu")
        torch.cuda.empty_cache()

    model_dict[model_key] = model_manager, pipe
    return model_manager, pipe
# Cache of loaded models: "{model_type}:{model_path}" -> (model_manager, pipe).
model_dict = {}

# NOTE(review): the source this was restored from registered none of the
# handlers below with any event, which leaves the UI inert.  The `gr.on`
# wiring here is inferred from each handler's parameters and return values --
# confirm the triggers.
# NOTE(review): `spaces` is imported by this file but not used in the visible
# code; if the app runs on ZeroGPU hardware, the GPU-bound handlers presumably
# need the `spaces.GPU` decorator -- confirm.
with gr.Blocks() as app:
    gr.Markdown("# DiffSynth-Studio Painter")
    with gr.Row():
        # Left column: model selection and generation parameters.
        with gr.Column(scale=382, min_width=100):

            with gr.Accordion(label="Model"):
                model_type = gr.Dropdown(choices=["Kolors", "FLUX"], label="Model type", value="Kolors")
                model_path = gr.Dropdown(choices=["Kolors"], interactive=True, label="Model path", value="Kolors")

                @gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
                def model_type_to_model_path(model_type):
                    """Refresh the model-path choices for the selected model type."""
                    return gr.Dropdown(choices=load_model_list(model_type))

            with gr.Accordion(label="Prompt"):
                prompt = gr.Textbox(label="Prompt", lines=3)
                negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
                cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
                embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")

            with gr.Accordion(label="Image"):
                num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
                height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
                width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
                with gr.Column():
                    use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
                    seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)

            @gr.on(
                inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
                outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
                triggers=model_path.change,
            )
            def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
                """Load the selected model and apply its configured default parameters.

                Parameters without a configured default keep their current
                UI value; the prompts are returned unchanged.
                """
                # Nothing selected: leave every control as it is.
                if model_type is None or model_path is None:
                    return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
                load_model(model_type, model_path)
                defaults = config["model_config"][model_type]["default_parameters"]
                cfg_scale = defaults.get("cfg_scale", cfg_scale)
                embedded_guidance = defaults.get("embedded_guidance", embedded_guidance)
                num_inference_steps = defaults.get("num_inference_steps", num_inference_steps)
                height = defaults.get("height", height)
                width = defaults.get("width", width)
                return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width

        # Right column: painter layers and results.
        with gr.Column(scale=618, min_width=100):
            with gr.Accordion(label="Painter"):
                enable_local_prompt_list = []
                local_prompt_list = []
                mask_scale_list = []
                canvas_list = []
                for painter_layer_id in range(config["max_num_painter_layers"]):
                    with gr.Tab(label=f"Layer {painter_layer_id}"):
                        enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
                        local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
                        mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
                        canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
                                                brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
                                                label="Painter", key=f"canvas_{painter_layer_id}")

                        # The handler receives its canvas as an argument, so
                        # defining it inside the loop captures no loop state.
                        @gr.on(
                            inputs=[height, width, canvas],
                            outputs=canvas,
                            triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change],
                        )
                        def resize_canvas(height, width, canvas):
                            """Return a blank white canvas when the size no longer matches.

                            The canvas is returned untouched when its
                            background already is ``height`` x ``width``.
                            """
                            background = canvas.get("background") if canvas else None
                            # A missing background is treated like a size mismatch.
                            if background is None or tuple(background.shape[:2]) != (height, width):
                                return np.ones((int(height), int(width), 3), dtype=np.uint8) * 255
                            return canvas

                        enable_local_prompt_list.append(enable_local_prompt)
                        local_prompt_list.append(local_prompt)
                        mask_scale_list.append(mask_scale)
                        canvas_list.append(canvas)

            with gr.Accordion(label="Results"):
                run_button = gr.Button(value="Generate", variant="primary")
                output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
                output_to_painter_button = gr.Button(value="Set as painter's background")
                painter_background = gr.State(None)
                input_background = gr.State(None)

                @gr.on(
                    inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed]
                    + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
                    outputs=[output_image],
                    triggers=run_button.click,
                )
                def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
                    """Generate an image from the prompt and the enabled painter layers.

                    ``args`` holds four consecutive groups of
                    ``config["max_num_painter_layers"]`` values each: enable
                    flags, local prompts, mask scales and canvases.

                    Returns:
                        Whatever the pipeline call returns.
                    """
                    _, pipe = load_model(model_type, model_path)
                    input_params = {
                        "prompt": prompt,
                        "negative_prompt": negative_prompt,
                        "cfg_scale": cfg_scale,
                        "num_inference_steps": num_inference_steps,
                        "height": height,
                        "width": width,
                        "progress_bar_cmd": progress.tqdm,
                    }
                    if isinstance(pipe, FluxImagePipeline):
                        input_params["embedded_guidance"] = embedded_guidance

                    # Split the flat argument tuple back into its four groups.
                    num_layers = config["max_num_painter_layers"]
                    enable_flags, layer_prompts, layer_scales, layer_canvases = (
                        args[group * num_layers: (group + 1) * num_layers]
                        for group in range(4)
                    )
                    local_prompts, masks, mask_scales = [], [], []
                    for enabled, local_prompt, mask_scale, canvas in zip(
                        enable_flags, layer_prompts, layer_scales, layer_canvases
                    ):
                        if not enabled:
                            continue
                        drawn_layers = canvas.get("layers") if canvas else None
                        # An enabled layer with nothing drawn has no mask to
                        # apply; skip it instead of failing on layers[0].
                        if not drawn_layers:
                            continue
                        local_prompts.append(local_prompt)
                        # The last channel of the first drawn layer is the mask.
                        masks.append(Image.fromarray(drawn_layers[0][:, :, -1]).convert("RGB"))
                        mask_scales.append(mask_scale)
                    input_params.update({
                        "local_prompts": local_prompts,
                        "masks": masks,
                        "mask_scales": mask_scales,
                    })
                    # torch.manual_seed needs an int; the number box may hand
                    # over a float, or None when it has been cleared.
                    torch.manual_seed(int(seed) if seed is not None else 0)
                    image = pipe(**input_params)
                    return image

                @gr.on(
                    inputs=[output_image] + canvas_list,
                    outputs=canvas_list,
                    triggers=output_to_painter_button.click,
                )
                def send_output_to_painter_background(output_image, *canvas_list):
                    """Set the generated image as the background of every canvas.

                    The image is resized to each canvas's current background
                    size.  Canvases are returned unchanged when there is no
                    generated image yet or a canvas has no background.
                    """
                    if output_image is None:
                        return tuple(canvas_list)
                    for canvas in canvas_list:
                        background = canvas.get("background") if canvas else None
                        if background is None:
                            continue
                        h, w = background.shape[:2]
                        canvas["background"] = output_image.resize((w, h))
                    return tuple(canvas_list)

app.launch()