File size: 1,980 Bytes
e0c2d04 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
# Copyright 2022, Lefebvre Dalloz Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
"""
This module adds quantization support to all models based on the Roberta architecture.
"""
import torch
import torch.utils.checkpoint
from transformer_deploy.QDQModels.ast_utils import PatchModule
def qdq_create_position_tensorrt(
    input_ids: torch.Tensor, padding_idx: int, past_key_values_length: int = 0
) -> torch.Tensor:
    """
    Build position ids from input ids in a way TensorRT can execute.

    Drop-in replacement for `create_position_ids_from_input_ids` from
    `transformers.models.roberta.modeling_roberta`.
    The cumsum operator in TensorRT doesn't support integer types,
    see https://github.com/onnx/onnx-tensorrt/blob/master/docs/operators.md
    so the non-padding mask is computed as float and only the final result is cast back to integer.

    Along dim 1, non-padding tokens are numbered consecutively starting at
    `padding_idx + 1 + past_key_values_length`; padding tokens receive `padding_idx`.

    :param input_ids: token ids; the cumulative sum runs over dim 1, so at least 2 dimensions are required
        (presumably (batch, sequence) -- confirm against callers)
    :param padding_idx: id of the padding token, also used as the position id of padded slots
    :param past_key_values_length: offset added to the position of every non-padding token
    :return: int64 tensor of position ids with the same shape as `input_ids`
    """
    # QDQ change below
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    # int() -> float() because of a limitation in the cumsum operator implementation in TensorRT
    mask = input_ids.ne(padding_idx).float()
    # running count of non-padding tokens, shifted by the cache length; multiplying by the mask zeroes padded slots
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    # back to integer indices; padded slots (0) end up exactly at padding_idx
    return incremental_indices.long() + padding_idx
# Patch description for the Hugging Face Roberta implementation.
# `monkey_patch` associates the name of the function to replace inside `module`
# with a (replacement callable, replacement name) pair.
qdq_roberta_mapping: PatchModule = PatchModule(
    module="transformers.models.roberta.modeling_roberta",
    monkey_patch={
        # float-based cumsum variant, required for TensorRT
        "create_position_ids_from_input_ids": (
            qdq_create_position_tensorrt,
            "qdq_create_position_tensorrt",
        ),
    },
)
|