Support setting max batch/workspace size when converting to TRT (#692)
tools/trt.py (+4 -1)
@@ -27,6 +27,8 @@ def make_parser():
         help="pls input your expriment description file",
     )
     parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
+    parser.add_argument("-w", '--workspace', type=int, default=32, help='max workspace size in detect')
+    parser.add_argument("-b", '--batch', type=int, default=1, help='max batch size in detect')
     return parser
 
 
@@ -59,7 +61,8 @@ def main():
         [x],
         fp16_mode=True,
         log_level=trt.Logger.INFO,
-        max_workspace_size=(1 << 32),
+        max_workspace_size=(1 << args.workspace),
+        max_batch_size=args.batch,
     )
     torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
     logger.info("Converted TensorRT model done.")