import multiprocessing as mp

mp.set_start_method("spawn", force=True)

import spaces

import tempfile
from PIL import Image
import gradio as gr
import string
import random, time, math
import os

os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

import src.flux.generate
from src.flux.generate import generate_from_test_sample, seed_everything
from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, load_dit_lora
from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxes
from eval.tools.face_id import FaceID
from eval.tools.florence_sam import ObjectDetector
import shutil
import yaml
import numpy as np
from huggingface_hub import snapshot_download, hf_hub_download
import torch

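# ---------------------------------------------------------------------------
# Checkpoint downloads: fetch the models the demo depends on into
# ./checkpoints/ so the *_MODEL_PATH environment variables set below resolve
# to local copies.
# ---------------------------------------------------------------------------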
snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    local_dir="./checkpoints/FLUX.1-dev",
    local_dir_use_symlinks=False,
)

snapshot_download(
    repo_id="microsoft/Florence-2-large",
    local_dir="./checkpoints/Florence-2-large",
    local_dir_use_symlinks=False,
)

snapshot_download(
    repo_id="openai/clip-vit-large-patch14",
    local_dir="./checkpoints/clip-vit-large-patch14",
    local_dir_use_symlinks=False,
)

snapshot_download(
    repo_id="facebook/dino-vits16",
    local_dir="./checkpoints/dino-vits16",
    local_dir_use_symlinks=False,
)

snapshot_download(
    repo_id="xingjianleng/mplug_visual-question-answering_coco_large_en",
    local_dir="./checkpoints/mplug_visual-question-answering_coco_large_en",
    local_dir_use_symlinks=False,
)

snapshot_download(
    repo_id="ByteDance/XVerse",
    local_dir="./checkpoints/XVerse",
    local_dir_use_symlinks=False,
)

hf_hub_download(
    repo_id="facebook/sam2.1-hiera-large",
    local_dir="./checkpoints/",
    filename="sam2.1_hiera_large.pt",
)

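# ---------------------------------------------------------------------------
# Point the XVerse tooling at the locally downloaded checkpoints.
# ---------------------------------------------------------------------------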
os.environ["FLORENCE2_MODEL_PATH"] = "./checkpoints/Florence-2-large" |
|
os.environ["SAM2_MODEL_PATH"] = "./checkpoints/sam2.1_hiera_large.pt" |
|
os.environ["FACE_ID_MODEL_PATH"] = "./checkpoints/model_ir_se50.pth" |
|
os.environ["CLIP_MODEL_PATH"] = "./checkpoints/clip-vit-large-patch14" |
|
os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-dev" |
|
os.environ["DPG_VQA_MODEL_PATH"] = "./checkpoints/mplug_visual-question-answering_coco_large_en" |
|
os.environ["DINO_MODEL_PATH"] = "./checkpoints/dino-vits16" |
|
|
|
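# ---------------------------------------------------------------------------
# Pipeline setup: build the FLUX-based CustomFluxPipeline from the demo
# config, plus the face-ID model and Florence/SAM object detector used by the
# preprocessing buttons.
# ---------------------------------------------------------------------------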
dtype = torch.bfloat16
device = "cuda"

config_path = "train/config/XVerse_config_demo.yaml"
config = config_train = get_train_config(config_path)

config["model"]["use_dit_lora"] = False
model = CustomFluxPipeline(
    config, device, torch_dtype=dtype,
)
model.pipe.set_progress_bar_config(leave=False)

face_model = FaceID(device)
detector = ObjectDetector(device)

config = get_train_config(config_path)
model.config = config

run_mode = "mod_only"
store_attn_map = False
run_name = time.strftime("%m%d-%H%M")

num_inputs = 2  # number of reference-image slots exposed in the UI below

ckpt_root = "./checkpoints/XVerse" |
|
model.clear_modulation_adapters() |
|
model.pipe.unload_lora_weights() |
|
if not os.path.exists(ckpt_root): |
|
print("Checkpoint root does not exist.") |
|
|
|
modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False) |
|
model.add_modulation_adapter(modulation_adapter) |
|
if config["model"]["use_dit_lora"]: |
|
load_dit_lora(model, model.pipe, config, dtype, device, f"{ckpt_root}", is_training=False) |
|
|
|
vae_skip_iter = None |
|
attn_skip_iter = 0 |
|
|
|
|
|
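# ---------------------------------------------------------------------------
# Preprocessing callbacks. Each is wrapped in @spaces.GPU() so it requests a
# GPU worker when the demo is hosted on Hugging Face Spaces.
# ---------------------------------------------------------------------------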
def clear_images():
    return [None, ] * num_inputs


@spaces.GPU()
def det_seg_img(image, label):
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    instance_result_dict = detector.get_multiple_instances(image, label, min_size=image.size[0] // 20)
    indices = list(range(len(instance_result_dict["instance_images"])))
    ins, bbox = merge_instances(image, indices, instance_result_dict["instance_bboxes"], instance_result_dict["instance_images"])
    return ins


@spaces.GPU()
def crop_face_img(image):
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")

    image = pad_to_square(image).resize((2048, 2048))

    face_bbox = face_model.detect(
        (pil2tensor(image).unsqueeze(0) * 255).to(torch.uint8).to(device), 1.4
    )[0]
    face = image.crop(face_bbox)
    return face


@spaces.GPU()
def vlm_img_caption(image):
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")

    try:
        caption = detector.detector.caption(image, "<CAPTION>").strip()
        if caption.endswith("."):
            caption = caption[:-1]
    except Exception as e:
        print(e)
        caption = ""

    caption = caption.lower()
    return caption


def generate_random_string(length=4):
    letters = string.ascii_letters
    result_str = ''.join(random.choice(letters) for _ in range(length))
    return result_str


def resize_keep_aspect_ratio(pil_image, target_size=1024):
    H, W = pil_image.height, pil_image.width
    target_area = target_size * target_size
    current_area = H * W
    scaling_factor = (target_area / current_area) ** 0.5
    new_H = int(round(H * scaling_factor))
    new_W = int(round(W * scaling_factor))
    return pil_image.resize((new_W, new_H))

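# ---------------------------------------------------------------------------
# Main generation entry point. Resolves per-subject ID/IP weights from the
# checkbox states, resizes and saves the selected reference images to a
# temporary directory, assembles a test_sample dict, and hands it to
# generate_from_test_sample.
# ---------------------------------------------------------------------------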
@spaces.GPU()
def generate_image(
    prompt,
    cond_size, target_height, target_width,
    seed,
    vae_skip_iter, control_weight_lambda,
    double_attention,
    single_attention,
    ip_scale,
    latent_sblora_scale_str, vae_lora_scale,
    indexs,
    *images_captions_faces,  # flattened per-slot inputs: images, captions, ID checkboxes
):
    torch.cuda.empty_cache()
    num_images = 1

    # Split the flattened inputs into images, captions and ID checkboxes,
    # then keep only the currently active slots.
    images = list(images_captions_faces[:num_inputs])
    captions = list(images_captions_faces[num_inputs:2 * num_inputs])
    idips_checkboxes = list(images_captions_faces[2 * num_inputs:3 * num_inputs])
    images = [images[i] for i in indexs]
    captions = [captions[i] for i in indexs]
    idips_checkboxes = [idips_checkboxes[i] for i in indexs]

    print(f"Length of images: {len(images)}")
    print(f"Length of captions: {len(captions)}")
    print(f"Indexs: {indexs}")

    print(f"Control weight lambda: {control_weight_lambda}")
    if control_weight_lambda != "no":
        # Expand each "range:global/id/ip" entry into per-subject weights,
        # using the ID weight for checked subjects and the IP weight otherwise.
        parts = control_weight_lambda.split(',')
        new_parts = []
        for part in parts:
            if ':' in part:
                left, right = part.split(':')
                values = right.split('/')
                global_value = values[0]
                id_value = values[1]
                ip_value = values[2]
                new_values = [global_value]
                for is_id in idips_checkboxes:
                    if is_id:
                        new_values.append(id_value)
                    else:
                        new_values.append(ip_value)
                new_part = f"{left}:{'/'.join(new_values)}"
                new_parts.append(new_part)
            else:
                new_parts.append(part)
        control_weight_lambda = ','.join(new_parts)

    print(f"Control weight lambda: {control_weight_lambda}")

    src_inputs = []
    use_words = []
    cur_run_time = time.strftime("%m%d-%H%M%S")
    tmp_dir_root = f"tmp/gradio_demo/{run_name}"
    temp_dir = f"{tmp_dir_root}/{cur_run_time}_{generate_random_string(4)}"
    os.makedirs(temp_dir, exist_ok=True)
    print(f"Temporary directory created: {temp_dir}")
    for i, (image_path, caption) in enumerate(zip(images, captions)):
        if image_path:
            if caption.startswith("a ") or caption.startswith("A "):
                word = caption[2:]
            else:
                word = caption

            if f"ENT{i+1}" in prompt:
                prompt = prompt.replace(f"ENT{i+1}", caption)

            image = resize_keep_aspect_ratio(Image.open(image_path), 768)
            save_path = f"{temp_dir}/tmp_resized_input_{i}.png"
            image.save(save_path)

            input_image_path = save_path

            src_inputs.append(
                {
                    "image_path": input_image_path,
                    "caption": caption
                }
            )
            use_words.append((i, word, word))

    test_sample = dict(
        input_images=[], position_delta=[0, -32],
        prompt=prompt,
        target_height=target_height,
        target_width=target_width,
        seed=seed,
        cond_size=cond_size,
        vae_skip_iter=vae_skip_iter,
        lora_scale=ip_scale,
        control_weight_lambda=control_weight_lambda,
        latent_sblora_scale=latent_sblora_scale_str,
        condition_sblora_scale=vae_lora_scale,
        double_attention=double_attention,
        single_attention=single_attention,
    )
    if len(src_inputs) > 0:
        test_sample["modulation"] = [
            dict(
                type="adapter",
                src_inputs=src_inputs,
                use_words=use_words,
            ),
        ]

    json_dump(test_sample, f"{temp_dir}/test_sample.json", 'utf-8')
    assert single_attention == True
    target_size = int(round((target_width * target_height) ** 0.5) // 16 * 16)
    print(test_sample)

    model.config["train"]["dataset"]["val_condition_size"] = cond_size
    model.config["train"]["dataset"]["val_target_size"] = target_size

    if control_weight_lambda == "no":
        control_weight_lambda = None
    if vae_skip_iter == "no":
        vae_skip_iter = None
    use_condition_sblora_control = True
    use_latent_sblora_control = True
    image = generate_from_test_sample(
        test_sample, model.pipe, model.config,
        num_images=num_images,
        target_height=target_height,
        target_width=target_width,
        seed=seed,
        store_attn_map=store_attn_map,
        vae_skip_iter=vae_skip_iter,
        control_weight_lambda=control_weight_lambda,
        double_attention=double_attention,
        single_attention=single_attention,
        ip_scale=ip_scale,
        use_latent_sblora_control=use_latent_sblora_control,
        latent_sblora_scale=latent_sblora_scale_str,
        use_condition_sblora_control=use_condition_sblora_control,
        condition_sblora_scale=vae_lora_scale,
    )
    if isinstance(image, list):
        num_cols = 2
        num_rows = int(math.ceil(num_images / num_cols))
        image = image_grid(image, num_rows, num_cols)

    save_path = f"{temp_dir}/tmp_result.png"
    image.save(save_path)

    return image

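# Paste the selected instance crops back onto a white canvas and return the
# merged region together with its bounding box.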
def merge_instances(orig_img, indices, ins_bboxes, ins_images):
    orig_image_width, orig_image_height = orig_img.width, orig_img.height
    final_img = Image.new("RGB", (orig_image_width, orig_image_height), color=(255, 255, 255))
    bboxes = []
    for i in indices:
        bbox = np.array(ins_bboxes[i], dtype=int).tolist()
        bboxes.append(bbox)

        img = cv2pil(ins_images[i])
        mask = (np.array(img)[..., :3] != 255).any(axis=-1)
        mask = Image.fromarray(mask.astype(np.uint8) * 255, mode='L')
        final_img.paste(img, (bbox[0], bbox[1]), mask)

    bbox = merge_bboxes(bboxes)
    img = final_img.crop(bbox)
    return img, bbox

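# State helpers that keep indexs_state (the list of active reference slots)
# in sync with per-slot accordions; they are not wired up in this two-slot
# layout, where indexs_state stays fixed at [0, 1].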
def change_accordion(at: bool, index: int, state: list):
    print(at, state)
    indexs = state
    if at:
        if index not in indexs:
            indexs.append(index)
    else:
        if index in indexs:
            indexs.remove(index)

    indexs.sort()
    print(indexs)
    return gr.Accordion(open=at), indexs


def update_inputs(is_open, index, state: list):
    indexs = state
    if is_open:
        if index not in indexs:
            indexs.append(index)
    else:
        if index in indexs:
            indexs.remove(index)

    indexs.sort()
    print(indexs)
    return indexs, is_open

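# ---------------------------------------------------------------------------
# Gradio UI: prompt plus advanced controls on the left, two reference-image
# slots below them, and the generated image, seed and generate button on the
# right.
# ---------------------------------------------------------------------------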
if __name__ == "__main__":

    with gr.Blocks() as demo:

        indexs_state = gr.State([0, 1])

        gr.Markdown("### XVerse Demo")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="")
                with gr.Accordion("Open for More!", open=False):

                    with gr.Row():
                        target_height = gr.Slider(512, 1024, step=128, value=768, label="Generated Height", info="")
                        target_width = gr.Slider(512, 1024, step=128, value=768, label="Generated Width", info="")
                        cond_size = gr.Slider(256, 384, step=128, value=256, label="Condition Size", info="")
                    with gr.Row():
                        weight_id = gr.Slider(0.1, 5, step=0.1, value=3, label="weight_id")
                        weight_ip = gr.Slider(0.1, 5, step=0.1, value=5, label="weight_ip")
                    with gr.Row():
                        ip_scale_str = gr.Slider(0.5, 1.5, step=0.01, value=0.85, label="latent_lora_scale")
                        vae_lora_scale = gr.Slider(0.5, 1.5, step=0.01, value=1.3, label="vae_lora_scale")
                    with gr.Row():
                        vae_skip_iter_s1 = gr.Slider(0, 1, step=0.01, value=0.05, label="vae_skip_iter_before")
                        vae_skip_iter_s2 = gr.Slider(0, 1, step=0.01, value=0.8, label="vae_skip_iter_after")

                    with gr.Row():
                        weight_id_ip_str = gr.Textbox(
                            value="0-1:1/3/5",
                            label="weight_id_ip_str",
                            interactive=False, visible=False
                        )
                        weight_id.change(
                            lambda s1, s2: f"0-1:1/{s1}/{s2}",
                            inputs=[weight_id, weight_ip],
                            outputs=weight_id_ip_str
                        )
                        weight_ip.change(
                            lambda s1, s2: f"0-1:1/{s1}/{s2}",
                            inputs=[weight_id, weight_ip],
                            outputs=weight_id_ip_str
                        )
                        vae_skip_iter = gr.Textbox(
                            value="0-0.05:1,0.8-1:1",
                            label="vae_skip_iter",
                            interactive=False, visible=False
                        )
                        vae_skip_iter_s1.change(
                            lambda s1, s2: f"0-{s1}:1,{s2}-1:1",
                            inputs=[vae_skip_iter_s1, vae_skip_iter_s2],
                            outputs=vae_skip_iter
                        )
                        vae_skip_iter_s2.change(
                            lambda s1, s2: f"0-{s1}:1,{s2}-1:1",
                            inputs=[vae_skip_iter_s1, vae_skip_iter_s2],
                            outputs=vae_skip_iter
                        )

                    with gr.Row():
                        db_latent_lora_scale_str = gr.Textbox(
                            value="0-1:0.85",
                            label="db_latent_lora_scale_str",
                            interactive=False, visible=False
                        )
                        sb_latent_lora_scale_str = gr.Textbox(
                            value="0-1:0.85",
                            label="sb_latent_lora_scale_str",
                            interactive=False, visible=False
                        )
                        vae_lora_scale_str = gr.Textbox(
                            value="0-1:1.3",
                            label="vae_lora_scale_str",
                            interactive=False, visible=False
                        )
                        vae_lora_scale.change(
                            lambda s: f"0-1:{s}",
                            inputs=vae_lora_scale,
                            outputs=vae_lora_scale_str
                        )
                        ip_scale_str.change(
                            lambda s: [f"0-1:{s}", f"0-1:{s}"],
                            inputs=ip_scale_str,
                            outputs=[db_latent_lora_scale_str, sb_latent_lora_scale_str]
                        )

                    with gr.Row():
                        double_attention = gr.Checkbox(value=False, label="Double Attention", visible=False)
                        single_attention = gr.Checkbox(value=True, label="Single Attention", visible=False)

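                # Two reference-image slots; each can be auto-captioned,
                # segmented (Det & Seg) or face-cropped before generation.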
clear_btn = gr.Button("清空输入图像") |
|
with gr.Row(): |
|
with gr.Column(): |
|
image_1 = gr.Image(type="filepath", label=f"Image 1") |
|
caption_1 = gr.Textbox(label=f"Caption 1", value="") |
|
id_ip_checkbox_1 = gr.Checkbox(value=False, label=f"ID or not 1", visible=True) |
|
with gr.Row(): |
|
vlm_btn_1 = gr.Button("Auto Caption") |
|
det_btn_1 = gr.Button("Det & Seg") |
|
face_btn_1 = gr.Button("Crop Face") |
|
|
|
with gr.Column(): |
|
image_2 = gr.Image(type="filepath", label=f"Image 2") |
|
caption_2 = gr.Textbox(label=f"Caption 2", value="") |
|
id_ip_checkbox_2 = gr.Checkbox(value=False, label=f"ID or not 2", visible=True) |
|
with gr.Row(): |
|
vlm_btn_2 = gr.Button("Auto Caption") |
|
det_btn_2 = gr.Button("Det & Seg") |
|
face_btn_2 = gr.Button("Crop Face") |
|
|
|
with gr.Column(): |
|
output = gr.Image(label="生成的图像") |
|
seed = gr.Number(value=42, label="Seed", info="") |
|
gen_btn = gr.Button("生成图像") |
|
|
|
gr.Markdown("### Examples") |
|
        gen_btn.click(
            generate_image,
            inputs=[
                prompt, cond_size, target_height, target_width, seed,
                vae_skip_iter, weight_id_ip_str,
                double_attention, single_attention,
                db_latent_lora_scale_str, sb_latent_lora_scale_str, vae_lora_scale_str,
                indexs_state,
                # flattened per-slot inputs consumed via *images_captions_faces
                image_1, image_2,
                caption_1, caption_2,
                id_ip_checkbox_1, id_ip_checkbox_2,
            ],
            outputs=output
        )

        clear_btn.click(clear_images, outputs=[image_1, image_2])

        face_btn_1.click(crop_face_img, inputs=[image_1], outputs=[image_1])
        det_btn_1.click(det_seg_img, inputs=[image_1, caption_1], outputs=[image_1])
        vlm_btn_1.click(vlm_img_caption, inputs=[image_1], outputs=[caption_1])

        face_btn_2.click(crop_face_img, inputs=[image_2], outputs=[image_2])
        det_btn_2.click(det_seg_img, inputs=[image_2, caption_2], outputs=[image_2])
        vlm_btn_2.click(vlm_img_caption, inputs=[image_2], outputs=[caption_2])

    demo.queue()
    demo.launch()