use safetensors
Browse files
app.py
CHANGED
@@ -228,7 +228,7 @@ def calculate_iou(box1, box2):
|
|
228 |
iou = intersection_area / union_area
|
229 |
return iou
|
230 |
|
231 |
-
# @spaces.GPU(enable_queue=True, duration=
|
232 |
def buildmodel(**kwargs):
|
233 |
from modeling_crello import CrelloModel, CrelloModelConfig
|
234 |
from quantizer import get_quantizer
|
@@ -337,7 +337,7 @@ def construction_layout():
|
|
337 |
return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
|
338 |
|
339 |
@torch.no_grad()
|
340 |
-
@spaces.GPU(duration=
|
341 |
def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
|
342 |
json_example = inputs
|
343 |
input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
|
@@ -379,13 +379,15 @@ def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_s
|
|
379 |
return pred_json_example
|
380 |
|
381 |
def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
|
382 |
-
def FormulateInput(intension: str):
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
rawdata =
|
|
|
|
|
389 |
|
390 |
if generate_method == 'v1':
|
391 |
max_try_time = 5
|
@@ -399,7 +401,7 @@ def inference(generate_method, intention, model, quantizer, tokenizer, width, he
|
|
399 |
|
400 |
return preddata
|
401 |
|
402 |
-
# @spaces.GPU(enable_queue=True, duration=
|
403 |
def construction():
|
404 |
global pipeline
|
405 |
global transp_vae
|
@@ -418,6 +420,7 @@ def construction():
|
|
418 |
"WYBar/ART_test_weights",
|
419 |
subfolder="custom_vae",
|
420 |
torch_dtype=torch.float32,
|
|
|
421 |
# cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
|
422 |
)
|
423 |
|
@@ -433,7 +436,7 @@ def construction():
|
|
433 |
|
434 |
# return pipeline, transp_vae
|
435 |
|
436 |
-
@spaces.GPU(duration=
|
437 |
def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
|
438 |
print(validation_box)
|
439 |
output, rgba_output, _, _ = pipeline(
|
|
|
228 |
iou = intersection_area / union_area
|
229 |
return iou
|
230 |
|
231 |
+
# @spaces.GPU(enable_queue=True, duration=120)
|
232 |
def buildmodel(**kwargs):
|
233 |
from modeling_crello import CrelloModel, CrelloModelConfig
|
234 |
from quantizer import get_quantizer
|
|
|
337 |
return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
|
338 |
|
339 |
@torch.no_grad()
|
340 |
+
@spaces.GPU(duration=120)
|
341 |
def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
|
342 |
json_example = inputs
|
343 |
input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
|
|
|
379 |
return pred_json_example
|
380 |
|
381 |
def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
|
382 |
+
# def FormulateInput(intension: str):
|
383 |
+
# resdict = {}
|
384 |
+
# resdict["wholecaption"] = intension
|
385 |
+
# resdict["layout"] = []
|
386 |
+
# return resdict
|
387 |
+
# rawdata = FormulateInput(intention)
|
388 |
+
rawdata = {}
|
389 |
+
rawdata["wholecaption"] = intention
|
390 |
+
rawdata["layout"] = []
|
391 |
|
392 |
if generate_method == 'v1':
|
393 |
max_try_time = 5
|
|
|
401 |
|
402 |
return preddata
|
403 |
|
404 |
+
# @spaces.GPU(enable_queue=True, duration=120)
|
405 |
def construction():
|
406 |
global pipeline
|
407 |
global transp_vae
|
|
|
420 |
"WYBar/ART_test_weights",
|
421 |
subfolder="custom_vae",
|
422 |
torch_dtype=torch.float32,
|
423 |
+
use_safetensors=True,
|
424 |
# cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
|
425 |
)
|
426 |
|
|
|
436 |
|
437 |
# return pipeline, transp_vae
|
438 |
|
439 |
+
@spaces.GPU(duration=120)
|
440 |
def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
|
441 |
print(validation_box)
|
442 |
output, rgba_output, _, _ = pipeline(
|