WYBar committed
Commit 13ab714 · Parent(s): 8ba1e94

construction all

Files changed (1):
  1. app.py +110 -161
app.py CHANGED
@@ -228,18 +228,36 @@ def calculate_iou(box1, box2):
     iou = intersection_area / union_area
     return iou
 
-# @spaces.GPU(enable_queue=True, duration=120)
-def buildmodel(**kwargs):
+def construction_all():
     global model
     global quantizer
     global tokenizer
+    global pipeline
+    global transp_vae
     from modeling_crello import CrelloModel, CrelloModelConfig
     from quantizer import get_quantizer
+    from custom_model_mmdit import CustomFluxTransformer2DModel
+    from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
+    from custom_pipeline import CustomFluxPipelineCfg
+
+    params_dict = {
+        "input_model": "/openseg_blob/v-sirui/temporary/2024-02-21/Layout_train/COLEv2/Design_LLM/checkpoint/Meta-Llama-3-8B",
+        "resume": "/openseg_blob/v-sirui/temporary/2024-02-21/SVD/Int2lay_1016/checkpoint/int2lay_1031/1031_test/checkpoint-26000/",
+        "seed": 0,
+        "mask_values": False,
+        "quantizer_version": 'v4',
+        "mask_type": 'cm3',
+        "decimal_quantize_types": [],
+        "num_mask_tokens": 0,
+        "width": 512,
+        "height": 512,
+        "device": 0,
+    }
+
     # seed / input model / resume
-    resume = kwargs.get('resume', None)
-    seed = kwargs.get('seed', None)
-    input_model = kwargs.get('input_model', None)
-    quantizer_version = kwargs.get('quantizer_version', 'v4')
+    seed = params_dict.get('seed', None)
+    input_model = params_dict.get('input_model', None)
+    quantizer_version = params_dict.get('quantizer_version', 'v4')
 
     set_seed(seed)
     # old_tokenizer = AutoTokenizer.from_pretrained(input_model, trust_remote_code=True)
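
Aside: the new construction_all() populates module-level globals rather than returning objects, so the Gradio callbacks defined later can reach the loaded components without threading them through arguments. A minimal sketch of that pattern, with hypothetical string stubs standing in for the real from_pretrained/get_quantizer loaders:

model = None
quantizer = None
tokenizer = None
pipeline = None
transp_vae = None

def construction_all():
    # Rebind the module-level names so other functions see the loaded objects.
    global model, quantizer, tokenizer, pipeline, transp_vae
    model = "stub CrelloModel"          # stands in for CrelloModel.from_pretrained(...)
    quantizer = "stub quantizer"        # stands in for get_quantizer(...)
    tokenizer = "stub tokenizer"
    pipeline = "stub Flux pipeline"     # stands in for CustomFluxPipelineCfg.from_pretrained(...)
    transp_vae = "stub transparency VAE"

construction_all()
assert None not in (model, quantizer, tokenizer, pipeline, transp_vae)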
@@ -261,13 +279,13 @@ def buildmodel(**kwargs):
     quantizer = get_quantizer(
         quantizer_version,
         update_vocab = False,
-        decimal_quantize_types = kwargs.get('decimal_quantize_types'),
-        mask_values = kwargs['mask_values'],
-        width = kwargs['width'],
-        height = kwargs['height'],
+        decimal_quantize_types = params_dict.get('decimal_quantize_types'),
+        mask_values = params_dict['mask_values'],
+        width = params_dict['width'],
+        height = params_dict['height'],
         simplify_json = False,
         num_mask_tokens = 0,
-        mask_type = kwargs.get('mask_type'),
+        mask_type = params_dict.get('mask_type'),
     )
     quantizer.setup_tokenizer(tokenizer)
 
@@ -280,11 +298,7 @@ def buildmodel(**kwargs):
     model_args.freeze_lm = False
     model_args.opt_version = input_model
     model_args.use_lora = False
-    model_args.load_in_4bit = kwargs.get('load_in_4bit', False)
-    # model = CrelloModel.from_pretrained(
-    #     resume,
-    #     config=model_args
-    # ).to(device)
+    model_args.load_in_4bit = params_dict.get('load_in_4bit', False)
 
     model = CrelloModel.from_pretrained(
         "WYBar/LLM_For_Layout_Planning",
@@ -300,63 +314,46 @@
     for token in added_special_tokens_list:
         quantizer.additional_special_tokens.add(token)
 
-    print(f"before .to(device):{model.device} {model.lm.device}")
+    transformer = CustomFluxTransformer2DModel.from_pretrained(
+        "WYBar/ART_test_weights",
+        subfolder="fused_transformer",
+        torch_dtype=torch.bfloat16,
+        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
+    )
+
+    transp_vae = CustomVAE.from_pretrained(
+        "WYBar/ART_test_weights",
+        subfolder="custom_vae",
+        torch_dtype=torch.float32,
+        use_safetensors=True,
+        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
+    )
+
+    token = os.environ.get("HF_TOKEN")
+    pipeline = CustomFluxPipelineCfg.from_pretrained(
+        "black-forest-labs/FLUX.1-dev",
+        transformer=transformer,
+        torch_dtype=torch.bfloat16,
+        token=token,
+        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
+    ).to("cuda")
+    pipeline.enable_model_cpu_offload(gpu_id=0) # Save GPU memory
+
+    print(f"before .to(device):{model.device} {model.lm.device} {pipeline.device}")
     model = model.to("cuda")
-    print(f"after .to(device):{model.device} {model.lm.device}")
+    pipeline = pipeline.to("cuda")
+    print(f"after .to(device):{model.device} {model.lm.device} {pipeline.device}")
     model = model.bfloat16()
     model.eval()
-    # quantizer = quantizer.to("cuda")
-    # tokenizer = tokenizer.to("cuda")
-    # model.lm = model.lm.to("cuda")
-    print(model.lm.device)
-
-    # return model, quantizer, tokenizer
-
-def construction_layout():
-    params_dict = {
-        # to be modified
-        "input_model": "/openseg_blob/v-sirui/temporary/2024-02-21/Layout_train/COLEv2/Design_LLM/checkpoint/Meta-Llama-3-8B",
-        "resume": "/openseg_blob/v-sirui/temporary/2024-02-21/SVD/Int2lay_1016/checkpoint/int2lay_1031/1031_test/checkpoint-26000/",
-
-        "seed": 0,
-        "mask_values": False,
-        "quantizer_version": 'v4',
-        "mask_type": 'cm3',
-        "decimal_quantize_types": [],
-        "num_mask_tokens": 0,
-        "width": 512,
-        "height": 512,
-        "device": 0,
-    }
-    device = "cuda"
-    # Init model
-    buildmodel(**params_dict)
-    # model, quantizer, tokenizer = buildmodel(**params_dict)
-
-    # print('resize token embeddings to match the tokenizer', 129423)
-    # model.lm.resize_token_embeddings(129423)
-    # model.input_embeddings = model.lm.get_input_embeddings()
-    # print('after token embeddings to match the tokenizer', 129423)
-
-    # print("before .to(device)")
-    # model = model.to("cuda")
-    # print("after .to(device)")
-    # model = model.bfloat16()
-    # model.eval()
-    # # quantizer = quantizer.to("cuda")
-    # # tokenizer = tokenizer.to("cuda")
-    # # model.lm = model.lm.to("cuda")
-    # print(model.lm.device)
-
-    return params_dict["width"], params_dict["height"], device
-    # return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
+    print(f"after bf16 & eval .to(device):{model.device} {model.lm.device} {pipeline.device}")
 
 @torch.no_grad()
 @spaces.GPU(duration=120)
-def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
-    print(model.lm.device)
+def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
+    print(f"evaluate_v1 {model.device} {model.lm.device} {pipeline.device}")
     json_example = inputs
     input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
+
     print("tokenizer1")
     inputs = tokenizer(
         input_intension, return_tensors="pt"
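
The reworked device logging brackets a move/cast/eval sequence on the layout model. Its effect in isolation, on a toy torch module (the CUDA move is guarded so the sketch also runs on CPU):

import torch

m = torch.nn.Linear(4, 4)
if torch.cuda.is_available():
    m = m.to("cuda")   # move parameters onto the GPU
m = m.bfloat16()       # cast parameters to bfloat16
m.eval()               # inference mode: disables dropout and similar training behavior
p = next(m.parameters())
print(p.device, p.dtype)  # e.g. cuda:0 torch.bfloat16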
@@ -395,7 +392,7 @@ def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_s
     pred_json_example = None
     return pred_json_example
 
-def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
+def inference(generate_method, intention, model, quantizer, tokenizer, width, height, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
     rawdata = {}
     rawdata["wholecaption"] = intention
     rawdata["layout"] = []
@@ -404,7 +401,7 @@ def inference(generate_method, intention, model, quantizer, tokenizer, width, he
     max_try_time = 5
     preddata = None
     while preddata is None and max_try_time > 0:
-        preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, device, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
+        preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
         max_try_time -= 1
     else:
         print("Please input correct generate method")
@@ -412,41 +409,6 @@ def inference(generate_method, intention, model, quantizer, tokenizer, width, he
 
     return preddata
 
-# @spaces.GPU(enable_queue=True, duration=120)
-def construction():
-    global pipeline
-    global transp_vae
-    from custom_model_mmdit import CustomFluxTransformer2DModel
-    from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
-    from custom_pipeline import CustomFluxPipelineCfg
-
-    transformer = CustomFluxTransformer2DModel.from_pretrained(
-        "WYBar/ART_test_weights",
-        subfolder="fused_transformer",
-        torch_dtype=torch.bfloat16,
-        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
-    )
-
-    transp_vae = CustomVAE.from_pretrained(
-        "WYBar/ART_test_weights",
-        subfolder="custom_vae",
-        torch_dtype=torch.float32,
-        use_safetensors=True,
-        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
-    )
-
-    token = os.environ.get("HF_TOKEN")
-    pipeline = CustomFluxPipelineCfg.from_pretrained(
-        "black-forest-labs/FLUX.1-dev",
-        transformer=transformer,
-        torch_dtype=torch.bfloat16,
-        token=token,
-        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
-    ).to("cuda")
-    pipeline.enable_model_cpu_offload(gpu_id=0) # Save GPU memory
-
-    # return pipeline, transp_vae
-
 @spaces.GPU(duration=120)
 def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
     print(validation_box)
@@ -477,7 +439,7 @@ def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps,
     return output_gradio
 
 def svg_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, inference_steps, pipeline, transp_vae):
-    print("svg_test_one_sample")
+    print(f"svg_test_one_sample {model.device} {model.lm.device} {pipeline.device}")
     generator = torch.Generator().manual_seed(seed)
     try:
         validation_box = ast.literal_eval(validation_box_str)
@@ -511,7 +473,7 @@ def svg_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, in
     return result_images, svg_file_path
 
 def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
-    print("precess_svg")
+    print(f"process_svg {model.device} {model.lm.device} {pipeline.device}")
     result_images = []
     result_images, svg_file_path = svg_test_one_sample(text_input, tuple_input, seed, true_gs, inference_steps, pipeline=pipeline, transp_vae=transp_vae)
     # result_images, svg_file_path = gradio_test_one_sample_partial(text_input, tuple_input, seed, true_gs, inference_steps)
@@ -534,64 +496,52 @@ def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
     """
 
     return result_images, svg_file_path, svg_editor
-
-def main():
-    # model, quantizer, tokenizer, width, height, device = construction_layout()
-    width, height, device = construction_layout()
-
-    # inference_partial = partial(
-    #     inference,
-    #     model=model,
-    #     quantizer=quantizer,
-    #     tokenizer=tokenizer,
-    #     width=width,
-    #     height=height,
-    #     device=device
-    # )
-
-    def process_preddate(intention, temperature, top_p, generate_method='v1'):
-        intention = intention.replace('\n', '').replace('\r', '').replace('\\', '')
-        intention = ensure_space_after_period(intention)
-        print(f"process_preddate: {model.lm.device}")
-        model.lm.to("cuda")
-        print(f"after process_preddate: {model.lm.device}")
-        if temperature == 0.0:
-            # print("looking for greedy decoding strategies, set `do_sample=False`.")
-            # preddata = inference_partial(generate_method, intention, do_sample=False)
-            preddata = inference(generate_method, intention, model=model, quantizer=quantizer, tokenizer=tokenizer, width=width, height=height, device=device, do_sample=False)
+
+def process_preddate(intention, temperature, top_p, generate_method='v1'):
+    intention = intention.replace('\n', '').replace('\r', '').replace('\\', '')
+    intention = ensure_space_after_period(intention)
+    print(f"process_preddate: {model.lm.device}")
+    if temperature == 0.0:
+        # print("looking for greedy decoding strategies, set `do_sample=False`.")
+        # preddata = inference_partial(generate_method, intention, do_sample=False)
+        preddata = inference(generate_method, intention, model=model, quantizer=quantizer, tokenizer=tokenizer, width=512, height=512, do_sample=False)
+    else:
+        # preddata = inference_partial(generate_method, intention, temperature=temperature, top_p=top_p)
+        preddata = inference(generate_method, intention, model=model, quantizer=quantizer, tokenizer=tokenizer, width=512, height=512, temperature=temperature, top_p=top_p)
+
+    layouts = preddata["layout"]
+    list_box = []
+    for i, layout in enumerate(layouts):
+        x, y = layout["x"], layout["y"]
+        width, height = layout["width"], layout["height"]
+        if i == 0:
+            list_box.append((0, 0, width, height))
+            list_box.append((0, 0, width, height))
         else:
-            # preddata = inference_partial(generate_method, intention, temperature=temperature, top_p=top_p)
-            preddata = inference(generate_method, intention, model=model, quantizer=quantizer, tokenizer=tokenizer, width=width, height=height, device=device, temperature=temperature, top_p=top_p)
-
-        layouts = preddata["layout"]
-        list_box = []
-        for i, layout in enumerate(layouts):
-            x, y = layout["x"], layout["y"]
-            width, height = layout["width"], layout["height"]
-            if i == 0:
-                list_box.append((0, 0, width, height))
-                list_box.append((0, 0, width, height))
-            else:
-                left = x - width // 2
-                top = y - height // 2
-                right = x + width // 2
-                bottom = y + height // 2
-                list_box.append((left, top, right, bottom))
-
-        # print(list_box)
-        filtered_boxes = list_box[:2]
-        for i in range(2, len(list_box)):
-            keep = True
-            for j in range(1, len(filtered_boxes)):
-                iou = calculate_iou(list_box[i], filtered_boxes[j])
-                if iou > 0.65:
-                    print(list_box[i], filtered_boxes[j])
-                    keep = False
-                    break
-            if keep:
-                filtered_boxes.append(list_box[i])
+            left = x - width // 2
+            top = y - height // 2
+            right = x + width // 2
+            bottom = y + height // 2
+            list_box.append((left, top, right, bottom))
 
-        return str(filtered_boxes), intention, str(filtered_boxes)
+    # print(list_box)
+    filtered_boxes = list_box[:2]
+    for i in range(2, len(list_box)):
+        keep = True
+        for j in range(1, len(filtered_boxes)):
+            iou = calculate_iou(list_box[i], filtered_boxes[j])
+            if iou > 0.65:
+                print(list_box[i], filtered_boxes[j])
+                keep = False
+                break
+        if keep:
+            filtered_boxes.append(list_box[i])
+
+    return str(filtered_boxes), intention, str(filtered_boxes)
+
+def main():
+    construction_all()
+    print(f"after construction_all:{model.device} {model.lm.device} {pipeline.device}")
 
     # def process_preddate(intention, generate_method='v1'):
     #     list_box = [(0, 0, 512, 512), (0, 0, 512, 512), (136, 184, 512, 512), (144, 0, 512, 512), (0, 0, 328, 136), (160, 112, 512, 360), (168, 112, 512, 360), (40, 232, 112, 296), (32, 88, 248, 176), (48, 424, 144, 448), (48, 464, 144, 488), (240, 464, 352, 488), (384, 464, 488, 488), (48, 480, 144, 504), (240, 480, 360, 504), (456, 0, 512, 56), (0, 0, 56, 40), (440, 0, 512, 40), (0, 24, 48, 88), (48, 168, 168, 240)]
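
The dedup pass above keeps a candidate box only if its IoU with every already-kept box (skipping the first, full-canvas entry) stays at or below 0.65. calculate_iou itself appears in this diff only as its final two lines; a standard axis-aligned implementation consistent with that visible tail, assuming (left, top, right, bottom) tuples as built in process_preddate (a sketch, not the file's exact body):

def calculate_iou(box1, box2):
    # Intersection rectangle; non-overlapping boxes clamp to zero area.
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])
    intersection_area = max(0, right - left) * max(0, bottom - top)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = area1 + area2 - intersection_area
    if union_area == 0:
        return 0.0
    iou = intersection_area / union_area
    return iou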
@@ -600,7 +550,6 @@ def main():
     #     return wholecaption, str(list_box), json_file
 
     # pipeline, transp_vae = construction()
-    construction()
 
     # gradio_test_one_sample_partial = partial(
     #     svg_test_one_sample,
 
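
Net effect of the commit: the old two-stage startup (construction_layout() for the layout LLM, construction() for the Flux pipeline) collapses into a single construction_all(), the unused device/resume plumbing drops out of evaluate_v1/inference, and process_preddate moves from inside main() to module level with width/height fixed at 512. A minimal driver consistent with the new entry points; the prompt and sampler settings here are illustrative, not values from the repo:

construction_all()  # loads model, quantizer, tokenizer, pipeline, transp_vae into globals
boxes, intention, _ = process_preddate("A minimalist poster of a lighthouse at dusk.",
                                       temperature=0.0, top_p=1.0)
result_images, svg_path, svg_editor = process_svg(intention, boxes,
                                                  seed=0, true_gs=3.5, inference_steps=28)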