# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import spaces import tempfile from PIL import Image import gradio as gr import string import random, time, math import os import uuid import src.flux.generate from src.flux.generate import generate_from_test_sample, seed_everything from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, load_dit_lora from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxes from eval.tools.face_id import FaceID from eval.tools.florence_sam import ObjectDetector import shutil import yaml import numpy as np from huggingface_hub import snapshot_download, hf_hub_download import torch import re os.environ["XVERSE_PREPROCESSED_DATA"] = f"{os.getcwd()}/proprocess_data" # FLUX.1-schnell snapshot_download( repo_id="black-forest-labs/FLUX.1-schnell", local_dir="./checkpoints/FLUX.1-schnell", local_dir_use_symlinks=False ) # Florence-2-large snapshot_download( repo_id="microsoft/Florence-2-large", local_dir="./checkpoints/Florence-2-large", local_dir_use_symlinks=False ) # CLIP ViT Large snapshot_download( repo_id="openai/clip-vit-large-patch14", local_dir="./checkpoints/clip-vit-large-patch14", local_dir_use_symlinks=False ) # DINO ViT-s16 snapshot_download( repo_id="facebook/dino-vits16", local_dir="./checkpoints/dino-vits16", local_dir_use_symlinks=False ) # mPLUG Visual Question Answering snapshot_download( repo_id="xingjianleng/mplug_visual-question-answering_coco_large_en", local_dir="./checkpoints/mplug_visual-question-answering_coco_large_en", local_dir_use_symlinks=False ) # XVerse snapshot_download( repo_id="ByteDance/XVerse", local_dir="./checkpoints/XVerse", local_dir_use_symlinks=False ) hf_hub_download( repo_id="facebook/sam2.1-hiera-large", local_dir="./checkpoints/", filename="sam2.1_hiera_large.pt", ) os.environ["FLORENCE2_MODEL_PATH"] = "./checkpoints/Florence-2-large" os.environ["SAM2_MODEL_PATH"] = "./checkpoints/sam2.1_hiera_large.pt" os.environ["FACE_ID_MODEL_PATH"] = "./checkpoints/model_ir_se50.pth" os.environ["CLIP_MODEL_PATH"] = "./checkpoints/clip-vit-large-patch14" os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-schnell" os.environ["DPG_VQA_MODEL_PATH"] = "./checkpoints/mplug_visual-question-answering_coco_large_en" os.environ["DINO_MODEL_PATH"] = "./checkpoints/dino-vits16" dtype = torch.bfloat16 device = "cuda" config_path = "train/config/XVerse_config_demo.yaml" config = config_train = get_train_config(config_path) # config["model"]["dit_quant"] = "int8-quanto" config["model"]["use_dit_lora"] = False model = CustomFluxPipeline( config, device, torch_dtype=dtype, ) model.pipe.set_progress_bar_config(leave=False) face_model = FaceID(device) detector = ObjectDetector(device) config = get_train_config(config_path) model.config = config run_mode = "mod_only" store_attn_map = False run_name = time.strftime("%m%d-%H%M") num_inputs = 3 images = [] captions = [] face_btns = [] det_btns = [] vlm_btns = [] idip_checkboxes = [] ckpt_root = "./checkpoints/XVerse" model.clear_modulation_adapters() model.pipe.unload_lora_weights() if not os.path.exists(ckpt_root): print("Checkpoint root does not exist.") modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False) model.add_modulation_adapter(modulation_adapter) if config["model"]["use_dit_lora"]: load_dit_lora(model, model.pipe, config, dtype, device, f"{ckpt_root}", is_training=False) vae_skip_iter = None attn_skip_iter = 0 def clear_images(): return [None, ]*num_inputs + ["", ]*num_inputs @spaces.GPU() def det_seg_img(image, label): if isinstance(image, str): image = Image.open(image).convert("RGB") instance_result_dict = detector.get_multiple_instances(image, label, min_size=image.size[0]//20) indices = list(range(len(instance_result_dict["instance_images"]))) ins, bbox = merge_instances(image, indices, instance_result_dict["instance_bboxes"], instance_result_dict["instance_images"]) return ins @spaces.GPU() def crop_face_img(image): if isinstance(image, str): image = Image.open(image).convert("RGB") image = pad_to_square(image).resize((2048, 2048)) face_bbox = face_model.detect( (pil2tensor(image).unsqueeze(0) * 255).to(torch.uint8).to(device), 1.4 )[0] face = image.crop(face_bbox) return face @spaces.GPU() def vlm_img_caption(image): if isinstance(image, str): image = Image.open(image).convert("RGB") try: caption = detector.detector.caption(image, "").strip() if caption.endswith("."): caption = caption[:-1] except Exception as e: print(e) caption = "" caption = caption.lower() return caption def generate_random_string(length=4): letters = string.ascii_letters result_str = ''.join(random.choice(letters) for i in range(length)) return result_str def resize_keep_aspect_ratio(pil_image, target_size=1024): H, W = pil_image.height, pil_image.width target_area = target_size * target_size current_area = H * W scaling_factor = (target_area / current_area) ** 0.5 # sqrt(target_area / current_area) new_H = int(round(H * scaling_factor)) new_W = int(round(W * scaling_factor)) return pil_image.resize((new_W, new_H)) @spaces.GPU() def generate_image( prompt, image_1, caption_1, image_2 = None, caption_2 = None, image_3 = None, caption_3 = None, num_inference_steps = 8, cond_size = 256, target_height = 768, target_width = 768, seed = 42, vae_skip_iter = "0-0.05:1,0.8-1:1", control_weight_lambda = "0-1:1/1.6/5", double_attention = False, single_attention = True, ip_scale = "0-1:0.85", latent_sblora_scale_str = "0-1:0.85", vae_lora_scale = "0-1:1.3", session_id = None, ): if session_id is None: session_id = uuid.uuid4().hex torch.cuda.empty_cache() num_images = 1 images = [image_1, image_2, image_3] captions = [caption_1, caption_2, caption_3] idips_checkboxes = [True, True, True] # ——— Fallback to VLM caption if any caption is empty ——— for idx, (img, cap) in enumerate(zip(images, captions)): if img is not None: if not cap or cap.strip() == "": try: captions[idx] = vlm_img_caption(img) gr.Info(f"ENT{idx+1} Prompt Missing. Auto Generating Caption") except Exception as e: # if captioning fails, leave it empty or log as you wish print(f"Failed to generate caption for image {idx}: {e}") # ———————————————————————————————————————————————————————— # ——— Build a default prompt if none was provided ——— if not prompt or prompt.strip() == "": # collect ENT1, ENT2, … for every non-None image ents = [f"ENT{i+1}" for i, img in enumerate(images) if img is not None] # join them with spaces (or “ and ” if you prefer) prompt = " and ".join(ents) # e.g. prompt becomes "ENT1 ENT2" or "ENT1 ENT2 ENT3" gr.Info(f"Prompt Not Provided, Defaulting to {prompt}") # ———————————————————————————————————————— print(f"Length of images: {len(images)}") print(f"Length of captions: {len(captions)}") print(f"Control weight lambda: {control_weight_lambda}") if control_weight_lambda != "no": parts = control_weight_lambda.split(',') new_parts = [] for part in parts: if ':' in part: left, right = part.split(':') values = right.split('/') global_value = values[0] id_value = values[1] ip_value = values[2] new_values = [global_value] for is_id in idips_checkboxes: if is_id: new_values.append(id_value) else: new_values.append(ip_value) new_part = f"{left}:{('/'.join(new_values))}" new_parts.append(new_part) else: new_parts.append(part) control_weight_lambda = ','.join(new_parts) print(f"Control weight lambda: {control_weight_lambda}") src_inputs = [] use_words = [] cur_run_time = time.strftime("%m%d-%H%M%S") processed_directory = os.environ["XVERSE_PREPROCESSED_DATA"] tmp_dir_root = f'{processed_directory}' temp_dir = f"{tmp_dir_root}/{session_id}/{cur_run_time}_{generate_random_string(4)}" os.makedirs(temp_dir, exist_ok=True) print(f"Temporary directory created: {temp_dir}") for i, (image_path, caption) in enumerate(zip(images, captions)): if image_path: if caption.startswith("a ") or caption.startswith("A "): word = caption[2:] else: word = caption # case-insensitive replace of the ENT token prompt = re.sub( rf"ent{i+1}", # match "ent1", "ENT1", "Ent1", etc. caption, prompt, flags=re.IGNORECASE ) image = resize_keep_aspect_ratio(Image.open(image_path), 768) save_path = f"{temp_dir}/tmp_resized_input_{i}.png" image.save(save_path) input_image_path = save_path src_inputs.append( { "image_path": input_image_path, "caption": caption } ) use_words.append((i, word, word)) test_sample = dict( input_images=[], position_delta=[0, -32], prompt=prompt, target_height=target_height, target_width=target_width, seed=seed, cond_size=cond_size, vae_skip_iter=vae_skip_iter, lora_scale=ip_scale, control_weight_lambda=control_weight_lambda, latent_sblora_scale=latent_sblora_scale_str, condition_sblora_scale=vae_lora_scale, double_attention=double_attention, single_attention=single_attention, ) if len(src_inputs) > 0: test_sample["modulation"] = [ dict( type="adapter", src_inputs=src_inputs, use_words=use_words, ), ] json_dump(test_sample, f"{temp_dir}/test_sample.json", 'utf-8') assert single_attention == True target_size = int(round((target_width * target_height) ** 0.5) // 16 * 16) model.config["train"]["dataset"]["val_condition_size"] = cond_size model.config["train"]["dataset"]["val_target_size"] = target_size if control_weight_lambda == "no": control_weight_lambda = None if vae_skip_iter == "no": vae_skip_iter = None use_condition_sblora_control = True use_latent_sblora_control = True image = generate_from_test_sample( test_sample, model.pipe, model.config, num_inference_steps=num_inference_steps, num_images=num_images, target_height=target_height, target_width=target_width, seed=seed, store_attn_map=store_attn_map, vae_skip_iter=vae_skip_iter, control_weight_lambda=control_weight_lambda, double_attention=double_attention, single_attention=single_attention, ip_scale=ip_scale, use_latent_sblora_control=use_latent_sblora_control, latent_sblora_scale=latent_sblora_scale_str, use_condition_sblora_control=use_condition_sblora_control, condition_sblora_scale=vae_lora_scale, ) if isinstance(image, list): num_cols = 2 num_rows = int(math.ceil(num_images / num_cols)) image = image_grid(image, num_rows, num_cols) return image, prompt def create_image_input(index, open=True, indices_state=None): accordion_state = gr.State(open) with gr.Column(): with gr.Accordion(f"Input Image {index + 1}", open=accordion_state.value) as accordion: image = gr.Image(type="filepath", label=f"Image {index + 1}") caption = gr.Textbox(label=f"ENT{index + 1}", value="") id_ip_checkbox = gr.Checkbox(value=False, label=f"ID or not {index + 1}", visible=True) with gr.Row(): vlm_btn = gr.Button("Generate Caption") face_btn = gr.Button("Crop Face") det_btn = gr.Button("Crop to Prompt") accordion.expand( inputs=[indices_state], fn = lambda x: update_inputs(True, index, x), outputs=[indices_state, accordion_state], ) accordion.collapse( inputs=[indices_state], fn = lambda x: update_inputs(False, index, x), outputs=[indices_state, accordion_state], ) return image, caption, face_btn, det_btn, vlm_btn, accordion_state, accordion, id_ip_checkbox def create_min_image_input(index, open=True, indices_state=None): with gr.Column(min_width=128): image = gr.Image(type="filepath", label=f"Image {index + 1}") caption = gr.Textbox(label=f"ENT{index + 1} Prompt", value="") face_btn = gr.Button("Crop to Face") det_btn = gr.Button("Crop to Prompt") id_ip_checkbox = gr.Checkbox(value=True, label=f"ID or not {index + 1}", visible=False) with gr.Row(): vlm_btn = gr.Button("Generate Caption", visible=False) return image, caption, face_btn, det_btn, vlm_btn, id_ip_checkbox def merge_instances(orig_img, indices, ins_bboxes, ins_images): orig_image_width, orig_image_height = orig_img.width, orig_img.height final_img = Image.new("RGB", (orig_image_width, orig_image_height), color=(255, 255, 255)) bboxes = [] for i in indices: bbox = np.array(ins_bboxes[i], dtype=int).tolist() bboxes.append(bbox) img = cv2pil(ins_images[i]) mask = (np.array(img)[..., :3] != 255).any(axis=-1) mask = Image.fromarray(mask.astype(np.uint8) * 255, mode='L') final_img.paste(img, (bbox[0], bbox[1]), mask) bbox = merge_bboxes(bboxes) img = final_img.crop(bbox) return img, bbox def change_accordion(at: bool, index: int, state: list): print(at, state) indices = state if at: if index not in indices: indices.append(index) else: if index in indices: indices.remove(index) # 确保 indices 是有序的 indices.sort() print(indices) return gr.Accordion(open=at), indices def update_inputs(is_open, index, state: list): indices = state if is_open: if index not in indices: indices.append(index) else: if index in indices: indices.remove(index) indices.sort() print(indices) return indices, is_open def start_session(request: gr.Request): """ Initialize a new user session and return the session identifier. This function is triggered when the Gradio demo loads and creates a unique session hash that will be used to organize outputs and temporary files for this specific user session. Args: request (gr.Request): Gradio request object containing session information Returns: str: Unique session hash identifier """ return request.session_hash # Cleanup on unload def cleanup(request: gr.Request): """ Clean up session-specific directories and temporary files when the user session ends. This function is triggered when the Gradio demo is unloaded (e.g., when the user closes the browser tab or navigates away). It removes all temporary files and directories created during the user's session to free up storage space. Args: request (gr.Request): Gradio request object containing session information """ sid = request.session_hash if sid: d1 = os.path.join(os.environ["XVERSE_PREPROCESSED_DATA"], sid) shutil.rmtree(d1, ignore_errors=True) css = """ #col-container { margin: 0 auto; max-width: 1400px; } """ if __name__ == "__main__": with gr.Blocks(css=css) as demo: session_state = gr.State() demo.load(start_session, outputs=[session_state]) indices_state = gr.State([0, 1]) with gr.Column(elem_id="col-container"): gr.Markdown( """ # XVerse – Consistent Multi-Subject Control of Identity and Semantic Attributes via DiT Modulation • Source: [Github](https://github.com/bytedance/XVerse) • HF Space by : [@alexandernasa](https://twitter.com/alexandernasa/) """ ) with gr.Row(): with gr.Column(): with gr.Row(): for i in range(num_inputs): image, caption, face_btn, det_btn, vlm_btn, id_ip_checkbox = create_min_image_input(i, open=i<2, indices_state=indices_state) images.append(image) idip_checkboxes.append(id_ip_checkbox) captions.append(caption) face_btns.append(face_btn) det_btns.append(det_btn) vlm_btns.append(vlm_btn) prompt = gr.Textbox(label="Prompt", placeholder="e.g., ENT1 and ENT2") gen_btn = gr.Button("Generate", variant="primary") clear_btn = gr.Button("Clear") steps_slider = gr.Slider(minimum=4, maximum=40, step=1, value=8, label="inference steps") with gr.Accordion("Advanced Settings", open=False): seed = gr.Number(value=42, label="Seed", info="") with gr.Row(): target_height = gr.Slider(512, 1024, step=128, value=768, label="Generated Height", info="") target_width = gr.Slider(512, 1024, step=128, value=768, label="Generated Width", info="") cond_size = gr.Slider(256, 384, step=128, value=256, label="Condition Size", info="") with gr.Row(): weight_id = gr.Slider(0.1, 5, step=0.1, value=1.6, label="weight_id") weight_ip = gr.Slider(0.1, 5, step=0.1, value=5, label="weight_ip") with gr.Row(): ip_scale_str = gr.Slider(0.5, 1.5, step=0.01, value=0.85, label="latent_lora_scale") vae_lora_scale = gr.Slider(0.5, 1.5, step=0.01, value=1.3, label="vae_lora_scale") with gr.Row(): vae_skip_iter_s1 = gr.Slider(0, 1, step=0.01, value=0.05, label="vae_skip_iter_before") vae_skip_iter_s2 = gr.Slider(0, 1, step=0.01, value=0.8, label="vae_skip_iter_after") with gr.Row(): weight_id_ip_str = gr.Textbox( value="0-1:1/1.6/5", label="weight_id_ip_str", interactive=False, visible=False ) weight_id.change( lambda s1, s2: f"0-1:1/{s1}/{s2}", inputs=[weight_id, weight_ip], outputs=weight_id_ip_str ) weight_ip.change( lambda s1, s2: f"0-1:1/{s1}/{s2}", inputs=[weight_id, weight_ip], outputs=weight_id_ip_str ) vae_skip_iter = gr.Textbox( value="0-0.05:1,0.8-1:1", label="vae_skip_iter", interactive=False, visible=False ) vae_skip_iter_s1.change( lambda s1, s2: f"0-{s1}:1,{s2}-1:1", inputs=[vae_skip_iter_s1, vae_skip_iter_s2], outputs=vae_skip_iter ) vae_skip_iter_s2.change( lambda s1, s2: f"0-{s1}:1,{s2}-1:1", inputs=[vae_skip_iter_s1, vae_skip_iter_s2], outputs=vae_skip_iter ) with gr.Row(): db_latent_lora_scale_str = gr.Textbox( value="0-1:0.85", label="db_latent_lora_scale_str", interactive=False, visible=False ) sb_latent_lora_scale_str = gr.Textbox( value="0-1:0.85", label="sb_latent_lora_scale_str", interactive=False, visible=False ) vae_lora_scale_str = gr.Textbox( value="0-1:1.3", label="vae_lora_scale_str", interactive=False, visible=False ) vae_lora_scale.change( lambda s: f"0-1:{s}", inputs=vae_lora_scale, outputs=vae_lora_scale_str ) ip_scale_str.change( lambda s: [f"0-1:{s}", f"0-1:{s}"], inputs=ip_scale_str, outputs=[db_latent_lora_scale_str, sb_latent_lora_scale_str] ) with gr.Row(): double_attention = gr.Checkbox(value=False, label="Double Attention", visible=False) single_attention = gr.Checkbox(value=True, label="Single Attention", visible=False) with gr.Column(): output = gr.Image(label="Result") final_text = gr.Textbox(show_label=False, interactive=False) examples = gr.Examples( examples=[ [ "ENT1 with ENT2 holding ENT3", "sample/sam.jpg", "a man", "sample/hair.jpg", "curly hair", "sample/can.jpg", "a can", 24, ], [ "ENT1 with long curly hair wearing ENT2 at Met Gala", "sample/woman2.jpg", "a woman", "sample/dress.jpg", "a dress", None, None, 8, ], [ "ENT1 wearing a tiny hat", "sample/hamster.jpg", "a hamster", None, None, None, None, 8, ], [ "a drawing of ENT1 and ENT2 that the ENT1 is running alongside of a giant ENT2, in style of a comic book", "sample/woman.jpg", "a woman", "sample/hamster.jpg", "a hamster", None, None, 8, ], [ "a drawing of ENT1 and ENT2 that the ENT1 is running alongside of a giant ENT2, in style of a comic book", "sample/woman3.jpg", "a woman", "sample/hamster.jpg", "a hamster", None, None, 8, ], ], inputs=[ prompt, images[0], captions[0], images[1], captions[1], images[2], captions[2], steps_slider, ], outputs=[output, final_text], fn=generate_image, cache_examples=True, ) gen_btn.click( generate_image, inputs=[ prompt, images[0], captions[0], images[1], captions[1], images[2], captions[2], steps_slider, cond_size, target_height, target_width, seed, vae_skip_iter, weight_id_ip_str, double_attention, single_attention, db_latent_lora_scale_str, sb_latent_lora_scale_str, vae_lora_scale_str, session_state, ], outputs=[output, final_text] ) clear_btn.click(clear_images, outputs=images + captions) for i in range(num_inputs): face_btns[i].click(det_seg_img, inputs=[images[i], gr.State("A face")], outputs=[images[i]]) det_btns[i].click(det_seg_img, inputs=[images[i], captions[i]], outputs=[images[i]]) vlm_btns[i].click(vlm_img_caption, inputs=[images[i]], outputs=[captions[i]]) images[i].change( lambda img: "" if img is None else gr.update(), inputs=[images[i]], outputs=[captions[i]] ) demo.unload(cleanup) demo.queue() demo.launch()