skip optimum if not installed (#2)
Browse files- skip optimum if not installed (e7968d5b22061a5a5a5f996b52bf284520722c94)
Co-authored-by: Joan Fontanals Martínez <[email protected]>
- configuration_bert.py +24 -15
configuration_bert.py
CHANGED
@@ -17,11 +17,18 @@
|
|
17 |
""" BERT model configuration"""
|
18 |
from collections import OrderedDict
|
19 |
from typing import Mapping
|
|
|
20 |
|
21 |
from transformers.configuration_utils import PretrainedConfig
|
22 |
-
from optimum.exporters.onnx.model_configs import OnnxConfig, BertOnnxConfig
|
23 |
from transformers.utils import logging
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
logger = logging.get_logger(__name__)
|
27 |
|
@@ -152,17 +159,19 @@ class JinaBertConfig(PretrainedConfig):
|
|
152 |
self.emb_pooler = emb_pooler
|
153 |
self.attn_implementation = attn_implementation
|
154 |
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
17 |
""" BERT model configuration"""
|
18 |
from collections import OrderedDict
|
19 |
from typing import Mapping
|
20 |
+
import warnings
|
21 |
|
22 |
from transformers.configuration_utils import PretrainedConfig
|
|
|
23 |
from transformers.utils import logging
|
24 |
|
25 |
+
try:
|
26 |
+
from optimum.exporters.onnx.model_configs import BertOnnxConfig
|
27 |
+
OPTIMUM_INSTALLED = True
|
28 |
+
except ImportError:
|
29 |
+
warnings.warn("optimum is not installed. To use OnnxConfig and BertOnnxConfig, make sure that `optimum` package is installed")
|
30 |
+
OPTIMUM_INSTALLED = False
|
31 |
+
|
32 |
|
33 |
logger = logging.get_logger(__name__)
|
34 |
|
|
|
159 |
self.emb_pooler = emb_pooler
|
160 |
self.attn_implementation = attn_implementation
|
161 |
|
162 |
+
if OPTIMUM_INSTALLED:
|
163 |
+
|
164 |
+
class JinaBertOnnxConfig(BertOnnxConfig):
|
165 |
+
|
166 |
+
@property
|
167 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
168 |
+
if self.task == "multiple-choice":
|
169 |
+
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
170 |
+
else:
|
171 |
+
dynamic_axis = {0: "batch", 1: "sequence"}
|
172 |
+
return OrderedDict(
|
173 |
+
[
|
174 |
+
("input_ids", dynamic_axis),
|
175 |
+
("attention_mask", dynamic_axis),
|
176 |
+
]
|
177 |
+
)
|