#  Copyright 2022, Lefebvre Dalloz Services
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
"""
This module add quantization support to all Roberta architecture based models.
"""


import torch
import torch.utils.checkpoint

from transformer_deploy.QDQModels.ast_utils import PatchModule


def qdq_create_position_tensorrt(input_ids, padding_idx, past_key_values_length=0):
    """
    Override qdq_create_position_tensorrt function.
    It appeared that cumsum operator in TensorRT doesn't support integer type.
    see https://github.com/onnx/onnx-tensorrt/blob/master/docs/operators.md
    This override uses float instead.
    """
    # QDQ change below
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    # int() -> float() because of a limitation in the cumsum operator implementation in TensorRT
    mask = input_ids.ne(padding_idx).float()
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx


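# Monkey patch mapping: when applied, it replaces create_position_ids_from_input_ids
# in transformers.models.roberta.modeling_roberta with the TensorRT-friendly
# version defined above.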
qdq_roberta_mapping: PatchModule = PatchModule(
    module="transformers.models.roberta.modeling_roberta",
    monkey_patch={
        "create_position_ids_from_input_ids": (qdq_create_position_tensorrt, "qdq_create_position_tensorrt"),
    },
)
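

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the library (assumptions:
    # RoBERTa's default padding_idx of 1; the non-special token ids below
    # are arbitrary).
    ids = torch.tensor([[0, 42, 7, 2, 1, 1]])
    positions = qdq_create_position_tensorrt(input_ids=ids, padding_idx=1)
    # Non-padding tokens get positions padding_idx + 1, padding_idx + 2, ...,
    # while padding tokens keep padding_idx itself.
    print(positions)  # tensor([[2, 3, 4, 5, 1, 1]])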