Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,579 Bytes
021dc80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
from flux.util import print_load_warning
class DepthImageEncoder:
    """Convert image batches into 3-channel depth maps with a pretrained estimator.

    The checkpoint named by ``depth_model_name`` and its matching processor are
    loaded from the Hugging Face hub when the encoder is constructed.
    """

    depth_model_name = "LiheYoung/depth-anything-large-hf"

    def __init__(self, device):
        """Load the depth model onto ``device`` and fetch its processor.

        Args:
            device: Device the depth model is moved to and on which inference runs.
        """
        self.device = device
        estimator = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name)
        self.depth_model = estimator.to(device)
        self.processor = AutoProcessor.from_pretrained(self.depth_model_name)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Estimate depth for ``img`` and return it at the input's spatial size.

        Args:
            img: Image tensor whose last two dimensions are used as the output
                size. Values are clamped to ``[-1, 1]`` before processing
                (presumably a ``(batch, channel, height, width)`` batch; confirm
                against the caller).

        Returns:
            The predicted depth, copied across a new channel axis of size 3,
            bicubically resized to the input's last two dimensions and then
            mapped through ``x / 127.5 - 1.0``.
        """
        target_size = img.shape[-2:]

        # Shift the clamped [-1, 1] range onto [0, 255] and truncate to uint8,
        # which is what gets handed to the Hugging Face processor.
        as_uint8 = ((torch.clamp(img, -1.0, 1.0) + 1.0) * 127.5).byte()
        pixel_values = self.processor(as_uint8, return_tensors="pt")["pixel_values"]

        prediction = self.depth_model(pixel_values.to(self.device)).predicted_depth

        # The model yields one map per image; add a channel axis of size 3.
        three_channel = repeat(prediction, "b h w -> b 3 h w")
        resized = torch.nn.functional.interpolate(
            three_channel, target_size, mode="bicubic", antialias=True
        )

        return resized / 127.5 - 1.0
class CannyImageEncoder:
    """Produce a 3-channel Canny edge map from a single image tensor."""

    def __init__(
        self,
        device,
        min_t: int = 50,
        max_t: int = 200,
    ):
        """Store the output device and the two Canny thresholds.

        Args:
            device: Device the resulting edge map is moved to.
            min_t: First threshold passed to ``cv2.Canny``.
            max_t: Second threshold passed to ``cv2.Canny``.
        """
        self.device = device
        self.min_t = min_t
        self.max_t = max_t

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Run Canny edge detection on ``img`` and return it as a tensor.

        Args:
            img: Tensor laid out as ``(1, c, h, w)``; values are clamped to
                ``[-1, 1]`` before being scaled to 8-bit. The tensor may live on
                any device and may require grad.

        Returns:
            A float tensor of shape ``(1, 3, h, w)`` on ``self.device`` holding
            the 8-bit edge map rescaled with ``x / 127.5 - 1.0`` and repeated
            over three channels.

        Raises:
            AssertionError: If the leading (batch) dimension is not 1.
        """
        assert img.shape[0] == 1, "Only batch size 1 is supported"

        img = rearrange(img[0], "c h w -> h w c")
        img = torch.clamp(img, -1.0, 1.0)
        scaled = (img + 1.0) * 127.5

        # Tensor.numpy() only accepts CPU tensors that do not require grad, and
        # NumPy has no bfloat16 dtype; normalise those cases before converting.
        # For a plain CPU float tensor these calls are no-ops.
        scaled = scaled.detach().cpu()
        if scaled.dtype == torch.bfloat16:
            scaled = scaled.float()
        img_np = scaled.numpy().astype(np.uint8)

        # Apply Canny edge detection
        edges = cv2.Canny(img_np, self.min_t, self.max_t)

        # Convert back to torch tensor and reshape
        canny = torch.from_numpy(edges).float() / 127.5 - 1.0
        canny = rearrange(canny, "h w -> 1 1 h w")
        canny = repeat(canny, "b 1 ... -> b 3 ...")
        return canny.to(self.device)
class ReduxImageEncoder(nn.Module):
    """Encode a PIL image with SigLIP and project it through the Redux MLP.

    The Redux projection weights (``redux_up`` / ``redux_down``) are loaded from
    a safetensors file; the SigLIP vision tower and its image processor come
    from the Hugging Face checkpoint named by ``siglip_model_name``.
    """

    siglip_model_name = "google/siglip-so400m-patch14-384"

    def __init__(
        self,
        device,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        redux_path: str | None = os.getenv("FLUX_REDUX"),
        dtype=torch.bfloat16,
    ) -> None:
        """Build the projection layers, load their weights and load SigLIP.

        Args:
            device: ``torch.device`` or anything ``torch.device()`` accepts.
            redux_dim: Input width of ``redux_up``.
            txt_in_features: Output width of ``redux_down``; the hidden width
                between the two layers is three times this value.
            redux_path: Path to the Redux safetensors file. When ``None``, the
                ``FLUX_REDUX`` environment variable is consulted.
            dtype: Parameter dtype for the projection layers and SigLIP.

        Raises:
            AssertionError: If no Redux path is given and ``FLUX_REDUX`` is unset.
        """
        # The signature default is evaluated once at import time, so a variable
        # exported afterwards would be missed; look it up again at call time.
        if redux_path is None:
            redux_path = os.getenv("FLUX_REDUX")
        assert redux_path is not None, "Redux path must be provided"

        super().__init__()

        self.redux_dim = redux_dim
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.dtype = dtype

        # Create the projection layers directly on the target device.
        with self.device:
            self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
            self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)

            # Load before SigLIP is attached so that only the projection
            # layers take part in the missing/unexpected key report.
            sd = load_sft(redux_path, device=str(device))
            missing, unexpected = self.load_state_dict(sd, strict=False, assign=True)
            print_load_warning(missing, unexpected)

        # __call__ moves its inputs to self.device, so the vision tower has to
        # live there as well; move it explicitly instead of relying on the
        # ambient default device.
        self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(
            device=self.device, dtype=dtype
        )
        self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name)

    def __call__(self, x: Image.Image) -> torch.Tensor:
        """Return the Redux projection of the SigLIP hidden states for ``x``.

        Args:
            x: Image to encode; it is converted to RGB and resized by the
                SigLIP image processor.

        Returns:
            ``redux_down(silu(redux_up(h)))`` where ``h`` is SigLIP's
            ``last_hidden_state`` for the preprocessed image.
        """
        imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)
        model_inputs = imgs.to(device=self.device, dtype=self.dtype)

        hidden = self.siglip(**model_inputs).last_hidden_state

        widened = nn.functional.silu(self.redux_up(hidden))
        return self.redux_down(widened)
|