import os import sys # Disable bitsandbytes triton integration to avoid conflicts os.environ["BITSANDBYTES_NOWELCOME"] = "1" os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Try to handle spaces import gracefully try: import spaces SPACES_AVAILABLE = True except Exception as e: print(f"Warning: Could not import spaces: {e}") SPACES_AVAILABLE = False # Create a dummy decorator if spaces is not available class spaces: @staticmethod def GPU(duration=None): def decorator(func): return func return decorator import time import gradio as gr import torch from PIL import Image from torchvision import transforms from dataclasses import dataclass, field import math from typing import Callable from tqdm import tqdm import random from einops import rearrange, repeat from diffusers import AutoencoderKL from torch import Tensor, nn from transformers import CLIPTextModel, CLIPTokenizer from transformers import T5EncoderModel, T5Tokenizer # Import bitsandbytes after spaces to avoid conflicts try: import bitsandbytes as bnb from bitsandbytes.nn.modules import Params4bit, QuantState BNB_AVAILABLE = True except Exception as e: print(f"Warning: Could not import bitsandbytes: {e}") BNB_AVAILABLE = False # Store original Linear class before any modifications original_linear = nn.Linear # Disable BNB for now due to compatibility issues BNB_AVAILABLE = False print("Note: BitsAndBytes quantization disabled for compatibility") # ---------------- Encoders ---------------- class HFEmbedder(nn.Module): def __init__(self, version: str, max_length: int, **hf_kwargs): super().__init__() self.is_clip = version.startswith("openai") self.max_length = max_length self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" if self.is_clip: self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) else: self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) self.hf_module = self.hf_module.eval().requires_grad_(False) def forward(self, text: list[str]) -> Tensor: batch_encoding = self.tokenizer( text, truncation=True, max_length=self.max_length, return_length=False, return_overflowing_tokens=False, padding="max_length", return_tensors="pt", ) outputs = self.hf_module( input_ids=batch_encoding["input_ids"].to(self.hf_module.device), attention_mask=None, output_hidden_states=False, ) return outputs[self.output_key] # Initialize models without GPU decorator first t5 = None clip = None ae = None model = None model_initialized = False def initialize_models(): global t5, clip, ae, model, model_initialized if not model_initialized: print("Initializing models...") device = "cuda" if torch.cuda.is_available() else "cpu" # Load standard models print("Loading T5 encoder...") t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) t5 = t5.to(device) print("Loading CLIP encoder...") clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) clip = clip.to(device) print("Loading VAE...") ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) ae = ae.to(device) print("Loading Flux model...") # Use the standard Flux model instead of quantized version # This will use more memory but avoid compatibility issues from huggingface_hub import hf_hub_download from safetensors.torch import load_file try: # Try to load from the standard Flux checkpoint print("Loading standard Flux model (this may take a while)...") model = Flux() model = model.to(dtype=torch.bfloat16, device=device) # You would need to download the standard Flux weights # For now, let's create a randomly initialized model for testing print("Warning: Using randomly initialized Flux model for testing") print("To use a pretrained model, you need to load proper Flux weights") except Exception as e: print(f"Error initializing Flux model: {e}") raise model_initialized = True print("Models initialized successfully!") # ---------------- NF4 ---------------- if BNB_AVAILABLE: def functional_linear_4bits(x, weight, bias): import bitsandbytes as bnb out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state) out = out.to(x) return out class ForgeParams4bit(Params4bit): """Subclass to force re-quantization to GPU if needed.""" def to(self, *args, **kwargs): import torch device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) if device is not None and device.type == "cuda" and not self.bnb_quantized: return self._quantize(device) else: n = ForgeParams4bit( torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, quant_state=self.quant_state, compress_statistics=False, blocksize=64, quant_type=self.quant_type, quant_storage=self.quant_storage, bnb_quantized=self.bnb_quantized, module=self.module ) self.module.quant_state = n.quant_state self.data = n.data self.quant_state = n.quant_state return n class ForgeLoader4Bit(nn.Module): def __init__(self, *, device, dtype, quant_type, **kwargs): super().__init__() self.dummy = nn.Parameter(torch.empty(1, device=device, dtype=dtype)) self.weight = None self.quant_state = None self.bias = None self.quant_type = quant_type def _save_to_state_dict(self, destination, prefix, keep_vars): super()._save_to_state_dict(destination, prefix, keep_vars) from bitsandbytes.nn.modules import QuantState quant_state = getattr(self.weight, "quant_state", None) if quant_state is not None: for k, v in quant_state.as_dict(packed=True).items(): destination[prefix + "weight." + k] = v if keep_vars else v.detach() return def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): from bitsandbytes.nn.modules import Params4bit import torch quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")} if any('bitsandbytes' in k for k in quant_state_keys): quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys} self.weight = ForgeParams4bit.from_prequantized( data=state_dict[prefix + 'weight'], quantized_stats=quant_state_dict, requires_grad=False, device=torch.device('cuda'), module=self ) self.quant_state = self.weight.quant_state if prefix + 'bias' in state_dict: self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy)) del self.dummy elif hasattr(self, 'dummy'): if prefix + 'weight' in state_dict: self.weight = ForgeParams4bit( state_dict[prefix + 'weight'].to(self.dummy), requires_grad=False, compress_statistics=True, quant_type=self.quant_type, quant_storage=torch.uint8, module=self, ) self.quant_state = self.weight.quant_state if prefix + 'bias' in state_dict: self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy)) del self.dummy else: super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) class Linear(ForgeLoader4Bit): def __init__(self, *args, device=None, dtype=None, **kwargs): super().__init__(device=device, dtype=dtype, quant_type='nf4') def forward(self, x): self.weight.quant_state = self.quant_state if self.bias is not None and self.bias.dtype != x.dtype: self.bias.data = self.bias.data.to(x.dtype) return functional_linear_4bits(x, self.weight, self.bias) # Don't override Linear globally - we'll only use it for Flux model pass else: print("Warning: BitsAndBytes not available, using standard Linear layers") # ---------------- Model ---------------- def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: q, k = apply_rope(q, k, pe) x = torch.nn.functional.scaled_dot_product_attention(q, k, v) x = x.permute(0, 2, 1, 3).reshape(x.size(0), x.size(2), -1) return x def rope(pos, dim, theta): import torch scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim omega = 1.0 / (theta ** scale) out = pos.unsqueeze(-1) * omega.unsqueeze(0) cos_out = torch.cos(out) sin_out = torch.sin(out) out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) b, n, d, _ = out.shape out = out.view(b, n, d, 2, 2) return out.float() def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list[int]): super().__init__() self.dim = dim self.theta = theta self.axes_dim = axes_dim def forward(self, ids: Tensor) -> Tensor: import torch n_axes = ids.shape[-1] emb = torch.cat( [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3, ) return emb.unsqueeze(1) def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): import torch, math t = time_factor * t half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) if torch.is_floating_point(t): embedding = embedding.to(t) return embedding class MLPEmbedder(nn.Module): def __init__(self, in_dim: int, hidden_dim: int): super().__init__() self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) self.silu = nn.SiLU() self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) class RMSNorm(torch.nn.Module): def __init__(self, dim: int): super().__init__() self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x: Tensor): import torch x_dtype = x.dtype x = x.float() rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) return (x * rrms).to(dtype=x_dtype) * self.scale class QKNorm(torch.nn.Module): def __init__(self, dim: int): super().__init__() self.query_norm = RMSNorm(dim) self.key_norm = RMSNorm(dim) def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: q = self.query_norm(q) k = self.key_norm(k) return q.to(v), k.to(v) class SelfAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): super().__init__() self.num_heads = num_heads self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) head_dim = dim // num_heads self.norm = QKNorm(head_dim) self.proj = nn.Linear(dim, dim) def forward(self, x: Tensor, pe: Tensor) -> Tensor: qkv = self.qkv(x) B, L, _ = qkv.shape qkv = qkv.view(B, L, 3, self.num_heads, -1) q, k, v = qkv.permute(2, 0, 3, 1, 4) q, k = self.norm(q, k, v) x = attention(q, k, v, pe=pe) x = self.proj(x) return x from dataclasses import dataclass @dataclass class ModulationOut: shift: Tensor scale: Tensor gate: Tensor class Modulation(nn.Module): def __init__(self, dim: int, double: bool): super().__init__() self.is_double = double self.multiplier = 6 if double else 3 self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) def forward(self, vec: Tensor): out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) first = ModulationOut(*out[:3]) second = ModulationOut(*out[3:]) if self.is_double else None return first, second class DoubleStreamBlock(nn.Module): def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) self.num_heads = num_heads self.hidden_size = hidden_size self.img_mod = Modulation(hidden_size, double=True) self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_mlp = nn.Sequential( nn.Linear(hidden_size, mlp_hidden_dim, bias=True), nn.GELU(approximate="tanh"), nn.Linear(mlp_hidden_dim, hidden_size, bias=True), ) self.txt_mod = Modulation(hidden_size, double=True) self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_mlp = nn.Sequential( nn.Linear(hidden_size, mlp_hidden_dim, bias=True), nn.GELU(approximate="tanh"), nn.Linear(mlp_hidden_dim, hidden_size, bias=True), ) def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) # Image attention img_modulated = self.img_norm1(img) img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift img_qkv = self.img_attn.qkv(img_modulated) B, L, _ = img_qkv.shape H = self.num_heads D = img_qkv.shape[-1] // (3 * H) img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4) img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # Text attention txt_modulated = self.txt_norm1(txt) txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift txt_qkv = self.txt_attn.qkv(txt_modulated) B, L, _ = txt_qkv.shape txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) # Combined attention q = torch.cat((txt_q, img_q), dim=2) k = torch.cat((txt_k, img_k), dim=2) v = torch.cat((txt_v, img_v), dim=2) attn = attention(q, k, v, pe=pe) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # Img final img = img + img_mod1.gate * self.img_attn.proj(img_attn) img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) # Text final txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt class SingleStreamBlock(nn.Module): def __init__( self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, qk_scale: float | None = None, ): super().__init__() self.hidden_dim = hidden_size self.num_heads = num_heads head_dim = hidden_size // num_heads self.scale = qk_scale or head_dim**-0.5 self.mlp_hidden_dim = int(hidden_size * mlp_ratio) self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) self.norm = QKNorm(head_dim) self.hidden_size = hidden_size self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.mlp_act = nn.GELU(approximate="tanh") self.modulation = Modulation(hidden_size, double=False) def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: mod, _ = self.modulation(vec) x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads) q, k, v = qkv.permute(2, 0, 3, 1, 4) q, k = self.norm(q, k, v) attn = attention(q, k, v, pe=pe) output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + mod.gate * output class LastLayer(nn.Module): def __init__(self, hidden_size: int, patch_size: int, out_channels: int): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) def forward(self, x: Tensor, vec: Tensor) -> Tensor: shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] x = self.linear(x) return x from dataclasses import dataclass, field @dataclass class FluxParams: in_channels: int = 64 vec_in_dim: int = 768 context_in_dim: int = 4096 hidden_size: int = 3072 mlp_ratio: float = 4.0 num_heads: int = 24 depth: int = 19 depth_single_blocks: int = 38 axes_dim: list[int] = field(default_factory=lambda: [16, 56, 56]) theta: int = 10000 qkv_bias: bool = True guidance_embed: bool = True class Flux(nn.Module): def __init__(self, params = FluxParams()): super().__init__() self.params = params self.in_channels = params.in_channels self.out_channels = self.in_channels if params.hidden_size % params.num_heads != 0: raise ValueError( f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" ) pe_dim = params.hidden_size // params.num_heads if sum(params.axes_dim) != pe_dim: raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") self.hidden_size = params.hidden_size self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) self.guidance_in = ( MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() ) self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) self.double_blocks = nn.ModuleList( [ DoubleStreamBlock( self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, ) for _ in range(params.depth) ] ) self.single_blocks = nn.ModuleList( [ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for _ in range(params.depth_single_blocks) ] ) self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) def forward( self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, y: Tensor, guidance: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: raise ValueError("No guidance strength provided for guidance-distilled model.") vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) vec = vec + self.vector_in(y) txt = self.txt_in(txt) ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) for block in self.double_blocks: img, txt = block(img=img, txt=txt, vec=vec, pe=pe) img = torch.cat((txt, img), 1) for block in self.single_blocks: img = block(img, vec=vec, pe=pe) img = img[:, txt.shape[1] :, ...] img = self.final_layer(img, vec) return img def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: import torch bs, c, h, w = img.shape if bs == 1 and not isinstance(prompt, str): bs = len(prompt) img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img.shape[0] == 1 and bs > 1: img = repeat(img, "1 ... -> bs ...", bs=bs) img_ids = torch.zeros(h // 2, w // 2, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) if isinstance(prompt, str): prompt = [prompt] txt = t5(prompt) if txt.shape[0] == 1 and bs > 1: txt = repeat(txt, "1 ... -> bs ...", bs=bs) txt_ids = torch.zeros(bs, txt.shape[1], 3) vec = clip(prompt) if vec.shape[0] == 1 and bs > 1: vec = repeat(vec, "1 ... -> bs ...", bs=bs) return { "img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device), } def time_shift(mu: float, sigma: float, t: Tensor): import math return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) def get_lin_function( x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 ) -> Callable[[float], float]: import math m = (y2 - y1) / (x2 - x1) b = y1 - m * x1 return lambda x: m * x + b def get_schedule( num_steps: int, image_seq_len: int, base_shift: float = 0.5, max_shift: float = 1.15, shift: bool = True, ) -> list[float]: import torch import math timesteps = torch.linspace(1, 0, num_steps + 1) if shift: mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) timesteps = time_shift(mu, 1.0, timesteps) return timesteps.tolist() def denoise( model: Flux, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, vec: Tensor, timesteps: list[float], guidance: float = 4.0, ): import torch guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) pred = model( img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec, ) img = img + (t_prev - t_curr) * pred return img def unpack(x: Tensor, height: int, width: int) -> Tensor: return rearrange( x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2, ) @dataclass class SamplingOptions: prompt: str width: int height: int guidance: float seed: int | None def get_image(image) -> torch.Tensor | None: if image is None: return None image = Image.fromarray(image).convert("RGB") transform = transforms.Compose([ transforms.ToTensor(), transforms.Lambda(lambda x: 2.0 * x - 1.0), ]) img: torch.Tensor = transform(image) return img[None, ...] @spaces.GPU(duration=120) @torch.no_grad() def generate_image( prompt, width, height, guidance, inference_steps, seed, do_img2img, init_image, image2image_strength, resize_img, progress=gr.Progress(track_tqdm=True), ): # Initialize models on first run initialize_models() if seed == 0: seed = int(random.random() * 1_000_000) device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = torch.device(device) if do_img2img and init_image is not None: init_image = get_image(init_image) if resize_img: init_image = torch.nn.functional.interpolate(init_image, (height, width)) else: h, w = init_image.shape[-2:] init_image = init_image[..., : 16 * (h // 16), : 16 * (w // 16)] height = init_image.shape[-2] width = init_image.shape[-1] init_image = ae.encode(init_image.to(torch_device).to(torch.bfloat16)).latent_dist.sample() init_image = (init_image - ae.config.shift_factor) * ae.config.scaling_factor generator = torch.Generator(device=device).manual_seed(seed) x = torch.randn( 1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=torch.bfloat16, generator=generator ) timesteps = get_schedule(inference_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True) if do_img2img and init_image is not None: t_idx = int((1 - image2image_strength) * inference_steps) t = timesteps[t_idx] timesteps = timesteps[t_idx:] x = t * x + (1.0 - t) * init_image.to(x.dtype) inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt) x = denoise(model, **inp, timesteps=timesteps, guidance=guidance) x = unpack(x.float(), height, width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = (x / ae.config.scaling_factor) + ae.config.shift_factor x = ae.decode(x).sample x = x.clamp(-1, 1) x = rearrange(x[0], "c h w -> h w c") img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) return img, seed def create_demo(): with gr.Blocks(css=".gradio-container {background-color: #282828 !important;}") as demo: gr.HTML( """