# flux/modules/image_embedders.py (from Flux.1-Fill-dev)
import os
import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
from flux.util import print_load_warning
class DepthImageEncoder:
    """Produce a 3-channel depth map for an image tensor via Depth-Anything.

    Input is an RGB tensor in the pipeline's [-1, 1] range; the output depth
    map is resized back to the input's spatial size.
    """

    depth_model_name = "LiheYoung/depth-anything-large-hf"

    def __init__(self, device):
        self.device = device
        self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
        self.processor = AutoProcessor.from_pretrained(self.depth_model_name)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        out_size = img.shape[-2:]
        # Map [-1, 1] floats to uint8 [0, 255] for the HF image processor.
        img_u8 = ((torch.clamp(img, -1.0, 1.0) + 1.0) * 127.5).byte()
        pixels = self.processor(img_u8, return_tensors="pt")["pixel_values"]
        depth = self.depth_model(pixels.to(self.device)).predicted_depth
        # Broadcast the single depth channel to 3 and restore the input resolution.
        depth = repeat(depth, "b h w -> b 3 h w")
        depth = torch.nn.functional.interpolate(depth, out_size, mode="bicubic", antialias=True)
        # Rescale into the pipeline's [-1, 1] convention.
        # NOTE(review): assumes predicted_depth lies roughly in [0, 255] — confirm.
        return depth / 127.5 - 1.0
class CannyImageEncoder:
    """Produce a Canny edge map, replicated to 3 channels, from an image tensor."""

    def __init__(
        self,
        device,
        min_t: int = 50,
        max_t: int = 200,
    ):
        self.device = device
        self.min_t = min_t  # lower hysteresis threshold for cv2.Canny
        self.max_t = max_t  # upper hysteresis threshold for cv2.Canny

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        assert img.shape[0] == 1, "Only batch size 1 is supported"
        # HWC view of the single image, converted to uint8 [0, 255] for OpenCV.
        hwc = rearrange(img[0], "c h w -> h w c")
        pixels = ((torch.clamp(hwc, -1.0, 1.0) + 1.0) * 127.5).numpy().astype(np.uint8)
        edges = cv2.Canny(pixels, self.min_t, self.max_t)
        # Back to a [-1, 1] float tensor with batch/channel dims, tripled to RGB.
        edges = torch.from_numpy(edges).float() / 127.5 - 1.0
        edges = rearrange(edges, "h w -> 1 1 h w")
        edges = repeat(edges, "b 1 ... -> b 3 ...")
        return edges.to(self.device)
class ReduxImageEncoder(nn.Module):
siglip_model_name = "google/siglip-so400m-patch14-384"
def __init__(
self,
device,
redux_dim: int = 1152,
txt_in_features: int = 4096,
redux_path: str | None = os.getenv("FLUX_REDUX"),
dtype=torch.bfloat16,
) -> None:
assert redux_path is not None, "Redux path must be provided"
super().__init__()
self.redux_dim = redux_dim
self.device = device if isinstance(device, torch.device) else torch.device(device)
self.dtype = dtype
with self.device:
self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
sd = load_sft(redux_path, device=str(device))
missing, unexpected = self.load_state_dict(sd, strict=False, assign=True)
print_load_warning(missing, unexpected)
self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(dtype=dtype)
self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name)
def __call__(self, x: Image.Image) -> torch.Tensor:
imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)
_encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state
projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x)))
return projected_x