Update app.py
app.py CHANGED
@@ -31,117 +31,117 @@ os.environ["NCCL_P2P_DISABLE"]="1"
 os.environ["NCCL_IB_DISABLE"]="1"

 import src.flux.generate
-… (removed import lines not preserved in this view)
+from src.flux.generate import generate_from_test_sample, seed_everything
+from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, load_dit_lora
+from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxes
+from eval.tools.face_id import FaceID
+from eval.tools.florence_sam import ObjectDetector
+import shutil
+import yaml
+import numpy as np
+from huggingface_hub import snapshot_download, hf_hub_download
 import torch

-… (removed setup lines not preserved in this view)
+# FLUX.1-dev
+snapshot_download(
+    repo_id="black-forest-labs/FLUX.1-dev",
+    local_dir="./checkpoints/FLUX.1-dev",
+    local_dir_use_symlinks=False
+)
+
+# Florence-2-large
+snapshot_download(
+    repo_id="microsoft/Florence-2-large",
+    local_dir="./checkpoints/Florence-2-large",
+    local_dir_use_symlinks=False
+)
+
+# CLIP ViT Large
+snapshot_download(
+    repo_id="openai/clip-vit-large-patch14",
+    local_dir="./checkpoints/clip-vit-large-patch14",
+    local_dir_use_symlinks=False
+)
+
+# DINO ViT-s16
+snapshot_download(
+    repo_id="facebook/dino-vits16",
+    local_dir="./checkpoints/dino-vits16",
+    local_dir_use_symlinks=False
+)
+
+# mPLUG Visual Question Answering
+snapshot_download(
+    repo_id="xingjianleng/mplug_visual-question-answering_coco_large_en",
+    local_dir="./checkpoints/mplug_visual-question-answering_coco_large_en",
+    local_dir_use_symlinks=False
+)
+
+# XVerse
+snapshot_download(
+    repo_id="ByteDance/XVerse",
+    local_dir="./checkpoints/XVerse",
+    local_dir_use_symlinks=False
+)
+
+hf_hub_download(
+    repo_id="facebook/sam2.1-hiera-large",
+    local_dir="./checkpoints/",
+    filename="sam2.1_hiera_large.pt",
+)
+
+
+os.environ["FLORENCE2_MODEL_PATH"] = "./checkpoints/Florence-2-large"
+os.environ["SAM2_MODEL_PATH"] = "./checkpoints/sam2.1_hiera_large.pt"
+os.environ["FACE_ID_MODEL_PATH"] = "./checkpoints/model_ir_se50.pth"
+os.environ["CLIP_MODEL_PATH"] = "./checkpoints/clip-vit-large-patch14"
+os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-dev"
+os.environ["DPG_VQA_MODEL_PATH"] = "./checkpoints/mplug_visual-question-answering_coco_large_en"
+os.environ["DINO_MODEL_PATH"] = "./checkpoints/dino-vits16"
+
+dtype = torch.bfloat16
+device = "cuda"
+
+config_path = "train/config/XVerse_config_demo.yaml"
+
+config = config_train = get_train_config(config_path)
+# config["model"]["dit_quant"] = "int8-quanto"
+config["model"]["use_dit_lora"] = False
+model = CustomFluxPipeline(
+    config, device, torch_dtype=dtype,
+)
+model.pipe.set_progress_bar_config(leave=False)
+
+face_model = FaceID(device)
+detector = ObjectDetector(device)
+
+config = get_train_config(config_path)
+model.config = config
+
+run_mode = "mod_only"  # orig_only, mod_only, both
+store_attn_map = False
+run_name = time.strftime("%m%d-%H%M")
+
+num_inputs = 6
+
+ckpt_root = "./checkpoints/XVerse"
+model.clear_modulation_adapters()
+model.pipe.unload_lora_weights()
+if not os.path.exists(ckpt_root):
+    print("Checkpoint root does not exist.")
+
+modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False)
+model.add_modulation_adapter(modulation_adapter)
+if config["model"]["use_dit_lora"]:
+    load_dit_lora(model, model.pipe, config, dtype, device, f"{ckpt_root}", is_training=False)
+
+vae_skip_iter = None
+attn_skip_iter = 0
+
+
+def clear_images():
+    return [None, ]*num_inputs

 @spaces.GPU()
 def det_seg_img(image, label):
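
Note on the hunk above: every model is pinned to a local path through environment variables, but `model_ir_se50.pth` (FACE_ID_MODEL_PATH) is not fetched by any of the download calls; it presumably ships inside the ByteDance/XVerse snapshot. A minimal, hypothetical fail-fast guard (not part of the commit) that could run right after the `os.environ` block:

import os

def assert_checkpoints_present():
    # Hypothetical guard: verify every path exported above exists
    # before any model is constructed.
    keys = [
        "FLORENCE2_MODEL_PATH", "SAM2_MODEL_PATH", "FACE_ID_MODEL_PATH",
        "CLIP_MODEL_PATH", "FLUX_MODEL_PATH", "DPG_VQA_MODEL_PATH",
        "DINO_MODEL_PATH",
    ]
    missing = [k for k in keys if not os.path.exists(os.environ.get(k, ""))]
    if missing:
        raise FileNotFoundError(f"missing checkpoints for: {', '.join(missing)}")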
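The commented-out `dit_quant` line in the hunk above hints at an optional int8 path through quanto. A hedged sketch of enabling it, assuming `CustomFluxPipeline` reads that config key (the commit itself leaves the DiT unquantized in bf16):

# Assumption: pipeline_tools honors config["model"]["dit_quant"].
config["model"]["dit_quant"] = "int8-quanto"
config["model"]["use_dit_lora"] = False
model = CustomFluxPipeline(config, device, torch_dtype=dtype)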
@@ -211,145 +211,143 @@ def generate_image(
     indexs,  # newly added parameter
     # *images_captions_faces, # Combine all unpacked arguments into one tuple
 ):
-    … (previous implementation removed; not preserved in this view)
+    torch.cuda.empty_cache()
+    num_images = 1
+
+    # Determine the number of images, captions, and faces based on the indexs length
+    images = list(images_captions_faces[:num_inputs])
+    captions = list(images_captions_faces[num_inputs:2 * num_inputs])
+    idips_checkboxes = list(images_captions_faces[2 * num_inputs:3 * num_inputs])
+    images = [images[i] for i in indexs]
+    captions = [captions[i] for i in indexs]
+    idips_checkboxes = [idips_checkboxes[i] for i in indexs]
+
+    print(f"Length of images: {len(images)}")
+    print(f"Length of captions: {len(captions)}")
+    print(f"Indexs: {indexs}")

+    print(f"Control weight lambda: {control_weight_lambda}")
+    if control_weight_lambda != "no":
+        parts = control_weight_lambda.split(',')
+        new_parts = []
+        for part in parts:
+            if ':' in part:
+                left, right = part.split(':')
+                values = right.split('/')
+                # keep the overall value
+                global_value = values[0]
+                id_value = values[1]
+                ip_value = values[2]
+                new_values = [global_value]
+                for is_id in idips_checkboxes:
+                    if is_id:
+                        new_values.append(id_value)
+                    else:
+                        new_values.append(ip_value)
+                new_part = f"{left}:{('/'.join(new_values))}"
+                new_parts.append(new_part)
+            else:
+                new_parts.append(part)
+        control_weight_lambda = ','.join(new_parts)

+    print(f"Control weight lambda: {control_weight_lambda}")
+
+    src_inputs = []
+    use_words = []
+    cur_run_time = time.strftime("%m%d-%H%M%S")
+    tmp_dir_root = f"tmp/gradio_demo/{run_name}"
+    temp_dir = f"{tmp_dir_root}/{cur_run_time}_{generate_random_string(4)}"
+    os.makedirs(temp_dir, exist_ok=True)
+    print(f"Temporary directory created: {temp_dir}")
+    for i, (image_path, caption) in enumerate(zip(images, captions)):
+        if image_path:
+            if caption.startswith("a ") or caption.startswith("A "):
+                word = caption[2:]
+            else:
+                word = caption

+            if f"ENT{i+1}" in prompt:
+                prompt = prompt.replace(f"ENT{i+1}", caption)

+            image = resize_keep_aspect_ratio(Image.open(image_path), 768)
+            save_path = f"{temp_dir}/tmp_resized_input_{i}.png"
+            image.save(save_path)

+            input_image_path = save_path
+
+            src_inputs.append(
+                {
+                    "image_path": input_image_path,
+                    "caption": caption
+                }
+            )
+            use_words.append((i, word, word))
+
+
+    test_sample = dict(
+        input_images=[], position_delta=[0, -32],
+        prompt=prompt,
+        target_height=target_height,
+        target_width=target_width,
+        seed=seed,
+        cond_size=cond_size,
+        vae_skip_iter=vae_skip_iter,
+        lora_scale=ip_scale,
+        control_weight_lambda=control_weight_lambda,
+        latent_sblora_scale=latent_sblora_scale_str,
+        condition_sblora_scale=vae_lora_scale,
+        double_attention=double_attention,
+        single_attention=single_attention,
+    )
+    if len(src_inputs) > 0:
+        test_sample["modulation"] = [
+            dict(
+                type="adapter",
+                src_inputs=src_inputs,
+                use_words=use_words,
+            ),
+        ]

+    json_dump(test_sample, f"{temp_dir}/test_sample.json", 'utf-8')
+    assert single_attention == True
+    target_size = int(round((target_width * target_height) ** 0.5) // 16 * 16)
+    print(test_sample)

+    model.config["train"]["dataset"]["val_condition_size"] = cond_size
+    model.config["train"]["dataset"]["val_target_size"] = target_size

+    if control_weight_lambda == "no":
+        control_weight_lambda = None
+    if vae_skip_iter == "no":
+        vae_skip_iter = None
+    use_condition_sblora_control = True
+    use_latent_sblora_control = True
+    image = generate_from_test_sample(
+        test_sample, model.pipe, model.config,
+        num_images=num_images,
+        target_height=target_height,
+        target_width=target_width,
+        seed=seed,
+        store_attn_map=store_attn_map,
+        vae_skip_iter=vae_skip_iter,  # use the new parameter
+        control_weight_lambda=control_weight_lambda,  # pass the new parameter
+        double_attention=double_attention,  # newly added parameter
+        single_attention=single_attention,  # newly added parameter
+        ip_scale=ip_scale,
+        use_latent_sblora_control=use_latent_sblora_control,
+        latent_sblora_scale=latent_sblora_scale_str,
+        use_condition_sblora_control=use_condition_sblora_control,
+        condition_sblora_scale=vae_lora_scale,
+    )
+    if isinstance(image, list):
+        num_cols = 2
+        num_rows = int(math.ceil(num_images / num_cols))
+        image = image_grid(image, num_rows, num_cols)
+
+    save_path = f"{temp_dir}/tmp_result.png"
+    image.save(save_path)
+
-    return None
+    return image
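
`generate_image` above receives every Gradio component value as one flat sequence (`images_captions_faces`, bound outside this hunk) and slices it into thirds before filtering by `indexs`. A standalone trace with toy values:

num_inputs = 6
# Toy stand-ins for six image paths, six captions, six ID checkboxes.
flat = (
    [f"img_{i}.png" for i in range(num_inputs)]
    + [f"caption {i}" for i in range(num_inputs)]
    + [i % 2 == 0 for i in range(num_inputs)]
)
images = list(flat[:num_inputs])                              # first third
captions = list(flat[num_inputs:2 * num_inputs])              # second third
idips_checkboxes = list(flat[2 * num_inputs:3 * num_inputs])  # last third

indexs = [0, 2]  # only the slots the user actually filled
images = [images[i] for i in indexs]      # ['img_0.png', 'img_2.png']
captions = [captions[i] for i in indexs]  # ['caption 0', 'caption 2']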
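The `control_weight_lambda` rewrite above expands a compact `range:global/ID/IP` spec into one weight per selected subject, choosing the ID weight when that subject's checkbox is ticked and the IP weight otherwise. The same logic as a pure function, with a worked example (the spec format is inferred from the code, not documented in the diff):

def expand_weights(spec: str, id_flags: list) -> str:
    # The caller handles spec == "no"; each part is "range:global/id/ip".
    out = []
    for part in spec.split(','):
        if ':' not in part:
            out.append(part)
            continue
        rng, weights = part.split(':')
        glob, id_w, ip_w = weights.split('/')
        vals = [glob] + [id_w if flag else ip_w for flag in id_flags]
        out.append(f"{rng}:{'/'.join(vals)}")
    return ','.join(out)

# Three subjects, the first two identity-checked:
assert expand_weights("0-1:1/0.9/0.6", [True, True, False]) == "0-1:1/0.9/0.9/0.6"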
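The per-subject loop above swaps `ENT1`…`ENTn` placeholders in the prompt for the corresponding captions. In isolation:

prompt = "ENT1 and ENT2 standing on a beach"
captions = ["a corgi", "a red robot"]
for i, caption in enumerate(captions):
    if f"ENT{i+1}" in prompt:
        prompt = prompt.replace(f"ENT{i+1}", caption)
assert prompt == "a corgi and a red robot standing on a beach"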
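`resize_keep_aspect_ratio` is called above with a 768-pixel budget but defined elsewhere in app.py; a plausible minimal version, assuming the second argument caps the longer edge:

from PIL import Image

def resize_keep_aspect_ratio(img: Image.Image, max_side: int) -> Image.Image:
    # Hypothetical reimplementation: scale so the longer edge equals
    # max_side while preserving the aspect ratio.
    w, h = img.size
    scale = max_side / max(w, h)
    return img.resize((max(1, round(w * scale)), max(1, round(h * scale))))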
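`target_size` above snaps the geometric mean of the requested dimensions down to a multiple of 16 (FLUX latents are patchified at 16-pixel granularity: 8x VAE downsampling times 2x2 patches). Traced for a 768x1024 request:

target_width, target_height = 768, 1024
target_size = int(round((target_width * target_height) ** 0.5) // 16 * 16)
# sqrt(786432) ~= 886.8 -> round -> 887 -> 887 // 16 * 16 -> 880
assert target_size == 880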
@@ -533,7 +531,7 @@ if __name__ == "__main__":
     )

     # # Update the outputs of the clear function
-    … (removed line not preserved in this view)
+    clear_btn.click(clear_images, outputs=images)

     face_btn_1.click(crop_face_img, inputs=[image_1], outputs=[image_1])
     det_btn_1.click(det_seg_img, inputs=[image_1, caption_1], outputs=[image_1])
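
The re-enabled `clear_btn.click` wiring pairs the module-level `clear_images` from the first hunk, which returns one `None` per slot, with the list of image components. A self-contained sketch of the same pattern (component construction here is illustrative; the real layout lives outside this diff):

import gradio as gr

num_inputs = 6

def clear_images():
    return [None] * num_inputs  # one value per output component

with gr.Blocks() as demo:
    images = [gr.Image(type="filepath", label=f"Image {i+1}") for i in range(num_inputs)]
    clear_btn = gr.Button("Clear Images")
    clear_btn.click(clear_images, outputs=images)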