jupyterjazz JoanFM commited on
Commit
3baf9e3
·
verified ·
1 Parent(s): 2122141

skip optimmum if not installed (#2)

Browse files

- skip optimmum if not installed (e7968d5b22061a5a5a5f996b52bf284520722c94)


Co-authored-by: Joan Fontanals Martínez <[email protected]>

Files changed (1) hide show
  1. configuration_bert.py +24 -15
configuration_bert.py CHANGED
@@ -17,11 +17,18 @@
17
  """ BERT model configuration"""
18
  from collections import OrderedDict
19
  from typing import Mapping
 
20
 
21
  from transformers.configuration_utils import PretrainedConfig
22
- from optimum.exporters.onnx.model_configs import OnnxConfig, BertOnnxConfig
23
  from transformers.utils import logging
24
 
 
 
 
 
 
 
 
25
 
26
  logger = logging.get_logger(__name__)
27
 
@@ -152,17 +159,19 @@ class JinaBertConfig(PretrainedConfig):
152
  self.emb_pooler = emb_pooler
153
  self.attn_implementation = attn_implementation
154
 
155
- class JinaBertOnnxConfig(BertOnnxConfig):
156
-
157
- @property
158
- def inputs(self) -> Mapping[str, Mapping[int, str]]:
159
- if self.task == "multiple-choice":
160
- dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
161
- else:
162
- dynamic_axis = {0: "batch", 1: "sequence"}
163
- return OrderedDict(
164
- [
165
- ("input_ids", dynamic_axis),
166
- ("attention_mask", dynamic_axis),
167
- ]
168
- )
 
 
 
17
  """ BERT model configuration"""
18
  from collections import OrderedDict
19
  from typing import Mapping
20
+ import warnings
21
 
22
  from transformers.configuration_utils import PretrainedConfig
 
23
  from transformers.utils import logging
24
 
25
+ try:
26
+ from optimum.exporters.onnx.model_configs import BertOnnxConfig
27
+ OPTIMUM_INSTALLED = True
28
+ except ImportError:
29
+ warnings.warn("optimum is not installed. To use OnnxConfig and BertOnnxConfig, make sure that `optimum` package is installed")
30
+ OPTIMUM_INSTALLED = False
31
+
32
 
33
  logger = logging.get_logger(__name__)
34
 
 
159
  self.emb_pooler = emb_pooler
160
  self.attn_implementation = attn_implementation
161
 
162
+ if OPTIMUM_INSTALLED:
163
+
164
+ class JinaBertOnnxConfig(BertOnnxConfig):
165
+
166
+ @property
167
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
168
+ if self.task == "multiple-choice":
169
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
170
+ else:
171
+ dynamic_axis = {0: "batch", 1: "sequence"}
172
+ return OrderedDict(
173
+ [
174
+ ("input_ids", dynamic_axis),
175
+ ("attention_mask", dynamic_axis),
176
+ ]
177
+ )