kaczmarj commited on
Commit
c6e4a1c
·
verified ·
1 Parent(s): 01a2351

export to onnx

Browse files
Files changed (2) hide show
  1. export_to_torchscript.py +13 -3
  2. model.onnx +3 -0
export_to_torchscript.py CHANGED
@@ -84,10 +84,20 @@ def ctranspath():
84
 
85
  model = ctranspath()
86
  model.head = torch.nn.Identity()
87
- td = torch.load(r"./ctranspath.pth")
88
  model.load_state_dict(td["model"], strict=True)
89
 
90
-
91
  jitted = torch.jit.script(model)
92
-
93
  torch.jit.save(jitted, "torchscript_model.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  model = ctranspath()
86
  model.head = torch.nn.Identity()
87
+ td = torch.load("ctranspath.pth")
88
  model.load_state_dict(td["model"], strict=True)
89
 
 
90
  jitted = torch.jit.script(model)
 
91
  torch.jit.save(jitted, "torchscript_model.pt")
92
+
93
+ torch.onnx.export(
94
+ model,
95
+ args=torch.ones(1, 3, 224, 224),
96
+ f="model.onnx",
97
+ input_names=["image"],
98
+ output_names=["embedding"],
99
+ dynamic_axes={
100
+ "input": {0: "batch_size"},
101
+ "output": {0: "batch_size"},
102
+ },
103
+ )
model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eec3f58a2b57142c48411467858847f75b2726f7b22db877428fcf06aaac2958
3
+ size 112312109