import psutil
from transformers import (
    AutoConfig,
    MT5Config,
    T5ForConditionalGeneration,
    MT5ForConditionalGeneration,
)
import torch
import time
import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
from transformers.modeling_outputs import (
    Seq2SeqLMOutput,
    BaseModelOutput,
)
import os
from pathlib import Path
from progress.bar import Bar
import operator
import functools
from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
    ExecutionMode,
)
_auth_token = None


def set_auth_token(token):
    """Set the token which allows the user to authenticate to hugginface.co for downloading private models

    Args:
        token (Union[str, bool]): The token value to store. One of:
            - an API key (from https://huggingface.co/organizations/ORGNAME/settings/token),
            - a login token obtained by running `$ transformers-cli login`
            - `True`, which tells transformers to use the login token stored in ~/.huggingface/token

    Returns:
        None
    """
    global _auth_token
    _auth_token = token


def get_auth_token():
    """Get the user-configurable auth token, which defaults to None

    Returns:
        auth_token (Optional[Union[str, bool]]) for authenticating with huggingface.co
    """
    global _auth_token
    return _auth_token
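

# Usage sketch (not executed here): pass True to reuse the token cached by
# `transformers-cli login`, or pass an API-key string for private models, e.g.
#   set_auth_token(True)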


os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=True))
os.environ["OMP_WAIT_POLICY"] = "ACTIVE"


def get_onnx_runtime_sessions(
    model_paths,
    default: bool = True,
    opt_level: int = 99,
    parallel_exe_mode: bool = True,
    n_threads: int = 0,
    provider=[
        "CPUExecutionProvider",
    ],
):
    """
            Optimizes the model

    Args:
        model_paths (List or Tuple of str) : the path to, in order:
            path_to_encoder (str) : the path of input onnx encoder model.
            path_to_decoder (str) : the path of input onnx decoder model.
            path_to_initial_decoder (str) :  the path of input initial onnx decoder model.
        default : set this to true, ort will choose the best settings for your hardware.
                  (you can test out different settings for better results.)
        opt_level (int) : sess_options.GraphOptimizationLevel param if set 1 uses 'ORT_ENABLE_BASIC',
                          2 for 'ORT_ENABLE_EXTENDED' and 99 for 'ORT_ENABLE_ALL',
                          default value is set to 99.
        parallel_exe_mode (bool) :  Sets the execution mode. Default is True (parallel).
        n_threads (int) :  Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose
        provider : execution providers list.

    Returns:
        encoder_session : encoder onnx InferenceSession
        decoder_session : decoder onnx InferenceSession
        decoder_sess_init : initial decoder onnx InferenceSession

    """
    path_to_encoder, path_to_decoder, path_to_initial_decoder = model_paths

    if default:

        encoder_sess = InferenceSession(str(path_to_encoder))

        decoder_sess = InferenceSession(str(path_to_decoder))

        decoder_sess_init = InferenceSession(str(path_to_initial_decoder))

    else:

        # A few session options that can have an impact on performance
        options = SessionOptions()

        if opt_level == 1:
            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
        elif opt_level == 2:
            options.graph_optimization_level = (
                GraphOptimizationLevel.ORT_ENABLE_EXTENDED
            )
        else:
            assert opt_level == 99
            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        # parallel execution mode can improve throughput on multi-core CPUs
        if parallel_exe_mode:
            options.execution_mode = ExecutionMode.ORT_PARALLEL
        else:
            options.execution_mode = ExecutionMode.ORT_SEQUENTIAL

        options.intra_op_num_threads = n_threads
        # options.inter_op_num_threads = 10

        # options.enable_profiling = True

        encoder_sess = InferenceSession(
            str(path_to_encoder), options, providers=provider
        )

        decoder_sess = InferenceSession(
            str(path_to_decoder), options, providers=provider
        )

        decoder_sess_init = InferenceSession(
            str(path_to_initial_decoder), options, providers=provider
        )

    return encoder_sess, decoder_sess, decoder_sess_init
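

# Usage sketch (assumes these three .onnx files already exist; the option
# values below are illustrative, not benchmarked recommendations):
#
#   paths = ("models/t5-encoder.onnx",
#            "models/t5-decoder.onnx",
#            "models/t5-init-decoder.onnx")
#   enc_sess, dec_sess, dec_init_sess = get_onnx_runtime_sessions(
#       paths, default=False, opt_level=99,
#       parallel_exe_mode=True, n_threads=psutil.cpu_count(logical=True))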


class DecoderWithLMhead(torch.nn.Module):
    """ Creation of a class to combine the decoder and the lm head """

    def __init__(self, decoder, lm_head, config):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config

    def forward(self, *inputs):

        input_ids, attention_mask, encoder_hidden_states = inputs[:3]

        list_pkv = inputs[3:]
        past_key_values = tuple(list_pkv[i: i + 4]
                                for i in range(0, len(list_pkv), 4))

        decoder_output = self.decoder(
            input_ids=input_ids,  # decoder_input_ids
            encoder_attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
        )

        lm_head_out = self.lm_head(
            decoder_output[0] * (self.config.d_model ** -0.5))

        return lm_head_out, decoder_output[1]


# Export-time encoder wrapper; named T5ExportEncoder so it is not shadowed by
# the ONNX-runtime T5Encoder wrapper defined further below.
class T5ExportEncoder(torch.nn.Module):
    """Wraps the encoder so it outputs only the last hidden state."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, *input, **kwargs):
        return self.encoder(*input, **kwargs)[0]


class DecoderWithLMheadInitial(torch.nn.Module):
    """ Creation of a class to combine the decoder and the lm head """

    def __init__(self, decoder, lm_head, config):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config

    def forward(self, input_ids, attention_mask, encoder_hidden_states):
        decoder_output = self.decoder(
            input_ids=input_ids,
            encoder_attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
        )

        return (
            self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5)),
            decoder_output[1],
        )


_folder = Path.cwd()
saved_models_path = _folder.joinpath("models")

Bar.check_tty = False


def create_t5_encoder_decoder(pretrained_version="t5-base"):
    """Generates an encoder and a decoder model with a language model head from a pretrained huggingface model

    Args:
        pretrained_version (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5

    Returns:
        simplified_encoder: pytorch t5 encoder with a wrapper to output only the hidden states
        decoder_with_lm_head: pytorch t5 decoder with a language modeling head
    """

    if 'mt5' in pretrained_version:
        model = MT5ForConditionalGeneration.from_pretrained(
            pretrained_version, use_auth_token=get_auth_token())
    else:
        model = T5ForConditionalGeneration.from_pretrained(
            pretrained_version, use_auth_token=get_auth_token())

    return turn_model_into_encoder_decoder(model)


def turn_model_into_encoder_decoder(model):
    encoder = model.encoder
    decoder = model.decoder
    lm_head = model.lm_head

    decoder_with_lm_head = DecoderWithLMhead(decoder, lm_head, model.config)
    simplified_encoder = T5ExportEncoder(encoder)
    decoder_with_lm_head_init = DecoderWithLMheadInitial(
        decoder, lm_head, model.config)

    return simplified_encoder, decoder_with_lm_head, decoder_with_lm_head_init


def generate_onnx_representation(
    pretrained_version=None,
    model=None,
    output_path=None,
    input_sequence_length=256,
    onnx_opset_version=12,  # no other opset versions are tested, change at your own risk
):
    """Exports a given huggingface pretrained model, or a given model and tokenizer, to onnx

    Args:
        pretrained_version (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5
        model (Optional): an already-loaded T5/MT5 model to export (pretrained_version is still used for the config and file names)
        output_path (Optional[str]): if missing then use ./models
        input_sequence_length (Optional[int]): typical input sequence length, used by ORT for possible optimization
        onnx_opset_version (Optional[int]): ONNX Operator Set Version; the default, 12, is the only tested version
    """
    if (pretrained_version is None) and model is None:
        print(
            "You need to specify pretrained_version (the pretrained model you wish to export). Alternatively you can export a model you have in memory."
        )
        return

    if model is not None:
        (
            simplified_encoder,
            decoder_with_lm_head,
            decoder_with_lm_head_init,
        ) = turn_model_into_encoder_decoder(model)
    else:
        (
            simplified_encoder,
            decoder_with_lm_head,
            decoder_with_lm_head_init,
        ) = create_t5_encoder_decoder(pretrained_version)

    # model paths for enc, dec and dec_init
    output_path = saved_models_path if output_path is None else Path(
        output_path)
    encoder_path, decoder_path, init_decoder_path = get_model_paths(
        pretrained_version, output_path, quantized=False
    )

    model_config = AutoConfig.from_pretrained(
        pretrained_version, use_auth_token=get_auth_token())

    # Though these are dummy inputs, ORT optimizations do reference these values,
    # so it is worth using values as close to production as possible
    batch_size = 1  # not configurable since only CPU
    enc_seq_length = input_sequence_length
    # a decoder sequence length is always one because it's just the last generated token
    dec_seq_length = 1
    input_ids = torch.ones(batch_size, enc_seq_length, dtype=torch.int64)
    attention_mask = torch.ones(batch_size, enc_seq_length, dtype=torch.int64)

    n_heads = model_config.num_heads
    d_kv = model_config.d_kv

    input_ids_dec = torch.ones(batch_size, dec_seq_length, dtype=torch.int64)
    attention_mask_dec = torch.ones(
        batch_size, dec_seq_length, dtype=torch.int64)
    enc_out = torch.ones(
        (batch_size, enc_seq_length, model_config.d_model), dtype=torch.float32
    )

    # self_attention_past_key_values = torch.ones(
    #     (model_config.num_decoder_layers, 2, batch_size, n_heads, seq_length_a, d_kv), dtype=torch.float32)
    # cross_attention_past_key_values = torch.ones(
    #     (model_config.num_decoder_layers, 2, batch_size, n_heads, seq_length_b, d_kv), dtype=torch.float32)

    sa = torch.ones(
        (batch_size, n_heads, dec_seq_length, d_kv), dtype=torch.float32
    )  # 1, 8, 1, 64
    ca = torch.ones(
        (batch_size, n_heads, enc_seq_length, d_kv), dtype=torch.float32
    )  # 1, 8, variable, 64
    t5_block = (sa, sa, ca, ca)
    past_key_values = (t5_block,) * model_config.num_decoder_layers

    flat_past_key_values = functools.reduce(
        operator.iconcat, past_key_values, [])

    decoder_all_inputs = tuple(
        [input_ids_dec, attention_mask_dec, enc_out] + flat_past_key_values
    )

    # for progress bars
    bar = Bar("Exporting to onnx...", max=3)

    import warnings

    # ignores all the warnings during conversion
    warnings.filterwarnings("ignore")

    # Exports to ONNX
    with torch.no_grad():

        decoder_inputs = [
            "input_ids",
            "encoder_attention_mask",
            "encoder_hidden_states",
        ]

        pkv_input_names = ["pkv_{}".format(
            i) for i in range(len(flat_past_key_values))]

        decoder_input_names = decoder_inputs + pkv_input_names

        decoder_output_names = ["logits", "output_past_key_values"]

        dyn_axis_general = {0: "batch", 1: "sequence"}
        dyn_axis_pkv = {0: "batch", 2: "seq_length"}

        dyn_axis = {
            "input_ids": dyn_axis_general,
            "encoder_attention_mask": dyn_axis_general,
            "encoder_hidden_states": dyn_axis_general,
            "logits": dyn_axis_general,
            "output_past_key_values": dyn_axis_general,
        }

        dyn_pkv = {
            "pkv_{}".format(i): dyn_axis_pkv
            for i in range(len(flat_past_key_values))
        }

        dyn_axis_params = {**dyn_axis, **dyn_pkv}

        # decoder to utilize past key values:
        torch.onnx.export(
            decoder_with_lm_head,
            decoder_all_inputs,
            decoder_path.as_posix(),
            export_params=True,
            do_constant_folding=True,
            opset_version=onnx_opset_version,
            input_names=decoder_input_names,
            output_names=decoder_output_names,
            dynamic_axes=dyn_axis_params,
        )
        bar.next()

        torch.onnx.export(
            simplified_encoder,
            args=(input_ids, attention_mask),
            f=encoder_path.as_posix(),
            export_params=True,
            opset_version=onnx_opset_version,
            do_constant_folding=True,
            input_names=["input_ids", "attention_mask"],
            output_names=["hidden_states"],
            dynamic_axes={
                "input_ids": dyn_axis_general,
                "attention_mask": dyn_axis_general,
                "hidden_states": dyn_axis_general,
            },
        )
        bar.next()
        # initial decoder to produce past key values
        torch.onnx.export(
            decoder_with_lm_head_init,
            (input_ids_dec, attention_mask_dec, enc_out),
            init_decoder_path.as_posix(),
            export_params=True,
            opset_version=onnx_opset_version,
            input_names=[
                "input_ids",
                "encoder_attention_mask",
                "encoder_hidden_states",
            ],
            output_names=["logits", "past_key_values"],
            dynamic_axes={
                # batch_size, seq_length = input_shape
                "input_ids": dyn_axis_general,
                "encoder_attention_mask": dyn_axis_general,
                "encoder_hidden_states": dyn_axis_general,
                "logits": dyn_axis_general,
                "past_key_values": dyn_axis_general,
            },
        )
        bar.next()
        bar.finish()

    return encoder_path, decoder_path, init_decoder_path


def get_model_paths(pretrained_model, model_path, quantized):

    model_path.mkdir(parents=True, exist_ok=True)

    # gets only the filename
    pretrained_model_name = Path(pretrained_model).stem

    if not quantized:
        encoder_path = model_path.joinpath(
            f"{pretrained_model_name}-encoder.onnx")
        decoder_path = model_path.joinpath(
            f"{pretrained_model_name}-decoder.onnx")
        init_decoder_path = model_path.joinpath(
            f"{pretrained_model_name}-init-decoder.onnx"
        )
    else:
        encoder_path = model_path.joinpath(
            f"{pretrained_model_name}-encoder-quantized.onnx"
        )
        decoder_path = model_path.joinpath(
            f"{pretrained_model_name}-decoder-quantized.onnx"
        )
        init_decoder_path = model_path.joinpath(
            f"{pretrained_model_name}-init-decoder-quantized.onnx"
        )

    return encoder_path, decoder_path, init_decoder_path


def quantize(models_name_or_path):
    """
    Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPUs

    Uses unsigned ints for activation values and signed ints for weights, per
    https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
    this is faster on most CPU architectures.

    Args:
        models_name_or_path: list/tuple of Paths to the exported ONNX models

    Returns:
        Tuple of paths to the quantized models
    """
    from onnxruntime.quantization import quantize_dynamic, QuantType

    bar = Bar("Quantizing...", max=3)

    quant_model_paths = []
    for model in models_name_or_path:
        model_name = model.as_posix()
        output_model_name = f"{model_name[:-5]}-quantized.onnx"
        quantize_dynamic(
            model_input=model_name,
            model_output=output_model_name,
            per_channel=True,
            reduce_range=True,  # should be the same as per_channel
            activation_type=QuantType.QUInt8,
            weight_type=QuantType.QInt8,  # per docs, signed is faster on most CPUs
            optimize_model=False,
        )  # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul' ],
        quant_model_paths.append(output_model_name)
        bar.next()

    bar.finish()

    return tuple(quant_model_paths)
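

# Usage sketch: quantize() expects the Path objects returned by
# generate_onnx_representation() and writes a "<name>-quantized.onnx" file
# next to each input model, e.g.
#
#   onnx_paths = generate_onnx_representation("t5-small")
#   quantized_paths = quantize(onnx_paths)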


class T5Encoder(torch.nn.Module):
    def __init__(self, encoder_sess):
        super().__init__()
        self.encoder = encoder_sess
        self.main_input_name = "input_ids"

    def forward(
        self,
        input_ids,
        attention_mask,
        inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):

        encoder_hidden_state = torch.from_numpy(
            self.encoder.run(
                None,
                {
                    "input_ids": input_ids.cpu().numpy(),
                    "attention_mask": attention_mask.cpu().numpy(),
                },
            )[0]
        )

        return BaseModelOutput(encoder_hidden_state)


class T5DecoderInit(torch.nn.Module):
    def __init__(self, decoder_sess):
        super().__init__()
        self.decoder = decoder_sess

    def forward(self, input_ids, encoder_attention_mask, encoder_hidden_states):

        decoder_outputs = self.decoder.run(
            None,
            {
                "input_ids": input_ids.cpu().numpy(),
                "encoder_attention_mask": encoder_attention_mask.cpu().numpy(),
                "encoder_hidden_states": encoder_hidden_states.cpu().numpy(),
            },
        )

        list_pkv = tuple(torch.from_numpy(x) for x in decoder_outputs[1:])

        out_past_key_values = tuple(
            list_pkv[i: i + 4] for i in range(0, len(list_pkv), 4)
        )

        return torch.from_numpy(decoder_outputs[0]), out_past_key_values


class T5Decoder(torch.nn.Module):
    def __init__(self, decoder_sess):
        super().__init__()
        self.decoder = decoder_sess

    def forward(self, input_ids, attention_mask, encoder_output, past_key_values):

        decoder_inputs = {
            "input_ids": input_ids.cpu().numpy(),
            "encoder_attention_mask": attention_mask.cpu().numpy(),
            "encoder_hidden_states": encoder_output.cpu().numpy(),
        }

        flat_past_key_values = functools.reduce(
            operator.iconcat, past_key_values, [])

        past_key_values = {
            f"pkv_{i}": pkv.cpu().numpy() for i, pkv in enumerate(flat_past_key_values)
        }

        decoder_outputs = self.decoder.run(
            None, {**decoder_inputs, **past_key_values})
        # converts each value of the list to tensor from numpy
        list_pkv = tuple(torch.from_numpy(x) for x in decoder_outputs[1:])

        # creates a tuple of tuples of shape 6x4 from the above tuple
        out_past_key_values = tuple(
            list_pkv[i: i + 4] for i in range(0, len(list_pkv), 4)
        )

        return torch.from_numpy(decoder_outputs[0]), out_past_key_values


class OnnxT5(T5ForConditionalGeneration):
    """creates a T5 model using onnx sessions (encode, decoder & init_decoder)"""

    def __init__(self, model_or_model_path, onnx_model_sessions):
        config = AutoConfig.from_pretrained(
            model_or_model_path, use_auth_token=get_auth_token()
        )
        super().__init__(config)

        # monkeypatch to work for MT5
        if (
            isinstance(model_or_model_path, str)
            and "mt5" in model_or_model_path.lower()
        ) or (
            hasattr(model_or_model_path, "name_or_path")
            and "mt5" in model_or_model_path.name_or_path
        ):
            self.model_type = "mt5"
            self.config_class = MT5Config
            self._keys_to_ignore_on_load_missing = [
                r"encoder\.embed_tokens\.weight",
            ]
            self._keys_to_ignore_on_save = [
                r"encoder\.embed_tokens\.weight",
            ]

        assert len(onnx_model_sessions) == 3, "all three models should be given"

        encoder_sess, decoder_sess, decoder_sess_init = onnx_model_sessions

        self.encoder = T5Encoder(encoder_sess)
        self.decoder = T5Decoder(decoder_sess)
        self.decoder_init = T5DecoderInit(decoder_sess_init)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):

        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids, attention_mask=attention_mask
            )

        encoder_hidden_states = encoder_outputs[0]

        if past_key_values is not None:
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        if past_key_values is None:

            # runs only for the first time:
            init_onnx_outputs = self.decoder_init(
                decoder_input_ids, attention_mask, encoder_hidden_states
            )

            logits, past_key_values = init_onnx_outputs

        else:

            onnx_outputs = self.decoder(
                decoder_input_ids,
                attention_mask,
                encoder_hidden_states,
                past_key_values,
            )

            logits, past_key_values = onnx_outputs

        return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)


def export_and_get_onnx_model(
    model_or_model_path, custom_output_path=saved_models_path, quantized=True
):
    """
                          Method for whole pipeline,
    converts from pytorch to onnx --> quantizes model --> sets onnx runtime
                --> builds whole onnx model with all sessions

    """

    # Step 1. convert huggingfaces t5 model to onnx
    onnx_model_paths = generate_onnx_representation(
        model_or_model_path, output_path=custom_output_path
    )

    if quantized:
        # Step 2. (recommended) quantize the converted model for fast inference and to reduce model size.
        quant_model_paths = quantize(onnx_model_paths)

        # step 3. setup onnx runtime
        print("Setting up onnx model...")
        model_sessions = get_onnx_runtime_sessions(quant_model_paths)
    else:
        print("Setting up onnx model...")
        model_sessions = get_onnx_runtime_sessions(onnx_model_paths)

    # step 4. get the onnx model
    model = OnnxT5(model_or_model_path, model_sessions)
    print("Done!")

    return model
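

# End-to-end usage sketch (assumes "t5-small" is a valid checkpoint on
# huggingface.co; the tokenizer is loaded separately):
#
#   onnx_model = export_and_get_onnx_model("t5-small")
#   tok = AutoTokenizer.from_pretrained("t5-small")
#   enc = tok("translate English to French: hello", return_tensors="pt")
#   out = onnx_model.generate(**enc, max_length=32)
#   print(tok.decode(out[0], skip_special_tokens=True))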


def get_onnx_model(model_name, onnx_models_path=saved_models_path, quantized=True):
    """
    method gets the onnx model, if already converted models exists
    Example:
    >> get_onnx_model(model_name="t5-finetuned", onnx_models_path="../models/onnx/quantized/")

    """

    encoder_path, decoder_path, init_decoder_path = get_model_paths(
        model_name, Path(onnx_models_path), quantized
    )

    if quantized:
        assert (
            encoder_path.exists()
            and decoder_path.exists()
            and init_decoder_path.exists()
        ), "quantized models don't exist in the model folder, quantize the model first!"
    else:
        assert (
            encoder_path.exists()
            and decoder_path.exists()
            and init_decoder_path.exists()
        ), "all or some models don't exist in the model folder, convert the model first!"

    model_paths = encoder_path, decoder_path, init_decoder_path

    model_sessions = get_onnx_runtime_sessions(model_paths)

    model = OnnxT5(model_name, model_sessions)

    return model


trained_model_path = './t5_squad_v1/'

pretrained_model_name = Path(trained_model_path).stem

encoder_path = os.path.join(
    trained_model_path, f"{pretrained_model_name}-encoder_quantized.onnx")
decoder_path = os.path.join(
    trained_model_path, f"{pretrained_model_name}-decoder_quantized.onnx")
init_decoder_path = os.path.join(
    trained_model_path, f"{pretrained_model_name}-init-decoder_quantized.onnx")

model_paths = encoder_path, decoder_path, init_decoder_path
model_sessions = get_onnx_runtime_sessions(model_paths)
model = OnnxT5(trained_model_path, model_sessions)

tokenizer = AutoTokenizer.from_pretrained(trained_model_path)


def get_question(sentence, answer, mdl, tknizer):
    text = "context: {} answer: {}".format(sentence, answer)
    print(text)
    max_len = 256
    encoding = tknizer.encode_plus(
        text, max_length=max_len, padding=False, truncation=True, return_tensors="pt")
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
    outs = mdl.generate(input_ids=input_ids,
                        attention_mask=attention_mask,
                        early_stopping=True,
                        num_beams=5,
                        num_return_sequences=1,
                        no_repeat_ngram_size=2,
                        max_length=300)

    dec = [tknizer.decode(ids, skip_special_tokens=True) for ids in outs]

    Question = dec[0].replace("question:", "")
    Question = Question.strip()
    return Question


# context = "Ramsri loves to watch cricket during his free time"
# answer = "cricket"
context = "Donald Trump is an American media personality and businessman who served as the 45th president of the United States."
answer = "Donald Trump"
ques = get_question(context, answer, model, tokenizer)
print("question: ", ques)


context = gr.components.Textbox(
    lines=5, placeholder="Enter paragraph/context here...")
answer = gr.components.Textbox(
    lines=3, placeholder="Enter answer/keyword here...")
question = gr.components.Textbox(type="text", label="Question")


def generate_question(context, answer):
    start_time = time.time()  # Record the start time
    result = get_question(context, answer, model, tokenizer)
    end_time = time.time()    # Record the end time
    latency = end_time - start_time  # Calculate latency
    print(f"Latency: {latency} seconds")
    return result


iface = gr.Interface(
    fn=generate_question,
    inputs=[context, answer],
    outputs=question
)

iface.launch()