import random import gradio as gr import numpy as np import spaces import torch from diffusers import AutoencoderKL from mixture_tiling_sdxl import StableDiffusionXLTilingPipeline MAX_SEED = np.iinfo(np.int32).max SCHEDULERS = [ "LMSDiscreteScheduler", "DEISMultistepScheduler", "HeunDiscreteScheduler", "EulerAncestralDiscreteScheduler", "EulerDiscreteScheduler", "DPMSolverMultistepScheduler", "DPMSolverMultistepScheduler-Karras", "DPMSolverMultistepScheduler-Karras-SDE", "UniPCMultistepScheduler" ] # 모델 로딩: VAE 및 타일링 파이프라인 초기화 vae = AutoencoderKL.from_pretrained( "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 ).to("cuda") model_id = "stablediffusionapi/yamermix-v8-vae" pipe = StableDiffusionXLTilingPipeline.from_pretrained( model_id, torch_dtype=torch.float16, vae=vae, use_safetensors=False, # for yammermix ).to("cuda") pipe.enable_model_cpu_offload() # VRAM이 제한된 경우 사용 pipe.enable_vae_tiling() pipe.enable_vae_slicing() #region functions def select_scheduler(scheduler_name): scheduler_parts = scheduler_name.split("-") scheduler_class_name = scheduler_parts[0] add_kwargs = { "beta_start": 0.00085, "beta_end": 0.012, "beta_schedule": "scaled_linear", "num_train_timesteps": 1000 } if len(scheduler_parts) > 1: add_kwargs["use_karras_sigmas"] = True if len(scheduler_parts) > 2: add_kwargs["algorithm_type"] = "sde-dpmsolver++" import diffusers scheduler_cls = getattr(diffusers, scheduler_class_name) scheduler = scheduler_cls.from_config(pipe.scheduler.config, **add_kwargs) return scheduler @spaces.GPU def predict(left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs, overlap_pixels, steps, generation_seed, scheduler, tile_height, tile_width, target_height, target_width): global pipe print(f"Using scheduler: {scheduler}...") pipe.scheduler = select_scheduler(scheduler) generator = torch.Generator("cuda").manual_seed(generation_seed) target_height = int(target_height) target_width = int(target_width) tile_height = int(tile_height) tile_width = int(tile_width) image = pipe( prompt=[[left_prompt, center_prompt, right_prompt]], negative_prompt=negative_prompt, tile_height=tile_height, tile_width=tile_width, tile_row_overlap=0, tile_col_overlap=overlap_pixels, guidance_scale_tiles=[[left_gs, center_gs, right_gs]], height=target_height, width=target_width, generator=generator, num_inference_steps=steps, )["images"][0] return image def calc_tile_size(target_height, target_width, overlap_pixels, max_tile_width_size=1280): num_cols = 3 num_rows = 1 min_tile_dimension = 8 reduction_step = 8 max_tile_height_size = 1024 best_tile_width = 0 best_tile_height = 0 best_adjusted_target_width = 0 best_adjusted_target_height = 0 found_valid_solution = False tile_width = max_tile_width_size tile_height = max_tile_height_size while tile_width >= min_tile_dimension: horizontal_borders = num_cols - 1 total_horizontal_overlap = overlap_pixels * horizontal_borders adjusted_target_width = tile_width * num_cols - total_horizontal_overlap vertical_borders = num_rows - 1 total_vertical_overlap = overlap_pixels * vertical_borders adjusted_target_height = tile_height * num_rows - total_vertical_overlap if tile_width <= max_tile_width_size and adjusted_target_width <= target_width: if adjusted_target_width > best_adjusted_target_width: best_tile_width = tile_width best_adjusted_target_width = adjusted_target_width found_valid_solution = True tile_width -= reduction_step if found_valid_solution: tile_width = best_tile_width tile_height = max_tile_height_size while tile_height >= min_tile_dimension: horizontal_borders = num_cols - 1 total_horizontal_overlap = overlap_pixels * horizontal_borders adjusted_target_width = tile_width * num_cols - total_horizontal_overlap vertical_borders = num_rows - 1 total_vertical_overlap = overlap_pixels * vertical_borders adjusted_target_height = tile_height * num_rows - total_vertical_overlap if tile_height <= max_tile_height_size and adjusted_target_height <= target_height: if adjusted_target_height > best_adjusted_target_height: best_tile_height = tile_height best_adjusted_target_height = adjusted_target_height tile_height -= reduction_step new_target_height = best_adjusted_target_height new_target_width = best_adjusted_target_width # ★ 새로 계산된 target 높이와 너비가 8로 나누어 떨어지도록 조정 (오류 방지) new_target_height = new_target_height - (new_target_height % 8) new_target_width = new_target_width - (new_target_width % 8) tile_width = best_tile_width tile_height = best_tile_height print("--- TILE SIZE CALCULATED VALUES ---") print(f"Requested Overlap Pixels: {overlap_pixels}") print(f"Tile Height (max {max_tile_height_size}, divisible by 8): {tile_height}") print(f"Tile Width (max {max_tile_width_size}, divisible by 8): {tile_width}") print(f"Columns: {num_cols} | Rows: {num_rows}") print(f"Original Target: {target_height} x {target_width}") print(f"Adjusted Target (divisible by 8): {new_target_height} x {new_target_width}\n") return new_target_height, new_target_width, tile_height, tile_width def do_calc_tile(target_height, target_width, overlap_pixels, max_tile_size): new_target_height, new_target_width, tile_height, tile_width = calc_tile_size(target_height, target_width, overlap_pixels, max_tile_size) return gr.update(value=tile_height), gr.update(value=tile_width), gr.update(value=new_target_height), gr.update(value=new_target_width) def clear_result(): return gr.update(value=None) def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int: if randomize_seed: generation_seed = random.randint(0, MAX_SEED) return generation_seed #endregion # CSS 개선: 입체감 있는 배경, 반투명 컨테이너, 그림자, 애니메이션 효과 등 적용 css = """ body { background: linear-gradient(135deg, #667eea, #764ba2); font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; color: #333; margin: 0; padding: 0; } .gradio-container { background: rgba(255, 255, 255, 0.95); border-radius: 15px; padding: 30px 40px; box-shadow: 0 8px 30px rgba(0, 0, 0, 0.3); margin: 40px auto; max-width: 1200px; } .gradio-container h1 { color: #333; text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.2); } .fillable { width: 95% !important; max-width: unset !important; } #examples_container { margin: auto; width: 90%; } #examples_row { justify-content: center; } .sidebar { background: rgba(255, 255, 255, 0.98); border-radius: 10px; padding: 20px; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2); } button, .btn { background: linear-gradient(90deg, #ff8a00, #e52e71); border: none; color: #fff; padding: 12px 24px; text-transform: uppercase; font-weight: bold; letter-spacing: 1px; border-radius: 5px; cursor: pointer; transition: transform 0.2s ease-in-out; } button:hover, .btn:hover { transform: scale(1.05); } """ title = """
좌/중앙/우 각 영역에 다른 프롬프트를 적용하여 타일링 이미지를 생성합니다.
아래 예제를 클릭하면 입력창에 값이 채워집니다.