import os
import spaces
import time
import gradio as gr
import torch
from torch import Tensor, nn
from PIL import Image
from torchvision import transforms
from dataclasses import dataclass
import math
from typing import Callable
import random
from tqdm import tqdm
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
from transformers import (
    MarianTokenizer,
    MarianMTModel,
    CLIPTextModel, CLIPTokenizer,
    T5EncoderModel, T5Tokenizer,
)
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from einops import rearrange, repeat
# 1) Device setup
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) Translation model: keep it on the CPU, loaded into PyTorch
trans_tokenizer = MarianTokenizer.from_pretrained(
    "Helsinki-NLP/opus-mt-ko-en"
)
trans_model = MarianMTModel.from_pretrained(
    "Helsinki-NLP/opus-mt-ko-en",
    from_tf=True,  # load the TF checkpoint into PyTorch
    torch_dtype=torch.float32,
).to(torch.device("cpu"))
def translate_ko_to_en(text: str, max_length: int = 512) -> str:
    """Korean -> English translation (CPU)."""
    batch = trans_tokenizer([text], return_tensors="pt", padding=True)
    # The model already lives on the CPU, so no .to("cpu") is needed here
    gen = trans_model.generate(
        **batch, max_length=max_length
    )
    return trans_tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
# ---------------- Encoders ----------------
class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version, **hf_kwargs
            )
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version, **hf_kwargs
            )
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
# Move T5, CLIP, and the VAE to the target device (GPU if available, else CPU)
t5 = HFEmbedder(
    "DeepFloyd/t5-v1_1-xxl",
    max_length=512,
    torch_dtype=torch.bfloat16,
).to(torch_device)
clip = HFEmbedder(
    "openai/clip-vit-large-patch14",
    max_length=77,
    torch_dtype=torch.bfloat16,
).to(torch_device)
ae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
).to(torch_device)
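# Expected output shapes for the configurations above (a reference note, not
# something this script asserts):
#   t5(["prompt"])   -> [1, 512, 4096]  (T5-XXL last_hidden_state)
#   clip(["prompt"]) -> [1, 768]        (CLIP-L pooler_output)
#   ae encodes RGB images into 16-channel latents at 1/8 spatial resolution.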
# ---------------- NF4 support code ----------------
def functional_linear_4bits(x, weight, bias):
    out = bnb.matmul_4bit(
        x, weight.t(), bias=bias, quant_state=weight.quant_state
    )
    return out.to(x)
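# Note: bnb.matmul_4bit dequantizes the NF4-packed weight on the fly using its
# QuantState, so no full-precision copy of the weight is kept as a module
# parameter; the result is cast back to the input's dtype/device via .to(x).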
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
    if state is None:
        return None

    device = device or state.absmax.device

    state2 = (
        QuantState(
            absmax=state.state2.absmax.to(device),
            shape=state.state2.shape,
            code=state.state2.code.to(device),
            blocksize=state.state2.blocksize,
            quant_type=state.state2.quant_type,
            dtype=state.state2.dtype,
        )
        if state.nested
        else None
    )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )
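# copy_quant_state clones a bitsandbytes QuantState onto another device,
# including the nested state2 (used when the absmax statistics are themselves
# quantized), so a moved 4-bit weight can still be dequantized correctly.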
class ForgeParams4bit(Params4bit):
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)
        new = ForgeParams4bit(
            torch.nn.Parameter.to(
                self, device=device, dtype=dtype, non_blocking=non_blocking
            ),
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, device),
            compress_statistics=False,
            blocksize=self.blocksize,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module,
        )
        self.module.quant_state = new.quant_state
        self.data = new.data
        self.quant_state = new.quant_state
        return new
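# The first .to("cuda") on a not-yet-quantized parameter triggers NF4
# quantization via _quantize(); later moves copy both the packed data and the
# QuantState so the owning module keeps a consistent view on the new device.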
class ForgeLoader4Bit(torch.nn.Module):
    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        qs_keys = {
            k[len(prefix + "weight.") :]
            for k in state_dict
            if k.startswith(prefix + "weight.")
        }
        if any("bitsandbytes" in k for k in qs_keys):
            qs = {
                k: state_dict[prefix + "weight." + k] for k in qs_keys
            }
            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[prefix + "weight"],
                quantized_stats=qs,
                requires_grad=False,
                device=torch.device("cuda"),
                module=self,
            )
            self.quant_state = self.weight.quant_state
            if prefix + "bias" in state_dict:
                self.bias = torch.nn.Parameter(
                    state_dict[prefix + "bias"].to(self.dummy)
                )
            del self.dummy
        else:
            super()._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )
class Linear(ForgeLoader4Bit):
    def __init__(self, *args, device=None, dtype=None, **kwargs):
        super().__init__(device=device, dtype=dtype, quant_type="nf4")

    def forward(self, x):
        self.weight.quant_state = self.quant_state
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)
        return functional_linear_4bits(x, self.weight, self.bias)


nn.Linear = Linear
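# Monkey-patching nn.Linear means every linear layer constructed by the Flux
# model definition below is built as an NF4 4-bit layer and can load the
# pre-quantized state dict directly via _load_from_state_dict above.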
# ---------------- Flux model definition (unchanged from the original) ----------------
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    # ... (original code kept as-is, nothing omitted)
    q, k = apply_rope(q, k, pe)
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = x.permute(0, 2, 1, 3).reshape(x.size(0), x.size(2), -1)
    return x

# Include apply_rope, rope, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm,
# QKNorm, SelfAttention, Modulation, DoubleStreamBlock, SingleStreamBlock,
# LastLayer, FluxParams, and the Flux class exactly as in the original
# (a reference sketch of rope/apply_rope follows below).
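# For orientation only: a minimal sketch of rope/apply_rope as used by
# attention() above, following the public black-forest-labs/flux reference
# implementation. Treat it as an assumption and prefer the original module.
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    # Outer product of positions and frequencies, then a 2x2 rotation per pair
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # Interpret the last dimension as pairs and rotate each pair by freqs_cis
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)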
# ---------------- Model loading ----------------
sd = load_file(
    hf_hub_download(
        repo_id="lllyasviel/flux1-dev-bnb-nf4",
        filename="flux1-dev-bnb-nf4-v2.safetensors",
    )
)
sd = {
    k.replace("model.diffusion_model.", ""): v
    for k, v in sd.items()
    if "model.diffusion_model" in k
}
model = Flux().to(torch_device, dtype=torch.bfloat16)
model.load_state_dict(sd)
model_zero_init = False
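# model_zero_init defers the actual move of the NF4 weights onto the GPU until
# the first generation request, i.e. until a ZeroGPU worker actually has a
# CUDA device attached.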
# ---------------- Utility functions ----------------
def get_image(image) -> torch.Tensor | None:
    if image is None:
        return None
    image = Image.fromarray(image).convert("RGB")
    tfm = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: 2.0 * x - 1.0),
        ]
    )
    return tfm(image)[None, ...]
def prepare(t5, clip, img, prompt):
    bs, c, h, w = img.shape
    img = rearrange(
        img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
    )
    if bs == 1 and isinstance(prompt, list):
        img = repeat(img, "1 ... -> bs ...", bs=len(prompt))

    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
    img_ids[..., 1] = torch.arange(h // 2, device=img.device)[:, None]
    img_ids[..., 2] = torch.arange(w // 2, device=img.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=img.shape[0])

    txt = t5([prompt] if isinstance(prompt, str) else prompt)
    if txt.shape[0] == 1 and img.shape[0] > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=img.shape[0])
    txt_ids = torch.zeros(txt.size(0), txt.size(1), 3, device=img.device)

    vec = clip([prompt] if isinstance(prompt, str) else prompt)
    if vec.shape[0] == 1 and img.shape[0] > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=img.shape[0])

    return {
        "img": img,
        "img_ids": img_ids,
        "txt": txt,
        "txt_ids": txt_ids,
        "vec": vec,
    }
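# prepare() patchifies the latent into 2x2 tokens and builds per-token position
# ids: img_ids carries (0, row, col) for each latent patch, while txt_ids is all
# zeros so text tokens receive no spatial rotation from the rotary embeddings.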
def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
    timesteps = torch.linspace(1, 0, num_steps + 1)
    if shift:
        # Interpolate the shift parameter mu linearly between base_shift
        # (at 256 image tokens) and max_shift (at 4096 image tokens)
        mu = ((max_shift - base_shift) / (4096 - 256)) * image_seq_len + (
            base_shift - (256 * (max_shift - base_shift) / (4096 - 256))
        )
        # Resolution-dependent time shift (sigma = 1)
        timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1) ** 1.0)
    return timesteps.tolist()
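# Worked example (assuming the defaults above): a 768x768 request gives a
# 96x96 latent, i.e. image_seq_len = 96 * 96 / 4 = 2304 patch tokens and
# mu ≈ 0.85; at 1024x1024 the sequence length is 4096 and mu = max_shift = 1.15.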
def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
    guidance_vec = torch.full(
        (img.size(0),), guidance, device=img.device, dtype=img.dtype
    )
    for t_curr, t_prev in tqdm(
        zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1
    ):
        t_vec = torch.full(
            (img.size(0),), t_curr, device=img.device, dtype=img.dtype
        )
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        img = img + (t_prev - t_curr) * pred
    return img
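# Note: each loop iteration above is one Euler step of the rectified-flow ODE,
# x <- x + (t_prev - t_curr) * v_theta(x, t_curr); the guidance scale is fed
# directly to the guidance-distilled model rather than applied via
# classifier-free mixing.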
# ---------------- Gradio demo ----------------
@spaces.GPU  # ZeroGPU: request a GPU for this call (assumed decorator; matches the `import spaces` above)
def generate_image(
    prompt,
    width,
    height,
    guidance,
    inference_steps,
    seed,
    do_img2img,
    init_image,
    image2image_strength,
    resize_img,
    progress=gr.Progress(track_tqdm=True),
):
    # If the prompt contains Korean characters, run it through the CPU translator
    if any("\u3131" <= c <= "\u318E" or "\uAC00" <= c <= "\uD7A3" for c in prompt):
        prompt = translate_ko_to_en(prompt)

    if seed == 0:
        seed = random.randint(1, 1_000_000)

    global model_zero_init, model
    if not model_zero_init:
        model = model.to(torch_device)
        model_zero_init = True
    if do_img2img and init_image is not None:
        init_img = get_image(init_image)
        if resize_img:
            init_img = torch.nn.functional.interpolate(
                init_img, (height, width)
            )
        else:
            h0, w0 = init_img.shape[-2:]
            init_img = init_img[..., : 16 * (h0 // 16), : 16 * (w0 // 16)]
            height, width = init_img.shape[-2:]
        init_img = ae.encode(
            init_img.to(torch_device).to(torch.bfloat16)
        ).latent_dist.sample()
        init_img = (
            init_img - ae.config.shift_factor
        ) * ae.config.scaling_factor
    else:
        init_img = None

    generator = torch.Generator(device=str(torch_device)).manual_seed(seed)
    x = torch.randn(
        1,
        16,
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=torch_device,
        dtype=torch.bfloat16,
        generator=generator,
    )
    timesteps = get_schedule(
        inference_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True
    )

    if do_img2img and init_img is not None:
        t_idx = int((1 - image2image_strength) * inference_steps)
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        x = t * x + (1 - t) * init_img.to(x.dtype)

    inp = prepare(t5, clip, x, prompt)
    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
    x = rearrange(
        x[:, inp["txt"].shape[1] :, ...].float(),
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )

    with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
        x = (x / ae.config.scaling_factor) + ae.config.shift_factor
        x = ae.decode(x).sample

    x = x.clamp(-1, 1)
    img = Image.fromarray(
        (127.5 * (rearrange(x[0], "c h w -> h w c") + 1.0))
        .cpu()
        .byte()
        .numpy()
    )
    return img, seed
css = """ | |
footer { | |
visibility: hidden; | |
} | |
""" | |
def create_demo():
    with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
        gr.Markdown(
            "# News! Multilingual version "
            "[https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual]"
            "(https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual)"
        )
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt (Korean supported)",
                    value="A cute and fluffy golden retriever puppy sitting upright...",
                )
                # gr.Slider's third positional argument is `value`, so pass
                # step and value explicitly by keyword
                width = gr.Slider(minimum=128, maximum=2048, step=64, value=768, label="Width")
                height = gr.Slider(minimum=128, maximum=2048, step=64, value=768, label="Height")
                guidance = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, value=3.5, label="Guidance")
                steps = gr.Slider(minimum=1, maximum=30, step=1, value=30, label="Inference steps")
                seed = gr.Number(label="Seed", precision=0, value=0)  # 0 -> random seed
                do_i2i = gr.Checkbox(label="Image to Image", value=False)
                init_img = gr.Image(label="Input Image", visible=False)
                strength = gr.Slider(
                    minimum=0.0, maximum=1.0, step=0.01, value=0.8,
                    label="Noising strength", visible=False,
                )
                resize = gr.Checkbox(label="Resize image", value=True, visible=False)
                btn = gr.Button("Generate")
            with gr.Column():
                out_img = gr.Image(label="Generated Image")
                out_seed = gr.Text(label="Used Seed")
        do_i2i.change(
            fn=lambda x: [gr.update(visible=x)] * 3,
            inputs=[do_i2i],
            outputs=[init_img, strength, resize],
        )
        btn.click(
            fn=generate_image,
            inputs=[
                prompt,
                width,
                height,
                guidance,
                steps,
                seed,
                do_i2i,
                init_img,
                strength,
                resize,
            ],
            outputs=[out_img, out_seed],
        )
    return demo


if __name__ == "__main__":
    create_demo().launch()