wwqgtxx commited on
Commit
49504ee
·
1 Parent(s): de572cd

support setting max batch/workspace size when convert to trt (#692)

Browse files
Files changed (1) hide show
  1. tools/trt.py +4 -1
tools/trt.py CHANGED
@@ -27,6 +27,8 @@ def make_parser():
27
  help="pls input your expriment description file",
28
  )
29
  parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
 
 
30
  return parser
31
 
32
 
@@ -59,7 +61,8 @@ def main():
59
  [x],
60
  fp16_mode=True,
61
  log_level=trt.Logger.INFO,
62
- max_workspace_size=(1 << 32),
 
63
  )
64
  torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
65
  logger.info("Converted TensorRT model done.")
 
27
  help="pls input your expriment description file",
28
  )
29
  parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
30
+ parser.add_argument("-w", '--workspace', type=int, default=32, help='max workspace size in detect')
31
+ parser.add_argument("-b", '--batch', type=int, default=1, help='max batch size in detect')
32
  return parser
33
 
34
 
 
61
  [x],
62
  fp16_mode=True,
63
  log_level=trt.Logger.INFO,
64
+ max_workspace_size=(1 << args.workspace),
65
+ max_batch_size=args.batch,
66
  )
67
  torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
68
  logger.info("Converted TensorRT model done.")