File size: 4,563 Bytes
33473a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from pathlib import Path

import tensorrt_llm
import torch
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp

from modules import shared
from modules.logging_colors import logger
from modules.text_generation import (
    get_max_prompt_length,
    get_reply_from_output_ids
)


class TensorRTLLMModel:
    """Wrapper around a TensorRT-LLM engine for text generation.

    Exposes the loader/generator interface the rest of the project expects:
    `from_pretrained` to load an engine directory, `generate_with_streaming`
    to yield incremental replies, and `generate` for a single final reply.
    Tokenization is delegated to `shared.tokenizer` (set elsewhere).
    """

    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path_to_model):
        """Load a TensorRT-LLM engine from `<model_dir>/<path_to_model>`.

        Chooses between the C++ runner (`ModelRunnerCpp`) and the Python
        runner (`ModelRunner`) based on `shared.args.cpp_runner`, and returns
        a `TensorRTLLMModel` instance with `.model` and `.runtime_rank` set.
        """
        path_to_model = Path(shared.args.model_dir) / path_to_model
        # MPI rank of this process (relevant for multi-GPU engines).
        runtime_rank = tensorrt_llm.mpi_rank()

        # Common settings for both runner implementations
        runner_kwargs = dict(
            engine_dir=str(path_to_model),
            lora_dir=None,
            rank=runtime_rank,
            debug_mode=False,
            lora_ckpt_source="hf",
        )

        if shared.args.cpp_runner:
            logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
            # The C++ runner requires fixed buffer sizes up front:
            # reserve 512 tokens of the sequence length for the output.
            runner_kwargs.update(
                max_batch_size=1,
                max_input_len=shared.args.max_seq_len - 512,
                max_output_len=512,
                max_beam_width=1,
                max_attention_window_size=None,
                sink_token_length=None,
            )
        else:
            logger.info("TensorRT-LLM: Using \"ModelRunner\"")

        # Load the model
        runner_cls = ModelRunnerCpp if shared.args.cpp_runner else ModelRunner
        runner = runner_cls.from_dir(**runner_kwargs)

        result = cls()
        result.model = runner
        result.runtime_rank = runtime_rank

        return result

    def generate_with_streaming(self, prompt, state):
        """Generate a reply for `prompt`, yielding the cumulative text.

        With the Python runner the reply is streamed token-by-token; with the
        C++ runner generation is non-streaming and a single final reply is
        yielded. `state` carries the sampling parameters from the UI.
        """
        batch_input_ids = []
        input_ids = shared.tokenizer.encode(
            prompt,
            add_special_tokens=True,
            truncation=False,
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int32)
        input_ids = input_ids[-get_max_prompt_length(state):]  # Apply truncation_length
        batch_input_ids.append(input_ids)

        if shared.args.cpp_runner:
            # The C++ runner was built with max_output_len=512 (see
            # from_pretrained), so cap the request accordingly.
            max_new_tokens = min(512, state['max_new_tokens'])
        elif state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
        else:
            max_new_tokens = state['max_new_tokens']

        with torch.no_grad():
            generator = self.model.generate(
                batch_input_ids,
                max_new_tokens=max_new_tokens,
                max_attention_window_size=None,
                sink_token_length=None,
                # end_id=-1 disables EOS termination when the EOS token is banned
                end_id=shared.tokenizer.eos_token_id if not state['ban_eos_token'] else -1,
                pad_id=shared.tokenizer.pad_token_id or shared.tokenizer.eos_token_id,
                temperature=state['temperature'],
                top_k=state['top_k'],
                top_p=state['top_p'],
                num_beams=1,
                length_penalty=1.0,
                repetition_penalty=state['repetition_penalty'],
                presence_penalty=state['presence_penalty'],
                frequency_penalty=state['frequency_penalty'],
                stop_words_list=None,
                bad_words_list=None,
                lora_uids=None,
                prompt_table_path=None,
                prompt_tasks=None,
                streaming=not shared.args.cpp_runner,  # C++ runner is non-streaming
                output_sequence_lengths=True,
                return_dict=True,
                medusa_choices=None
            )

        torch.cuda.synchronize()

        cumulative_reply = ''
        # Decode only tokens generated after the prompt.
        starting_from = batch_input_ids[0].shape[-1]

        if shared.args.cpp_runner:
            # Non-streaming: a single dict with the full output.
            # Indexing [0][0] selects batch 0, beam 0 (num_beams=1, batch of 1).
            sequence_length = generator['sequence_lengths'][0].item()
            output_ids = generator['output_ids'][0][0][:sequence_length].tolist()

            cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
            starting_from = sequence_length
            yield cumulative_reply
        else:
            # Streaming: each iteration carries the sequence so far; decode
            # only the newly produced suffix each time.
            for curr_outputs in generator:
                if shared.stop_everything:
                    break

                sequence_length = curr_outputs['sequence_lengths'][0].item()
                output_ids = curr_outputs['output_ids'][0][0][:sequence_length].tolist()

                cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
                starting_from = sequence_length
                yield cumulative_reply

    def generate(self, prompt, state):
        """Generate a complete reply (non-streaming convenience wrapper)."""
        output = ''
        # Drain the streaming generator; the last yield is the full reply.
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output