|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
This module add quantization support to all Deberta architecture based models. |
|
For now, Deberta export to ONNX doesn't work well. |
|
This PR may help: https://github.com/microsoft/DeBERTa/pull/6 |
|
""" |
|
|
|
import torch |
|
|
|
from transformer_deploy.QDQModels.ast_utils import PatchModule |
|
|
|
|
|
def get_attention_mask(self, attention_mask): |
|
""" |
|
Override existing get_attention_mask method in DebertaV2Encoder class. |
|
This one uses signed integers instead of unsigned one. |
|
""" |
|
if attention_mask.dim() <= 2: |
|
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) |
|
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) |
|
|
|
|
|
elif attention_mask.dim() == 3: |
|
attention_mask = attention_mask.unsqueeze(1) |
|
|
|
return attention_mask |
|
|
|
|
|
def symbolic(g, self, mask, dim): |
|
""" |
|
Override existing symbolic static function of Xsoftmax class. |
|
This one uses signed integers instead of unsigned one. |
|
Symbolic function are used during ONNX conversion instead of Pytorch code. |
|
""" |
|
import torch.onnx.symbolic_helper as sym_help |
|
from torch.onnx.symbolic_opset9 import masked_fill, softmax |
|
|
|
mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"]) |
|
|
|
|
|
r_mask = g.op( |
|
"Cast", |
|
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), |
|
to_i=sym_help.cast_pytorch_to_onnx["Char"], |
|
) |
|
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf")))) |
|
output = softmax(g, output, dim) |
|
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int8))) |
|
|
|
|
|
qdq_deberta_mapping: PatchModule = PatchModule( |
|
module="transformers.models.deberta.modeling_deberta", |
|
monkey_patch={ |
|
"XSoftmax.symbolic": (symbolic, "symbolic"), |
|
"DebertaEncoder.get_attention_mask": (get_attention_mask, "get_attention_mask"), |
|
}, |
|
) |
|
|
|
|
|
qdq_deberta_v2_mapping: PatchModule = PatchModule( |
|
module="transformers.models.deberta_v2.modeling_deberta_v2", |
|
monkey_patch={ |
|
"XSoftmax.symbolic": (symbolic, "symbolic"), |
|
"DebertaV2Encoder.get_attention_mask": (get_attention_mask, "get_attention_mask"), |
|
}, |
|
) |
|
|