WYBar commited on
Commit
dac0408
·
1 Parent(s): 3a4bde2

delete local args for global models

Browse files
Files changed (1) hide show
  1. app.py +13 -14
app.py CHANGED
@@ -360,11 +360,11 @@ def construction_all():
360
 
361
  @torch.no_grad()
362
  @spaces.GPU(duration=120)
363
- def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
364
  print(f"evaluate_v1")
365
  print(f"evaluate_v1 {model.device} {model.lm.device} {pipeline.device}")
366
- model = model.to("cuda")
367
- print(f"after evaluate_v1 {model.device} {model.lm.device} {pipeline.device}")
368
 
369
  json_example = inputs
370
  input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
@@ -407,7 +407,7 @@ def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, do_sample=Fa
407
  pred_json_example = None
408
  return pred_json_example
409
 
410
- def inference(generate_method, intention, model, quantizer, tokenizer, width, height, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
411
  print(f"start inference")
412
  rawdata = {}
413
  rawdata["wholecaption"] = intention
@@ -418,7 +418,7 @@ def inference(generate_method, intention, model, quantizer, tokenizer, width, he
418
  preddata = None
419
  while preddata is None and max_try_time > 0:
420
  print(f"preddata = evaluate_v1")
421
- preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
422
  max_try_time -= 1
423
  else:
424
  print("Please input correct generate method")
@@ -434,11 +434,11 @@ def process_preddate(intention, temperature, top_p, generate_method='v1'):
434
  # print("looking for greedy decoding strategies, set `do_sample=False`.")
435
  # preddata = inference_partial(generate_method, intention, do_sample=False)
436
  print(f"preddata = inference temperatrue = 0.0")
437
- preddata = inference(generate_method, intention, model=model, quantizer=quantizer, tokenizer=tokenizer, width=512, height=512, do_sample=False)
438
  else:
439
  # preddata = inference_partial(generate_method, intention, temperature=temperature, top_p=top_p)
440
  print(f"preddata = inference temperatrue != 0.0")
441
- preddata = inference(generate_method, intention, model=model, quantizer=quantizer, tokenizer=tokenizer, width=512, height=512, temperature=temperature, top_p=top_p)
442
 
443
  layouts = preddata["layout"]
444
  list_box = []
@@ -471,8 +471,9 @@ def process_preddate(intention, temperature, top_p, generate_method='v1'):
471
  return str(filtered_boxes), intention, str(filtered_boxes)
472
 
473
  @spaces.GPU(duration=120)
474
- def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
475
  print(f"test_one_sample: {validation_box}")
 
476
  output, rgba_output, _, _ = pipeline(
477
  prompt=validation_prompt,
478
  validation_box=validation_box,
@@ -500,11 +501,9 @@ def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps,
500
 
501
  return output_gradio
502
 
503
- def gradio_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, inference_steps, pipeline, transp_vae):
504
  print(f"svg_test_one_sample")
505
- # print(f"svg_test_one_sample {model.device} {model.lm.device} {pipeline.device}")
506
- # generator = torch.Generator().manual_seed(seed)
507
- generator = torch.Generator(device=torch.device("cuda", index=0)).manual_seed(seed)
508
  try:
509
  if isinstance(validation_box_str, (list, tuple)):
510
  validation_box = validation_box_str
@@ -518,7 +517,7 @@ def gradio_test_one_sample(validation_prompt, validation_box_str, seed, true_gs,
518
  validation_box = adjust_validation_box(validation_box)
519
 
520
  print("result_images = test_one_sample")
521
- result_images = test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae)
522
  print("after result_images = test_one_sample")
523
  svg_img = pngs_to_svg(result_images[1:])
524
 
@@ -543,7 +542,7 @@ def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
543
  print(f"precess_svg")
544
  # print(f"precess_svg {model.device} {model.lm.device} {pipeline.device}")
545
  result_images = []
546
- result_images, svg_file_path = gradio_test_one_sample(text_input, tuple_input, seed, true_gs, inference_steps, pipeline=pipeline, transp_vae=transp_vae)
547
  # result_images, svg_file_path = gradio_test_one_sample_partial(text_input, tuple_input, seed, true_gs, inference_steps)
548
 
549
  url, unique_filename = upload_to_github(file_path=svg_file_path)
 
360
 
361
  @torch.no_grad()
362
  @spaces.GPU(duration=120)
363
+ def evaluate_v1(inputs, width, height, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
364
  print(f"evaluate_v1")
365
  print(f"evaluate_v1 {model.device} {model.lm.device} {pipeline.device}")
366
+ # model = model.to("cuda")
367
+ # print(f"after evaluate_v1 {model.device} {model.lm.device} {pipeline.device}")
368
 
369
  json_example = inputs
370
  input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
 
407
  pred_json_example = None
408
  return pred_json_example
409
 
410
+ def inference(generate_method, intention, width, height, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
411
  print(f"start inference")
412
  rawdata = {}
413
  rawdata["wholecaption"] = intention
 
418
  preddata = None
419
  while preddata is None and max_try_time > 0:
420
  print(f"preddata = evaluate_v1")
421
+ preddata = evaluate_v1(rawdata, width, height, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
422
  max_try_time -= 1
423
  else:
424
  print("Please input correct generate method")
 
434
  # print("looking for greedy decoding strategies, set `do_sample=False`.")
435
  # preddata = inference_partial(generate_method, intention, do_sample=False)
436
  print(f"preddata = inference temperatrue = 0.0")
437
+ preddata = inference(generate_method, intention, width=512, height=512, do_sample=False)
438
  else:
439
  # preddata = inference_partial(generate_method, intention, temperature=temperature, top_p=top_p)
440
  print(f"preddata = inference temperatrue != 0.0")
441
+ preddata = inference(generate_method, intention, width=512, height=512, temperature=temperature, top_p=top_p)
442
 
443
  layouts = preddata["layout"]
444
  list_box = []
 
471
  return str(filtered_boxes), intention, str(filtered_boxes)
472
 
473
  @spaces.GPU(duration=120)
474
+ def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, seed):
475
  print(f"test_one_sample: {validation_box}")
476
+ generator = torch.Generator(device=torch.device("cuda", index=0)).manual_seed(seed)
477
  output, rgba_output, _, _ = pipeline(
478
  prompt=validation_prompt,
479
  validation_box=validation_box,
 
501
 
502
  return output_gradio
503
 
504
+ def gradio_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, inference_steps):
505
  print(f"svg_test_one_sample")
506
+
 
 
507
  try:
508
  if isinstance(validation_box_str, (list, tuple)):
509
  validation_box = validation_box_str
 
517
  validation_box = adjust_validation_box(validation_box)
518
 
519
  print("result_images = test_one_sample")
520
+ result_images = test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, seed)
521
  print("after result_images = test_one_sample")
522
  svg_img = pngs_to_svg(result_images[1:])
523
 
 
542
  print(f"precess_svg")
543
  # print(f"precess_svg {model.device} {model.lm.device} {pipeline.device}")
544
  result_images = []
545
+ result_images, svg_file_path = gradio_test_one_sample(text_input, tuple_input, seed, true_gs, inference_steps)
546
  # result_images, svg_file_path = gradio_test_one_sample_partial(text_input, tuple_input, seed, true_gs, inference_steps)
547
 
548
  url, unique_filename = upload_to_github(file_path=svg_file_path)