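"""Triton BLS implementation of the Decoder interface from lib.decode.

Routes preprocessing (tokenization), LLM generation, and postprocessing
(detokenization) through separate Triton models via BLS (Business Logic
Scripting) inference requests.
"""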
from collections.abc import Callable, Generator
from typing import Dict, Optional

import numpy as np
import triton_python_backend_utils as pb_utils
from lib.decode import *
from typing_extensions import override


class TritonDecoder(Decoder):
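    """Decoder implementation backed by Triton BLS inference requests.

    Tokenization, generation, and detokenization are delegated to the
    Triton models named in __init__; an optional draft LLM model enables
    speculative decoding.
    """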

    def __init__(self,
                 streaming=False,
                 accumulate=False,
                 preproc_model_name="preprocessing",
                 postproc_model_name="postprocessing",
                 llm_model_name="tensorrt_llm",
                 draft_llm_model_name: Optional[str] = None):
        super().__init__(streaming=streaming, accumulate=accumulate)
        self.preproc_model_name = preproc_model_name
        self.postproc_model_name = postproc_model_name
        self.llm_model_name = llm_model_name
        self.draft_llm_model_name = draft_llm_model_name
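
        # Output tensor names requested from each downstream Triton model.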
        self._preproc_outputs = [
            "INPUT_ID",
            "DECODER_INPUT_ID",
            "REQUEST_INPUT_LEN",
            "REQUEST_DECODER_INPUT_LEN",
            "BAD_WORDS_IDS",
            "STOP_WORDS_IDS",
            "EMBEDDING_BIAS",
            "OUT_PAD_ID",
            "OUT_END_ID",
        ]

        self._llm_outputs = [
            "output_ids",
            "sequence_length",
            "cum_log_probs",
            "output_log_probs",
            "context_logits",
            "generation_logits",
        ]

        self._postproc_outputs = [
            "OUTPUT",
        ]
        self.input_names = [
            "text_input",
            "decoder_text_input",
            "max_tokens",
            "bad_words",
            "stop_words",
            "end_id",
            "pad_id",
            "top_k",
            "top_p",
            "temperature",
            "length_penalty",
            "repetition_penalty",
            "min_length",
            "presence_penalty",
            "frequency_penalty",
            "random_seed",
            "return_log_probs",
            "return_context_logits",
            "return_generation_logits",
            "beam_width",
            "stream",
            "prompt_embedding_table",
            "prompt_vocab_size",
            "embedding_bias_words",
            "embedding_bias_weights",
            "num_draft_tokens",
            "use_draft_logits",
        ]
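
        # Scalar parameters that may arrive squeezed to 1-D; __undo_reshape
        # restores their leading batch dimension before they are forwarded.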
        self.__undo_reshape_whitelist = {
            "max_tokens",
            "end_id",
            "pad_id",
            "top_k",
            "top_p",
            "temperature",
            "length_penalty",
            "repetition_penalty",
            "min_length",
            "presence_penalty",
            "frequency_penalty",
            "random_seed",
            "return_log_probs",
            "return_context_logits",
            "return_generation_logits",
            "beam_width",
            "stream",
            "prompt_vocab_size",
            "num_draft_tokens",
            "use_draft_logits",
        }

    def _exec_triton_request(self, request):
        """Execute a decoupled (streaming) BLS request, yielding each response."""
        responses = request.exec(decoupled=True)
        for r in responses:
            if r.has_error():
                raise pb_utils.TritonModelException(r.error().message())
            yield r

    def _exec_triton_request_single(self, request):
        """Execute a non-decoupled BLS request and return its single response."""
        responses = request.exec(decoupled=False)
        if responses.has_error():
            raise pb_utils.TritonModelException(responses.error().message())
        return responses

    def create_triton_response(self, response: Response):
        name_map = {
            "text_output": "text_output",
            "cum_log_probs": "cum_log_probs",
            "output_log_probs": "output_log_probs",
            "context_logits": "context_logits",
            "generation_logits": "generation_logits",
        }
        tensors = self.create_triton_tensors(response, name_map)
        return pb_utils.InferenceResponse(output_tensors=tensors)

    def convert_triton_request(self, triton_request) -> Request:
        request = Request()
        for triton_name in self.input_names:
            tensor = pb_utils.get_input_tensor_by_name(triton_request,
                                                       triton_name)
            if tensor is None:
                continue
            target_name = triton_name
            if not hasattr(request, target_name):
                raise AttributeError(
                    f"Request has no attribute '{target_name}'")
            setattr(request, target_name, tensor.as_numpy())
        return request

    def convert_triton_response(self,
                                triton_response,
                                response_factory: Callable,
                                name_map=None):
        response = response_factory()
        for tensor in triton_response.output_tensors():
            if tensor is None:
                continue
            triton_name = tensor.name()
            value = tensor.as_numpy()
            if name_map:
                if triton_name not in name_map:
                    continue
                target_name = name_map[triton_name]
            else:
                target_name = triton_name
            if target_name is None:
                # A None mapping drops the tensor deliberately.
                continue
            if not hasattr(response, target_name):
                raise AttributeError(
                    f"Response object has no attribute '{target_name}'")
            setattr(response, target_name, value)
        return response

    def __undo_reshape(self, x, name):
        if name in self.__undo_reshape_whitelist and len(x.shape) == 1:
            # The batch dimension was squeezed away; restore it.
            return np.expand_dims(x, 0)
        else:
            return x

    def create_triton_tensors(self, obj, name_map: dict):
        tensors = []
        for name, triton_name in name_map.items():
            if triton_name is None:
                continue
            value = getattr(obj, name)
            if value is None:
                continue
            t = pb_utils.Tensor(triton_name, self.__undo_reshape(value, name))
            tensors.append(t)
        return tensors

    @override
    def preprocess(self, request: Request) -> PreprocResponse:
        input_tensors = self._get_preproc_tensors(request)
        triton_req = pb_utils.InferenceRequest(
            model_name=self.preproc_model_name,
            inputs=input_tensors,
            requested_output_names=self._preproc_outputs)
        triton_output = self._exec_triton_request_single(triton_req)
        return self._get_preproc_response(triton_output)

    def _get_preproc_tensors(self, request: Request):
        name_map = {
            "text_input": "QUERY",
            "decoder_text_input": "DECODER_QUERY",
            "max_tokens": "REQUEST_OUTPUT_LEN",
            "bad_words": "BAD_WORDS_DICT",
            "stop_words": "STOP_WORDS_DICT",
            "embedding_bias_words": "EMBEDDING_BIAS_WORDS",
            "embedding_bias_weights": "EMBEDDING_BIAS_WEIGHTS",
            "pad_id": "PAD_ID",
            "end_id": "END_ID",
        }
        return self.create_triton_tensors(request, name_map)

    def _get_preproc_response(self, triton_output):
        name_map = {
            "INPUT_ID": "input_ids",
            "DECODER_INPUT_ID": "decoder_input_ids",
            "REQUEST_INPUT_LEN": "input_lengths",
            "REQUEST_DECODER_INPUT_LEN": "decoder_input_lengths",
            "BAD_WORDS_IDS": "bad_words_list",
            "STOP_WORDS_IDS": "stop_words_list",
            "EMBEDDING_BIAS": "embedding_bias",
            "OUT_PAD_ID": "pad_id",
            "OUT_END_ID": "end_id",
        }
        return self.convert_triton_response(triton_output, PreprocResponse,
                                            name_map)

    @override
    def _draft_generate_non_streaming(
            self, preproc: PreprocResponse, request: Request,
            num_draft_tokens: int) -> GenerationResponse:
        """Run the draft model once to propose num_draft_tokens tokens."""
        input_tensors = self._get_llm_tensors(preproc, request,
                                              num_draft_tokens, None, True)
        triton_req = pb_utils.InferenceRequest(
            model_name=self.draft_llm_model_name,
            inputs=input_tensors,
            requested_output_names=self._llm_outputs)
        triton_response = self._exec_triton_request_single(triton_req)
        llm_response = self._get_llm_response(triton_response)
        return llm_response

    @override
    def _generate(
            self,
            preproc: PreprocResponse,
            request: Request,
            draft_request: Optional[DraftRequest] = None
    ) -> Generator[GenerationResponse, None, None]:
        """Stream responses from the target model via a decoupled request."""
        input_tensors = self._get_llm_tensors(preproc, request, None,
                                              draft_request)
        triton_req = pb_utils.InferenceRequest(
            model_name=self.llm_model_name,
            inputs=input_tensors,
            requested_output_names=self._llm_outputs)
        for r in self._exec_triton_request(triton_req):
            yield self._get_llm_response(r)

    @override
    def _generate_non_streaming(
            self,
            preproc: PreprocResponse,
            request: Request,
            draft_request: Optional[DraftRequest] = None
    ) -> GenerationResponse:
        """Collect a single response from the target model."""
        input_tensors = self._get_llm_tensors(preproc, request, None,
                                              draft_request)
        triton_req = pb_utils.InferenceRequest(
            model_name=self.llm_model_name,
            inputs=input_tensors,
            requested_output_names=self._llm_outputs)
        r = self._exec_triton_request_single(triton_req)
        return self._get_llm_response(r)

    def _get_llm_tensors(self,
                         preproc: PreprocResponse,
                         request: Request,
                         num_output_tokens: Optional[int] = None,
                         draft_request: Optional[DraftRequest] = None,
                         is_draft_model_request: bool = False):
        tensors = []
        tensors.extend(self._get_tensors_from_preproc(preproc))
        tensors.extend(
            self._get_llm_tensors_from_request(request, num_output_tokens,
                                               draft_request,
                                               is_draft_model_request))
        return tensors

    def _get_tensors_from_preproc(self, preproc: PreprocResponse):
        name_map = {
            "input_ids": "input_ids",
            "decoder_input_ids": "decoder_input_ids",
            "input_lengths": "input_lengths",
            "bad_words_list": "bad_words_list",
            "stop_words_list": "stop_words_list",
            "embedding_bias": "embedding_bias",
            "pad_id": "pad_id",
            "end_id": "end_id",
        }
        return self.create_triton_tensors(preproc, name_map)

    def _get_llm_tensors_from_request(
            self,
            request: Request,
            num_output_tokens: Optional[int] = None,
            draft_request: Optional[DraftRequest] = None,
            is_draft_model_request: bool = False):
        name_map: Dict[str, Optional[str]] = {
            "beam_width": "beam_width",
            "top_k": "runtime_top_k",
            "top_p": "runtime_top_p",
            "length_penalty": "len_penalty",
            "repetition_penalty": "repetition_penalty",
            "min_length": "min_length",
            "presence_penalty": "presence_penalty",
            "frequency_penalty": "frequency_penalty",
            "random_seed": "random_seed",
            "return_log_probs": "return_log_probs",
            "stream": "streaming",
            "prompt_embedding_table": "prompt_embedding_table",
            "prompt_vocab_size": "prompt_vocab_size",
        }
        tensors = self.create_triton_tensors(request, name_map)
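
        # Determine request_output_len: an explicit num_output_tokens (used
        # for draft-model requests) takes precedence; when verifying a draft,
        # the target model generates len(draft_input_ids) + 1 tokens.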
        out_len = None
        if request.max_tokens is not None:
            out_len = request.max_tokens[0][0]
        if num_output_tokens is not None:
            out_len = num_output_tokens
        elif draft_request:
            if draft_request.draft_input_ids is not None:
                out_len = len(draft_request.draft_input_ids[0]) + 1
            else:
                out_len = 1

        if out_len is None:
            raise Exception("Could not determine request_output_len")
        tensors.append(
            pb_utils.Tensor("request_output_len",
                            np.array([[out_len]], dtype=np.int32)))

        if draft_request:
            if draft_request.draft_input_ids is not None:
                tensors.append(
                    pb_utils.Tensor("draft_input_ids",
                                    draft_request.draft_input_ids))
            use_draft_logits = (request.use_draft_logits is not None
                                and request.use_draft_logits[0])
            if draft_request.draft_logits is not None and use_draft_logits:
                tensors.append(
                    pb_utils.Tensor("draft_logits",
                                    draft_request.draft_logits))

        return_context_logits = False
        return_generation_logits = False
        if draft_request is None:
            if is_draft_model_request:
                return_generation_logits = request.use_draft_logits[
                    0] if request.use_draft_logits is not None else False
            else:
                return_context_logits = request.return_context_logits[
                    0] if request.return_context_logits is not None else False
                return_generation_logits = request.return_generation_logits[
                    0] if request.return_generation_logits is not None else False

        tensors.append(
            pb_utils.Tensor("return_context_logits",
                            np.array([[return_context_logits]])))
        tensors.append(
            pb_utils.Tensor("return_generation_logits",
                            np.array([[return_generation_logits]])))
        return tensors

    def _get_llm_response(self, triton_output):
        name_map = {
            "output_ids": "output_ids",
            "sequence_length": "sequence_length",
            "cum_log_probs": "cum_log_probs",
            "output_log_probs": "output_log_probs",
            "context_logits": "context_logits",
            "generation_logits": "generation_logits",
        }
        return self.convert_triton_response(triton_output, GenerationResponse,
                                            name_map)

    def _postprocess(self, tokens: np.ndarray,
                     sequence_lengths: Optional[np.ndarray],
                     gen_response: GenerationResponse) -> Response:
        input_tensors = self._get_postproc_tensors(tokens, sequence_lengths,
                                                   gen_response)
        triton_req = pb_utils.InferenceRequest(
            model_name=self.postproc_model_name,
            inputs=input_tensors,
            requested_output_names=self._postproc_outputs)
        r = self._exec_triton_request_single(triton_req)
        response = self._get_response(r, gen_response)
        return response

    def _get_postproc_tensors(self, tokens: np.ndarray,
                              sequence_lengths: Optional[np.ndarray],
                              gen_response: GenerationResponse):
        # Truth-testing a multi-element ndarray raises ValueError, so compare
        # against None explicitly.
        tensors = [
            pb_utils.Tensor("TOKENS_BATCH", tokens),
            pb_utils.Tensor(
                "SEQUENCE_LENGTH",
                sequence_lengths if sequence_lengths is not None else
                gen_response.sequence_length)
        ]
        return tensors

    def _get_response(self, triton_output, gen_res: GenerationResponse):
        tensors = triton_output.output_tensors()
        t_map = {}
        for named_t in tensors:
            name = named_t.name()
            t = named_t.as_numpy()
            t_map[name] = t
        response = Response(text_output=t_map["OUTPUT"],
                            cum_log_probs=gen_res.cum_log_probs,
                            output_log_probs=gen_res.output_log_probs,
                            context_logits=gen_res.context_logits,
                            generation_logits=gen_res.generation_logits)
        return response
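

# A minimal usage sketch (assumptions, not part of this file: the Decoder
# base class from lib.decode exposes a decode() generator, and this code
# runs inside a Triton Python-backend BLS model's execute() loop; `sender`
# stands for the request's response sender):
#
#   decoder = TritonDecoder(streaming=True,
#                           llm_model_name="tensorrt_llm")
#   request = decoder.convert_triton_request(triton_request)
#   for response in decoder.decode(request):
#       sender.send(decoder.create_triton_response(response))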