fix-onnx-export (#3)
Browse files- Fix ONNX export (16686c5c06ad6a49f2aac8d22414b2554533c3b5)
- onnx_export.py +6 -7
onnx_export.py
CHANGED
@@ -41,7 +41,11 @@ def convert_onnx(model_id: str, task: str, folder: str, opset: int) -> List:
|
|
41 |
model_name = getattr(model, "name", None)
|
42 |
|
43 |
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
|
44 |
-
|
|
|
|
|
|
|
|
|
45 |
)
|
46 |
onnx_config = onnx_config_constructor(model.config)
|
47 |
|
@@ -66,12 +70,7 @@ def convert_onnx(model_id: str, task: str, folder: str, opset: int) -> List:
|
|
66 |
opset = onnx_config.DEFAULT_ONNX_OPSET
|
67 |
|
68 |
output = Path(folder).joinpath("model.onnx")
|
69 |
-
onnx_inputs, onnx_outputs = export(
|
70 |
-
model,
|
71 |
-
onnx_config,
|
72 |
-
opset,
|
73 |
-
output,
|
74 |
-
)
|
75 |
|
76 |
atol = onnx_config.ATOL_FOR_VALIDATION
|
77 |
if isinstance(atol, dict):
|
|
|
41 |
model_name = getattr(model, "name", None)
|
42 |
|
43 |
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
|
44 |
+
exporter="onnx",
|
45 |
+
model=model,
|
46 |
+
task=task,
|
47 |
+
model_name=model_name,
|
48 |
+
model_type=model_type,
|
49 |
)
|
50 |
onnx_config = onnx_config_constructor(model.config)
|
51 |
|
|
|
70 |
opset = onnx_config.DEFAULT_ONNX_OPSET
|
71 |
|
72 |
output = Path(folder).joinpath("model.onnx")
|
73 |
+
onnx_inputs, onnx_outputs = export(model, onnx_config, output, opset)
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
atol = onnx_config.ATOL_FOR_VALIDATION
|
76 |
if isinstance(atol, dict):
|