|
import random |
|
import re |
|
import gradio as gr |
|
import numpy as np |
|
import spaces |
|
import torch |
|
from diffusers import AutoencoderKL |
|
from mixture_tiling_sdxl import StableDiffusionXLTilingPipeline |
|
from transformers import pipeline |
|
|
|
|
|
# Korean -> English translation pipeline, used by translate_if_needed().
translator = pipeline(task="translation", model="Helsinki-NLP/opus-mt-ko-en")
|
|
|
# Largest seed offered by the UI: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Choices for the scheduler dropdown. The text before the first "-" is looked
# up as a class name in ``diffusers``; the "-Karras" / "-SDE" suffixes are
# interpreted by select_scheduler().
SCHEDULERS = [
    "LMSDiscreteScheduler",
    "DEISMultistepScheduler",
    "HeunDiscreteScheduler",
    "EulerAncestralDiscreteScheduler",
    "EulerDiscreteScheduler",
    "DPMSolverMultistepScheduler",
    "DPMSolverMultistepScheduler-Karras",
    "DPMSolverMultistepScheduler-Karras-SDE",
    "UniPCMultistepScheduler",
]
|
|
|
|
|
# VAE loaded in half precision from the "sdxl-vae-fp16-fix" repository and
# moved to the GPU; it is handed to the pipeline below in place of the
# checkpoint's own VAE.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")

# Checkpoint used for the tiling pipeline.
model_id = "stablediffusionapi/yamermix-v8-vae"
# NOTE(review): use_safetensors=False presumably makes the loader read the
# non-safetensors (pickle-based) weight files - confirm the repository is
# trusted or switch to safetensors if the repo provides them.
pipe = StableDiffusionXLTilingPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    vae=vae,
    use_safetensors=False,
).to("cuda")

# Offload / tiling / slicing switches, presumably enabled to keep GPU memory
# usage down for wide panoramas.
# NOTE(review): the pipeline is moved to "cuda" above and CPU offload is
# enabled right after - confirm this combination is intended.
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
|
|
|
|
|
def translate_if_needed(text: str) -> str:
    """Return *text* in English, translating it only when it contains Hangul.

    Args:
        text: Prompt text typed by the user. ``None`` is treated as an empty
            prompt.

    Returns:
        The first ``translation_text`` produced by the module-level
        ``translator`` when *text* contains at least one character in the
        range U+AC00-U+D7A3 (Hangul syllables); otherwise *text* unchanged.
        Empty or ``None`` input yields ``""``.
    """
    if not text:
        # Nothing to inspect; this also keeps re.search from raising on None.
        return ""
    if re.search(r'[\uac00-\ud7a3]', text):
        return translator(text)[0]["translation_text"]
    return text
|
|
|
# Scheduler config the pipeline was loaded with. Captured on the first call to
# select_scheduler() so that every later selection starts from the same
# baseline instead of from whatever scheduler was chosen previously.
_BASE_SCHEDULER_CONFIG = None


def select_scheduler(scheduler_name):
    """Build the diffusers scheduler described by *scheduler_name*.

    The name has the form ``"<ClassName>[-<suffix>[-<suffix>]]"``: the part
    before the first ``-`` is looked up as an attribute of the ``diffusers``
    package, a first suffix turns on ``use_karras_sigmas`` and a second suffix
    sets ``algorithm_type`` to ``"sde-dpmsolver++"``.

    Args:
        scheduler_name: One of the entries offered in the scheduler dropdown.

    Returns:
        A new scheduler instance created with ``from_config``.

    Raises:
        AttributeError: If ``diffusers`` has no attribute named like the class
            part of *scheduler_name*.
    """
    global _BASE_SCHEDULER_CONFIG
    import diffusers

    if _BASE_SCHEDULER_CONFIG is None:
        # The first call still sees the scheduler the pipeline was loaded
        # with. Building from ``pipe.scheduler.config`` on every call would
        # carry options of the previously selected scheduler (Karras sigmas,
        # the SDE algorithm type) over into the next selection.
        _BASE_SCHEDULER_CONFIG = pipe.scheduler.config

    scheduler_class_name, *variants = scheduler_name.split("-")

    add_kwargs = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "num_train_timesteps": 1000,
    }
    if len(variants) >= 1:
        add_kwargs["use_karras_sigmas"] = True
    if len(variants) >= 2:
        add_kwargs["algorithm_type"] = "sde-dpmsolver++"

    scheduler_cls = getattr(diffusers, scheduler_class_name)
    return scheduler_cls.from_config(_BASE_SCHEDULER_CONFIG, **add_kwargs)
|
|
|
@spaces.GPU
def predict(left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs,
            overlap_pixels, steps, generation_seed, scheduler, tile_height, tile_width, target_height, target_width):
    """Generate one three-tile panorama and return the resulting image.

    Args:
        left_prompt, center_prompt, right_prompt: Prompts for the three
            columns; text containing Hangul is translated to English first.
        negative_prompt: Negative prompt shared by all tiles (translated the
            same way).
        left_gs, center_gs, right_gs: Guidance scale for each column.
        overlap_pixels: Horizontal overlap between neighbouring tiles.
        steps: Number of inference steps.
        generation_seed: Seed for the CUDA random generator.
        scheduler: Scheduler name understood by ``select_scheduler``.
        tile_height, tile_width: Size of one tile.
        target_height, target_width: Size of the final image.

    Returns:
        The first entry of the pipeline's ``"images"`` output.

    Raises:
        gr.Error: If any tile or target dimension is not positive, which is
            what ``calc_tile_size`` reports when no tile layout fits.
    """
    # The dimension fields are read-only textboxes and sliders may deliver
    # floats, so normalise every numeric input before it reaches the pipeline.
    tile_height = int(float(tile_height))
    tile_width = int(float(tile_width))
    target_height = int(float(target_height))
    target_width = int(float(target_width))
    overlap_pixels = int(overlap_pixels)
    steps = int(steps)
    generation_seed = int(generation_seed)

    if min(tile_height, tile_width, target_height, target_width) <= 0:
        raise gr.Error(
            "No valid tile layout for the requested size and overlap. "
            "Adjust the target size or the tile overlap and try again."
        )

    print(f"Using scheduler: {scheduler}...")
    pipe.scheduler = select_scheduler(scheduler)
    generator = torch.Generator("cuda").manual_seed(generation_seed)

    left_prompt = translate_if_needed(left_prompt)
    center_prompt = translate_if_needed(center_prompt)
    right_prompt = translate_if_needed(right_prompt)
    # Non-Hangul text passes through unchanged, so English input is unaffected.
    negative_prompt = translate_if_needed(negative_prompt)

    image = pipe(
        prompt=[[left_prompt, center_prompt, right_prompt]],
        negative_prompt=negative_prompt,
        tile_height=tile_height,
        tile_width=tile_width,
        tile_row_overlap=0,
        tile_col_overlap=overlap_pixels,
        guidance_scale_tiles=[[left_gs, center_gs, right_gs]],
        height=target_height,
        width=target_width,
        generator=generator,
        num_inference_steps=steps,
    )["images"][0]
    return image
|
|
|
def _largest_fitting_tile(max_tile_size, target_size, num_tiles, overlap_pixels,
                          min_tile_dimension=8, reduction_step=8):
    """Return ``(tile_size, covered_size)`` for the largest tile that fits.

    Candidates start at *max_tile_size* and shrink by *reduction_step* while
    they stay at or above *min_tile_dimension*. ``covered_size`` is the extent
    spanned by *num_tiles* tiles overlapping by *overlap_pixels* at each inner
    border. ``(0, 0)`` is returned when no candidate covers a positive extent
    within *target_size*.
    """
    total_overlap = overlap_pixels * (num_tiles - 1)
    tile_size = max_tile_size
    while tile_size >= min_tile_dimension:
        covered_size = tile_size * num_tiles - total_overlap
        if covered_size <= target_size:
            # Candidates only get smaller from here, so the first fit is the
            # best one; a non-positive extent means nothing usable exists.
            return (tile_size, covered_size) if covered_size > 0 else (0, 0)
        tile_size -= reduction_step
    return 0, 0


def calc_tile_size(target_height, target_width, overlap_pixels, max_tile_width_size=1280,
                   max_tile_height_size=1024):
    """Compute tile and image sizes for a 3-column, 1-row tiled panorama.

    Args:
        target_height: Requested image height in pixels.
        target_width: Requested image width in pixels.
        overlap_pixels: Overlap between neighbouring tiles in pixels.
        max_tile_width_size: Largest tile width to try.
        max_tile_height_size: Largest tile height to try.

    Returns:
        ``(new_target_height, new_target_width, tile_height, tile_width)``.
        The target values are the extents actually covered by the chosen
        tiles, rounded down to a multiple of 8. All four values are ``0`` when
        no tile width fits; the two height values are ``0`` when no tile
        height fits.
    """
    num_cols = 3
    num_rows = 1

    # Whole numbers keep the results free of a trailing ".0"; the UI shows
    # them in textboxes whose content is later parsed back with int().
    target_height = int(target_height)
    target_width = int(target_width)
    overlap_pixels = int(overlap_pixels)
    max_tile_width_size = int(max_tile_width_size)
    max_tile_height_size = int(max_tile_height_size)

    tile_width, new_target_width = _largest_fitting_tile(
        max_tile_width_size, target_width, num_cols, overlap_pixels
    )
    if tile_width:
        tile_height, new_target_height = _largest_fitting_tile(
            max_tile_height_size, target_height, num_rows, overlap_pixels
        )
    else:
        # Without a usable width there is no layout at all.
        tile_height, new_target_height = 0, 0

    new_target_height -= new_target_height % 8
    new_target_width -= new_target_width % 8

    print("--- TILE SIZE CALCULATED VALUES ---")
    print(f"Requested Overlap Pixels: {overlap_pixels}")
    print(f"Tile Height (max {max_tile_height_size}, divisible by 8): {tile_height}")
    print(f"Tile Width (max {max_tile_width_size}, divisible by 8): {tile_width}")
    print(f"Columns: {num_cols} | Rows: {num_rows}")
    print(f"Original Target: {target_height} x {target_width}")
    print(f"Adjusted Target (divisible by 8): {new_target_height} x {new_target_width}\n")

    return new_target_height, new_target_width, tile_height, tile_width
|
|
|
def do_calc_tile(target_height, target_width, overlap_pixels, max_tile_size):
    """Run ``calc_tile_size`` and wrap its results as Gradio updates.

    Returns:
        Four ``gr.update`` objects in the order tile height, tile width,
        adjusted image height, adjusted image width.
    """
    adjusted_height, adjusted_width, tile_h, tile_w = calc_tile_size(
        target_height, target_width, overlap_pixels, max_tile_size
    )
    # calc_tile_size reports the image size first; the UI outputs list the
    # tile size first.
    display_order = (tile_h, tile_w, adjusted_height, adjusted_width)
    return tuple(gr.update(value=shown) for shown in display_order)
|
|
|
def clear_result():
    """Return a Gradio update that sets the target component's value to ``None``."""
    emptied = gr.update(value=None)
    return emptied
|
|
|
def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for the next generation.

    Args:
        generation_seed: Seed currently set in the UI.
        randomize_seed: When true, the current seed is ignored and a new one
            is drawn.

    Returns:
        *generation_seed* itself, or a random integer in ``[0, MAX_SEED]``
        when *randomize_seed* is true.
    """
    if not randomize_seed:
        return generation_seed
    return random.randint(0, MAX_SEED)
|
|
|
|
|
|
|
# Stylesheet handed to gr.Blocks(css=...). The #examples_container and
# #examples_row selectors match the elem_id values given to the examples
# layout further down in this file.
css = """
body {
background: linear-gradient(135deg, #667eea, #764ba2);
font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
color: #333;
margin: 0;
padding: 0;
}
.gradio-container {
background: rgba(255, 255, 255, 0.95);
border-radius: 15px;
padding: 30px 40px;
box-shadow: 0 8px 30px rgba(0, 0, 0, 0.3);
margin: 40px auto;
max-width: 1200px;
}
.gradio-container h1 {
color: #333;
text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.2);
}
.fillable {
width: 95% !important;
max-width: unset !important;
}
#examples_container {
margin: auto;
width: 90%;
}
#examples_row {
justify-content: center;
}
.sidebar {
background: rgba(255, 255, 255, 0.98);
border-radius: 10px;
padding: 20px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
}
button, .btn {
background: linear-gradient(90deg, #ff8a00, #e52e71);
border: none;
color: #fff;
padding: 12px 24px;
text-transform: uppercase;
font-weight: bold;
letter-spacing: 1px;
border-radius: 5px;
cursor: pointer;
transition: transform 0.2s ease-in-out;
}
button:hover, .btn:hover {
transform: scale(1.05);
}
"""
|
|
|
# HTML header rendered through gr.Markdown at the top of the page.
title = """
<h1 align="center" style="margin-bottom: 0.2em;"> Panorama X3 Image </h1>
<p align="center" style="font-size:1.1em; color:#555;">
Left/Center/Right prompts (English/Korean allowed)
</p>
"""
|
|
|
# Page layout and event wiring.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been stripped - confirm the placement of the sidebar and of
# the examples row against the intended layout.
# NOTE(review): the three Korean placeholder texts were restored from
# mis-encoded text - have a Korean speaker confirm the wording.
with gr.Blocks(css=css, title="SDXL Tiling Pipeline") as app:
    gr.Markdown(title)

    with gr.Row():
        # Main column: generate button, the three region prompts, the
        # negative prompt and the result image.
        with gr.Column(scale=7):
            generate_button = gr.Button("Generate", elem_id="generate_btn")
            with gr.Row():
                with gr.Column(variant="panel"):
                    gr.Markdown("### Left Region")
                    left_prompt = gr.Textbox(lines=4, placeholder="e.g., 울창한 숲과 햇살이 비추는 나무...", label="Left Prompt (English/Korean allowed)")
                    left_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Left CFG Scale")
                with gr.Column(variant="panel"):
                    gr.Markdown("### Center Region")
                    center_prompt = gr.Textbox(lines=4, placeholder="e.g., 잔잔한 호수와 반짝이는 수면...", label="Center Prompt (English/Korean allowed)")
                    center_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Center CFG Scale")
                with gr.Column(variant="panel"):
                    gr.Markdown("### Right Region")
                    right_prompt = gr.Textbox(lines=4, placeholder="e.g., 웅장한 산맥과 하늘을 가르는 구름...", label="Right Prompt (English/Korean allowed)")
                    right_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Right CFG Scale")
            with gr.Row():
                negative_prompt = gr.Textbox(
                    lines=2,
                    label="Negative Prompt",
                    placeholder="e.g., blurry, low resolution, artifacts, poor details",
                    value="blurry, low resolution, artifacts, poor details"
                )
            with gr.Row():
                result = gr.Image(label="Generated Image", show_label=True, format="png", interactive=False, scale=1)

    # Sidebar with the size, tiling, sampling and scheduler controls.
    with gr.Sidebar(label="Parameters", open=True):
        gr.Markdown("### Generation Parameters")
        with gr.Row():
            height = gr.Slider(label="Target Height", value=1024, step=8, minimum=512, maximum=1024)
            width = gr.Slider(label="Target Width", value=1280, step=8, minimum=512, maximum=3840)
            overlap = gr.Slider(minimum=0, maximum=512, value=128, step=8, label="Tile Overlap")
            max_tile_size = gr.Dropdown(label="Max Tile Size", choices=[1024, 1280], value=1280)
            calc_tile = gr.Button("Calculate Tile Size")
        with gr.Row():
            # Read-only: filled in by do_calc_tile.
            tile_height = gr.Textbox(label="Tile Height", value=1024, interactive=False)
            tile_width = gr.Textbox(label="Tile Width", value=1024, interactive=False)
        with gr.Row():
            new_target_height = gr.Textbox(label="New Image Height", value=1024, interactive=False)
            new_target_width = gr.Textbox(label="New Image Width", value=1280, interactive=False)
        with gr.Row():
            steps = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Inference Steps")
            generation_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
        with gr.Row():
            scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULERS, value=SCHEDULERS[0])

    # Clickable example inputs; each row lists its values in the order of the
    # ``inputs`` components below.
    with gr.Row(elem_id="examples_row"):
        with gr.Column(scale=12, elem_id="examples_container"):
            gr.Markdown("### Example Prompts")
            gr.Examples(
                examples=[
                    [
                        "Lush green forest with sun rays filtering through the canopy",
                        "Crystal clear lake reflecting a vibrant sky",
                        "Majestic mountains with snowy peaks in the distance",
                        "blurry, low resolution, artifacts, poor details",
                        7, 7, 7,
                        128,
                        30,
                        123456789,
                        "UniPCMultistepScheduler",
                        1024, 1280,
                        1024, 1920,
                        1280
                    ],
                    [
                        "Vibrant abstract strokes with fluid, swirling patterns in cool tones",
                        "Interlocking geometric shapes bursting with color and texture",
                        "Dynamic composition of splattered ink with smooth gradients",
                        "text, watermark, signature, distorted",
                        6, 6, 6,
                        80,
                        25,
                        192837465,
                        "DPMSolverMultistepScheduler-Karras",
                        1024, 1280,
                        1024, 1920,
                        1280
                    ],
                    [
                        "Enchanted forest with glowing bioluminescent plants and mystical fog",
                        "Ancient castle with towering spires bathed in moonlight",
                        "Majestic dragon soaring above a starry night sky",
                        "low quality, artifact, deformed, sketchy",
                        9, 9, 9,
                        150,
                        40,
                        1029384756,
                        "DPMSolverMultistepScheduler-Karras-SDE",
                        1024, 1280,
                        1024, 1920,
                        1280
                    ]
                ],
                inputs=[left_prompt, center_prompt, right_prompt, negative_prompt,
                        left_gs, center_gs, right_gs, overlap, steps, generation_seed,
                        scheduler, tile_height, tile_width, height, width, max_tile_size],
                cache_examples=False
            )

    # Shared by the "Calculate Tile Size" button and the generate chain.
    event_calc_tile_size = {
        "fn": do_calc_tile,
        "inputs": [height, width, overlap, max_tile_size],
        "outputs": [tile_height, tile_width, new_target_height, new_target_width]
    }
    calc_tile.click(**event_calc_tile_size)

    # Generate: clear the old image, recompute the tile layout, optionally
    # draw a new seed, then run the pipeline with the recomputed sizes.
    generate_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(**event_calc_tile_size).then(
        fn=randomize_seed_fn,
        inputs=[generation_seed, randomize_seed],
        outputs=generation_seed,
        queue=False,
        api_name=False,
    ).then(
        fn=predict,
        inputs=[left_prompt, center_prompt, right_prompt, negative_prompt,
                left_gs, center_gs, right_gs, overlap, steps, generation_seed,
                scheduler, tile_height, tile_width, new_target_height, new_target_width],
        outputs=result,
    )
|
|
|
# Start the Gradio app; share=False requests no public share link.
app.launch(share=False)
|
|