export to onnx
Browse files- export_to_torchscript.py +13 -3
- model.onnx +3 -0
export_to_torchscript.py
CHANGED
@@ -84,10 +84,20 @@ def ctranspath():
|
|
84 |
|
85 |
model = ctranspath()
|
86 |
model.head = torch.nn.Identity()
|
87 |
-
td = torch.load(
|
88 |
model.load_state_dict(td["model"], strict=True)
|
89 |
|
90 |
-
|
91 |
jitted = torch.jit.script(model)
|
92 |
-
|
93 |
torch.jit.save(jitted, "torchscript_model.pt")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
model = ctranspath()
|
86 |
model.head = torch.nn.Identity()
|
87 |
+
td = torch.load("ctranspath.pth")
|
88 |
model.load_state_dict(td["model"], strict=True)
|
89 |
|
|
|
90 |
jitted = torch.jit.script(model)
|
|
|
91 |
torch.jit.save(jitted, "torchscript_model.pt")
|
92 |
+
|
93 |
+
torch.onnx.export(
|
94 |
+
model,
|
95 |
+
args=torch.ones(1, 3, 224, 224),
|
96 |
+
f="model.onnx",
|
97 |
+
input_names=["image"],
|
98 |
+
output_names=["embedding"],
|
99 |
+
dynamic_axes={
|
100 |
+
"input": {0: "batch_size"},
|
101 |
+
"output": {0: "batch_size"},
|
102 |
+
},
|
103 |
+
)
|
model.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eec3f58a2b57142c48411467858847f75b2726f7b22db877428fcf06aaac2958
|
3 |
+
size 112312109
|