import random

import gradio as gr
import numpy as np
import spaces  # kept ahead of torch/diffusers, matching the original import order
import torch
import diffusers
from diffusers import AutoencoderKL
from mixture_tiling_sdxl import StableDiffusionXLTilingPipeline

MAX_SEED = np.iinfo(np.int32).max

# Scheduler names offered in the UI, formatted as "<diffusers class>[-Karras[-SDE]]";
# select_scheduler() interprets the dash-separated suffixes.
SCHEDULERS = [
    "LMSDiscreteScheduler",
    "DEISMultistepScheduler",
    "HeunDiscreteScheduler",
    "EulerAncestralDiscreteScheduler",
    "EulerDiscreteScheduler",
    "DPMSolverMultistepScheduler",
    "DPMSolverMultistepScheduler-Karras",
    "DPMSolverMultistepScheduler-Karras-SDE",
    "UniPCMultistepScheduler",
]

# Model loading: initialize the VAE and the tiling pipeline (fp16).
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")

model_id = "stablediffusionapi/yamermix-v8-vae"
pipe = StableDiffusionXLTilingPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    vae=vae,
    use_safetensors=False,  # for yammermix
).to("cuda")

# Use when VRAM is limited.
# NOTE(review): the pipeline is moved to CUDA immediately before model CPU
# offload is enabled; offloading manages device placement itself, so the
# explicit `.to("cuda")` calls above may be redundant — confirm on the target
# hardware before removing them.
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()


#region functions
def select_scheduler(scheduler_name):
    """Instantiate the scheduler described by one of the SCHEDULERS names.

    The name is split on "-": the first part is a class name looked up on the
    ``diffusers`` package; a second part (e.g. "Karras") enables
    ``use_karras_sigmas``; a third part (e.g. "SDE") additionally sets
    ``algorithm_type="sde-dpmsolver++"``. The scheduler is built from the
    current pipeline scheduler's config with fixed beta/timestep settings
    overriding it.

    Args:
        scheduler_name: One of the SCHEDULERS entries.

    Returns:
        A new scheduler instance (the caller assigns it to the pipeline).
    """
    class_name, *suffixes = scheduler_name.split("-")
    add_kwargs = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "num_train_timesteps": 1000,
    }
    if suffixes:
        add_kwargs["use_karras_sigmas"] = True
    if len(suffixes) > 1:
        add_kwargs["algorithm_type"] = "sde-dpmsolver++"

    scheduler_cls = getattr(diffusers, class_name)
    return scheduler_cls.from_config(pipe.scheduler.config, **add_kwargs)


@spaces.GPU
def predict(left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs, overlap_pixels, steps, generation_seed, scheduler, tile_height, tile_width, target_height, target_width):
    """Generate one image from a 1-row x 3-column grid of prompt tiles.

    Args:
        left_prompt, center_prompt, right_prompt: Prompts for the three columns.
        negative_prompt: Negative prompt passed to the pipeline.
        left_gs, center_gs, right_gs: Per-column guidance scales.
        overlap_pixels: Horizontal overlap between neighbouring tiles.
        steps: Number of inference steps.
        generation_seed: Seed for the CUDA random generator.
        scheduler: A SCHEDULERS name; see select_scheduler().
        tile_height, tile_width, target_height, target_width: Tile and output
            sizes. In the UI these come from read-only Textboxes (strings).

    Returns:
        The first entry of the pipeline's "images" output.

    Side effects:
        Replaces the module-level pipeline's scheduler on every call.
    """
    print(f"Using scheduler: {scheduler}...")
    # Attribute assignment on the module-level `pipe`; no `global` needed.
    pipe.scheduler = select_scheduler(scheduler)

    # Every numeric UI value is coerced to int: Textboxes deliver strings, and
    # the seed/steps/overlap feed integer-only APIs (generator seed, tile indices).
    generator = torch.Generator("cuda").manual_seed(int(generation_seed))

    image = pipe(
        prompt=[[left_prompt, center_prompt, right_prompt]],
        negative_prompt=negative_prompt,
        tile_height=int(tile_height),
        tile_width=int(tile_width),
        tile_row_overlap=0,  # a single row has no vertical seams
        tile_col_overlap=int(overlap_pixels),
        guidance_scale_tiles=[[left_gs, center_gs, right_gs]],
        height=int(target_height),
        width=int(target_width),
        generator=generator,
        num_inference_steps=int(steps),
    )["images"][0]
    return image


def _largest_fitting_tile(max_size, tile_count, overlap, target, min_size=8, step=8):
    """Find the largest tile size whose tiled span fits within `target`.

    Candidates step down from `max_size` by `step` while >= `min_size`.
    `tile_count` tiles sharing `tile_count - 1` seams of `overlap` pixels span
    ``size * tile_count - overlap * (tile_count - 1)`` pixels. Because that span
    shrinks as the size shrinks, the first candidate with 0 < span <= target is
    the largest fitting one.

    Returns:
        (size, span), or (0, 0) if no candidate fits.
    """
    seams = overlap * (tile_count - 1)
    for size in range(max_size, min_size - 1, -step):
        span = size * tile_count - seams
        if 0 < span <= target:
            return size, span
    return 0, 0


def calc_tile_size(target_height, target_width, overlap_pixels, max_tile_width_size=1280):
    """Choose tile sizes for the fixed 1-row x 3-column layout.

    The tile width is the largest multiple-of-8 step down from
    `max_tile_width_size` for which three tiles with `overlap_pixels` of
    overlap fit within `target_width`. The tile height is the largest step
    down from 1024 that fits within `target_height` (one row, so no vertical
    overlap). The resulting output size is then rounded down to a multiple of 8.

    Args:
        target_height: Requested output height in pixels.
        target_width: Requested output width in pixels.
        overlap_pixels: Horizontal overlap between neighbouring tiles.
        max_tile_width_size: Upper bound for the tile width.

    Returns:
        (new_target_height, new_target_width, tile_height, tile_width)

    Raises:
        gr.Error: If no tile width or height fits the requested target.
    """
    num_cols = 3
    num_rows = 1
    max_tile_height_size = 1024
    # UI components may hand over non-int numbers; range() needs ints.
    overlap_pixels = int(overlap_pixels)
    max_tile_width_size = int(max_tile_width_size)

    tile_width, new_target_width = _largest_fitting_tile(
        max_tile_width_size, num_cols, overlap_pixels, target_width
    )
    if not tile_width:
        # Previously this silently produced 0-sized tiles and failed later in the pipeline.
        raise gr.Error(
            f"No tile width fits a target width of {target_width}px with "
            f"{overlap_pixels}px overlap across {num_cols} columns."
        )

    tile_height, new_target_height = _largest_fitting_tile(
        max_tile_height_size, num_rows, overlap_pixels, target_height
    )
    if not tile_height:
        raise gr.Error(f"No tile height fits a target height of {target_height}px.")

    # Make the new target height and width divisible by 8 (prevents errors).
    new_target_height = new_target_height - (new_target_height % 8)
    new_target_width = new_target_width - (new_target_width % 8)

    print("--- TILE SIZE CALCULATED VALUES ---")
    print(f"Requested Overlap Pixels: {overlap_pixels}")
    print(f"Tile Height (max {max_tile_height_size}, divisible by 8): {tile_height}")
    print(f"Tile Width (max {max_tile_width_size}, divisible by 8): {tile_width}")
    print(f"Columns: {num_cols} | Rows: {num_rows}")
    print(f"Original Target: {target_height} x {target_width}")
    print(f"Adjusted Target (divisible by 8): {new_target_height} x {new_target_width}\n")

    return new_target_height, new_target_width, tile_height, tile_width


def do_calc_tile(target_height, target_width, overlap_pixels, max_tile_size):
    """Gradio handler: compute tile sizes and push them into the size Textboxes."""
    new_target_height, new_target_width, tile_height, tile_width = calc_tile_size(target_height, target_width, overlap_pixels, max_tile_size)
    return gr.update(value=tile_height), gr.update(value=tile_width), gr.update(value=new_target_height), gr.update(value=new_target_width)


def clear_result():
    """Gradio handler: clear the result image before a new generation."""
    return gr.update(value=None)


def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int:
    """Return a random seed in [0, MAX_SEED] if requested, else the given seed."""
    if randomize_seed:
        generation_seed = random.randint(0, MAX_SEED)
    return generation_seed
#endregion


# Styling: gradient background, translucent container, shadows and hover animation.
css = """
body {
    background: linear-gradient(135deg, #667eea, #764ba2);
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
    color: #333;
    margin: 0;
    padding: 0;
}
.gradio-container {
    background: rgba(255, 255, 255, 0.95);
    border-radius: 15px;
    padding: 30px 40px;
    box-shadow: 0 8px 30px rgba(0, 0, 0, 0.3);
    margin: 40px auto;
    max-width: 1200px;
}
.gradio-container h1 {
    color: #333;
    text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.2);
}
.fillable {
    width: 95% !important;
    max-width: unset !important;
}
#examples_container {
    margin: auto;
    width: 90%;
}
#examples_row {
    justify-content: center;
}
.sidebar {
    background: rgba(255, 255, 255, 0.98);
    border-radius: 10px;
    padding: 20px;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
}
button, .btn {
    background: linear-gradient(90deg, #ff8a00, #e52e71);
    border: none;
    color: #fff;
    padding: 12px 24px;
    text-transform: uppercase;
    font-weight: bold;
    letter-spacing: 1px;
    border-radius: 5px;
    cursor: pointer;
    transition: transform 0.2s ease-in-out;
}
button:hover, .btn:hover {
    transform: scale(1.05);
}
"""

title = """

Mixture-of-Diffusers for SDXL Tiling Pipeline 🤗

좌/중앙/우 각 영역에 다른 프롬프트를 적용하여 타일링 이미지를 생성합니다.
아래 예제를 클릭하면 입력창에 값이 채워집니다.

"""

with gr.Blocks(css=css, title="SDXL Tiling Pipeline") as app:
    gr.Markdown(title)
    with gr.Row():
        # Left/center/right prompts and the result area
        with gr.Column(scale=7):
            generate_button = gr.Button("Generate", elem_id="generate_btn")
            with gr.Row():
                with gr.Column(variant="panel"):
                    gr.Markdown("### Left Region")
                    left_prompt = gr.Textbox(lines=4, placeholder="예: 울창한 숲과 햇살이 비추는 나무...", label="Left Prompt")
                    left_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Left CFG Scale")
                with gr.Column(variant="panel"):
                    gr.Markdown("### Center Region")
                    center_prompt = gr.Textbox(lines=4, placeholder="예: 잔잔한 호수와 반짝이는 수면...", label="Center Prompt")
                    center_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Center CFG Scale")
                with gr.Column(variant="panel"):
                    gr.Markdown("### Right Region")
                    right_prompt = gr.Textbox(lines=4, placeholder="예: 웅장한 산맥과 하늘을 가르는 구름...", label="Right Prompt")
                    right_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Right CFG Scale")
            with gr.Row():
                negative_prompt = gr.Textbox(
                    lines=2,
                    label="Negative Prompt",
                    placeholder="예: blurry, low resolution, artifacts, poor details",
                    value="blurry, low resolution, artifacts, poor details",
                )
            with gr.Row():
                result = gr.Image(label="Generated Image", show_label=True, format="png", interactive=False, scale=1)

        # Sidebar: generation parameters and tile-size calculation
        with gr.Sidebar(label="Parameters", open=True):
            gr.Markdown("### Generation Parameters")
            with gr.Row():
                height = gr.Slider(label="Target Height", value=1024, step=8, minimum=512, maximum=1024)
                width = gr.Slider(label="Target Width", value=1280, step=8, minimum=512, maximum=3840)
                overlap = gr.Slider(minimum=0, maximum=512, value=128, step=8, label="Tile Overlap")
                max_tile_size = gr.Dropdown(label="Max Tile Size", choices=[1024, 1280], value=1280)
                calc_tile = gr.Button("Calculate Tile Size")
            with gr.Row():
                tile_height = gr.Textbox(label="Tile Height", value=1024, interactive=False)
                tile_width = gr.Textbox(label="Tile Width", value=1024, interactive=False)
            with gr.Row():
                new_target_height = gr.Textbox(label="New Image Height", value=1024, interactive=False)
                new_target_width = gr.Textbox(label="New Image Width", value=1280, interactive=False)
            with gr.Row():
                steps = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Inference Steps")
                generation_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
            with gr.Row():
                scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULERS, value=SCHEDULERS[0])

    # Centered examples area
    with gr.Row(elem_id="examples_row"):
        with gr.Column(scale=12, elem_id="examples_container"):
            gr.Markdown("### Example Prompts")
            gr.Examples(
                examples=[
                    [
                        "Lush green forest with sun rays filtering through the canopy",
                        "Crystal clear lake reflecting a vibrant sky",
                        "Majestic mountains with snowy peaks in the distance",
                        "blurry, low resolution, artifacts, poor details",
                        7, 7, 7,
                        128,
                        30,
                        123456789,
                        "UniPCMultistepScheduler",
                        1024, 1280,
                        1024, 1920,
                        1280,
                    ],
                    [
                        "Vibrant city street with neon signs and bustling crowds",
                        "Sleek modern skyscrapers with digital billboards",
                        "High-speed maglev train gliding over a futuristic urban landscape",
                        "blurry, poorly rendered, low quality, disfigured",
                        8, 8, 8,
                        100,
                        35,
                        987654321,
                        "EulerDiscreteScheduler",
                        1024, 1280,
                        1024, 1920,
                        1280,
                    ],
                    [
                        "Vibrant abstract strokes with fluid, swirling patterns in cool tones",
                        "Interlocking geometric shapes bursting with color and texture",
                        "Dynamic composition of splattered ink with smooth gradients",
                        "text, watermark, signature, distorted",
                        6, 6, 6,
                        80,
                        25,
                        192837465,
                        "DPMSolverMultistepScheduler-Karras",
                        1024, 1280,
                        1024, 1920,
                        1280,
                    ],
                    [
                        "Enchanted forest with glowing bioluminescent plants and mystical fog",
                        "Ancient castle with towering spires bathed in moonlight",
                        "Majestic dragon soaring above a starry night sky",
                        "low quality, artifact, deformed, sketchy",
                        9, 9, 9,
                        150,
                        40,
                        1029384756,
                        "DPMSolverMultistepScheduler-Karras-SDE",
                        1024, 1280,
                        1024, 1920,
                        1280,
                    ],
                ],
                inputs=[left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs, overlap, steps, generation_seed, scheduler, tile_height, tile_width, height, width, max_tile_size],
                cache_examples=False,
            )

    # Event wiring: tile-size calculation and image generation
    event_calc_tile_size = {
        "fn": do_calc_tile,
        "inputs": [height, width, overlap, max_tile_size],
        "outputs": [tile_height, tile_width, new_target_height, new_target_width],
    }
    calc_tile.click(**event_calc_tile_size)

    # `.success()` (not `.then()`) after the tile calculation: if it raises,
    # seed randomization and generation are skipped instead of running with
    # stale tile sizes.
    generate_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(**event_calc_tile_size).success(
        fn=randomize_seed_fn,
        inputs=[generation_seed, randomize_seed],
        outputs=generation_seed,
        queue=False,
        api_name=False,
    ).then(
        fn=predict,
        inputs=[left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs, overlap, steps, generation_seed, scheduler, tile_height, tile_width, new_target_height, new_target_width],
        outputs=result,
    )

if __name__ == "__main__":
    app.launch(share=False)