import spaces import torch from diffusers import FluxPipeline import gradio as gr import random import numpy as np import os # GPU 사용 가능 여부 확인 if torch.cuda.is_available(): device = "cuda" print("GPU를 사용합니다") else: device = "cpu" print("CPU를 사용합니다") # HuggingFace 토큰 로그인 HF_TOKEN = os.getenv("HF_TOKEN") MAX_SEED = np.iinfo(np.int32).max CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1" # 파이프라인 초기화 및 모델 다운로드 pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) pipe.to(device) # 이미지 생성 함수 정의 @spaces.GPU(duration=160) def generate_image(prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt, progress=gr.Progress(track_tqdm=True)): if seed == 0: seed = random.randint(1, MAX_SEED) generator = torch.Generator().manual_seed(seed) with torch.inference_mode(): output = pipe( prompt=prompt, num_inference_steps=num_inference_steps, height=height, width=width, guidance_scale=guidance_scale, generator=generator, num_images_per_prompt=num_images_per_prompt ).images return output # 예제 프롬프트 examples = [ ["A cat holding a sign that says hello world"], ["a tiny astronaut hatching from an egg on the moon"], ["An astronaut on mars in a futuristic cyborg suit"], ] # 커스텀 CSS css = ''' .gradio-container { max-width: 1000px !important; margin: auto; } h1 { text-align: center; font-family: 'Pretendard', sans-serif; color: #EA580C; } .gr-button-primary { background-color: #F97316 !important; } .gr-button-primary:hover { background-color: #EA580C !important; } .footer-content { text-align: center; margin-top: 2rem; padding: 2rem 1rem; font-family: 'Pretendard', sans-serif; background-color: #F9FAFB; border-radius: 0.5rem; } .visit-button { background-color: #EA580C; color: white; padding: 0.75rem 1.5rem; border-radius: 0.5rem; font-weight: 500; text-decoration: none; display: inline-block; margin-top: 1rem; transition: all 0.2s ease-in-out; } .visit-button:hover { background-color: #C2410C; transform: translateY(-1px); box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); } ''' # Gradio 인터페이스 생성 with gr.Blocks( theme=gr.themes.Soft( primary_hue=gr.themes.Color( c50="#FFF7ED", c100="#FFEDD5", c200="#FED7AA", c300="#FDBA74", c400="#FB923C", c500="#F97316", c600="#EA580C", c700="#C2410C", c800="#9A3412", c900="#7C2D12", c950="#431407", ), secondary_hue="zinc", neutral_hue="zinc", font=("Pretendard", "sans-serif") ), css=css ) as demo: with gr.Row(): with gr.Column(): gr.HTML( """