FLUXllama / app.py
ginipick's picture
Update app.py
26b3498 verified
raw
history blame
32.2 kB
import os
import sys
# Environment switches, set here ahead of the heavy library imports further down.
# NOTE(review): BITSANDBYTES_NOWELCOME is believed to only silence the
# bitsandbytes start-up banner; it is not known to disable any triton
# integration by itself — confirm against the bitsandbytes version in use.
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
# NOTE(review): presumably lets PyTorch fall back to CPU for operators that are
# not implemented on Apple MPS — confirm.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# Import the Hugging Face `spaces` helper when possible; otherwise install a
# no-op stand-in so the rest of the file can keep using `@spaces.GPU`.
try:
    import spaces
    SPACES_AVAILABLE = True
except Exception as e:  # deliberately broad: any import-time failure selects the fallback
    print(f"Warning: Could not import spaces: {e}")
    SPACES_AVAILABLE = False

    class spaces:
        """Minimal stand-in for the `spaces` package exposing a no-op ``GPU``."""

        @staticmethod
        def GPU(task=None, *, duration=None, **_options):
            """Return the decorated function unchanged.

            Supports both decorator spellings:

            * ``@spaces.GPU`` (bare) - ``task`` is the function itself and is
              returned as-is.
            * ``@spaces.GPU(...)``, ``@spaces.GPU(duration=120)`` or
              ``@spaces.GPU(120)`` - a pass-through decorator is returned; the
              duration and any other options are ignored.
            """
            if callable(task):
                # Bare usage: the function was passed directly.
                return task

            def decorator(func):
                return func

            return decorator
import time
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from dataclasses import dataclass, field
import math
from typing import Callable
from tqdm import tqdm
import random
from einops import rearrange, repeat
from diffusers import AutoencoderKL
from torch import Tensor, nn
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import T5EncoderModel, T5Tokenizer
# Import bitsandbytes after spaces to avoid conflicts.
# The import is best-effort: any failure is printed and recorded in
# BNB_AVAILABLE instead of aborting start-up.
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn.modules import Params4bit, QuantState
    BNB_AVAILABLE = True
except Exception as e:
    print(f"Warning: Could not import bitsandbytes: {e}")
    BNB_AVAILABLE = False
# Store original Linear class before any modifications
original_linear = nn.Linear
# Quantization is switched off unconditionally, overriding the result of the
# import attempt above; the later `if BNB_AVAILABLE:` branch is therefore skipped.
BNB_AVAILABLE = False
print("Note: BitsAndBytes quantization disabled for compatibility")
# ---------------- Encoders ----------------
class HFEmbedder(nn.Module):
    """Frozen Hugging Face text encoder wrapped as an ``nn.Module``.

    A ``version`` string starting with ``"openai"`` selects the CLIP
    tokenizer/text model and returns its ``pooler_output``; any other version
    selects the T5 tokenizer/encoder and returns ``last_hidden_state``.
    """

    def __init__(self, version: str, max_length: int, **hf_kwargs):
        """Load tokenizer and encoder for ``version``.

        Args:
            version: Model identifier passed to ``from_pretrained``.
            max_length: Token length used for truncation and padding.
            **hf_kwargs: Forwarded to the encoder's ``from_pretrained``.
        """
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            tokenizer_cls, encoder_cls = CLIPTokenizer, CLIPTextModel
        else:
            tokenizer_cls, encoder_cls = T5Tokenizer, T5EncoderModel

        self.tokenizer = tokenizer_cls.from_pretrained(version, max_length=max_length)
        self.hf_module = encoder_cls.from_pretrained(version, **hf_kwargs)
        # Inference only: eval mode and no gradients.
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        """Tokenize ``text`` (padded/truncated to ``max_length``) and encode it."""
        tokens = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = tokens["input_ids"].to(self.hf_module.device)
        encoded = self.hf_module(
            input_ids=input_ids,
            attention_mask=None,
            output_hidden_states=False,
        )
        return encoded[self.output_key]
# Model handles are created lazily rather than at import time: they stay None
# until initialize_models() fills them in, and model_initialized records that
# the one-time load has already happened.
t5 = None
clip = None
ae = None
model = None
model_initialized = False
def initialize_models():
    """Populate the module-level ``t5``, ``clip``, ``ae`` and ``model`` once.

    Subsequent calls return immediately. The device is ``"cuda"`` when
    available and ``"cpu"`` otherwise. The Flux transformer is constructed
    with its default parameters and is NOT loaded from a checkpoint, so its
    weights are randomly initialized (the function prints a warning saying so).

    Raises:
        Exception: whatever constructing or moving the Flux model raises; the
            error is printed and re-raised.
    """
    global t5, clip, ae, model, model_initialized
    if model_initialized:
        return

    print("Initializing models...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Every pretrained component is loaded in bfloat16 with low CPU memory usage.
    load_kwargs = {"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True}

    print("Loading T5 encoder...")
    t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, **load_kwargs)
    t5 = t5.to(device)

    print("Loading CLIP encoder...")
    clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, **load_kwargs)
    clip = clip.to(device)

    print("Loading VAE...")
    ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", **load_kwargs)
    ae = ae.to(device)

    print("Loading Flux model...")
    # The standard (non-quantized) Flux model is used: more memory, fewer
    # compatibility issues. These two helpers are imported for a future
    # checkpoint load but are not called yet.
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    try:
        print("Loading standard Flux model (this may take a while)...")
        model = Flux()
        model = model.to(dtype=torch.bfloat16, device=device)
        # No pretrained weights are loaded here; the model is only usable for testing.
        print("Warning: Using randomly initialized Flux model for testing")
        print("To use a pretrained model, you need to load proper Flux weights")
    except Exception as e:
        print(f"Error initializing Flux model: {e}")
        raise

    model_initialized = True
    print("Models initialized successfully!")
# ---------------- NF4 ----------------
if BNB_AVAILABLE:
    # NOTE: BNB_AVAILABLE is forced to False earlier in this file, so as written
    # this whole branch is skipped and only the `else` warning below runs.
    def functional_linear_4bits(x, weight, bias):
        """Run ``bnb.matmul_4bit`` on ``x`` and ``weight.t()`` with the weight's quant state.

        The result is converted with ``out.to(x)`` before being returned.
        """
        import bitsandbytes as bnb
        out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
        out = out.to(x)
        return out

    class ForgeParams4bit(Params4bit):
        """Subclass to force re-quantization to GPU if needed."""

        def to(self, *args, **kwargs):
            """Move/convert the parameter, quantizing on first transfer to CUDA.

            If the target device is CUDA and ``self.bnb_quantized`` is false, the
            result of ``self._quantize(device)`` is returned. Otherwise a new
            ``ForgeParams4bit`` is built around the moved tensor, and this
            object's ``data``/``quant_state`` plus the owning module's
            ``quant_state`` are updated to match before it is returned.
            """
            import torch
            # Private torch helper that normalizes the many `.to(...)` call forms.
            device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
            if device is not None and device.type == "cuda" and not self.bnb_quantized:
                return self._quantize(device)
            else:
                n = ForgeParams4bit(
                    torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
                    requires_grad=self.requires_grad,
                    quant_state=self.quant_state,
                    compress_statistics=False,
                    blocksize=64,
                    quant_type=self.quant_type,
                    quant_storage=self.quant_storage,
                    bnb_quantized=self.bnb_quantized,
                    module=self.module
                )
                # Keep the owning module and this parameter pointing at the new state.
                self.module.quant_state = n.quant_state
                self.data = n.data
                self.quant_state = n.quant_state
                return n

    class ForgeLoader4Bit(nn.Module):
        """Base module that (de)serializes a 4-bit weight via the state dict.

        ``weight``, ``quant_state`` and ``bias`` start as ``None`` and are filled
        in by ``_load_from_state_dict``. ``dummy`` is a one-element parameter
        that carries the requested device/dtype until loading has happened.
        """

        def __init__(self, *, device, dtype, quant_type, **kwargs):
            super().__init__()
            self.dummy = nn.Parameter(torch.empty(1, device=device, dtype=dtype))
            self.weight = None
            self.quant_state = None
            self.bias = None
            self.quant_type = quant_type

        def _save_to_state_dict(self, destination, prefix, keep_vars):
            """Save the regular entries, then the packed quant state under ``weight.*``."""
            super()._save_to_state_dict(destination, prefix, keep_vars)
            from bitsandbytes.nn.modules import QuantState
            quant_state = getattr(self.weight, "quant_state", None)
            if quant_state is not None:
                for k, v in quant_state.as_dict(packed=True).items():
                    destination[prefix + "weight." + k] = v if keep_vars else v.detach()
            return

        def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        ):
            """Load the weight/bias, handling three cases.

            1. The state dict has ``<prefix>weight.<key>`` entries whose key
               contains ``'bitsandbytes'``: rebuild the weight with
               ``from_prequantized`` on the CUDA device.
            2. Otherwise, if ``dummy`` still exists: wrap the raw weight (cast
               like ``dummy``) in a ``ForgeParams4bit``.
            3. Otherwise: defer to ``nn.Module._load_from_state_dict``.

            In cases 1 and 2 the bias is loaded when present and ``dummy`` is
            deleted afterwards.
            """
            from bitsandbytes.nn.modules import Params4bit
            import torch
            # Suffixes of every key stored under "<prefix>weight.".
            quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
            if any('bitsandbytes' in k for k in quant_state_keys):
                quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
                self.weight = ForgeParams4bit.from_prequantized(
                    data=state_dict[prefix + 'weight'],
                    quantized_stats=quant_state_dict,
                    requires_grad=False,
                    device=torch.device('cuda'),
                    module=self
                )
                self.quant_state = self.weight.quant_state
                if prefix + 'bias' in state_dict:
                    self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
                del self.dummy
            elif hasattr(self, 'dummy'):
                if prefix + 'weight' in state_dict:
                    self.weight = ForgeParams4bit(
                        state_dict[prefix + 'weight'].to(self.dummy),
                        requires_grad=False,
                        compress_statistics=True,
                        quant_type=self.quant_type,
                        quant_storage=torch.uint8,
                        module=self,
                    )
                    self.quant_state = self.weight.quant_state
                if prefix + 'bias' in state_dict:
                    self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
                del self.dummy
            else:
                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    class Linear(ForgeLoader4Bit):
        """NF4 linear layer; positional args and extra kwargs are accepted but ignored."""

        def __init__(self, *args, device=None, dtype=None, **kwargs):
            super().__init__(device=device, dtype=dtype, quant_type='nf4')

        def forward(self, x):
            # Re-attach the module's quant state to the weight before every matmul.
            self.weight.quant_state = self.quant_state
            # Cast the bias in place so it matches the input dtype.
            if self.bias is not None and self.bias.dtype != x.dtype:
                self.bias.data = self.bias.data.to(x.dtype)
            return functional_linear_4bits(x, self.weight, self.bias)

    # Don't override Linear globally - we'll only use it for Flux model
    pass
else:
    print("Warning: BitsAndBytes not available, using standard Linear layers")
# ---------------- Model ----------------
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Rotary-embedded scaled dot-product attention.

    ``q`` and ``k`` are rotated with ``apply_rope`` using ``pe``; the attention
    output then has axes 1 and 2 swapped and its last two axes flattened.
    Callers in this file pass ``(batch, heads, seq, head_dim)`` tensors, which
    yields ``(batch, seq, heads * head_dim)``.
    """
    q_rot, k_rot = apply_rope(q, k, pe)
    out = torch.nn.functional.scaled_dot_product_attention(q_rot, k_rot, v)
    batch, seq_len = out.size(0), out.size(2)
    return out.permute(0, 2, 1, 3).reshape(batch, seq_len, -1)
def rope(pos, dim, theta):
    """Build 2x2 rotation matrices for rotary position embedding.

    Args:
        pos: 2-D tensor of positions, ``(b, n)``.
        dim: Embedding width for this axis; ``dim // 2`` frequencies are used.
        theta: Base of the geometric frequency progression.

    Returns:
        float32 tensor of shape ``(b, n, dim // 2, 2, 2)`` holding
        ``[[cos, -sin], [sin, cos]]`` for every position/frequency pair.
    """
    # Frequencies are computed in float64 and only cast down at the very end.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    inv_freq = 1.0 / (theta ** exponents)
    angles = pos.unsqueeze(-1) * inv_freq.unsqueeze(0)
    cos, sin = torch.cos(angles), torch.sin(angles)
    flat = torch.stack([cos, -sin, sin, cos], dim=-1)
    b, n, d, _ = flat.shape
    return flat.view(b, n, d, 2, 2).float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query and key tensors by the matrices in ``freqs_cis``.

    The last axis of each input is viewed as consecutive pairs; every pair is
    multiplied by the matching 2x2 matrix. The computation runs in float32 and
    each result is cast back to its input's dtype and original shape.
    """

    def _rotate(x: Tensor) -> Tensor:
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
class EmbedND(nn.Module):
    """Multi-axis rotary embedding: one ``rope`` table per position axis."""

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim, self.theta, self.axes_dim = dim, theta, axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Return the per-axis ``rope`` outputs joined on axis -3, with a new axis 1.

        The last axis of ``ids`` enumerates position axes; axis ``i`` uses
        ``self.axes_dim[i]`` as its embedding width.
        """
        per_axis = [
            rope(ids[..., axis], self.axes_dim[axis], self.theta)
            for axis in range(ids.shape[-1])
        ]
        return torch.cat(per_axis, dim=-3).unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """Sinusoidal embedding of a 1-D batch of timesteps.

    Args:
        t: 1-D tensor of timesteps; multiplied by ``time_factor`` first.
        dim: Output width. ``dim // 2`` cosines are followed by ``dim // 2``
            sines; an odd ``dim`` gets one trailing zero column.
        max_period: Controls the lowest frequency of the progression.
        time_factor: Scale applied to ``t`` before embedding.

    Returns:
        ``(len(t), dim)`` tensor, cast to match ``t`` when ``t`` is floating point.
    """
    scaled_t = time_factor * t
    half = dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponent).to(scaled_t.device)
    phases = scaled_t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
    if dim % 2 != 0:
        # Pad odd widths with a single zero column.
        zero_col = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_col], dim=-1)
    if torch.is_floating_point(scaled_t):
        embedding = embedding.to(scaled_t)
    return embedding
class MLPEmbedder(nn.Module):
    """Two-layer MLP: ``Linear -> SiLU -> Linear``, output width ``hidden_dim``."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        # Creation order is kept so parameter names and RNG draws stay the same.
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` to ``hidden_dim``, apply SiLU, then project again."""
        hidden = self.in_layer(x)
        activated = self.silu(hidden)
        return self.out_layer(activated)
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last axis with a learned scale."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        """Normalize ``x`` in float32, cast back to its dtype, then apply ``scale``."""
        in_dtype = x.dtype
        x32 = x.float()
        mean_sq = torch.mean(x32**2, dim=-1, keepdim=True)
        # 1e-6 keeps the reciprocal square root finite for all-zero rows.
        normed = x32 * torch.rsqrt(mean_sq + 1e-6)
        return normed.to(dtype=in_dtype) * self.scale
class QKNorm(torch.nn.Module):
    """Independent RMS normalization for queries and keys."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        """Normalize ``q`` and ``k`` and convert both with ``.to(v)``.

        ``v`` is used only as the conversion target; it is not modified or returned.
        """
        q_normed = self.query_norm(q)
        k_normed = self.key_norm(k)
        return q_normed.to(v), k_normed.to(v)
class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and QK normalization."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(dim // num_heads)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        """Attend over ``x`` (3-D: batch, sequence, features) using rotary table ``pe``."""
        fused = self.qkv(x)
        batch, seq_len, _ = fused.shape
        # (batch, seq, 3, heads, head_dim) -> three (batch, heads, seq, head_dim) tensors.
        per_head = fused.view(batch, seq_len, 3, self.num_heads, -1)
        q, k, v = per_head.permute(2, 0, 3, 1, 4).unbind(0)
        q, k = self.norm(q, k, v)
        attended = attention(q, k, v, pe=pe)
        return self.proj(attended)
from dataclasses import dataclass
@dataclass
class ModulationOut:
    """One shift/scale/gate triple, as built by ``Modulation.forward``.

    The blocks in this file apply it as ``(1 + scale) * x + shift`` and use
    ``gate`` to weight the residual branch.
    """

    shift: Tensor
    scale: Tensor
    gate: Tensor
class Modulation(nn.Module):
    """Projects a conditioning vector into one or two shift/scale/gate triples."""

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        # Three chunks per triple; a "double" modulation yields two triples.
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor):
        """Return ``(first, second)``; ``second`` is ``None`` unless ``double`` was set."""
        projected = self.lin(nn.functional.silu(vec))
        # Insert a broadcast axis at position 1, then split the feature axis.
        chunks = projected[:, None, :].chunk(self.multiplier, dim=-1)
        first = ModulationOut(*chunks[:3])
        if not self.is_double:
            return first, None
        return first, ModulationOut(*chunks[3:])
class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image/text weights and joint attention.

    Each stream has its own modulation, norms, QKV/proj layers and MLP; the two
    streams only interact inside a single attention call over the concatenated
    (text first, then image) sequence.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        # Image stream (creation order is kept so parameter naming/init is unchanged).
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        # Text stream.
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        """Update both streams and return ``(img, txt)``.

        Args:
            img: Image tokens, 3-D (batch, sequence, hidden).
            txt: Text tokens, 3-D (batch, sequence, hidden).
            vec: Conditioning vector fed to both modulation layers.
            pe: Rotary table covering the concatenated text+image sequence.
        """
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)
        heads = self.num_heads

        # Image stream: modulate, project to QKV, split heads, normalize Q/K.
        img_in = self.img_norm1(img)
        img_in = (1 + img_mod1.scale) * img_in + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_in)
        batch, img_len, fused_width = img_qkv.shape
        head_dim = fused_width // (3 * heads)
        img_q, img_k, img_v = img_qkv.view(batch, img_len, 3, heads, head_dim).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # Text stream: same steps, reusing the head size derived from the image QKV.
        txt_in = self.txt_norm1(txt)
        txt_in = (1 + txt_mod1.scale) * txt_in + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_in)
        batch, txt_len, _ = txt_qkv.shape
        txt_q, txt_k, txt_v = txt_qkv.view(batch, txt_len, 3, heads, head_dim).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # Joint attention over [text, image] along the sequence axis.
        joint_q = torch.cat((txt_q, img_q), dim=2)
        joint_k = torch.cat((txt_k, img_k), dim=2)
        joint_v = torch.cat((txt_v, img_v), dim=2)
        joint = attention(joint_q, joint_k, joint_v, pe=pe)
        n_txt = txt.shape[1]
        txt_attn, img_attn = joint[:, :n_txt], joint[:, n_txt:]

        # Image residual updates: attention branch, then MLP branch.
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img_mlp_in = (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        img = img + img_mod2.gate * self.img_mlp(img_mlp_in)

        # Text residual updates: attention branch, then MLP branch.
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt_mlp_in = (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        txt = txt + txt_mod2.gate * self.txt_mlp(txt_mlp_in)
        return img, txt
class SingleStreamBlock(nn.Module):
    """Transformer block with attention and MLP computed in parallel.

    ``linear1`` produces QKV and the MLP hidden activations in one projection;
    ``linear2`` maps the concatenated attention output and activated MLP
    hidden state back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        head_dim = hidden_size // num_heads
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        # Stored only; forward() does not read self.scale.
        self.scale = qk_scale or head_dim**-0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)

        # Fused input projection (QKV + MLP-in) and fused output projection.
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
        self.norm = QKNorm(head_dim)
        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        """Return ``x`` plus the gated output of the parallel attention/MLP branch."""
        mod, _ = self.modulation(vec)
        modulated = (1 + mod.scale) * self.pre_norm(x) + mod.shift

        fused = self.linear1(modulated)
        qkv, mlp_hidden = torch.split(fused, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        head_dim = self.hidden_size // self.num_heads
        per_head = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, head_dim)
        q, k, v = per_head.permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)
        attended = attention(q, k, v, pe=pe)

        merged = torch.cat((attended, self.mlp_act(mlp_hidden)), 2)
        return x + mod.gate * self.linear2(merged)
class LastLayer(nn.Module):
    """Final adaptive-LayerNorm modulation followed by a linear output projection."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        """Modulate the normalized ``x`` with shift/scale derived from ``vec`` and project."""
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        # unsqueeze(1) adds the broadcast axis over the sequence dimension.
        modulated = (1 + scale.unsqueeze(1)) * self.norm_final(x) + shift.unsqueeze(1)
        return self.linear(modulated)
from dataclasses import dataclass, field
@dataclass
class FluxParams:
    """Hyper-parameters read by ``Flux.__init__``."""

    # Input width of the image projection; Flux also uses it as the output channel count.
    in_channels: int = 64
    # Input width of the pooled-vector embedder (`vector_in`).
    vec_in_dim: int = 768
    # Input width of the text projection (`txt_in`).
    context_in_dim: int = 4096
    # Model width; must be divisible by num_heads.
    hidden_size: int = 3072
    # MLP hidden width as a multiple of hidden_size.
    mlp_ratio: float = 4.0
    num_heads: int = 24
    # Number of DoubleStreamBlock layers.
    depth: int = 19
    # Number of SingleStreamBlock layers.
    depth_single_blocks: int = 38
    # Per-axis rotary widths; their sum must equal hidden_size // num_heads.
    axes_dim: list[int] = field(default_factory=lambda: [16, 56, 56])
    # Rotary frequency base handed to EmbedND.
    theta: int = 10000
    qkv_bias: bool = True
    # When True, Flux builds a guidance embedder and forward() requires `guidance`.
    guidance_embed: bool = True
class Flux(nn.Module):
    """Flux transformer: double-stream blocks, then single-stream blocks, then a final layer."""

    def __init__(self, params: "FluxParams | None" = None):
        """Build the network.

        Args:
            params: Model hyper-parameters. ``None`` (the default) creates a
                fresh ``FluxParams()`` per instance, so instances never share
                one mutable default object.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by ``num_heads``,
                or ``sum(axes_dim)`` differs from ``hidden_size // num_heads``.
        """
        super().__init__()
        if params is None:
            params = FluxParams()
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads

        # Submodule creation order is unchanged so parameter names and
        # initialization order match the previous implementation.
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )
        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )
        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        """Predict the update for the image tokens.

        Args:
            img: Image tokens, 3-D.
            img_ids: Position ids for the image tokens.
            txt: Text tokens, 3-D.
            txt_ids: Position ids for the text tokens.
            timesteps: 1-D tensor of timesteps, one per batch element.
            y: Pooled conditioning vector fed to ``vector_in``.
            guidance: Guidance strengths; required when
                ``params.guidance_embed`` is true.

        Returns:
            Tensor with one output row per image token.

        Raises:
            ValueError: if ``img``/``txt`` are not 3-D, or guidance is missing
                while ``guidance_embed`` is enabled.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        img = self.img_in(img)

        # Conditioning vector: timestep (+ guidance) (+ pooled text vector).
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("No guidance strength provided for guidance-distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)

        txt = self.txt_in(txt)

        # Positions are concatenated text-first to match the token order used below.
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # Single-stream blocks run on the joint [text, image] sequence.
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        # Drop the text prefix again; only image tokens are projected out.
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)
        return img
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    """Assemble the keyword inputs for ``denoise`` from a latent and a prompt.

    The latent is split into 2x2 patches and flattened into a token sequence.
    Position ids for image tokens carry (row, column) in channels 1 and 2;
    text position ids are all zero. A batch of one is tiled to the number of
    prompts when a list of prompts is given.

    Returns:
        Dict with keys ``img``, ``img_ids``, ``txt``, ``txt_ids`` and ``vec``,
        the last four moved to ``img``'s device.
    """
    bs, _, h, w = img.shape
    single_prompt = isinstance(prompt, str)
    if bs == 1 and not single_prompt:
        bs = len(prompt)

    def _tile(t: Tensor) -> Tensor:
        # Expand a singleton batch to the effective batch size.
        if t.shape[0] == 1 and bs > 1:
            return repeat(t, "1 ... -> bs ...", bs=bs)
        return t

    img = _tile(rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2))

    rows, cols = h // 2, w // 2
    img_ids = torch.zeros(rows, cols, 3)
    img_ids[..., 1] += torch.arange(rows)[:, None]
    img_ids[..., 2] += torch.arange(cols)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    prompts = [prompt] if single_prompt else prompt
    txt = _tile(t5(prompts))
    txt_ids = torch.zeros(bs, txt.shape[1], 3)
    vec = _tile(clip(prompts))

    target = img.device
    return {
        "img": img,
        "img_ids": img_ids.to(target),
        "txt": txt.to(target),
        "txt_ids": txt_ids.to(target),
        "vec": vec.to(target),
    }
def time_shift(mu: float, sigma: float, t: Tensor):
    """Return ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` element-wise over ``t``."""
    exp_mu = math.exp(mu)
    odds_term = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + odds_term)
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the straight line through ``(x1, y1)`` and ``(x2, y2)`` as a callable."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x):
        return slope * x + intercept

    return line
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return ``num_steps + 1`` timesteps running from 1 down to 0.

    With ``shift`` enabled the evenly spaced values are warped by
    ``time_shift`` using a ``mu`` obtained from a linear function of
    ``image_seq_len`` (parameterized by ``base_shift`` and ``max_shift``).
    """
    timesteps = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return timesteps.tolist()
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, timesteps).tolist()
def denoise(
    model: Flux,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
):
    """Integrate the model prediction along ``timesteps`` with explicit Euler steps.

    For each consecutive pair ``(t_curr, t_prev)`` the model is evaluated at
    ``t_curr`` and ``img`` is advanced by ``(t_prev - t_curr) * pred``.
    Progress is reported through ``tqdm``.

    Returns:
        The image tokens after the last step.
    """
    batch = img.shape[0]
    guidance_vec = torch.full((batch,), guidance, device=img.device, dtype=img.dtype)
    step_pairs = zip(timesteps[:-1], timesteps[1:])
    for t_curr, t_prev in tqdm(step_pairs, total=len(timesteps) - 1):
        t_vec = torch.full((batch,), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        step = t_prev - t_curr
        img = img + step * pred
    return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Fold a token sequence of 2x2 patches back into a ``(b, c, H, W)`` grid.

    The patch grid has ``ceil(height / 16)`` rows and ``ceil(width / 16)`` columns.
    """
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    pattern = "b (h w) (c ph pw) -> b c (h ph) (w pw)"
    return rearrange(x, pattern, h=grid_h, w=grid_w, ph=2, pw=2)
@dataclass
class SamplingOptions:
    """Plain container for sampling settings.

    NOTE(review): nothing in the visible source constructs or reads this class.
    """

    prompt: str
    width: int
    height: int
    guidance: float
    # None is allowed by the annotation; presumably means "no fixed seed" — confirm.
    seed: int | None
def get_image(image) -> torch.Tensor | None:
if image is None:
return None
image = Image.fromarray(image).convert("RGB")
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: 2.0 * x - 1.0),
])
img: torch.Tensor = transform(image)
return img[None, ...]
@spaces.GPU(duration=120)
@torch.no_grad()
def generate_image(
    prompt, width, height, guidance, inference_steps, seed,
    do_img2img, init_image, image2image_strength, resize_img,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from the UI inputs and return ``(PIL image, seed used)``.

    Args:
        prompt: Text prompt.
        width, height: Requested output size in pixels (replaced by the cropped
            initial-image size in image-to-image mode without resizing).
        guidance: Guidance strength passed to the model.
        inference_steps: Number of denoising steps.
        seed: RNG seed; ``0`` or ``None`` selects a random seed.
        do_img2img: Enable image-to-image mode (needs ``init_image``).
        init_image: Initial image as an array, or ``None``.
        image2image_strength: Fraction of the schedule to run in
            image-to-image mode (``1.0`` = full schedule).
        resize_img: Resize the initial image to ``(height, width)`` instead of
            cropping it to a multiple of 16.
        progress: Gradio progress tracker (mirrors tqdm).
    """
    # Initialize models on first run
    initialize_models()

    # UI widgets may deliver floats, and an empty seed box delivers None; the
    # shape, schedule and RNG calls below require plain ints.
    width, height = int(width), int(height)
    inference_steps = int(inference_steps)
    seed = 0 if seed is None else int(seed)
    if seed == 0:
        seed = int(random.random() * 1_000_000)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_device = torch.device(device)

    use_init_image = do_img2img and init_image is not None
    if use_init_image:
        init_image = get_image(init_image)
        if resize_img:
            init_image = torch.nn.functional.interpolate(init_image, (height, width))
        else:
            # Crop to a multiple of 16 pixels and adopt that size as the output size.
            h, w = init_image.shape[-2:]
            init_image = init_image[..., : 16 * (h // 16), : 16 * (w // 16)]
            height = init_image.shape[-2]
            width = init_image.shape[-1]
        init_image = ae.encode(init_image.to(torch_device).to(torch.bfloat16)).latent_dist.sample()
        init_image = (init_image - ae.config.shift_factor) * ae.config.scaling_factor

    generator = torch.Generator(device=device).manual_seed(seed)
    x = torch.randn(
        1,
        16,
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=device,
        dtype=torch.bfloat16,
        generator=generator
    )

    timesteps = get_schedule(inference_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)

    if use_init_image:
        # Skip the first part of the schedule and blend noise with the encoded image.
        t_idx = int((1 - image2image_strength) * inference_steps)
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        x = t * x + (1.0 - t) * init_image.to(x.dtype)

    inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)
    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
    x = unpack(x.float(), height, width)

    with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
        # Undo the latent normalization before decoding.
        x = (x / ae.config.scaling_factor) + ae.config.shift_factor
        x = ae.decode(x).sample

    x = x.clamp(-1, 1)
    x = rearrange(x[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
    return img, seed
def create_demo():
    """Build and return the Gradio ``Blocks`` UI.

    The left column holds the generation controls, the right column the result
    image and the seed that was used. The three image-to-image controls are
    hidden until the "Image to Image" checkbox is ticked. The seed field is an
    integer defaulting to ``0``, which ``generate_image`` treats as "pick a
    random seed".
    """
    with gr.Blocks(css=".gradio-container {background-color: #282828 !important;}") as demo:
        gr.HTML(
            """
<div style="text-align: center; margin: 0 auto;">
<h1 style="color: #ffffff; font-weight: 900;">
FluxLLama
</h1>
</div>
"""
        )
        gr.HTML(
            """
<div class='container' style='display:flex; justify-content:center; gap:12px;'>
<a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
<img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
</a>
<a href="https://discord.gg/openfreeai" target="_blank">
<img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
</a>
</div>
"""
        )
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="A majestic castle on top of a floating island")
                width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=640)
                height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=640)
                guidance = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, label="Guidance", value=3.5)
                inference_steps = gr.Slider(
                    label="Inference steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=16,
                )
                # precision=0 yields an int; the previous precision=-1 rounded
                # the seed to the nearest ten. value=0 means "random seed".
                seed = gr.Number(label="Seed", value=0, precision=0)
                do_img2img = gr.Checkbox(label="Image to Image", value=False)
                init_image = gr.Image(label="Initial Image", visible=False)
                image2image_strength = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    label="Noising Strength",
                    value=0.8,
                    visible=False
                )
                resize_img = gr.Checkbox(label="Resize Initial Image", value=True, visible=False)
                generate_button = gr.Button("Generate", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Result")
                output_seed = gr.Text(label="Seed Used")
        # Show or hide the three image-to-image controls together.
        do_img2img.change(
            fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
            inputs=[do_img2img],
            outputs=[init_image, image2image_strength, resize_img]
        )
        generate_button.click(
            fn=generate_image,
            inputs=[
                prompt, width, height, guidance,
                inference_steps, seed, do_img2img,
                init_image, image2image_strength, resize_img
            ],
            outputs=[output_image, output_seed]
        )
    return demo
if __name__ == "__main__":
    # Build the Gradio Blocks UI (kept in the module-level name `demo`).
    demo = create_demo()
    # Enable the queue to handle concurrency
    demo.queue()
    # Launch with show_api disabled and share enabled.
    demo.launch(show_api=False, share=True)