Daisuke Nishimatsu
committed on
Commit
·
61b2559
1
Parent(s):
e578635
feat(tools): add option for decode output in export_onnx (#1113)
Browse files- tools/export_onnx.py +6 -1
tools/export_onnx.py
CHANGED
@@ -49,6 +49,11 @@ def make_parser():
|
|
49 |
default=None,
|
50 |
nargs=argparse.REMAINDER,
|
51 |
)
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
return parser
|
54 |
|
@@ -78,7 +83,7 @@ def main():
|
|
78 |
ckpt = ckpt["model"]
|
79 |
model.load_state_dict(ckpt)
|
80 |
model = replace_module(model, nn.SiLU, SiLU)
|
81 |
-
model.head.decode_in_inference = False
|
82 |
|
83 |
logger.info("loading checkpoint done.")
|
84 |
dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])
|
|
|
49 |
default=None,
|
50 |
nargs=argparse.REMAINDER,
|
51 |
)
|
52 |
+
parser.add_argument(
|
53 |
+
"--decode_in_inference",
|
54 |
+
action="store_true",
|
55 |
+
help="decode in inference or not"
|
56 |
+
)
|
57 |
|
58 |
return parser
|
59 |
|
|
|
83 |
ckpt = ckpt["model"]
|
84 |
model.load_state_dict(ckpt)
|
85 |
model = replace_module(model, nn.SiLU, SiLU)
|
86 |
+
model.head.decode_in_inference = args.decode_in_inference
|
87 |
|
88 |
logger.info("loading checkpoint done.")
|
89 |
dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])
|