# XFluxSpace / predict.py
# Author: stazizov — commit e5e58b2 ("init")
# Prediction interface for Cog ⚙️
# https://cog.run/python
import os
import shutil
import subprocess
import sys
import time
from typing import List

import torch
from cog import BasePredictor, Input, Path
from PIL import Image

from image_datasets.canny_dataset import canny_processor, c_crop
from src.flux.util import load_ae, load_clip, load_t5, load_flow_model, load_controlnet, load_safetensors
# Directory that predict() clears before a run and then scans for main.py's
# "controlnet_result_*" output.
OUTPUT_DIR = "controlnet_results"
# Destination directory for the T5 / CLIP weight archives.
MODEL_CACHE = "checkpoints"
CONTROLNET_URL = "https://huggingface.co/XLabs-AI/flux-controlnet-canny/resolve/main/controlnet.safetensors"
T5_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/t5-cache.tar"
CLIP_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/clip-cache.tar"
# HuggingFace access token. It is read from the environment so that no secret
# has to be committed to source control; empty when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
def download_weights(url, dest):
    """Download ``url`` into ``dest`` with the ``pget`` CLI, logging the elapsed time.

    The ``-xf`` flags are passed straight to pget (presumably "extract the
    archive" + "force"; the URLs used in this file point at ``.tar`` files —
    confirm against the pget documentation).

    Raises:
        subprocess.CalledProcessError: if pget exits with a non-zero status.
    """
    began = time.time()
    for label, value in (("downloading url: ", url), ("downloading to: ", dest)):
        print(label, value)
    pget_cmd = ["pget", "-xf", url, dest]
    subprocess.check_call(pget_cmd, close_fds=False)
    elapsed = time.time() - began
    print("downloading took: ", elapsed)
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    """Load the text encoders, the Flux model, the autoencoder and the ControlNet.

    Args:
        name: model name forwarded to the flow-model / AE / ControlNet loaders.
        device: device for T5, CLIP and the ControlNet.
        offload: when true, the flow model and autoencoder are loaded on the CPU
            instead of ``device``.
        is_schnell: selects a T5 ``max_length`` of 256 (schnell) instead of 512.

    Returns:
        ``(model, ae, t5, clip, controlnet)``; the ControlNet is cast to bfloat16.
    """
    flux_device = "cpu" if offload else device
    t5_max_length = 256 if is_schnell else 512

    t5 = load_t5(device, max_length=t5_max_length)
    clip = load_clip(device)
    model = load_flow_model(name, device=flux_device)
    ae = load_ae(name, device=flux_device)
    controlnet = load_controlnet(name, device)
    controlnet = controlnet.to(torch.bfloat16)
    return model, ae, t5, clip, controlnet
class Predictor(BasePredictor):
    """Cog predictor for the XLabs Flux canny ControlNet.

    ``setup`` downloads missing ControlNet / T5 / CLIP weights and loads the Flux
    models; ``predict`` runs ``main.py`` in a subprocess to generate an image
    conditioned on the canny edges of the input image.
    """

    # Local file the ControlNet weights are downloaded to and loaded from.
    CONTROLNET_CHECKPOINT = "controlnet.safetensors"
    # Inputs whose width or height exceeds this are downscaled before inference.
    MAX_IMAGE_SIZE = 1536
    # Fixed location for the downscaled input image.
    RESIZED_IMAGE_PATH = "/tmp/resized_image.png"

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        t1 = time.time()
        if HF_TOKEN:
            # Pass the token as a separate argv element instead of interpolating it
            # into a shell command line. Login failures stay non-fatal, as before.
            subprocess.run(["huggingface-cli", "login", "--token", HF_TOKEN], check=False)
        name = "flux-dev"
        self.offload = False
        checkpoint = self.CONTROLNET_CHECKPOINT

        print("Checking ControlNet weights")
        if not os.path.exists(checkpoint):
            self._download_controlnet(checkpoint)
        print("Checking T5 weights")
        if not os.path.exists(os.path.join(MODEL_CACHE, "models--google--t5-v1_1-xxl")):
            download_weights(T5_URL, MODEL_CACHE)
        print("Checking CLIP weights")
        if not os.path.exists(os.path.join(MODEL_CACHE, "models--openai--clip-vit-large-patch14")):
            download_weights(CLIP_URL, MODEL_CACHE)

        self.is_schnell = False
        device = "cuda"
        self.torch_device = torch.device(device)
        # NOTE(review): predict() runs inference in a separate main.py process and
        # does not visibly use the models loaded here. If main.py loads its own
        # copies, GPU memory is held twice — confirm this in-process load is needed.
        model, ae, t5, clip, controlnet = get_models(
            name,
            device=self.torch_device,
            offload=self.offload,
            is_schnell=self.is_schnell,
        )
        self.ae = ae
        self.t5 = t5
        self.clip = clip
        self.controlnet = controlnet
        self.model = model.to(self.torch_device)

        # The checkpoint name is a .safetensors constant; the torch.load branch is
        # kept only for pointing CONTROLNET_CHECKPOINT at a pickle checkpoint.
        if '.safetensors' in checkpoint:
            state_dict = load_safetensors(checkpoint)
        else:
            state_dict = torch.load(checkpoint, map_location='cpu')
        controlnet.load_state_dict(state_dict, strict=False)
        t2 = time.time()
        print(f"Setup time: {t2 - t1}")

    def preprocess_canny_image(self, image_path: str, width: int = 512, height: int = 512):
        """Center-crop (via ``c_crop``), resize to ``(width, height)`` and run the canny processor.

        Returns whatever ``canny_processor`` returns for the resized image.
        """
        # Context manager closes the underlying file handle once we are done.
        with Image.open(image_path) as image:
            image = c_crop(image)
            image = image.resize((width, height))
            return canny_processor(image)

    def predict(
        self,
        prompt: str = Input(description="Input prompt", default="a handsome viking man with white hair, cinematic, MM full HD"),
        image: Path = Input(description="Input image", default=None),
        num_inference_steps: int = Input(description="Number of inference steps", ge=1, le=64, default=28),
        cfg: float = Input(description="CFG", ge=0, le=10, default=3.5),
        seed: int = Input(description="Random seed", default=None)
    ) -> List[Path]:
        """Run a single prediction on the model.

        Args:
            prompt: text prompt forwarded to main.py.
            image: control image; its canny edges condition the generation.
            num_inference_steps: number of sampling steps.
            cfg: guidance value forwarded as ``--guidance``.
            seed: random seed; when omitted a value in [0, 65535] is drawn.

        Returns:
            A one-element list with the path of the generated image.

        Raises:
            ValueError: if no input image is supplied.
            subprocess.CalledProcessError: if main.py exits with an error.
            RuntimeError: if main.py finishes without writing a result file.
        """
        if image is None:
            # Without this check str(None) == "None" would be opened as a file path.
            raise ValueError("An input image is required for canny ControlNet conditioning")
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        # Clear stale results so the lookup below only sees this run's output.
        output_dir = OUTPUT_DIR
        shutil.rmtree(output_dir, ignore_errors=True)

        input_image, width, height = self._prepare_input_image(str(image))

        subprocess.check_call(
            [sys.executable, "main.py",  # same interpreter/venv as the predictor
             "--local_path", self.CONTROLNET_CHECKPOINT,
             "--image", input_image,
             "--use_controlnet",
             "--control_type", "canny",
             "--prompt", prompt,
             "--width", str(width),
             "--height", str(height),
             "--num_steps", str(num_inference_steps),
             "--guidance", str(cfg),
             "--seed", str(seed)
             ], close_fds=False)
        return [self._find_result(output_dir)]

    @staticmethod
    def _download_controlnet(dest: str) -> None:
        """Download the ControlNet weights to ``dest``, raising if wget fails."""
        partial = dest + ".part"
        # Download under a temporary name so an interrupted transfer never leaves a
        # truncated file at ``dest`` that a later setup would treat as complete.
        subprocess.check_call(["wget", "-O", partial, CONTROLNET_URL])
        os.replace(partial, dest)

    def _prepare_input_image(self, image_path: str) -> tuple:
        """Return ``(path, width, height)`` of the image to pass to main.py.

        Images larger than MAX_IMAGE_SIZE in either dimension are downscaled
        (aspect ratio kept) and saved to RESIZED_IMAGE_PATH. Other images are
        used from their original path.
        """
        with Image.open(image_path) as img:
            width, height = img.size
            scale = min(self.MAX_IMAGE_SIZE / width, self.MAX_IMAGE_SIZE / height, 1)
            if scale >= 1:
                return image_path, width, height
            # Clamp to 1 px so extreme aspect ratios cannot produce a 0-sized image.
            width = max(1, int(width * scale))
            height = max(1, int(height * scale))
            print(f"Scaling image down to {width}x{height}")
            resized = img.resize((width, height), resample=Image.Resampling.LANCZOS)
        resized.save(self.RESIZED_IMAGE_PATH)
        return self.RESIZED_IMAGE_PATH, width, height

    @staticmethod
    def _find_result(output_dir: str) -> Path:
        """Return the first (sorted) ``controlnet_result_*`` file in ``output_dir``.

        Raises:
            RuntimeError: if the directory is missing or has no matching file.
        """
        try:
            names = os.listdir(output_dir)
        except FileNotFoundError:
            names = []
        # Sort so the pick is deterministic if more than one result exists.
        results = sorted(name for name in names if name.startswith("controlnet_result_"))
        if not results:
            raise RuntimeError(f"main.py produced no controlnet_result_* file in {output_dir!r}")
        return Path(os.path.join(output_dir, results[0]))