Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import subprocess | |
import time | |
from cog import BasePredictor, Input, Path, Secret | |
from diffusers.utils import load_image, check_min_version | |
from diffusers import FluxFillPipeline | |
from diffusers import FluxTransformer2DModel | |
import numpy as np | |
import torch | |
from torchvision import transforms | |
class Predictor(BasePredictor):
    """Cog predictor for CatVTON-Flux garment try-on and try-off.

    ``setup`` loads both task-specific Flux transformers once. ``predict``
    selects one of them and wraps it in a FLUX.1-dev fill pipeline built with
    the caller's Hugging Face token. It then inpaints a side-by-side
    ``[garment | person]`` canvas and returns the two halves.
    """

    def setup(self) -> None:
        """Load part of the model into memory to make running multiple predictions efficient.

        Both transformers are loaded up front, so switching between try-on and
        try-off in ``predict`` does not reload weights.
        """
        self.dtype = torch.bfloat16
        self.try_on_transformer = FluxTransformer2DModel.from_pretrained(
            "xiaozaa/catvton-flux-beta",
            torch_dtype=self.dtype,
        )
        self.try_off_transformer = FluxTransformer2DModel.from_pretrained(
            "xiaozaa/cat-tryoff-flux",
            torch_dtype=self.dtype,
        )

    @staticmethod
    def _load_rgb(source, size):
        """Load *source* as an RGB PIL image resized to *size* ``(width, height)``.

        diffusers' ``load_image`` accepts only a ``str`` (URL or local path)
        or a ``PIL.Image``. Path-like inputs (cog's ``Path`` type) are
        therefore converted to ``str`` with ``os.fspath`` first.
        """
        if isinstance(source, os.PathLike):
            source = os.fspath(source)
        return load_image(source).convert("RGB").resize(size)

    def predict(
        self,
        hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/token. You also need to approve the Flux Dev terms."),
        image: Path = Input(description="Image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"),
        mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true"),
        try_on: bool = Input(True, description="Try on or try off"),
        garment: Path = Input(description="Garment file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg"),
        num_steps: int = Input(50, description="Number of steps to run the model for"),
        guidance_scale: float = Input(30, description="Guidance scale for the model"),
        seed: int = Input(0, description="Seed for the model"),
        width: int = Input(576, description="Width of the output image"),
        height: int = Input(768, description="Height of the output image"),
    ):
        """Run one try-on (``try_on=True``) or try-off (``try_on=False``) pass.

        Args:
            hf_token: Hugging Face token used to fetch FLUX.1-dev. Either a cog
                ``Secret`` (unwrapped via ``get_secret_value``) or a plain string.
            image: Person image (path-like, URL string, or PIL image).
            mask: Mask image for the person half. Only its first channel is used.
            try_on: Selects the try-on or the try-off transformer and mask layout.
            garment: Garment image (path-like, URL string, or PIL image).
            num_steps: Number of inference steps.
            guidance_scale: Guidance scale passed to the pipeline.
            seed: Seed for the CUDA ``torch.Generator``.
            width: Width of each half; the generated canvas is ``2 * width`` wide.
            height: Height of each half and of the canvas.

        Returns:
            ``(garment_result, try_result)``: the left and right halves cropped
            from the generated canvas, each ``width`` x ``height``.
        """
        size = (width, height)
        self.transformer = self.try_on_transformer if try_on else self.try_off_transformer

        # cog delivers the token wrapped in a Secret. The hub client needs the
        # raw string, so unwrap it. Values without get_secret_value (e.g. a
        # plain str from a direct Python caller) are forwarded unchanged.
        token = hf_token.get_secret_value() if hasattr(hf_token, "get_secret_value") else hf_token

        # The pipeline is rebuilt on every call because the token is only
        # known per request.
        self.pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            transformer=self.transformer,
            torch_dtype=self.dtype,
            token=token,
        ).to("cuda")
        self.pipe.transformer.to(self.dtype)

        # Images are normalized from [0, 1] to [-1, 1]; masks stay in [0, 1].
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # broadcast over the 3 RGB channels
        ])
        mask_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        person_img = self._load_rgb(image, size)
        mask_img = self._load_rgb(mask, size)
        garment_img = self._load_rgb(garment, size)

        image_tensor = transform(person_img)
        mask_tensor = mask_transform(mask_img)[:1]  # keep only the first channel
        garment_tensor = transform(garment_img)

        # ToTensor yields (C, H, W), so dim=2 concatenates along width:
        # garment on the left, person on the right.
        inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2)
        garment_mask = torch.zeros_like(mask_tensor)
        if try_on:
            # Garment half is all zeros; person half uses the supplied mask.
            extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)
        else:
            # Garment half is all ones; person half still uses the supplied mask.
            # NOTE(review): confirm the person half should not be all zeros for
            # try-off (i.e. left unchanged).
            extended_mask = torch.cat([1 - garment_mask, mask_tensor], dim=2)

        # Kept verbatim, including the missing space before "[IMAGE2]";
        # presumably it matches the prompt the checkpoints were trained with.
        prompt = (
            "The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; "
            "[IMAGE1] Detailed product shot of a clothing"
            "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
        )

        generator = torch.Generator(device="cuda").manual_seed(seed)
        result = self.pipe(
            height=height,
            width=width * 2,
            image=inpaint_image,
            mask_image=extended_mask,
            num_inference_steps=num_steps,
            generator=generator,
            max_sequence_length=512,
            guidance_scale=guidance_scale,
            prompt=prompt,
        ).images[0]

        # Split the canvas back into its left (garment) and right (person) halves.
        garment_result = result.crop((0, 0, width, height))
        try_result = result.crop((width, 0, width * 2, height))
        # NOTE(review): cog appears to require a return annotation on predict()
        # and file-like outputs (e.g. List[Path]) rather than PIL images.
        # Confirm, and if so save both crops and return their paths. The
        # tuple is kept here so direct Python callers are not broken.
        return garment_result, try_result