from typing import TypedDict

import diffusers.image_processor
import gradio as gr
import pillow_heif  # pyright: ignore[reportMissingTypeStubs]
import spaces  # pyright: ignore[reportMissingTypeStubs]
import torch
from PIL import Image

from pipeline import TryOffAnyone

# register HEIF/AVIF support so uploads in those formats open correctly
pillow_heif.register_heif_opener()  # pyright: ignore[reportUnknownMemberType]
pillow_heif.register_avif_opener()  # pyright: ignore[reportUnknownMemberType]

# allow TF32 matmuls for faster inference on recent NVIDIA GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True

TITLE = """
# Try Off Anyone

## ⚠️ Important

1. Choose an example image or upload your own
2. Use the Pen tool to draw a mask over the clothing area you want to extract

[[arxiv:2412.08573]](https://arxiv.org/abs/2412.08573)
[[github:ixarchakos/try-off-anyone]](https://github.com/ixarchakos/try-off-anyone)
"""

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

pipeline_tryoff = TryOffAnyone(
    device=DEVICE,
    dtype=DTYPE,
)

# binarized grayscale processor for the user-drawn mask
mask_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)
# standard processor for the model image
vae_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
)


class ImageData(TypedDict):
    """Value produced by `gr.ImageMask`: the uploaded image plus drawn layers."""

    background: Image.Image
    composite: Image.Image
    layers: list[Image.Image]


@spaces.GPU
def process(
    image_data: ImageData,
    image_width: int,
    image_height: int,
    num_inference_steps: int,
    condition_scale: float,
    seed: int,
) -> Image.Image:
    assert image_width > 0
    assert image_height > 0
    assert num_inference_steps > 0
    assert condition_scale > 0
    assert seed >= 0

    # extract the uploaded image and the drawn mask layer from image_data
    image = image_data["background"]
    mask = image_data["layers"][0]

    # preprocess image
    image = image.convert("RGB").resize((image_width, image_height))
    image_preprocessed = vae_processor.preprocess(  # pyright: ignore[reportUnknownMemberType,reportAssignmentType]
        image=image,
        width=image_width,
        height=image_height,
    )[0]

    # preprocess mask: the alpha channel holds the drawn strokes
    mask = mask.getchannel("A").resize((image_width, image_height))
    mask_preprocessed = mask_processor.preprocess(  # pyright: ignore[reportUnknownMemberType]
        image=mask,
        width=image_width,
        height=image_height,
    )[0]

    # generate the TryOff image with a seeded generator for reproducibility
    gen = torch.Generator(device=DEVICE).manual_seed(seed)
    tryoff_image = pipeline_tryoff(
        image_preprocessed,
        mask_preprocessed,
        inference_steps=num_inference_steps,
        scale=condition_scale,
        generator=gen,
    )[0]

    return tryoff_image


with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        with gr.Column():
            input_image = gr.ImageMask(
                label="Input Image",
                height=1024,  # https://github.com/gradio-app/gradio/issues/10236
                type="pil",
                interactive=True,
            )
            run_button = gr.Button(
                value="Extract Clothing",
            )
            gr.Examples(
                examples=[
                    ["examples/model_1.jpg"],
                    ["examples/model_2.jpg"],
                    ["examples/model_3.jpg"],
                    ["examples/model_4.jpg"],
                    ["examples/model_5.jpg"],
                    ["examples/model_6.jpg"],
                    ["examples/model_7.jpg"],
                    ["examples/model_8.jpg"],
                    ["examples/model_9.jpg"],
                ],
                inputs=[input_image],
            )
        with gr.Column():
            output_image = gr.Image(
                label="TryOff result",
                height=1024,
                image_mode="RGB",
                type="pil",
            )
            with gr.Accordion("Advanced Settings", open=True):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=100_000,
                    value=69_420,
                    step=1,
                )
                scale = gr.Slider(
                    label="Scale",
                    minimum=0.5,
                    maximum=5,
                    value=2.5,
                    step=0.05,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    value=25,
                    step=1,
                )
                with gr.Row():
                    image_width = gr.Slider(
                        label="Image Width",
                        minimum=64,
                        maximum=1024,
                        value=384,
                        step=8,
                    )
                    image_height = gr.Slider(
                        label="Image Height",
                        minimum=64,
                        maximum=1024,
                        value=512,
                        step=8,
                    )

    run_button.click(
        fn=process,
        inputs=[
            input_image,
            image_width,
            image_height,
            num_inference_steps,
            scale,
            seed,
        ],
        outputs=output_image,
    )

demo.launch()
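# Minimal headless sketch (an assumption, not part of the app's UI flow):
# `process` can also be called directly for smoke-testing. The mask layer's
# alpha channel marks the clothing region to extract; the rectangle below is
# a hypothetical stand-in for a user-drawn stroke.
#
#     from PIL import Image, ImageDraw
#
#     base = Image.open("examples/model_1.jpg")
#     mask = Image.new("RGBA", base.size, (0, 0, 0, 0))
#     ImageDraw.Draw(mask).rectangle((100, 100, 300, 400), fill=(0, 0, 0, 255))
#     result = process(
#         {"background": base, "layers": [mask], "composite": base},
#         image_width=384,
#         image_height=512,
#         num_inference_steps=25,
#         condition_scale=2.5,
#         seed=69_420,
#     )
#     result.save("tryoff.png")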