# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# STEP 1: very first thing in the file: force the "spawn" start method so that
# CUDA can be initialized safely in any worker processes.
import multiprocessing as mp
mp.set_start_method("spawn", force=True)

import spaces
import tempfile
from PIL import Image
import gradio as gr
import string
import random, time, math
import os

# Disable NCCL peer-to-peer and InfiniBand transports (works around hangs on
# some multi-GPU hosts).
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

import numpy as np
import torch

import src.flux.generate
# from src.flux.generate import generate_from_test_sample, seed_everything
# from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, load_dit_lora
# The helpers below are used by the active Gradio handlers, so their imports stay live.
from src.utils.data_utils import pil2tensor, pad_to_square, cv2pil, merge_bboxes
from eval.tools.face_id import FaceID
from eval.tools.florence_sam import ObjectDetector
# import shutil
# import yaml
# from huggingface_hub import snapshot_download, hf_hub_download

# Checkpoint downloads are commented out on the assumption that the weights are
# already present under ./checkpoints; uncomment them (together with the
# huggingface_hub import above) to fetch everything on first run.
# # FLUX.1-dev
# snapshot_download(
#     repo_id="black-forest-labs/FLUX.1-dev",
#     local_dir="./checkpoints/FLUX.1-dev",
#     local_dir_use_symlinks=False
# )
# # Florence-2-large
# snapshot_download(
#     repo_id="microsoft/Florence-2-large",
#     local_dir="./checkpoints/Florence-2-large",
#     local_dir_use_symlinks=False
# )
# # CLIP ViT Large
# snapshot_download(
#     repo_id="openai/clip-vit-large-patch14",
#     local_dir="./checkpoints/clip-vit-large-patch14",
#     local_dir_use_symlinks=False
# )
# # DINO ViT-s16
# snapshot_download(
#     repo_id="facebook/dino-vits16",
#     local_dir="./checkpoints/dino-vits16",
#     local_dir_use_symlinks=False
# )
# # mPLUG Visual Question Answering
# snapshot_download(
#     repo_id="xingjianleng/mplug_visual-question-answering_coco_large_en",
#     local_dir="./checkpoints/mplug_visual-question-answering_coco_large_en",
#     local_dir_use_symlinks=False
# )
# # XVerse
# snapshot_download(
#     repo_id="ByteDance/XVerse",
#     local_dir="./checkpoints/XVerse",
#     local_dir_use_symlinks=False
# )
# hf_hub_download(
#     repo_id="facebook/sam2.1-hiera-large",
#     local_dir="./checkpoints/",
#     filename="sam2.1_hiera_large.pt",
# )

# Model paths read by the detector and face model instantiated below.
os.environ["FLORENCE2_MODEL_PATH"] = "./checkpoints/Florence-2-large"
os.environ["SAM2_MODEL_PATH"] = "./checkpoints/sam2.1_hiera_large.pt"
os.environ["FACE_ID_MODEL_PATH"] = "./checkpoints/model_ir_se50.pth"
os.environ["CLIP_MODEL_PATH"] = "./checkpoints/clip-vit-large-patch14"
os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-dev"
os.environ["DPG_VQA_MODEL_PATH"] = "./checkpoints/mplug_visual-question-answering_coco_large_en"
os.environ["DINO_MODEL_PATH"] = "./checkpoints/dino-vits16"

device = "cuda"

# Full FLUX pipeline setup, disabled in this stripped-down demo
# (generate_image below is stubbed and returns None):
# dtype = torch.bfloat16
# config_path = "train/config/XVerse_config_demo.yaml"
# config = config_train = get_train_config(config_path)
# # config["model"]["dit_quant"] = "int8-quanto"
# config["model"]["use_dit_lora"] = False
# model = CustomFluxPipeline(
#     config, device, torch_dtype=dtype,
# )
# model.pipe.set_progress_bar_config(leave=False)
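# A minimal sketch (hypothetical helper, not part of the original demo) showing how
# the commented snapshot_download calls above could be gated so that weights are
# only fetched when missing:
def _ensure_checkpoint(repo_id, local_dir):
    """Download `repo_id` into `local_dir` only if it is not already present."""
    if not os.path.exists(local_dir):
        from huggingface_hub import snapshot_download
        snapshot_download(repo_id=repo_id, local_dir=local_dir)
# Example: _ensure_checkpoint("ByteDance/XVerse", "./checkpoints/XVerse")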
# FaceID and the Florence/SAM object detector are required by the active Gradio
# handlers below, so they are instantiated here.
face_model = FaceID(device)
detector = ObjectDetector(device)

# Remaining pipeline state, disabled together with the FLUX setup above:
# config = get_train_config(config_path)
# model.config = config
# run_mode = "mod_only"  # orig_only, mod_only, both
# store_attn_map = False
# run_name = time.strftime("%m%d-%H%M")
# num_inputs = 6
# ckpt_root = "./checkpoints/XVerse"
# model.clear_modulation_adapters()
# model.pipe.unload_lora_weights()
# if not os.path.exists(ckpt_root):
#     print("Checkpoint root does not exist.")
# modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False)
# model.add_modulation_adapter(modulation_adapter)
# if config["model"]["use_dit_lora"]:
#     load_dit_lora(model, model.pipe, config, dtype, device, f"{ckpt_root}", is_training=False)
# vae_skip_iter = None
# attn_skip_iter = 0

# def clear_images():
#     return [None, ] * num_inputs


@spaces.GPU()
def det_seg_img(image, label):
    """Detect and segment all instances of `label`, pasted onto a white canvas."""
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    instance_result_dict = detector.get_multiple_instances(image, label, min_size=image.size[0] // 20)
    indices = list(range(len(instance_result_dict["instance_images"])))
    ins, bbox = merge_instances(image, indices, instance_result_dict["instance_bboxes"], instance_result_dict["instance_images"])
    return ins


@spaces.GPU()
def crop_face_img(image):
    """Pad to square, upscale, then crop the detected face region."""
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    # image = resize_keep_aspect_ratio(image, 1024)
    image = pad_to_square(image).resize((2048, 2048))
    face_bbox = face_model.detect(
        (pil2tensor(image).unsqueeze(0) * 255).to(torch.uint8).to(device), 1.4
    )[0]
    face = image.crop(face_bbox)
    return face


@spaces.GPU()
def vlm_img_caption(image):
    """Caption the image with the VLM; returns a lowercase caption without a trailing period."""
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    try:
        caption = detector.detector.caption(image, "").strip()
        if caption.endswith("."):
            caption = caption[:-1]
    except Exception as e:
        print(e)
        caption = ""
    return caption.lower()


def generate_random_string(length=4):
    letters = string.ascii_letters  # upper- and lowercase ASCII letters
    return ''.join(random.choice(letters) for _ in range(length))


def resize_keep_aspect_ratio(pil_image, target_size=1024):
    """Resize so the output has roughly target_size**2 pixels while keeping the aspect ratio."""
    H, W = pil_image.height, pil_image.width
    target_area = target_size * target_size
    current_area = H * W
    scaling_factor = (target_area / current_area) ** 0.5  # sqrt(target_area / current_area)
    new_H = int(round(H * scaling_factor))
    new_W = int(round(W * scaling_factor))
    return pil_image.resize((new_W, new_H))
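# Worked example: a 1600x900 input with target_size=1024 gives a scaling factor of
# sqrt(1024**2 / (1600 * 900)), about 0.853, so the output is roughly 1365x768:
# the 16:9 aspect ratio is preserved while the pixel count matches ~1024**2.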
@spaces.GPU()
def generate_image(
    prompt,
    cond_size,
    target_height,
    target_width,
    seed,
    vae_skip_iter,
    control_weight_lambda,
    double_attention,
    single_attention,
    ip_scale,
    latent_sblora_scale_str,
    vae_lora_scale,
    indexs,
    # *images_captions_faces,  # combine all unpacked arguments into one tuple
):
    # Full generation path, disabled in this stripped-down demo:
    # torch.cuda.empty_cache()
    # num_images = 1

    # # Split the flat argument tuple into images, captions, and ID checkboxes,
    # # then keep only the entries selected via `indexs`.
    # images = list(images_captions_faces[:num_inputs])
    # captions = list(images_captions_faces[num_inputs:2 * num_inputs])
    # idips_checkboxes = list(images_captions_faces[2 * num_inputs:3 * num_inputs])
    # images = [images[i] for i in indexs]
    # captions = [captions[i] for i in indexs]
    # idips_checkboxes = [idips_checkboxes[i] for i in indexs]

    # print(f"Length of images: {len(images)}")
    # print(f"Length of captions: {len(captions)}")
    # print(f"Indexs: {indexs}")
    # print(f"Control weight lambda: {control_weight_lambda}")

    # # Expand each "range:global/id/ip" weight part into one value per subject.
    # if control_weight_lambda != "no":
    #     parts = control_weight_lambda.split(',')
    #     new_parts = []
    #     for part in parts:
    #         if ':' in part:
    #             left, right = part.split(':')
    #             values = right.split('/')
    #             # keep the global value
    #             global_value = values[0]
    #             id_value = values[1]
    #             ip_value = values[2]
    #             new_values = [global_value]
    #             for is_id in idips_checkboxes:
    #                 if is_id:
    #                     new_values.append(id_value)
    #                 else:
    #                     new_values.append(ip_value)
    #             new_part = f"{left}:{('/'.join(new_values))}"
    #             new_parts.append(new_part)
    #         else:
    #             new_parts.append(part)
    #     control_weight_lambda = ','.join(new_parts)

    # print(f"Control weight lambda: {control_weight_lambda}")

    # src_inputs = []
    # use_words = []
    # cur_run_time = time.strftime("%m%d-%H%M%S")
    # tmp_dir_root = f"tmp/gradio_demo/{run_name}"
    # temp_dir = f"{tmp_dir_root}/{cur_run_time}_{generate_random_string(4)}"
    # os.makedirs(temp_dir, exist_ok=True)
    # print(f"Temporary directory created: {temp_dir}")

    # # Resize each reference image, substitute ENT{i} placeholders in the prompt,
    # # and collect the per-subject inputs.
    # for i, (image_path, caption) in enumerate(zip(images, captions)):
    #     if image_path:
    #         if caption.startswith("a ") or caption.startswith("A "):
    #             word = caption[2:]
    #         else:
    #             word = caption
    #         if f"ENT{i+1}" in prompt:
    #             prompt = prompt.replace(f"ENT{i+1}", caption)
    #         image = resize_keep_aspect_ratio(Image.open(image_path), 768)
    #         save_path = f"{temp_dir}/tmp_resized_input_{i}.png"
    #         image.save(save_path)
    #         input_image_path = save_path
    #         src_inputs.append(
    #             {
    #                 "image_path": input_image_path,
    #                 "caption": caption
    #             }
    #         )
    #         use_words.append((i, word, word))

    # test_sample = dict(
    #     input_images=[], position_delta=[0, -32],
    #     prompt=prompt,
    #     target_height=target_height,
    #     target_width=target_width,
    #     seed=seed,
    #     cond_size=cond_size,
    #     vae_skip_iter=vae_skip_iter,
    #     lora_scale=ip_scale,
    #     control_weight_lambda=control_weight_lambda,
    #     latent_sblora_scale=latent_sblora_scale_str,
    #     condition_sblora_scale=vae_lora_scale,
    #     double_attention=double_attention,
    #     single_attention=single_attention,
    # )
    # if len(src_inputs) > 0:
    #     test_sample["modulation"] = [
    #         dict(
    #             type="adapter",
    #             src_inputs=src_inputs,
    #             use_words=use_words,
    #         ),
    #     ]
    # json_dump(test_sample, f"{temp_dir}/test_sample.json", 'utf-8')

    # assert single_attention == True
    # target_size = int(round((target_width * target_height) ** 0.5) // 16 * 16)
    # print(test_sample)

    # model.config["train"]["dataset"]["val_condition_size"] = cond_size
    # model.config["train"]["dataset"]["val_target_size"] = target_size

    # if control_weight_lambda == "no":
    #     control_weight_lambda = None
    # if vae_skip_iter == "no":
    #     vae_skip_iter = None
    # use_condition_sblora_control = True
    # use_latent_sblora_control = True

    # image = generate_from_test_sample(
    #     test_sample, model.pipe, model.config,
    #     num_images=num_images,
    #     target_height=target_height,
    #     target_width=target_width,
    #     seed=seed,
    #     store_attn_map=store_attn_map,
    #     vae_skip_iter=vae_skip_iter,
    #     control_weight_lambda=control_weight_lambda,
    #     double_attention=double_attention,
    #     single_attention=single_attention,
    #     ip_scale=ip_scale,
    #     use_latent_sblora_control=use_latent_sblora_control,
    #     latent_sblora_scale=latent_sblora_scale_str,
    #     use_condition_sblora_control=use_condition_sblora_control,
    #     condition_sblora_scale=vae_lora_scale,
    # )
    # if isinstance(image, list):
    #     num_cols = 2
    #     num_rows = int(math.ceil(num_images / num_cols))
    #     image = image_grid(image, num_rows, num_cols)
    # save_path = f"{temp_dir}/tmp_result.png"
    # image.save(save_path)
    # return image

    return None
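# A minimal sketch (hypothetical helper, not called by the demo) mirroring the
# commented weight-string expansion above: each "range:global/id/ip" part keeps the
# global value and appends the id value for ID-checked subjects, the ip value
# otherwise. E.g. with idips=[False, True], "0-1:1/3/5" expands to "0-1:1/5/3".
def _expand_weight_str(weight_str, idips):
    new_parts = []
    for part in weight_str.split(','):
        if ':' not in part:
            new_parts.append(part)
            continue
        left, right = part.split(':')
        global_value, id_value, ip_value = right.split('/')
        values = [global_value] + [id_value if is_id else ip_value for is_id in idips]
        new_parts.append(f"{left}:{'/'.join(values)}")
    return ','.join(new_parts)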
def merge_instances(orig_img, indices, ins_bboxes, ins_images):
    """Paste the selected instance crops onto a white canvas and crop to their merged bbox."""
    orig_image_width, orig_image_height = orig_img.width, orig_img.height
    final_img = Image.new("RGB", (orig_image_width, orig_image_height), color=(255, 255, 255))
    bboxes = []
    for i in indices:
        bbox = np.array(ins_bboxes[i], dtype=int).tolist()
        bboxes.append(bbox)
        img = cv2pil(ins_images[i])
        # Treat pure-white pixels as background so only the instance itself is pasted.
        mask = (np.array(img)[..., :3] != 255).any(axis=-1)
        mask = Image.fromarray(mask.astype(np.uint8) * 255, mode='L')
        final_img.paste(img, (bbox[0], bbox[1]), mask)

    bbox = merge_bboxes(bboxes)
    img = final_img.crop(bbox)
    return img, bbox


def change_accordion(at: bool, index: int, state: list):
    print(at, state)
    indexs = state
    if at:
        if index not in indexs:
            indexs.append(index)
    else:
        if index in indexs:
            indexs.remove(index)
    # keep indexs sorted
    indexs.sort()
    print(indexs)
    return gr.Accordion(open=at), indexs


def update_inputs(is_open, index, state: list):
    indexs = state
    if is_open:
        if index not in indexs:
            indexs.append(index)
    else:
        if index in indexs:
            indexs.remove(index)
    # keep indexs sorted
    indexs.sort()
    print(indexs)
    return indexs, is_open
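# Worked example of the bookkeeping above: starting from the initial state [0, 1],
# closing panel 0 leaves [1]; re-opening it appends 0 and the sort restores [0, 1].
# The sorted list is what generate_image later receives as `indexs`.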
if __name__ == "__main__":
    with gr.Blocks() as demo:
        indexs_state = gr.State([0, 1])  # state storing the selected input indexs

        gr.Markdown("### XVerse Demo")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="")
                with gr.Accordion("Open for More!", open=False):
                    with gr.Row():
                        target_height = gr.Slider(512, 1024, step=128, value=768, label="Generated Height", info="")
                        target_width = gr.Slider(512, 1024, step=128, value=768, label="Generated Width", info="")
                        cond_size = gr.Slider(256, 384, step=128, value=256, label="Condition Size", info="")
                    with gr.Row():
                        weight_id = gr.Slider(0.1, 5, step=0.1, value=3, label="weight_id")
                        weight_ip = gr.Slider(0.1, 5, step=0.1, value=5, label="weight_ip")
                    with gr.Row():
                        ip_scale_str = gr.Slider(0.5, 1.5, step=0.01, value=0.85, label="latent_lora_scale")
                        vae_lora_scale = gr.Slider(0.5, 1.5, step=0.01, value=1.3, label="vae_lora_scale")
                    with gr.Row():
                        vae_skip_iter_s1 = gr.Slider(0, 1, step=0.01, value=0.05, label="vae_skip_iter_before")
                        vae_skip_iter_s2 = gr.Slider(0, 1, step=0.01, value=0.8, label="vae_skip_iter_after")

                    with gr.Row():
                        # Hidden textbox composing "range:global/id/ip" from the two
                        # sliders, e.g. weight_id=3, weight_ip=5 -> "0-1:1/3/5".
                        weight_id_ip_str = gr.Textbox(
                            value="0-1:1/3/5",
                            label="weight_id_ip_str",
                            interactive=False, visible=False
                        )
                        weight_id.change(
                            lambda s1, s2: f"0-1:1/{s1}/{s2}",
                            inputs=[weight_id, weight_ip],
                            outputs=weight_id_ip_str
                        )
                        weight_ip.change(
                            lambda s1, s2: f"0-1:1/{s1}/{s2}",
                            inputs=[weight_id, weight_ip],
                            outputs=weight_id_ip_str
                        )
                        # Hidden textbox composing the VAE skip schedule,
                        # e.g. s1=0.05, s2=0.8 -> "0-0.05:1,0.8-1:1".
                        vae_skip_iter = gr.Textbox(
                            value="0-0.05:1,0.8-1:1",
                            label="vae_skip_iter",
                            interactive=False, visible=False
                        )
                        vae_skip_iter_s1.change(
                            lambda s1, s2: f"0-{s1}:1,{s2}-1:1",
                            inputs=[vae_skip_iter_s1, vae_skip_iter_s2],
                            outputs=vae_skip_iter
                        )
                        vae_skip_iter_s2.change(
                            lambda s1, s2: f"0-{s1}:1,{s2}-1:1",
                            inputs=[vae_skip_iter_s1, vae_skip_iter_s2],
                            outputs=vae_skip_iter
                        )

                    with gr.Row():
                        db_latent_lora_scale_str = gr.Textbox(
                            value="0-1:0.85",
                            label="db_latent_lora_scale_str",
                            interactive=False, visible=False
                        )
                        sb_latent_lora_scale_str = gr.Textbox(
                            value="0-1:0.85",
                            label="sb_latent_lora_scale_str",
                            interactive=False, visible=False
                        )
                        vae_lora_scale_str = gr.Textbox(
                            value="0-1:1.3",
                            label="vae_lora_scale_str",
                            interactive=False, visible=False
                        )
                        vae_lora_scale.change(
                            lambda s: f"0-1:{s}",
                            inputs=vae_lora_scale,
                            outputs=vae_lora_scale_str
                        )
                        ip_scale_str.change(
                            lambda s: [f"0-1:{s}", f"0-1:{s}"],
                            inputs=ip_scale_str,
                            outputs=[db_latent_lora_scale_str, sb_latent_lora_scale_str]
                        )

                    with gr.Row():
                        double_attention = gr.Checkbox(value=False, label="Double Attention", visible=False)
                        single_attention = gr.Checkbox(value=True, label="Single Attention", visible=False)

                clear_btn = gr.Button("Clear Input Images")
                with gr.Row():
                    with gr.Column():
                        image_1 = gr.Image(type="filepath", label="Image 1")
                        caption_1 = gr.Textbox(label="Caption 1", value="")
                        id_ip_checkbox_1 = gr.Checkbox(value=False, label="ID or not 1", visible=True)
                        with gr.Row():
                            vlm_btn_1 = gr.Button("Auto Caption")
                            det_btn_1 = gr.Button("Det & Seg")
                            face_btn_1 = gr.Button("Crop Face")
                    with gr.Column():
                        image_2 = gr.Image(type="filepath", label="Image 2")
                        caption_2 = gr.Textbox(label="Caption 2", value="")
                        id_ip_checkbox_2 = gr.Checkbox(value=False, label="ID or not 2", visible=True)
                        with gr.Row():
                            vlm_btn_2 = gr.Button("Auto Caption")
                            det_btn_2 = gr.Button("Det & Seg")
                            face_btn_2 = gr.Button("Crop Face")

            with gr.Column():
                output = gr.Image(label="Generated Image")
                seed = gr.Number(value=42, label="Seed", info="")
                gen_btn = gr.Button("Generate Image")

        gr.Markdown("### Examples")

        gen_btn.click(
            generate_image,
            inputs=[
                prompt,
                cond_size,
                target_height,
                target_width,
                seed,
                vae_skip_iter,
                weight_id_ip_str,
                double_attention,
                single_attention,
                db_latent_lora_scale_str,
                sb_latent_lora_scale_str,
                vae_lora_scale_str,
                indexs_state,  # pass the indexs state
                # *images,
                # *captions,
                # *idip_checkboxes,
            ],
            outputs=output
        )

        # # Wire the clear button once the per-image input lists are re-enabled:
        # clear_btn.click(clear_images, outputs=images)

        face_btn_1.click(crop_face_img, inputs=[image_1], outputs=[image_1])
        det_btn_1.click(det_seg_img, inputs=[image_1, caption_1], outputs=[image_1])
        vlm_btn_1.click(vlm_img_caption, inputs=[image_1], outputs=[caption_1])
        face_btn_2.click(crop_face_img, inputs=[image_2], outputs=[image_2])
        det_btn_2.click(det_seg_img, inputs=[image_2, caption_2], outputs=[image_2])
        vlm_btn_2.click(vlm_img_caption, inputs=[image_2], outputs=[caption_2])

    demo.queue()
    demo.launch()
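# Example launch (assumes the checkpoints referenced near the top of this file are
# present so that FaceID and ObjectDetector can find their weights):
#   python app.py
# Gradio serves the UI on http://127.0.0.1:7860 by default.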