import time
import gradio as gr
import torch
from einops import rearrange, repeat
from PIL import Image
import numpy as np
import spaces  # Hugging Face Spaces GPU decorator support
# Global state
model_initialized = False
flux_generator = None
initialization_message = "Loading model... please wait."
# Citation info
_CITE_ = """PuLID: Person-under-Language Image Diffusion Model"""
# Check GPU availability and pick a device. Not called in the main process,
# where CUDA must stay uninitialized.
def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        print("No CUDA GPU found. Falling back to CPU.")
        return torch.device('cpu')
def get_models(name: str, device, offload: bool):
    try:
        # Lazy-import only the modules we need
        from flux.util import load_ae, load_clip, load_flow_model, load_t5
        print(f"Loading models onto {device}.")
        t5 = load_t5(device, max_length=128)
        clip_model = load_clip(device)
        model = load_flow_model(name, device="cpu" if offload else device)
        model.eval()
        ae = load_ae(name, device="cpu" if offload else device)
        return model, ae, t5, clip_model
    except Exception as e:
        print(f"Error while loading models: {e}")
        return None, None, None, None
class FluxGenerator:
    def __init__(self):
        # GPU setup happens only inside the Spaces GPU decorator,
        # so no device is assigned at construction time.
        self.device = None
        self.offload = False
        self.model_name = 'flux-dev'
        self.initialized = False
        self.model = None
        self.ae = None
        self.t5 = None
        self.clip_model = None
        self.pulid_model = None
    def initialize(self):
        global initialization_message
        try:
            # Lazy-import the pipeline only inside the GPU context
            from pulid.pipeline_flux import PuLIDPipeline
            # Pick the device here; this method only runs inside the GPU decorator
            self.device = get_device()
            print("Starting model initialization...")
            self.model, self.ae, self.t5, self.clip_model = get_models(
                self.model_name,
                device=self.device,
                offload=self.offload,
            )
            if None in [self.model, self.ae, self.t5, self.clip_model]:
                print("Model initialization failed: one or more components could not be loaded.")
                self.initialized = False
                initialization_message = "Model loading failed: some components could not be loaded."
                return
            self.pulid_model = PuLIDPipeline(
                self.model,
                self.device.type,
                weight_dtype=torch.bfloat16 if self.device.type == 'cuda' else torch.float32,
            )
            self.pulid_model.load_pretrain()
            self.initialized = True
            print("Model initialization complete!")
            # Update the UI status message
            initialization_message = "Model loaded! You can now generate images."
        except Exception as e:
            import traceback
            error_msg = f"Error during model initialization: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            self.initialized = False
            # Update the UI status message
            initialization_message = f"Model loading failed: {str(e)}"
# Lazy model initialization -- runs inside the Spaces GPU decorator
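# On ZeroGPU Spaces, spaces.GPU attaches a GPU only for the duration of the
# decorated call (capped here at 60 seconds), so CUDA must never be touched
# in the main process.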
@spaces.GPU(duration=60)
def initialize_models():
    global flux_generator, model_initialized, initialization_message
    print("Starting model initialization inside the GPU decorator...")
    try:
        # Lazy imports (pre-loads the sampling modules used later by generate_image)
        from flux.sampling import denoise, get_noise, get_schedule, prepare, rf_denoise, rf_inversion, unpack
        from flux.util import SamplingOptions
        from pulid.utils import resize_numpy_image_long, seed_everything
        # Build and initialize the generator
        flux_generator = FluxGenerator()
        flux_generator.initialize()
        model_initialized = flux_generator.initialized
    except Exception as e:
        import traceback
        error_msg = f"Error during initialization: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        model_initialized = False
        initialization_message = f"Model initialization error: {str(e)}"
    return initialization_message
# ๋ชจ๋ธ ์ƒํƒœ ํ™•์ธ ํ•จ์ˆ˜
def check_model_status():
return initialization_message
# Spaces GPU decorator: request a GPU for up to 120 seconds per generation call
@spaces.GPU(duration=120)
@torch.inference_mode()
def generate_image(
prompt: str,
id_image,
num_steps: int,
guidance: float,
seed,
id_weight: float,
neg_prompt: str,
true_cfg: float,
gamma: float,
eta: float,
):
    global flux_generator, model_initialized
    # Bail out if the model has not been initialized yet
    if not model_initialized:
        return None, "Model initialization is not complete. Please press the 'Initialize Model' button."
    # An ID image is required before anything can run
    if id_image is None:
        return None, "Error: an ID image is required."
    try:
        # Lazy-import the sampling utilities
        from flux.sampling import denoise, get_noise, get_schedule, prepare, rf_denoise, rf_inversion, unpack
        from flux.util import SamplingOptions
        from pulid.utils import resize_numpy_image_long, seed_everything
        # Fixed parameters
        width = 512
        height = 512
        start_step = 0
        timestep_to_start_cfg = 1
        max_sequence_length = 128
        s = 0
        tau = 5
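        # (s and tau are fixed here and forwarded unchanged to rf_denoise below.)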
flux_generator.t5.max_length = max_sequence_length
        # Resolve the seed (fall back to random on anything non-numeric)
        try:
            seed = int(seed)
        except (TypeError, ValueError):
            seed = -1
if seed == -1:
seed = None
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
)
if opts.seed is None:
opts.seed = torch.Generator(device="cpu").seed()
seed_everything(opts.seed)
print(f"Generating prompt: '{opts.prompt}' (seed={opts.seed})...")
t0 = time.perf_counter()
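        # "True" CFG adds a second, negative-prompt conditioning pass;
        # it is active only when true_cfg differs from 1.0.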
use_true_cfg = abs(true_cfg - 1.0) > 1e-6
        # 1) Prepare the input noise
noise = get_noise(
num_samples=1,
height=opts.height,
width=opts.width,
device=flux_generator.device,
dtype=torch.bfloat16 if flux_generator.device.type == 'cuda' else torch.float32,
seed=opts.seed,
)
bs, c, h, w = noise.shape
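        # Pack 2x2 patches of the latent into a token sequence -- FLUX operates
        # on this patchified layout rather than on the raw (b, c, h, w) tensor.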
noise = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if noise.shape[0] == 1 and bs > 1:
noise = repeat(noise, "1 ... -> bs ...", bs=bs)
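        # (With num_samples=1 above, bs is always 1, so this repeat is a no-op
        # kept for potential multi-sample use.)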
# ID ์ด๋ฏธ์ง€ ์ธ์ฝ”๋”ฉ
encode_t0 = time.perf_counter()
id_image = id_image.resize((opts.width, opts.height), resample=Image.LANCZOS)
x = torch.from_numpy(np.array(id_image).astype(np.float32))
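        # Scale uint8 pixel values from [0, 255] to [-1, 1] for the autoencoder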
x = (x / 127.5) - 1.0
x = rearrange(x, "h w c -> 1 c h w")
x = x.to(flux_generator.device)
dtype = torch.bfloat16 if flux_generator.device.type == 'cuda' else torch.float32
with torch.autocast(device_type=flux_generator.device.type, dtype=dtype):
x = flux_generator.ae.encode(x)
x = x.to(dtype)
encode_t1 = time.perf_counter()
print(f"Encoded in {encode_t1 - encode_t0:.2f} seconds.")
timesteps = get_schedule(opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=False)
        # 2) Prepare the text embeddings
inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=opts.prompt)
inp_inversion = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt="")
inp_neg = None
if use_true_cfg:
inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=neg_prompt)
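        # prepare() bundles the T5/CLIP text embeddings with the image tokens.
        # The empty-prompt variant (inp_inversion) drives the inversion pass;
        # the negative-prompt variant is only needed when true CFG is active.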
        # 3) Generate the ID embeddings
id_embeddings = None
uncond_id_embeddings = None
if id_image is not None:
id_image = np.array(id_image)
id_image = resize_numpy_image_long(id_image, 1024)
id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
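        # Snapshot the clean encoded latents; rf_denoise anchors to them via y_0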
y_0 = inp["img"].clone().detach()
        # Edit: invert the encoded image toward noise, then denoise it with the edit prompt
inverted = rf_inversion(
flux_generator.model,
**inp_inversion,
timesteps=timesteps,
guidance=opts.guidance,
id=id_embeddings,
id_weight=id_weight,
start_step=start_step,
uncond_id=uncond_id_embeddings,
true_cfg=true_cfg,
timestep_to_start_cfg=timestep_to_start_cfg,
neg_txt=inp_neg["txt"] if use_true_cfg else None,
neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
neg_vec=inp_neg["vec"] if use_true_cfg else None,
aggressive_offload=False,
y_1=noise,
gamma=gamma
)
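        # Start both the edit pass and the inversion inputs from the inverted latents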
inp["img"] = inverted
inp_inversion["img"] = inverted
edited = rf_denoise(
flux_generator.model,
**inp,
timesteps=timesteps,
guidance=opts.guidance,
id=id_embeddings,
id_weight=id_weight,
start_step=start_step,
uncond_id=uncond_id_embeddings,
true_cfg=true_cfg,
timestep_to_start_cfg=timestep_to_start_cfg,
neg_txt=inp_neg["txt"] if use_true_cfg else None,
neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
neg_vec=inp_neg["vec"] if use_true_cfg else None,
aggressive_offload=False,
y_0=y_0,
eta=eta,
s=s,
tau=tau,
)
        # Decode the result back to image space
edited = unpack(edited.float(), opts.height, opts.width)
with torch.autocast(device_type=flux_generator.device.type, dtype=dtype):
edited = flux_generator.ae.decode(edited)
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.2f} seconds.")
        # Convert to a PIL image
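        # Map from [-1, 1] back to 8-bit RGB; clamp first so byte() cannot wrap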
edited = edited.clamp(-1, 1)
edited = rearrange(edited[0], "c h w -> h w c")
edited = Image.fromarray((127.5 * (edited + 1.0)).cpu().byte().numpy())
return edited, str(opts.seed)
except Exception as e:
import traceback
error_msg = f"์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
return None, error_msg
def create_demo():
with gr.Blocks() as demo:
gr.Markdown("# PuLID: ์ธ๋ฌผ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ๋„๊ตฌ")
# ๋ชจ๋ธ ์ƒํƒœ ํ‘œ์‹œ
status_box = gr.Textbox(label="๋ชจ๋ธ ์ƒํƒœ", value=initialization_message)
# ์ดˆ๊ธฐํ™” ๋ฒ„ํŠผ ์ถ”๊ฐ€ (๋ฐฑ๊ทธ๋ผ์šด๋“œ ์ดˆ๊ธฐํ™” ๋Œ€์‹  ๋ช…์‹œ์  ์ดˆ๊ธฐํ™” ๋ฒ„ํŠผ ์‚ฌ์šฉ)
init_btn = gr.Button("๋ชจ๋ธ ์ดˆ๊ธฐํ™”")
init_btn.click(fn=initialize_models, inputs=[], outputs=[status_box])
refresh_btn = gr.Button("์ƒํƒœ ์ƒˆ๋กœ๊ณ ์นจ")
refresh_btn.click(fn=check_model_status, inputs=[], outputs=[status_box])
with gr.Row():
with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
                id_image = gr.Image(label="ID Image", type="pil")
                id_weight = gr.Slider(0.0, 1.0, 0.4, step=0.05, label="ID Weight")
                num_steps = gr.Slider(1, 24, 16, step=1, label="Number of Steps")
                guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance")
                with gr.Accordion("Advanced Options", open=False):
                    neg_prompt = gr.Textbox(label="Negative Prompt", value="")
                    true_cfg = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="CFG Scale")
                    seed = gr.Textbox(value="-1", label="Seed (-1 for random)")
                    gr.Markdown("### Other Options")
                    gamma = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="Gamma")
                    eta = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Eta")
                generate_btn = gr.Button("Generate Image")
            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                seed_output = gr.Textbox(label="Result/Error Message")
gr.Markdown(_CITE_)
        # Examples
        with gr.Row():
            gr.Markdown("## Examples")
example_inps = [
[
'a portrait of a clown',
'example_inputs/unsplash/lhon-karwan-11tbHtK5STE-unsplash.jpg',
16, 3.5, "-1", 0.4, "", 3.5, 0.5, 0.8
],
[
'a portrait of a zombie',
'example_inputs/unsplash/baruk-granda-cfLL_jHQ-Iw-unsplash.jpg',
16, 3.5, "42", 0.4, "", 3.5, 0.5, 0.8
]
]
gr.Examples(
examples=example_inps,
inputs=[prompt, id_image, num_steps, guidance, seed,
id_weight, neg_prompt, true_cfg, gamma, eta]
)
        # Wire the generate button to the generation function
generate_btn.click(
fn=generate_image,
inputs=[
prompt, id_image, num_steps, guidance, seed,
id_weight, neg_prompt, true_cfg, gamma, eta
],
outputs=[output_image, seed_output],
)
return demo
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
parser.add_argument('--version', type=str, default='v0.9.1')
parser.add_argument("--name", type=str, default="flux-dev")
parser.add_argument("--port", type=int, default=8080)
args = parser.parse_args()
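    # Example usage (assuming a machine or Space with a CUDA GPU):
    #   python app.py --port 8080
    # Note: --name and --version are parsed but not consumed elsewhere in this
    # script; FluxGenerator always loads 'flux-dev'.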
print("Hugging Face Spaces ํ™˜๊ฒฝ์—์„œ ์‹คํ–‰ ์ค‘์ž…๋‹ˆ๋‹ค. GPU ํ• ๋‹น์„ ์š”์ฒญํ•ฉ๋‹ˆ๋‹ค.")
# ๋ฉ”์ธ ํ”„๋กœ์„ธ์Šค์—์„œ๋Š” CUDA ์ดˆ๊ธฐํ™”ํ•˜์ง€ ์•Š์Œ
# ๋ฐฑ๊ทธ๋ผ์šด๋“œ ์Šค๋ ˆ๋“œ ๋Œ€์‹  ๋ช…์‹œ์  ๋ฒ„ํŠผ์œผ๋กœ ์ดˆ๊ธฐํ™”
demo = create_demo()
# ์ˆ˜์ •๋œ ๋ถ€๋ถ„: create_demo.launch() -> demo.launch()
demo.launch(server_name="0.0.0.0", server_port=args.port)