# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Basic HuggingFace -> ONNX export script.

This script shows a basic HuggingFace -> ONNX export workflow. It works for an MPT model
that has been saved using `MPT.save_pretrained`. For more details and examples
of exporting and working with HuggingFace models with ONNX, see https://huggingface.co/docs/transformers/serialization#export-to-onnx.

Example usage:

    1) Local export

    python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder local/folder

    2) Remote export

    python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder s3://bucket/remote/folder

    3) Verify the exported model

    python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder local/folder --verify_export

    4) Change the batch size or max sequence length

    python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder local/folder --export_batch_size 1 --max_seq_len 32000
"""

import argparse
import os
from argparse import ArgumentTypeError
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from composer.utils import (maybe_create_object_store_from_uri, parse_uri,
                            reproducibility)
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def str2bool(v: Union[str, bool]):
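    """Parse a command-line string into a bool, accepting common spellings."""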
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise ArgumentTypeError('Boolean value expected.')


def str_or_bool(v: Union[str, bool]):
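    """Parse common boolean spellings, passing any other string through unchanged.

    This lets ``--use_auth_token`` accept either a boolean flag or a literal
    Hugging Face token string.
    """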
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        return v


def gen_random_batch(batch_size: int, vocab_size: int, max_seq_len: int):
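    """Generate a random input batch of the shape used for the ONNX trace.

    The token values are arbitrary; the batch only needs the right shapes and
    dtypes to trace the model and to compare outputs during verification.
    """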
    # generate input batch of random data
    batch = {
        'input_ids':
            torch.randint(
                low=0,
                high=vocab_size,
                size=(batch_size, max_seq_len),
                dtype=torch.int64,
            ),
        'attention_mask':
            torch.ones(size=(batch_size, max_seq_len), dtype=torch.bool)
    }
    return batch


def export_to_onnx(
    pretrained_model_name_or_path: str,
    output_folder: str,
    export_batch_size: int,
    max_seq_len: Optional[int],
    verify_export: bool,
    from_pretrained_kwargs: Dict[str, Any],
):
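    """Export a HuggingFace causal LM to ONNX and optionally verify the export.

    Args:
        pretrained_model_name_or_path (str): HF hub name or local path of the model.
        output_folder (str): Local folder or object store URI to write the export to.
        export_batch_size (int): Batch size of the sample input used for tracing.
        max_seq_len (Optional[int]): Sequence length of the sample input. Falls back
            to ``model.config.max_seq_len`` if not provided.
        verify_export (bool): Whether to check the ONNX output against the torch model.
        from_pretrained_kwargs (Dict[str, Any]): Extra kwargs for ``from_pretrained``.
    """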
    reproducibility.seed_all(42)
    save_object_store = maybe_create_object_store_from_uri(output_folder)
    _, _, parsed_save_path = parse_uri(output_folder)

    print('Loading HF config/model/tokenizer...')
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
                                              **from_pretrained_kwargs)
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
                                        **from_pretrained_kwargs)

    # specifically for MPT, switch to the torch version of attention for ONNX export
    if hasattr(config, 'attn_config'):
        config.attn_config['attn_impl'] = 'torch'

    model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path,
                                                 config=config,
                                                 **from_pretrained_kwargs)
    model.eval()

    if max_seq_len is None and not hasattr(model.config, 'max_seq_len'):
        raise ValueError(
            'max_seq_len must be specified in either the model config or as an argument to this function.'
        )
    elif max_seq_len is None:
        max_seq_len = model.config.max_seq_len

    assert isinstance(max_seq_len, int)  # pyright

    print('Creating random batch...')
    sample_input = gen_random_batch(
        export_batch_size,
        len(tokenizer),
        max_seq_len,
    )

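    # Run a single forward pass on the random batch as a sanity check before
    # tracing the model for export.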
    with torch.no_grad():
        model(**sample_input)

    output_file = Path(parsed_save_path) / 'model.onnx'
    os.makedirs(parsed_save_path, exist_ok=True)
    print('Exporting the model with ONNX...')
    torch.onnx.export(
        model,
        (sample_input,),
        str(output_file),
        input_names=['input_ids', 'attention_mask'],
        output_names=['output'],
        opset_version=16,
    )
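    # Note that no `dynamic_axes` are passed above, so the exported graph is
    # fixed to (export_batch_size, max_seq_len) inputs. If variable shapes are
    # needed at inference time, a `dynamic_axes` mapping along these lines
    # could be added to the `torch.onnx.export` call (a sketch, not tested here):
    #
    #     dynamic_axes={
    #         'input_ids': {0: 'batch_size', 1: 'seq_len'},
    #         'attention_mask': {0: 'batch_size', 1: 'seq_len'},
    #         'output': {0: 'batch_size', 1: 'seq_len'},
    #     }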

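    # Optionally reload the exported model and check that its logits match the
    # torch model's within a loose tolerance.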
    if verify_export:
        with torch.no_grad():
            orig_out = model(**sample_input)

        import onnx
        import onnx.checker
        import onnxruntime as ort

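        # onnx.load confirms the serialized protobuf parses; check_model then
        # validates the graph itself.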
        _ = onnx.load(str(output_file))

        onnx.checker.check_model(str(output_file))

        ort_session = ort.InferenceSession(str(output_file))

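        # onnxruntime consumes numpy arrays, so move the tensors to CPU and
        # convert them before running the session.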
        sample_input = {
            key: value.cpu().numpy() for key, value in sample_input.items()
        }

        loaded_model_out = ort_session.run(None, sample_input)

        torch.testing.assert_close(
            orig_out.logits.detach().numpy(),
            loaded_model_out[0],
            rtol=1e-2,
            atol=1e-2,
            msg='output mismatch between the original and ONNX exported model',
        )
        print('Exported model output matches the original model!')

    if save_object_store is not None:
        print('Uploading files to object storage...')
        for filename in os.listdir(parsed_save_path):
            full_path = str(Path(parsed_save_path) / filename)
            save_object_store.upload_object(full_path, full_path)


def parse_args():
    parser = argparse.ArgumentParser(description='Convert HF model to ONNX',)
    parser.add_argument(
        '--pretrained_model_name_or_path',
        type=str,
        required=True,
    )
    parser.add_argument(
        '--output_folder',
        type=str,
        required=True,
    )
    parser.add_argument(
        '--export_batch_size',
        type=int,
        default=8,
    )
    parser.add_argument(
        '--max_seq_len',
        type=int,
        default=None,
    )
    parser.add_argument(
        '--verify_export',
        action='store_true',
    )
    parser.add_argument('--trust_remote_code',
                        type=str2bool,
                        nargs='?',
                        const=True,
                        default=True)
    parser.add_argument('--use_auth_token',
                        type=str_or_bool,
                        nargs='?',
                        const=True,
                        default=None)
    parser.add_argument('--revision', type=str, default=None)
    return parser.parse_args()


def main(args: argparse.Namespace):
    from_pretrained_kwargs = {
        'use_auth_token': args.use_auth_token,
        'trust_remote_code': args.trust_remote_code,
        'revision': args.revision,
    }

    export_to_onnx(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        output_folder=args.output_folder,
        export_batch_size=args.export_batch_size,
        max_seq_len=args.max_seq_len,
        verify_export=args.verify_export,
        from_pretrained_kwargs=from_pretrained_kwargs)


if __name__ == '__main__':
    main(parse_args())