# Spaces: Running on Zero
import random | |
import gradio as gr | |
import numpy as np | |
import spaces | |
import torch | |
from diffusers import AutoencoderKL | |
from mixture_tiling_sdxl import StableDiffusionXLTilingPipeline | |
# Upper bound for seeds: the largest signed 32-bit integer (2**31 - 1).
_INT32_LIMITS = np.iinfo(np.int32)
MAX_SEED = _INT32_LIMITS.max

# Scheduler choices offered in the UI. Each entry is a diffusers scheduler
# class name, optionally followed by "-"-separated modifiers that
# select_scheduler interprets ("-Karras" enables Karras sigmas, a further
# "-SDE" selects the "sde-dpmsolver++" algorithm).
SCHEDULERS = [
    "LMSDiscreteScheduler",
    "DEISMultistepScheduler",
    "HeunDiscreteScheduler",
    "EulerAncestralDiscreteScheduler",
    "EulerDiscreteScheduler",
    "DPMSolverMultistepScheduler",
    "DPMSolverMultistepScheduler-Karras",
    "DPMSolverMultistepScheduler-Karras-SDE",
    "UniPCMultistepScheduler",
]
# Loading: initialize the VAE and the tiling pipeline model (runs once at import time).
# The SDXL VAE is loaded in float16; presumably "sdxl-vae-fp16-fix" is chosen because
# it is safe to run in half precision — TODO confirm.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")
model_id = "stablediffusionapi/yamermix-v8-vae"
# SDXL tiling pipeline (mixture of diffusers) built on the checkpoint above, reusing the VAE.
pipe = StableDiffusionXLTilingPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    vae=vae,
    use_safetensors=False,  # for yammermix (presumably the checkpoint is not published as safetensors — verify)
).to("cuda")
# Use when VRAM is limited.
# NOTE(review): the pipeline is moved to CUDA just above and then model CPU offload is
# enabled, which manages device placement itself — confirm the explicit .to("cuda") is needed.
pipe.enable_model_cpu_offload()
# Decode latents in tiles / slices to reduce peak memory for wide images.
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
#region functions | |
def select_scheduler(scheduler_name):
    """Build a diffusers scheduler instance from an entry of ``SCHEDULERS``.

    ``scheduler_name`` is split on "-": the first part is looked up as a class
    on the ``diffusers`` package; any second part turns on
    ``use_karras_sigmas``; any third part additionally sets
    ``algorithm_type="sde-dpmsolver++"``. Only the number of parts matters,
    not their text.

    The new scheduler is created from the current ``pipe.scheduler.config``
    with a fixed beta schedule (scaled_linear, 0.00085 -> 0.012, 1000 train
    timesteps) overriding the corresponding config values.

    Raises:
        AttributeError: if the class name is not an attribute of ``diffusers``.
    """
    import diffusers

    class_name, *modifiers = scheduler_name.split("-")

    overrides = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "num_train_timesteps": 1000,
    }
    if modifiers:
        overrides["use_karras_sigmas"] = True
    if len(modifiers) >= 2:
        overrides["algorithm_type"] = "sde-dpmsolver++"

    scheduler_cls = getattr(diffusers, class_name)
    return scheduler_cls.from_config(pipe.scheduler.config, **overrides)
@spaces.GPU
def predict(left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs,
            overlap_pixels, steps, generation_seed, scheduler, tile_height, tile_width, target_height, target_width):
    """Generate one image from three side-by-side prompt tiles (left / center / right).

    Decorated with ``spaces.GPU`` so that on a ZeroGPU Space a GPU is attached
    for the duration of the call; outside ZeroGPU the decorator has no effect.

    Args:
        left_prompt, center_prompt, right_prompt: Prompts for the three tiles,
            passed to the pipeline as a single-row prompt grid.
        negative_prompt: Passed as the pipeline's ``negative_prompt``.
        left_gs, center_gs, right_gs: Per-tile guidance scales
            (``guidance_scale_tiles``).
        overlap_pixels: Horizontal overlap between adjacent tiles
            (``tile_col_overlap``); the row overlap is fixed at 0.
        steps: Number of inference steps.
        generation_seed: Seed for the CUDA ``torch.Generator``.
        scheduler: A ``SCHEDULERS`` entry, resolved with ``select_scheduler``.
        tile_height, tile_width: Tile size in pixels.
        target_height, target_width: Output image size in pixels.

    Numeric arguments are coerced with ``int()`` because UI components may
    deliver them as strings (textboxes) or floats (sliders).

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # Select the scheduler. NOTE(review): this replaces the scheduler on the
    # shared module-level pipeline on every call.
    print(f"Using scheduler: {scheduler}...")
    pipe.scheduler = select_scheduler(scheduler)

    # Seed setup: manual_seed requires an int.
    generator = torch.Generator("cuda").manual_seed(int(generation_seed))

    # Generate the tiled image (left / center / right regions).
    image = pipe(
        prompt=[[left_prompt, center_prompt, right_prompt]],
        negative_prompt=negative_prompt,
        tile_height=int(tile_height),
        tile_width=int(tile_width),
        tile_row_overlap=0,  # a single row of tiles, so no vertical overlap
        tile_col_overlap=int(overlap_pixels),
        guidance_scale_tiles=[[left_gs, center_gs, right_gs]],
        height=int(target_height),
        width=int(target_width),
        generator=generator,
        num_inference_steps=int(steps),
    )["images"][0]
    return image
def _largest_fitting_tile(limit, max_size, count, overlap, min_size=8, step=8):
    """Return ``(tile, extent)`` for the largest tile size that fits on one axis, or ``None``.

    Candidate sizes are ``max_size, max_size - step, ...`` down to ``min_size``.
    ``extent`` is the span covered by ``count`` tiles that each share
    ``overlap`` pixels with their neighbour: ``tile * count - overlap * (count - 1)``.
    Because the extent strictly grows with the tile size (``count >= 1``), the
    first candidate whose extent lies in ``(0, limit]`` also has the largest
    extent, so the search can stop there.
    """
    tile = max_size
    while tile >= min_size:
        extent = tile * count - overlap * (count - 1)
        if 0 < extent <= limit:
            return tile, extent
        tile -= step
    return None


def calc_tile_size(target_height, target_width, overlap_pixels, max_tile_width_size=1280,
                   *, max_tile_height_size=1024, num_cols=3, num_rows=1):
    """Choose tile sizes so that a grid of overlapping tiles fits inside the target size.

    Each axis is searched independently. Tile sizes are tried downward in
    steps of 8 px from the axis maximum. The largest size whose covered extent
    is positive and does not exceed the target is kept. The extent is
    ``tile * n - overlap_pixels * (n - 1)`` for ``n`` tiles on that axis.

    Args:
        target_height: Requested image height in pixels (upper bound).
        target_width: Requested image width in pixels (upper bound).
        overlap_pixels: Overlap between neighbouring tiles, applied on both
            axes. With the default single row it only affects the width.
        max_tile_width_size: Largest tile width to try.
        max_tile_height_size: Largest tile height to try.
        num_cols, num_rows: Grid shape. The defaults (1 row x 3 columns) match
            the left/center/right prompt grid used by ``predict``.

    Returns:
        ``(new_target_height, new_target_width, tile_height, tile_width)``.
        The first two are the canvas size actually covered by the tiles; the
        last two are the chosen tile size.

    Raises:
        ValueError: if no tile size of at least 8 px fits on an axis. Returning
            zero sizes here would only fail later inside the pipeline.
    """
    # NOTE(review): nothing here keeps overlap_pixels below the tile width. For
    # example, target width 512 with overlap 512 yields tile width 512, so
    # adjacent tiles have no stride. Confirm whether that should be rejected.

    # Determine the tile width.
    width_fit = _largest_fitting_tile(target_width, max_tile_width_size, num_cols, overlap_pixels)
    if width_fit is None:
        raise ValueError(
            f"No tile width between 8 and {max_tile_width_size} px fits {num_cols} column(s) "
            f"with {overlap_pixels} px overlap into a target width of {target_width} px."
        )
    tile_width, new_target_width = width_fit

    # Determine the tile height.
    height_fit = _largest_fitting_tile(target_height, max_tile_height_size, num_rows, overlap_pixels)
    if height_fit is None:
        raise ValueError(
            f"No tile height between 8 and {max_tile_height_size} px fits {num_rows} row(s) "
            f"with {overlap_pixels} px overlap into a target height of {target_height} px."
        )
    tile_height, new_target_height = height_fit

    print("--- TILE SIZE CALCULATED VALUES ---")
    print(f"Requested Overlap Pixels: {overlap_pixels}")
    print(f"Tile Height (max {max_tile_height_size}, divisible by 8): {tile_height}")
    print(f"Tile Width (max {max_tile_width_size}, divisible by 8): {tile_width}")
    print(f"Columns: {num_cols} | Rows: {num_rows}")
    print(f"Original Target: {target_height} x {target_width}")
    print(f"Adjusted Target: {new_target_height} x {new_target_width}\n")
    return new_target_height, new_target_width, tile_height, tile_width
def do_calc_tile(target_height, target_width, overlap_pixels, max_tile_size):
    """Recompute tile and adjusted-target sizes and wrap them as Gradio updates.

    Returns four ``gr.update`` objects in the order
    (tile_height, tile_width, new_target_height, new_target_width).
    """
    new_h, new_w, tile_h, tile_w = calc_tile_size(
        target_height, target_width, overlap_pixels, max_tile_size
    )
    return tuple(gr.update(value=v) for v in (tile_h, tile_w, new_h, new_w))
def clear_result():
    """Return a Gradio update that sets the target component's value to None (clears it)."""
    empty_update = gr.update(value=None)
    return empty_update
def run_for_examples(left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs,
                     overlap_pixels, steps, generation_seed, scheduler, tile_height, tile_width, target_height, target_width, max_tile_width):
    """Produce the output image for one ``gr.Examples`` row.

    Mirrors the Generate button flow. The tile size and the adjusted canvas
    size are recomputed from the requested target size, ``overlap_pixels``
    and ``max_tile_width`` with ``calc_tile_size``, then ``predict`` is called.
    This way an example row yields the same geometry Generate would for the
    same inputs, even when the row's stored tile and target values are not
    mutually consistent. For example, a 1280 px tile width with a 1920 px
    target and 128 px overlap becomes 720 px tiles over a 1904 px canvas.

    ``tile_height`` and ``tile_width`` are accepted so the signature keeps
    matching the examples' ``inputs`` list, but the computed values supersede
    them.

    Returns:
        Whatever ``predict`` returns (the generated image).
    """
    calc_target_h, calc_target_w, calc_tile_h, calc_tile_w = calc_tile_size(
        target_height, target_width, overlap_pixels, max_tile_width
    )
    return predict(left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs,
                   overlap_pixels, steps, generation_seed, scheduler, calc_tile_h, calc_tile_w,
                   calc_target_h, calc_target_w)
def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int:
    """Pick a new seed when asked to, otherwise keep the current one.

    Args:
        generation_seed: The seed currently selected in the UI.
        randomize_seed: When truthy, draw a fresh seed uniformly from
            ``[0, MAX_SEED]`` (inclusive) with ``random.randint``.

    Returns:
        The new random seed, or ``generation_seed`` unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else generation_seed
#endregion | |
# Improved CSS: adds background color, padding, shadow, etc. for a cleaner UI.
# Passed to gr.Blocks(css=...).
css = """
body { background-color: #f0f2f5; }
.gradio-container {
background: #ffffff;
border-radius: 15px;
padding: 20px;
box-shadow: 0 4px 10px rgba(0,0,0,0.1);
}
.gradio-container h1 { color: #333333; }
.fillable { width: 95% !important; max-width: unset !important; }
"""
# Title and short description, rendered with gr.Markdown at the top of the app.
# The Korean paragraph says, roughly: "Generates a tiled image by applying a
# different prompt to each of the left/center/right regions. Try the examples
# below or enter your own prompts to create creative images."
# NOTE(review): the Korean text appears mis-encoded (UTF-8 bytes decoded with a
# Thai code page) and will render as garbage; restore it from a correctly
# encoded copy of the file.
title = """
<h1 align="center" style="margin-bottom: 0.2em;">Mixture-of-Diffusers for SDXL Tiling Pipeline ๐ค</h1>
<p align="center" style="font-size:1.1em; color:#555;">
์ข/์ค์/์ฐ ์ธ ์์ญ ๊ฐ๊ฐ์ ๋ค๋ฅธ ํ๋กฌํํธ๋ฅผ ์ ์ฉํ์ฌ ํ์ผ๋ง ์ด๋ฏธ์ง๋ฅผ ์์ฑํฉ๋๋ค.<br>
์๋์ ์์ ๋ ์ง์ ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํ์ฌ ์ฐฝ์์ ์ธ ์ด๋ฏธ์ง๋ฅผ ๋ง๋ค์ด๋ณด์ธ์.
</p>
"""
# Build the Gradio UI and launch it.
# NOTE(review): the nesting of the layout context managers below was
# reconstructed from a copy whose indentation was lost — confirm the intended
# layout against the original app.
# NOTE(review): the Korean placeholder strings appear mis-encoded (UTF-8 bytes
# decoded with a Thai code page); restore them from a correctly encoded copy.
with gr.Blocks(css=css, title="SDXL Tiling Pipeline") as app:
    gr.Markdown(title)
    with gr.Row():
        # Left/center/right region prompts and the result output area.
        with gr.Column(scale=7):
            generate_button = gr.Button("Generate", elem_id="generate_btn")
            with gr.Row():
                # Placeholder (intended Korean): "e.g. a lush forest and trees lit by sunlight..."
                with gr.Column(variant="panel"):
                    gr.Markdown("### Left Region")
                    left_prompt = gr.Textbox(lines=4, placeholder="์: ์ธ์ฐฝํ ์ฒ๊ณผ ํ์ด์ด ๋น์ถ๋ ๋๋ฌด...", label="Left Prompt")
                    left_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Left CFG Scale")
                # Placeholder (intended Korean): "e.g. a calm lake with a glittering surface..."
                with gr.Column(variant="panel"):
                    gr.Markdown("### Center Region")
                    center_prompt = gr.Textbox(lines=4, placeholder="์: ์์ํ ํธ์์ ๋ฐ์ง์ด๋ ์๋ฉด...", label="Center Prompt")
                    center_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Center CFG Scale")
                # Placeholder (intended Korean): "e.g. majestic mountains and clouds crossing the sky..."
                with gr.Column(variant="panel"):
                    gr.Markdown("### Right Region")
                    right_prompt = gr.Textbox(lines=4, placeholder="์: ์ ์ฅํ ์ฐ๋งฅ๊ณผ ํ๋์ ๊ฐ๋ฅด๋ ๊ตฌ๋ฆ...", label="Right Prompt")
                    right_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Right CFG Scale")
            with gr.Row():
                negative_prompt = gr.Textbox(
                    lines=2,
                    label="Negative Prompt",
                    placeholder="์: blurry, low resolution, artifacts, poor details",
                    value="blurry, low resolution, artifacts, poor details"
                )
            with gr.Row():
                result = gr.Image(label="Generated Image", show_label=True, format="png", interactive=False, scale=1)
        # Right sidebar: generation parameters and tile size calculation.
        with gr.Sidebar(label="Parameters", open=True):
            gr.Markdown("### Generation Parameters")
            with gr.Row():
                # Requested output size; the tiles are later fitted inside it.
                height = gr.Slider(label="Target Height", value=1024, step=8, minimum=512, maximum=1024)
                width = gr.Slider(label="Target Width", value=1280, step=8, minimum=512, maximum=3840)
                overlap = gr.Slider(minimum=0, maximum=512, value=128, step=8, label="Tile Overlap")
                max_tile_size = gr.Dropdown(label="Max Tile Size", choices=[1024, 1280], value=1280)
                calc_tile = gr.Button("Calculate Tile Size")
            # Read-only results of the tile-size calculation; these feed predict on Generate.
            with gr.Row():
                tile_height = gr.Textbox(label="Tile Height", value=1024, interactive=False)
                tile_width = gr.Textbox(label="Tile Width", value=1024, interactive=False)
            with gr.Row():
                new_target_height = gr.Textbox(label="New Image Height", value=1024, interactive=False)
                new_target_width = gr.Textbox(label="New Image Width", value=1280, interactive=False)
            with gr.Row():
                steps = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Inference Steps")
                generation_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
            with gr.Row():
                scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULERS, value=SCHEDULERS[0])
    # Examples section: example prompts for different themes.
    with gr.Row():
        gr.Markdown("### Example Prompts")
        # Each row lists values in the same order as `inputs` below:
        # 3 prompts, negative prompt, 3 CFG scales, overlap, steps, seed,
        # scheduler, tile height/width, target height/width, max tile size.
        # run_for_examples produces the outputs; cache_examples=True caches them.
        gr.Examples(
            examples=[
                # Example 1: Serene Nature
                [
                    "Lush green forest with sun rays filtering through the canopy",
                    "Crystal clear lake reflecting a vibrant sky",
                    "Majestic mountains with snowy peaks in the distance",
                    "blurry, low resolution, artifacts, poor details",
                    7, 7, 7,
                    128,
                    30,
                    123456789,
                    "UniPCMultistepScheduler",
                    1024, 1280,
                    1024, 1920,
                    1280
                ],
                # Example 2: Futuristic Cityscape
                [
                    "Vibrant city street with neon signs and bustling crowds",
                    "Sleek modern skyscrapers with digital billboards",
                    "High-speed maglev train gliding over a futuristic urban landscape",
                    "blurry, poorly rendered, low quality, disfigured",
                    8, 8, 8,
                    100,
                    35,
                    987654321,
                    "EulerDiscreteScheduler",
                    1024, 1280,
                    1024, 1920,
                    1280
                ],
                # Example 3: Abstract Art
                [
                    "Vibrant abstract strokes with fluid, swirling patterns in cool tones",
                    "Interlocking geometric shapes bursting with color and texture",
                    "Dynamic composition of splattered ink with smooth gradients",
                    "text, watermark, signature, distorted",
                    6, 6, 6,
                    80,
                    25,
                    192837465,
                    "DPMSolverMultistepScheduler-Karras",
                    1024, 1280,
                    1024, 1920,
                    1280
                ],
                # Example 4: Fantasy Landscape
                [
                    "Enchanted forest with glowing bioluminescent plants and mystical fog",
                    "Ancient castle with towering spires bathed in moonlight",
                    "Majestic dragon soaring above a starry night sky",
                    "low quality, artifact, deformed, sketchy",
                    9, 9, 9,
                    150,
                    40,
                    1029384756,
                    "DPMSolverMultistepScheduler-Karras-SDE",
                    1024, 1280,
                    1024, 1920,
                    1280
                ]
            ],
            inputs=[left_prompt, center_prompt, right_prompt, negative_prompt,
                    left_gs, center_gs, right_gs, overlap, steps, generation_seed,
                    scheduler, tile_height, tile_width, height, width, max_tile_size],
            fn=run_for_examples,
            outputs=result,
            cache_examples=True
        )
    # Event wiring: tile size calculation and image generation.
    # Shared handler config so the "Calculate Tile Size" button and the
    # Generate chain run the same tile-size computation.
    event_calc_tile_size = {
        "fn": do_calc_tile,
        "inputs": [height, width, overlap, max_tile_size],
        "outputs": [tile_height, tile_width, new_target_height, new_target_width]
    }
    calc_tile.click(**event_calc_tile_size)
    # Generate: clear the previous result -> recompute tile/target sizes ->
    # optionally randomize the seed -> run predict with the recomputed
    # tile size and adjusted target size (new_target_height / new_target_width).
    generate_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(**event_calc_tile_size).then(
        fn=randomize_seed_fn,
        inputs=[generation_seed, randomize_seed],
        outputs=generation_seed,
        queue=False,
        api_name=False,
    ).then(
        fn=predict,
        inputs=[left_prompt, center_prompt, right_prompt, negative_prompt,
                left_gs, center_gs, right_gs, overlap, steps, generation_seed,
                scheduler, tile_height, tile_width, new_target_height, new_target_width],
        outputs=result,
    )
app.launch(share=False)