WYBar commited on
Commit
fcb494b
·
1 Parent(s): a162d7b

use safetensors

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -228,7 +228,7 @@ def calculate_iou(box1, box2):
228
  iou = intersection_area / union_area
229
  return iou
230
 
231
- # @spaces.GPU(enable_queue=True, duration=60)
232
  def buildmodel(**kwargs):
233
  from modeling_crello import CrelloModel, CrelloModelConfig
234
  from quantizer import get_quantizer
@@ -337,7 +337,7 @@ def construction_layout():
337
  return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
338
 
339
  @torch.no_grad()
340
- @spaces.GPU(duration=60)
341
  def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
342
  json_example = inputs
343
  input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
@@ -379,13 +379,15 @@ def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_s
379
  return pred_json_example
380
 
381
  def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
382
- def FormulateInput(intension: str):
383
- resdict = {}
384
- resdict["wholecaption"] = intension
385
- resdict["layout"] = []
386
- return resdict
387
-
388
- rawdata = FormulateInput(intention)
 
 
389
 
390
  if generate_method == 'v1':
391
  max_try_time = 5
@@ -399,7 +401,7 @@ def inference(generate_method, intention, model, quantizer, tokenizer, width, he
399
 
400
  return preddata
401
 
402
- # @spaces.GPU(enable_queue=True, duration=60)
403
  def construction():
404
  global pipeline
405
  global transp_vae
@@ -418,6 +420,7 @@ def construction():
418
  "WYBar/ART_test_weights",
419
  subfolder="custom_vae",
420
  torch_dtype=torch.float32,
 
421
  # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
422
  )
423
 
@@ -433,7 +436,7 @@ def construction():
433
 
434
  # return pipeline, transp_vae
435
 
436
- @spaces.GPU(duration=60)
437
  def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
438
  print(validation_box)
439
  output, rgba_output, _, _ = pipeline(
 
228
  iou = intersection_area / union_area
229
  return iou
230
 
231
+ # @spaces.GPU(enable_queue=True, duration=120)
232
  def buildmodel(**kwargs):
233
  from modeling_crello import CrelloModel, CrelloModelConfig
234
  from quantizer import get_quantizer
 
337
  return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
338
 
339
  @torch.no_grad()
340
+ @spaces.GPU(duration=120)
341
  def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
342
  json_example = inputs
343
  input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
 
379
  return pred_json_example
380
 
381
  def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
382
+ # def FormulateInput(intension: str):
383
+ # resdict = {}
384
+ # resdict["wholecaption"] = intension
385
+ # resdict["layout"] = []
386
+ # return resdict
387
+ # rawdata = FormulateInput(intention)
388
+ rawdata = {}
389
+ rawdata["wholecaption"] = intention
390
+ rawdata["layout"] = []
391
 
392
  if generate_method == 'v1':
393
  max_try_time = 5
 
401
 
402
  return preddata
403
 
404
+ # @spaces.GPU(enable_queue=True, duration=120)
405
  def construction():
406
  global pipeline
407
  global transp_vae
 
420
  "WYBar/ART_test_weights",
421
  subfolder="custom_vae",
422
  torch_dtype=torch.float32,
423
+ use_safetensors=True,
424
  # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
425
  )
426
 
 
436
 
437
  # return pipeline, transp_vae
438
 
439
+ @spaces.GPU(duration=120)
440
  def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
441
  print(validation_box)
442
  output, rgba_output, _, _ = pipeline(