# Panorama / app.py
# (page residue) fantos's picture
# (page residue) Update app.py
# (page residue) 8bc44c9 verified
import random
import re
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import AutoencoderKL
from mixture_tiling_sdxl import StableDiffusionXLTilingPipeline
from transformers import pipeline
# Pipeline used to translate Korean input into English. No device argument is
# passed here; the original comment states that it runs on the CPU.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
# Largest int32 value; upper bound of the seed slider and of randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Scheduler names offered in the UI dropdown. select_scheduler() splits each
# name on "-": the first part is the diffusers class name, a second part turns
# on Karras sigmas, and a third part selects the "sde-dpmsolver++" algorithm.
SCHEDULERS = [
"LMSDiscreteScheduler",
"DEISMultistepScheduler",
"HeunDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"EulerDiscreteScheduler",
"DPMSolverMultistepScheduler",
"DPMSolverMultistepScheduler-Karras",
"DPMSolverMultistepScheduler-Karras-SDE",
"UniPCMultistepScheduler"
]
# ๋ชจ๋ธ ๋กœ๋”ฉ: VAE ๋ฐ ํƒ€์ผ๋ง ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™”
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")
model_id = "stablediffusionapi/yamermix-v8-vae"
pipe = StableDiffusionXLTilingPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16,
vae=vae,
use_safetensors=False, # for yammermix
).to("cuda")
pipe.enable_model_cpu_offload() # VRAM์ด ์ œํ•œ๋œ ๊ฒฝ์šฐ ์‚ฌ์šฉ
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
#region helper functions
def translate_if_needed(text: str) -> str:
    """Return ``text`` translated to English when it contains Hangul syllables.

    Args:
        text: Prompt text as entered by the user.

    Returns:
        The ``translation_text`` of the first result produced by the
        module-level ``translator`` pipeline when ``text`` contains at least
        one character in U+AC00..U+D7A3; otherwise ``text`` unchanged.
        Empty or ``None`` input is returned as-is without calling the
        translator.
    """
    # Nothing to inspect for empty/None input (re.search raises on None).
    if not text:
        return text
    if re.search(r'[\uac00-\ud7a3]', text):
        # The pipeline returns a list; the first element's
        # "translation_text" key holds the translated string.
        return translator(text)[0]["translation_text"]
    return text
# Scheduler config of the pipeline as it was loaded; captured on first use so
# that every scheduler is derived from the same starting point.
_BASE_SCHEDULER_CONFIG = None


def select_scheduler(scheduler_name):
    """Build the diffusers scheduler described by ``scheduler_name``.

    The name is split on ``"-"``. The first part is looked up as a class on
    the ``diffusers`` package; a second part enables ``use_karras_sigmas``;
    a third part additionally sets ``algorithm_type="sde-dpmsolver++"``.

    Args:
        scheduler_name: A scheduler name such as
            ``"DPMSolverMultistepScheduler-Karras-SDE"``.

    Returns:
        A new scheduler instance created with ``from_config``.

    Raises:
        AttributeError: If ``diffusers`` has no attribute named after the
            first part of ``scheduler_name``.
    """
    global _BASE_SCHEDULER_CONFIG
    import diffusers

    # from_config() copies every key of the given config that the target
    # class accepts. Building from the *current* pipe.scheduler.config would
    # therefore carry options of a previously selected scheduler (Karras
    # sigmas, SDE algorithm type) into the next one, so the pipeline's
    # initial config is remembered once and reused for every selection.
    if _BASE_SCHEDULER_CONFIG is None:
        _BASE_SCHEDULER_CONFIG = pipe.scheduler.config

    scheduler_parts = scheduler_name.split("-")
    scheduler_class_name = scheduler_parts[0]

    add_kwargs = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "num_train_timesteps": 1000,
    }
    if len(scheduler_parts) > 1:
        add_kwargs["use_karras_sigmas"] = True
    if len(scheduler_parts) > 2:
        add_kwargs["algorithm_type"] = "sde-dpmsolver++"

    scheduler_cls = getattr(diffusers, scheduler_class_name)
    return scheduler_cls.from_config(_BASE_SCHEDULER_CONFIG, **add_kwargs)
@spaces.GPU
def predict(left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs,
            overlap_pixels, steps, generation_seed, scheduler, tile_height, tile_width, target_height, target_width):
    """Generate one panorama image from three side-by-side prompts.

    Args:
        left_prompt, center_prompt, right_prompt: Prompts for the three tile
            columns; Korean text is translated to English first.
        negative_prompt: Negative prompt shared by all tiles; translated the
            same way as the other prompts.
        left_gs, center_gs, right_gs: Guidance scale for each tile column.
        overlap_pixels: Horizontal overlap between neighbouring tiles.
        steps: Number of inference steps.
        generation_seed: Seed for the CUDA random generator.
        scheduler: Scheduler name understood by ``select_scheduler``.
        tile_height, tile_width: Tile size; converted with ``int()`` because
            the UI delivers them from text boxes.
        target_height, target_width: Output size; converted with ``int()``.

    Returns:
        The first entry of the pipeline result's ``"images"`` list.
    """
    print(f"Using scheduler: {scheduler}...")
    pipe.scheduler = select_scheduler(scheduler)
    # UI widgets may deliver numbers as floats or strings; the generator seed
    # and the pixel/step counts have to be plain integers.
    generator = torch.Generator("cuda").manual_seed(int(generation_seed))

    # Translate prompts to English when they contain Korean text.
    left_prompt = translate_if_needed(left_prompt)
    center_prompt = translate_if_needed(center_prompt)
    right_prompt = translate_if_needed(right_prompt)
    negative_prompt = translate_if_needed(negative_prompt)

    target_height = int(target_height)
    target_width = int(target_width)
    tile_height = int(tile_height)
    tile_width = int(tile_width)
    overlap_pixels = int(overlap_pixels)
    steps = int(steps)

    image = pipe(
        prompt=[[left_prompt, center_prompt, right_prompt]],
        negative_prompt=negative_prompt,
        tile_height=tile_height,
        tile_width=tile_width,
        tile_row_overlap=0,
        tile_col_overlap=overlap_pixels,
        guidance_scale_tiles=[[left_gs, center_gs, right_gs]],
        height=target_height,
        width=target_width,
        generator=generator,
        num_inference_steps=steps,
    )["images"][0]
    return image
def _largest_fitting_tile(max_tile_size, tile_count, overlap_pixels, target_size,
                          min_tile_dimension=8, reduction_step=8):
    """Return ``(tile_size, covered_size)`` for the largest tile that fits, else ``(0, 0)``."""
    # The overlap is shared once per border between neighbouring tiles.
    total_overlap = overlap_pixels * (tile_count - 1)
    # Candidates go downwards from the maximum, so the first fit is the best.
    for tile_size in range(max_tile_size, min_tile_dimension - 1, -reduction_step):
        covered_size = tile_size * tile_count - total_overlap
        if 0 < covered_size <= target_size:
            return tile_size, covered_size
    return 0, 0


def calc_tile_size(target_height, target_width, overlap_pixels, max_tile_width_size=1280):
    """Compute tile and image sizes for a 3-column, 1-row tiled generation.

    The tile width is the largest value, stepping down by 8 from
    ``max_tile_width_size`` to 8, for which three tiles minus the two
    overlaps do not exceed ``target_width``. The tile height is found the
    same way from a maximum of 1024 against ``target_height``. The resulting
    image size is finally rounded down to a multiple of 8.

    Args:
        target_height: Requested image height in pixels.
        target_width: Requested image width in pixels.
        overlap_pixels: Overlap between neighbouring tiles in pixels.
        max_tile_width_size: Largest tile width to consider.

    Returns:
        ``(new_target_height, new_target_width, tile_height, tile_width)``.
        All four values are 0 when no tile width fits.
    """
    num_cols = 3
    num_rows = 1
    max_tile_height_size = 1024

    # UI widgets may deliver floats or numeric strings; work on integers.
    target_height = int(target_height)
    target_width = int(target_width)
    overlap_pixels = int(overlap_pixels)
    max_tile_width_size = int(max_tile_width_size)

    tile_width, new_target_width = _largest_fitting_tile(
        max_tile_width_size, num_cols, overlap_pixels, target_width
    )

    # The height is only searched once a valid width has been found.
    tile_height, new_target_height = 0, 0
    if tile_width:
        tile_height, new_target_height = _largest_fitting_tile(
            max_tile_height_size, num_rows, overlap_pixels, target_height
        )

    # Make the computed target height and width divisible by 8 (avoids errors).
    new_target_height -= new_target_height % 8
    new_target_width -= new_target_width % 8

    print("--- TILE SIZE CALCULATED VALUES ---")
    print(f"Requested Overlap Pixels: {overlap_pixels}")
    print(f"Tile Height (max {max_tile_height_size}, divisible by 8): {tile_height}")
    print(f"Tile Width (max {max_tile_width_size}, divisible by 8): {tile_width}")
    print(f"Columns: {num_cols} | Rows: {num_rows}")
    print(f"Original Target: {target_height} x {target_width}")
    print(f"Adjusted Target (divisible by 8): {new_target_height} x {new_target_width}\n")

    return new_target_height, new_target_width, tile_height, tile_width
def do_calc_tile(target_height, target_width, overlap_pixels, max_tile_size):
    """Run ``calc_tile_size`` and wrap its results as Gradio updates.

    Returns:
        Four ``gr.update`` objects in the order tile height, tile width,
        adjusted image height, adjusted image width.
    """
    adjusted_height, adjusted_width, tile_h, tile_w = calc_tile_size(
        target_height, target_width, overlap_pixels, max_tile_size
    )
    # Tile size first, then the adjusted image size.
    ordered_values = (tile_h, tile_w, adjusted_height, adjusted_width)
    return tuple(gr.update(value=v) for v in ordered_values)
def clear_result():
    """Return a Gradio update that sets the bound component's value to ``None``."""
    cleared = gr.update(value=None)
    return cleared
def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for the next generation.

    Args:
        generation_seed: Seed currently selected in the UI.
        randomize_seed: When true, a fresh seed is drawn instead.

    Returns:
        A random integer in ``[0, MAX_SEED]`` when ``randomize_seed`` is
        true, otherwise ``generation_seed`` unchanged.
    """
    if not randomize_seed:
        return generation_seed
    return random.randint(0, MAX_SEED)
#endregion
# CSS tweaks: gradient page background, semi-transparent container,
# shadows, and a hover animation on buttons.
# Passed to gr.Blocks(css=...) when the UI is built.
css = """
body {
background: linear-gradient(135deg, #667eea, #764ba2);
font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
color: #333;
margin: 0;
padding: 0;
}
.gradio-container {
background: rgba(255, 255, 255, 0.95);
border-radius: 15px;
padding: 30px 40px;
box-shadow: 0 8px 30px rgba(0, 0, 0, 0.3);
margin: 40px auto;
max-width: 1200px;
}
.gradio-container h1 {
color: #333;
text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.2);
}
.fillable {
width: 95% !important;
max-width: unset !important;
}
#examples_container {
margin: auto;
width: 90%;
}
#examples_row {
justify-content: center;
}
.sidebar {
background: rgba(255, 255, 255, 0.98);
border-radius: 10px;
padding: 20px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
}
button, .btn {
background: linear-gradient(90deg, #ff8a00, #e52e71);
border: none;
color: #fff;
padding: 12px 24px;
text-transform: uppercase;
font-weight: bold;
letter-spacing: 1px;
border-radius: 5px;
cursor: pointer;
transition: transform 0.2s ease-in-out;
}
button:hover, .btn:hover {
transform: scale(1.05);
}
"""
# HTML header rendered at the top of the page via gr.Markdown(title).
# The emoji was stored as mis-decoded bytes; it is restored to U+1F917 here.
title = """
<h1 align="center" style="margin-bottom: 0.2em;"> 🤗 Panorama X3 Image </h1>
<p align="center" style="font-size:1.1em; color:#555;">
Left/Center/Right prompts (English/Korean allowed)
</p>
"""
# UI definition. The Korean placeholder texts were stored as mis-decoded
# bytes and are restored to readable Hangul below.
with gr.Blocks(css=css, title="SDXL Tiling Pipeline") as app:
    gr.Markdown(title)
    with gr.Row():
        # Left/center/right prompts and the result area.
        with gr.Column(scale=7):
            generate_button = gr.Button("Generate", elem_id="generate_btn")
            with gr.Row():
                with gr.Column(variant="panel"):
                    gr.Markdown("### Left Region")
                    left_prompt = gr.Textbox(lines=4, placeholder="e.g., 울창한 숲과 햇살이 비추는 나무...", label="Left Prompt (English/Korean allowed)")
                    left_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Left CFG Scale")
                with gr.Column(variant="panel"):
                    gr.Markdown("### Center Region")
                    center_prompt = gr.Textbox(lines=4, placeholder="e.g., 잔잔한 호수와 반짝이는 수면...", label="Center Prompt (English/Korean allowed)")
                    center_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Center CFG Scale")
                with gr.Column(variant="panel"):
                    gr.Markdown("### Right Region")
                    right_prompt = gr.Textbox(lines=4, placeholder="e.g., 웅장한 산맥과 하늘을 가르는 구름...", label="Right Prompt (English/Korean allowed)")
                    right_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Right CFG Scale")
            with gr.Row():
                negative_prompt = gr.Textbox(
                    lines=2,
                    label="Negative Prompt",
                    placeholder="e.g., blurry, low resolution, artifacts, poor details",
                    value="blurry, low resolution, artifacts, poor details"
                )
            with gr.Row():
                result = gr.Image(label="Generated Image", show_label=True, format="png", interactive=False, scale=1)
        # Sidebar: generation parameters and tile size calculation.
        with gr.Sidebar(label="Parameters", open=True):
            gr.Markdown("### Generation Parameters")
            with gr.Row():
                height = gr.Slider(label="Target Height", value=1024, step=8, minimum=512, maximum=1024)
                width = gr.Slider(label="Target Width", value=1280, step=8, minimum=512, maximum=3840)
                overlap = gr.Slider(minimum=0, maximum=512, value=128, step=8, label="Tile Overlap")
                max_tile_size = gr.Dropdown(label="Max Tile Size", choices=[1024, 1280], value=1280)
                calc_tile = gr.Button("Calculate Tile Size")
            with gr.Row():
                tile_height = gr.Textbox(label="Tile Height", value=1024, interactive=False)
                tile_width = gr.Textbox(label="Tile Width", value=1024, interactive=False)
            with gr.Row():
                new_target_height = gr.Textbox(label="New Image Height", value=1024, interactive=False)
                new_target_width = gr.Textbox(label="New Image Width", value=1280, interactive=False)
            with gr.Row():
                steps = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Inference Steps")
                generation_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
            with gr.Row():
                scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULERS, value=SCHEDULERS[0])
    # Centered examples area.
    with gr.Row(elem_id="examples_row"):
        with gr.Column(scale=12, elem_id="examples_container"):
            gr.Markdown("### Example Prompts")
            # Each example row lists values in the same order as `inputs` below.
            gr.Examples(
                examples=[
                    [
                        "Lush green forest with sun rays filtering through the canopy",
                        "Crystal clear lake reflecting a vibrant sky",
                        "Majestic mountains with snowy peaks in the distance",
                        "blurry, low resolution, artifacts, poor details",
                        7, 7, 7,
                        128,
                        30,
                        123456789,
                        "UniPCMultistepScheduler",
                        1024, 1280,
                        1024, 1920,
                        1280
                    ],
                    [
                        "Vibrant abstract strokes with fluid, swirling patterns in cool tones",
                        "Interlocking geometric shapes bursting with color and texture",
                        "Dynamic composition of splattered ink with smooth gradients",
                        "text, watermark, signature, distorted",
                        6, 6, 6,
                        80,
                        25,
                        192837465,
                        "DPMSolverMultistepScheduler-Karras",
                        1024, 1280,
                        1024, 1920,
                        1280
                    ],
                    [
                        "Enchanted forest with glowing bioluminescent plants and mystical fog",
                        "Ancient castle with towering spires bathed in moonlight",
                        "Majestic dragon soaring above a starry night sky",
                        "low quality, artifact, deformed, sketchy",
                        9, 9, 9,
                        150,
                        40,
                        1029384756,
                        "DPMSolverMultistepScheduler-Karras-SDE",
                        1024, 1280,
                        1024, 1920,
                        1280
                    ]
                ],
                inputs=[left_prompt, center_prompt, right_prompt, negative_prompt,
                        left_gs, center_gs, right_gs, overlap, steps, generation_seed,
                        scheduler, tile_height, tile_width, height, width, max_tile_size],
                cache_examples=False
            )
    # Event wiring: tile size calculation and image generation.
    event_calc_tile_size = {
        "fn": do_calc_tile,
        "inputs": [height, width, overlap, max_tile_size],
        "outputs": [tile_height, tile_width, new_target_height, new_target_width]
    }
    calc_tile.click(**event_calc_tile_size)
    # Generate: clear the old image, recompute the tile sizes, optionally draw
    # a new seed, then run the pipeline with the freshly computed sizes.
    generate_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(**event_calc_tile_size).then(
        fn=randomize_seed_fn,
        inputs=[generation_seed, randomize_seed],
        outputs=generation_seed,
        queue=False,
        api_name=False,
    ).then(
        fn=predict,
        inputs=[left_prompt, center_prompt, right_prompt, negative_prompt,
                left_gs, center_gs, right_gs, overlap, steps, generation_seed,
                scheduler, tile_height, tile_width, new_target_height, new_target_width],
        outputs=result,
    )
app.launch(share=False)