import datetime
import json
import os
import sys
import time
from dataclasses import dataclass
from random import randint
from threading import Lock, Thread
from typing import Any, List

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from torch import from_numpy
from torch.utils.dlpack import from_dlpack

import tensorrt_llm.bindings.executor as trtllm
import tensorrt_llm.logger as logger
from tensorrt_llm.llmapi.tokenizer import _xgrammar_tokenizer_info

METRIC_TOTAL_OUTPUT_TOKENS = "total_output_tokens"
METRIC_TOTAL_INPUT_TOKENS = "total_input_tokens"

# From https://github.com/pytorch/pytorch/blob/39425feac799905402abe4d15667fa47c344f2d7/torch/testing/_internal/common_utils.py#L1761
# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
numpy_to_torch_dtype_dict = {
    np.bool_: torch.bool,
    np.uint8: torch.uint8,
    np.uint16: torch.uint16,
    np.uint32: torch.uint32,
    np.uint64: torch.uint64,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128
}

# Dict of torch dtype -> NumPy dtype
torch_to_numpy_dtype_dict = {
    value: key
    for (key, value) in numpy_to_torch_dtype_dict.items()
}
torch_to_numpy_dtype_dict.update({
    torch.bfloat16: np.float32,
    torch.complex32: np.complex64
})


@dataclass
class RequestData:
    triton_req_id: int
    triton_user_id: str
    batch_index: int
    batch_size: int
    num_return_sequences: int
    num_input_tokens: int
    num_output_tokens: int
    response_sender: Any


def mpi_comm():
    from mpi4py import MPI
    return MPI.COMM_WORLD


def mpi_rank():
    return mpi_comm().Get_rank()


def get_input_tensor_by_name(request,
                             name,
                             expected_batch_size=None,
                             batch_index=None,
                             force_on_torch=False):
    tensor = pb_utils.get_input_tensor_by_name(request, name)
    if tensor is None:
        return None

    if tensor.is_cpu() and not force_on_torch:
        tensor = tensor.as_numpy()
    else:
        tensor = from_dlpack(tensor.to_dlpack())

    if expected_batch_size is not None and tensor.shape[
            0] != expected_batch_size:
        raise pb_utils.TritonModelException(
            f"Expected batch size doesn't match batch size for tensor {name}. Expected {expected_batch_size} got {tensor.shape[0]}"
        )

    if batch_index is not None and expected_batch_size is not None and batch_index >= expected_batch_size:
        raise pb_utils.TritonModelException(
            f"Invalid batch index in get_input_tensor_by_name for {name}")

    if batch_index is not None:
        # Add leading 1 batch dimension
        if isinstance(tensor, np.ndarray):
            return np.expand_dims(tensor[batch_index], axis=0)
        elif isinstance(tensor, torch.Tensor):
            return torch.unsqueeze(tensor[batch_index], dim=0)
    else:
        return tensor


def get_input_scalar_by_name(request,
                             name,
                             expected_batch_size=1,
                             batch_index=0):
    tensor = pb_utils.get_input_tensor_by_name(request, name)
    if tensor is None:
        return None
    tensor = tensor.as_numpy()

    if tensor.size != expected_batch_size:
        raise pb_utils.TritonModelException(
            f"Expected a scalar tensor for tensor {name}")

    return tensor.item(batch_index)


def read_parameter_as_type(value, name, pytype=str):
    if value == "":
        return None
    if value.startswith("${") and value.endswith("}"):
        return None
    if pytype is bool:
        return value.lower() in ["1", "true"]
    try:
        result = pytype(value)
        return result
    except:
        pb_utils.Logger.log_warning(
            f"Could not read parameter '{name}' with value '{value}', will use default."
        )
        return None
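
# Illustrative sketch (not part of the original source): parameters arrive from
# the model's config.pbtxt as strings, e.g.
#
#   parameters: {
#     key: "max_beam_width"
#     value: { string_value: "1" }
#   }
#
# get_parameter() below looks the key up and converts the string_value with
# read_parameter_as_type(); unfilled "${...}" template placeholders read as None.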


def get_parameter(model_config, name, pytype=str):
    if name not in model_config['parameters']:
        return None
    return read_parameter_as_type(
        model_config['parameters'][name]['string_value'], name, pytype)


def convert_word_list(word_list):
    if word_list is None:
        return None
    word_list = word_list.tolist()
    if len(word_list) == 0 or len(word_list[0]) != 2:
        raise pb_utils.TritonModelException(f"Invalid format for word list.")
    words, indices = word_list[0]
    result = []
    current_index = 0
    for i in indices:
        if i == -1:
            continue
        if i > len(words):
            raise pb_utils.TritonModelException(
                f"Invalid format for word list.")
        current_word = []
        while current_index < i:
            current_word.append(words[current_index])
            current_index += 1
        result.append(current_word)
    return result
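

# Illustrative example (assumed config value, derived from the parsing below):
# "medusa_choices" is configured as a string such as "{0}, {0, 0}, {1}".
# parse_medusa_choices() rewrites the braces to brackets and wraps the result in
# an outer list, so the example parses to [[0], [0, 0], [1]]; anything that is
# not a non-empty list of integer lists is rejected.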
def parse_medusa_choices(medusa_choices):
    if medusa_choices is None:
        return None
    try:
        result = json.loads(
            "[" + medusa_choices.replace("{", "[").replace("}", "]") + "]")
        assert isinstance(result, list) and len(result) > 0
        assert all([isinstance(x, list) for x in result])
        assert all([isinstance(y, int) for x in result for y in x])
    except Exception:
        raise pb_utils.TritonModelException(
            "Invalid format for medusa_choices")
    return result


def parse_eagle_choices(eagle_choices):
    return parse_medusa_choices(eagle_choices)


def get_sampling_config_from_request(request, batch_size=1, batch_index=0):
    kwargs = {}
    kwargs['beam_width'] = get_input_scalar_by_name(
        request, 'beam_width', batch_size, batch_index) or 1
    kwargs['top_k'] = get_input_scalar_by_name(request, 'runtime_top_k',
                                               batch_size, batch_index)
    kwargs['top_p'] = get_input_scalar_by_name(request, 'runtime_top_p',
                                               batch_size, batch_index)
    kwargs['top_p'] = None if kwargs['top_p'] is None or kwargs[
        'top_p'] <= 0 else kwargs['top_p']
    kwargs['random_seed'] = get_input_scalar_by_name(request, 'random_seed',
                                                     batch_size, batch_index)
    kwargs['temperature'] = get_input_scalar_by_name(request, 'temperature',
                                                     batch_size, batch_index)
    kwargs['min_length'] = get_input_scalar_by_name(request, 'min_length',
                                                    batch_size, batch_index)
    kwargs['repetition_penalty'] = get_input_scalar_by_name(
        request, 'repetition_penalty', batch_size, batch_index)
    kwargs['presence_penalty'] = get_input_scalar_by_name(
        request, 'presence_penalty', batch_size, batch_index)
    kwargs['frequency_penalty'] = get_input_scalar_by_name(
        request, 'frequency_penalty', batch_size, batch_index)
    kwargs['length_penalty'] = get_input_scalar_by_name(
        request, 'len_penalty', batch_size, batch_index)
    kwargs['top_p_min'] = get_input_scalar_by_name(request,
                                                   'runtime_top_p_min',
                                                   batch_size, batch_index)
    kwargs['top_p_reset_ids'] = get_input_scalar_by_name(
        request, 'runtime_top_p_reset_ids', batch_size, batch_index)
    kwargs['top_p_decay'] = get_input_scalar_by_name(request,
                                                     'runtime_top_p_decay',
                                                     batch_size, batch_index)
    kwargs['beam_search_diversity_rate'] = get_input_scalar_by_name(
        request, 'beam_search_diversity_rate', batch_size, batch_index)
    kwargs['early_stopping'] = get_input_scalar_by_name(
        request, 'early_stopping', batch_size, batch_index)
    kwargs['num_return_sequences'] = get_input_scalar_by_name(
        request, 'num_return_sequences', batch_size, batch_index) or 1
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return trtllm.SamplingConfig(**kwargs)


def get_output_config_from_request(request, batch_size=1, batch_index=0):
    kwargs = {}
    kwargs["return_log_probs"] = get_input_scalar_by_name(
        request, 'return_log_probs', batch_size, batch_index)
    kwargs["return_context_logits"] = get_input_scalar_by_name(
        request, 'return_context_logits', batch_size, batch_index)
    kwargs["return_generation_logits"] = get_input_scalar_by_name(
        request, 'return_generation_logits', batch_size, batch_index)
    kwargs["return_perf_metrics"] = get_input_scalar_by_name(
        request, 'return_kv_cache_reuse_stats', batch_size, batch_index)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return trtllm.OutputConfig(**kwargs)


def get_external_draft_tokens_config_from_request(request,
                                                  batch_size=1,
                                                  batch_index=0):
    kwargs = {}
    draft_input_ids = get_input_tensor_by_name(request, 'draft_input_ids',
                                               batch_size, batch_index)
    if draft_input_ids is not None:
        kwargs['tokens'] = draft_input_ids[0].tolist()
    draft_logits = get_input_tensor_by_name(request, 'draft_logits',
                                            batch_size, batch_index)
    if draft_logits is not None:
        kwargs['logits'] = from_numpy(draft_logits).squeeze(dim=0)
    kwargs['acceptance_threshold'] = get_input_scalar_by_name(
        request, 'draft_acceptance_threshold', batch_size, batch_index)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.ExternalDraftTokensConfig(**kwargs)
    return None


def get_prompt_tuning_config_from_request(request,
                                          batch_size=1,
                                          batch_index=0,
                                          input_length=0):
    # prompt_vocab_size is unused by executor.
    kwargs = {}
    prompt_embedding_table = get_input_tensor_by_name(
        request, 'prompt_embedding_table', batch_size, batch_index)
    prompt_table_extra_ids = get_input_tensor_by_name(
        request, 'prompt_table_extra_ids', batch_size, batch_index)
    if prompt_embedding_table is not None:
        if isinstance(prompt_embedding_table, np.ndarray):
            kwargs["embedding_table"] = from_numpy(
                prompt_embedding_table).squeeze(dim=0)
        elif isinstance(prompt_embedding_table, torch.Tensor):
            kwargs["embedding_table"] = prompt_embedding_table.squeeze(dim=0)

        if prompt_table_extra_ids is not None:
            prompt_table_extra_ids = prompt_table_extra_ids[0].tolist()
            if len(prompt_table_extra_ids) != 0:
                kwargs["input_token_extra_ids"] = prompt_table_extra_ids[
                    0:input_length]
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.PromptTuningConfig(**kwargs)
    return None


def get_lora_config_from_request(request, batch_size=1, batch_index=0):
    kwargs = {}
    kwargs["task_id"] = get_input_scalar_by_name(request, 'lora_task_id',
                                                 batch_size, batch_index)
    lora_weights = get_input_tensor_by_name(request, 'lora_weights',
                                            batch_size, batch_index)
    if lora_weights is not None:
        kwargs["weights"] = from_numpy(lora_weights).squeeze(dim=0)
    lora_config = get_input_tensor_by_name(request, 'lora_config', batch_size,
                                           batch_index)
    if lora_config is not None:
        kwargs["config"] = from_numpy(lora_config).squeeze(dim=0)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.LoraConfig(**kwargs)
    return None


def get_guided_decoding_params_from_request(request,
                                            batch_size=1,
                                            batch_index=0):
    kwargs = {}
    guided_decoding_guide_type = get_input_tensor_by_name(
        request, 'guided_decoding_guide_type', batch_size, batch_index)
    if guided_decoding_guide_type is not None:
        guided_decoding_guide_type = guided_decoding_guide_type.squeeze(
            axis=0)[0].decode()
        guided_decoding_guide_type_mapping = {
            "json": trtllm.GuidedDecodingParams.GuideType.JSON,
            "json_schema": trtllm.GuidedDecodingParams.GuideType.JSON_SCHEMA,
            "regex": trtllm.GuidedDecodingParams.GuideType.REGEX,
            "ebnf_grammar": trtllm.GuidedDecodingParams.GuideType.EBNF_GRAMMAR
        }
        guided_decoding_guide_type = guided_decoding_guide_type_mapping.get(
            guided_decoding_guide_type)
        kwargs['guide_type'] = guided_decoding_guide_type

    guided_decoding_guide = get_input_tensor_by_name(request,
                                                     'guided_decoding_guide',
                                                     batch_size, batch_index)
    if guided_decoding_guide is not None:
        kwargs['guide'] = guided_decoding_guide.squeeze(axis=0)[0].decode()

    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.GuidedDecodingParams(**kwargs)
    return None
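

# Note (descriptive, not in the original source): the KV cache retention inputs
# below are parallel per-request arrays. "retention_token_range_starts"/"_ends"
# delimit token ranges (an end of -1 means "until the end of the sequence"),
# "_priorities" gives each range a retention priority, and the optional
# "_durations_ms" gives each range a lifetime (-1 meaning "no expiry"). All
# arrays must have the same length.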
def get_kv_cache_retention_config_from_request(request,
                                               batch_size=1,
                                               batch_index=0):

    def get_tensor_and_check_length(name: str, expected_length: int):
        tensor = get_input_tensor_by_name(request, name, batch_size,
                                          batch_index)
        if tensor is None:
            raise RuntimeError(f"{name} must be provided.")
        tensor = np.squeeze(tensor, axis=0)
        if len(tensor) != expected_length:
            raise RuntimeError(
                f"Invalid {name} length. Expected length {expected_length}, got length {len(tensor)}"
            )
        return tensor

    token_range_starts = get_input_tensor_by_name(
        request, "retention_token_range_starts", batch_size, batch_index)

    if token_range_starts is not None:
        token_range_starts = np.squeeze(token_range_starts, axis=0)
        token_range_ends = get_tensor_and_check_length(
            "retention_token_range_ends", len(token_range_starts))
        token_range_ends = [
            None if end == -1 else end for end in token_range_ends
        ]

        token_range_priorities = get_tensor_and_check_length(
            "retention_token_range_priorities", len(token_range_starts))

        token_range_durations_ms = get_input_tensor_by_name(
            request, "retention_token_range_durations_ms", batch_size,
            batch_index)
        if token_range_durations_ms is None:
            token_range_durations_ms = [None] * len(token_range_starts)
        else:
            token_range_durations_ms = np.squeeze(token_range_durations_ms,
                                                  axis=0)
            token_range_durations_ms = [
                None if duration == -1 else duration
                for duration in token_range_durations_ms
            ]
            if len(token_range_durations_ms) != len(token_range_starts):
                raise RuntimeError(
                    f"Invalid retention_token_range_durations length. Expected length {len(token_range_starts)}, got length {len(token_range_durations_ms)}"
                )

        ranges = []
        for start, end, priority, duration_ms in zip(
                token_range_starts, token_range_ends, token_range_priorities,
                token_range_durations_ms):
            ranges.append(
                trtllm.KvCacheRetentionConfig.TokenRangeRetentionConfig(
                    token_start=start,
                    token_end=end,
                    priority=priority.item(),
                    duration_ms=None if duration_ms is None else
                    datetime.timedelta(milliseconds=duration_ms.item())))

        decode_args = {}
        decode_priority = get_input_scalar_by_name(
            request, "retention_decode_priority", batch_size, batch_index)
        if decode_priority is not None:
            decode_args['decode_retention_priority'] = decode_priority

        decode_duration_ms = get_input_scalar_by_name(
            request, "retention_decode_duration_ms", batch_size, batch_index)
        if decode_duration_ms is not None:
            decode_args[
                'decode_duration_ms'] = decode_duration_ms if decode_duration_ms != -1 else None

        return trtllm.KvCacheRetentionConfig(
            token_range_retention_configs=ranges, **decode_args)

    return None


def build_1_2_5_buckets(max_value: int) -> List[int]:
    """
    Builds a list of buckets with increasing powers of 10 multiplied by
    mantissa values (1, 5), starting from 10 until the value exceeds the
    specified maximum.

    Example:
    >>> build_1_2_5_buckets(1000)
    [10, 50, 100, 500, 1000]
    """
    mantissa_lst = [1, 5]
    exponent = 1  # Start from exponent 1 instead of 0
    buckets: List[int] = []
    while True:
        for m in mantissa_lst:
            value = m * 10**exponent
            if value <= max_value:
                buckets.append(value)
            else:
                return buckets
        exponent += 1
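

# Note (descriptive, not in the original source): convert_request() fans a single
# Triton request with batch size N out into N independent trtllm.Request objects,
# one per batch index; convert_response() later tags responses with a
# "batch_index" tensor (when batch_size > 1) so the client can reassemble the
# batch.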
def convert_request(request, exclude_input_from_output, decoupled):
    inputs = {}
    input_token_ids = get_input_tensor_by_name(request, 'input_ids')
    if input_token_ids is None:
        raise pb_utils.TritonModelException(
            "A value is required for input_ids")
    if len(input_token_ids.shape) != 2:
        raise pb_utils.TritonModelException(f"Invalid format for input_ids")
    batch_size = input_token_ids.shape[0]
    requests = []
    for batch_index in range(0, batch_size):
        input_token_ids = get_input_tensor_by_name(request, 'input_ids',
                                                   batch_size, batch_index)[0]
        if input_token_ids is None:
            raise pb_utils.TritonModelException(
                "A value is required for input_ids")
        input_token_ids = input_token_ids.tolist()
        if len(input_token_ids) == 0:
            raise pb_utils.TritonModelException(
                f"Invalid format for input_ids")

        input_length = get_input_scalar_by_name(request, 'input_lengths',
                                                batch_size, batch_index)
        if input_length is None:
            input_length = len(input_token_ids)
        # Trim input token ids with input_lengths
        inputs['input_token_ids'] = input_token_ids[0:input_length]

        inputs['max_new_tokens'] = get_input_scalar_by_name(
            request, 'request_output_len', batch_size, batch_index)
        if inputs['max_new_tokens'] is None:
            raise pb_utils.TritonModelException(
                "A value is required for request_output_len")
        inputs['streaming'] = get_input_scalar_by_name(request, 'streaming',
                                                       batch_size,
                                                       batch_index)
        if inputs['streaming'] and not decoupled:
            raise pb_utils.TritonModelException(
                "Streaming is only supported in decoupled mode.")
        inputs['end_id'] = get_input_scalar_by_name(request, 'end_id',
                                                    batch_size, batch_index)
        inputs['pad_id'] = get_input_scalar_by_name(request, 'pad_id',
                                                    batch_size, batch_index)
        inputs['stop_words'] = convert_word_list(
            get_input_tensor_by_name(request, 'stop_words_list', batch_size,
                                     batch_index))
        inputs['bad_words'] = convert_word_list(
            get_input_tensor_by_name(request, 'bad_words_list', batch_size,
                                     batch_index))
        embedding_bias = get_input_tensor_by_name(request, 'embedding_bias',
                                                  batch_size, batch_index)
        if embedding_bias is not None and embedding_bias.size != 0:
            inputs['embedding_bias'] = from_numpy(embedding_bias).squeeze(
                dim=0)

        sampling_config = get_sampling_config_from_request(
            request, batch_size, batch_index)
        output_config = get_output_config_from_request(
            request, batch_size, batch_index)
        req_exclude_input_from_output = get_input_scalar_by_name(
            request, 'exclude_input_in_output', batch_size, batch_index)
        if req_exclude_input_from_output is None:
            # if request doesn't specify exclude_input_from_output, try to use the parameter
            output_config.exclude_input_from_output = (
                exclude_input_from_output
                if exclude_input_from_output is not None else False)
        else:
            output_config.exclude_input_from_output = req_exclude_input_from_output

        external_draft_tokens_config = get_external_draft_tokens_config_from_request(
            request, batch_size, batch_index)
        prompt_tuning_config = get_prompt_tuning_config_from_request(
            request, batch_size, batch_index, input_length)
        lora_config = get_lora_config_from_request(request, batch_size,
                                                   batch_index)
        kv_cache_retention_config = get_kv_cache_retention_config_from_request(
            request, batch_size, batch_index)

        # Inputs for mllama support
        encoder_input_features = get_input_tensor_by_name(
            request, 'encoder_input_features', batch_size, batch_index)
        if encoder_input_features is not None:
            if isinstance(encoder_input_features, np.ndarray):
                encoder_input_features = from_numpy(
                    encoder_input_features).squeeze(dim=0)
            elif isinstance(encoder_input_features, torch.Tensor):
                encoder_input_features = encoder_input_features.squeeze(dim=0)
            inputs['encoder_input_features'] = encoder_input_features
            logger.debug(
                f"inputs to llm: encoder_input_features ({encoder_input_features.shape})"
            )

        encoder_output_length = get_input_tensor_by_name(
            request, 'encoder_output_lengths', batch_size, batch_index)
        if encoder_output_length is not None:
            inputs['encoder_output_length'] = np.squeeze(
                encoder_output_length, axis=0)

        cross_attention_mask = get_input_tensor_by_name(
            request, 'cross_attention_mask', batch_size, batch_index)
        if cross_attention_mask is not None:
            inputs['cross_attention_mask'] = cross_attention_mask[0]
            logger.debug(
                f"inputs to llm: cross_attention_mask ({cross_attention_mask.shape})"
            )

        skip_cross_attn_blocks = get_input_tensor_by_name(
            request,
            'skip_cross_attn_blocks',
            batch_size,
            batch_index,
            force_on_torch=True)
        if skip_cross_attn_blocks is not None:
            inputs['skip_cross_attn_blocks'] = skip_cross_attn_blocks[0]
            logger.debug(
                f"inputs to llm: skip_cross_attn_blocks ({skip_cross_attn_blocks.shape})"
            )

        guided_decoding_params = get_guided_decoding_params_from_request(
            request, batch_size, batch_index)

        requests.append(
            trtllm.Request(
                **inputs,
                sampling_config=sampling_config,
                output_config=output_config,
                external_draft_tokens_config=external_draft_tokens_config,
                prompt_tuning_config=prompt_tuning_config,
                lora_config=lora_config,
                guided_decoding_params=guided_decoding_params,
                kv_cache_retention_config=kv_cache_retention_config))
    return requests
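

# Note (descriptive, not in the original source): convert_response() flattens one
# executor response into Triton output tensors. "output_ids" is padded to the
# longest beam with -1 and "sequence_length" holds the real per-beam lengths;
# optional tensors (log probs, logits, batch/sequence indices, KV cache reuse
# stats) are appended only when present in the result.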
def convert_response(response,
                     batch_index,
                     batch_size,
                     num_return_sequences,
                     expected_logits_dtype=torch.float32):
    if response.has_error():
        return pb_utils.InferenceResponse(output_tensors=[],
                                          error=pb_utils.TritonError(
                                              response.error_msg)), True, 0
    result = response.result
    beam_lengths = np.expand_dims(
        np.array([len(beam) for beam in result.output_token_ids], np.int32),
        0)
    max_beam_length = max([len(beam) for beam in result.output_token_ids])
    output_ids = np.full((1, len(result.output_token_ids), max_beam_length),
                         -1, np.int32)
    for idx, beam in enumerate(result.output_token_ids):
        output_ids[0, idx, :len(beam)] = beam
    output_lengths = output_ids.size

    output_tensors = [
        pb_utils.Tensor("output_ids", output_ids),
        pb_utils.Tensor("sequence_length", beam_lengths),
    ]

    if result.cum_log_probs is not None:
        output_tensors.append(
            pb_utils.Tensor(
                "cum_log_probs",
                np.expand_dims(np.array(result.cum_log_probs, np.float32),
                               0)))

    if result.log_probs is not None:
        output_tensors.append(
            pb_utils.Tensor(
                "output_log_probs",
                np.expand_dims(np.array(result.log_probs, np.float32), 0)))

    if result.context_logits is not None:
        assert (result.context_logits.dtype is expected_logits_dtype)
        output_tensors.append(
            pb_utils.Tensor(
                "context_logits",
                np.expand_dims(
                    np.array(
                        result.context_logits, torch_to_numpy_dtype_dict[
                            result.context_logits.dtype]), 0)))

    if result.generation_logits is not None:
        assert (result.generation_logits.dtype is expected_logits_dtype)
        output_tensors.append(
            pb_utils.Tensor(
                "generation_logits",
                np.expand_dims(
                    np.array(
                        result.generation_logits, torch_to_numpy_dtype_dict[
                            result.generation_logits.dtype]), 0)))

    if batch_size > 1:
        output_tensors.append(
            pb_utils.Tensor(
                "batch_index",
                np.expand_dims(np.array([batch_index], np.int32), 0)))

    if num_return_sequences > 1:
        output_tensors.append(
            pb_utils.Tensor(
                "sequence_index",
                np.expand_dims(np.array([result.sequence_index], np.int32),
                               0)))

    if result.request_perf_metrics is not None:
        kv_cache_metrics = result.request_perf_metrics.kv_cache_metrics
        output_tensors.append(
            pb_utils.Tensor(
                "kv_cache_alloc_new_blocks",
                np.expand_dims(
                    np.array([kv_cache_metrics.num_new_allocated_blocks],
                             np.int32), 0)))
        output_tensors.append(
            pb_utils.Tensor(
                "kv_cache_reused_blocks",
                np.expand_dims(
                    np.array([kv_cache_metrics.num_reused_blocks], np.int32),
                    0)))
        output_tensors.append(
            pb_utils.Tensor(
                "kv_cache_alloc_total_blocks",
                np.expand_dims(
                    np.array([kv_cache_metrics.num_total_allocated_blocks],
                             np.int32), 0)))

    return pb_utils.InferenceResponse(
        output_tensors), result.is_final, output_lengths


def convert_scheduler_policy(batch_scheduler_policy: str):
    if batch_scheduler_policy.lower() == "max_utilization":
        return trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
    elif batch_scheduler_policy.lower() == "guaranteed_no_evict":
        return trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
    raise pb_utils.TritonModelException(
        f"batch_scheduler_policy value of '{batch_scheduler_policy}' is not supported."
    )


def convert_batching_type(gpt_model_type: str):
    if gpt_model_type is None:
        return None
    if gpt_model_type.lower(
    ) == "inflight_fused_batching" or gpt_model_type.lower(
    ) == "inflight_batching":
        return trtllm.BatchingType.INFLIGHT
    elif gpt_model_type.lower() == "v1":
        return trtllm.BatchingType.STATIC
    raise pb_utils.TritonModelException(
        f"gpt_model_type value of '{gpt_model_type}' is not supported.")


def convert_decoding_mode(decoding_mode: str):
    if decoding_mode is None:
        return None
    elif decoding_mode == "auto":
        return trtllm.DecodingMode.Auto()
    elif decoding_mode == "top_k":
        return trtllm.DecodingMode.TopK()
    elif decoding_mode == "top_p":
        return trtllm.DecodingMode.TopP()
    elif decoding_mode == "top_k_top_p":
        return trtllm.DecodingMode.TopKTopP()
    elif decoding_mode == "beam_search":
        return trtllm.DecodingMode.BeamSearch()
    elif decoding_mode == "medusa":
        return trtllm.DecodingMode.Medusa()
    elif decoding_mode == "redrafter":
        return trtllm.DecodingMode.ExplicitDraftTokens()
    elif decoding_mode == "lookahead":
        return trtllm.DecodingMode.Lookahead()
    elif decoding_mode == "eagle":
        return trtllm.DecodingMode.Eagle()
    raise pb_utils.TritonModelException(
        f"decoding_mode value of '{decoding_mode}' is not supported.")


def convert_timestamp_to_seconds(timestamp: str):
    return int(
        datetime.datetime.strptime(timestamp,
                                   "%m-%d-%Y %H:%M:%S.%f").timestamp())


def triton_string_to_torch(dtype):
    type_map = {
        "TYPE_BOOL": torch.bool,
        "TYPE_UINT8": torch.uint8,
        "TYPE_INT8": torch.int8,
        "TYPE_INT16": torch.int16,
        "TYPE_INT32": torch.int32,
        "TYPE_INT64": torch.int64,
        "TYPE_FP16": torch.float16,
        "TYPE_FP32": torch.float32,
        "TYPE_FP64": torch.float64,
        "TYPE_BF16": torch.bfloat16
    }
    return type_map[dtype]


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """
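
    # Overview (descriptive comment, not in the original source): initialize()
    # builds an ExecutorConfig from config.pbtxt parameters and creates the
    # trtllm.Executor, execute() converts incoming Triton requests and enqueues
    # them, and three background threads (awaiter_loop, cancellation_loop,
    # metrics_loop) stream responses back, propagate client cancellations, and
    # publish executor statistics until finalize() shuts everything down.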
""" def get_scheduler_config(self, model_config): batch_scheduler_policy = get_parameter(model_config, "batch_scheduler_policy") if batch_scheduler_policy is None: return trtllm.SchedulerConfig() return trtllm.SchedulerConfig( convert_scheduler_policy(batch_scheduler_policy)) def get_kv_cache_config(self, model_config): kwargs = { "enable_block_reuse": get_parameter(model_config, "enable_kv_cache_reuse", bool), "max_tokens": get_parameter(model_config, "max_tokens_in_paged_kv_cache", int), "sink_token_length": get_parameter(model_config, "sink_token_length", int), "free_gpu_memory_fraction": get_parameter(model_config, "kv_cache_free_gpu_mem_fraction", float), "cross_kv_cache_fraction": get_parameter(model_config, "cross_kv_cache_fraction", float), "host_cache_size": get_parameter(model_config, "kv_cache_host_memory_bytes", int), "onboard_blocks": get_parameter(model_config, "kv_cache_onboard_blocks", bool), } max_attention_window_size = get_parameter(model_config, "max_attention_window_size") if max_attention_window_size: kwargs["max_attention_window"] = [ int(x) for x in max_attention_window_size.split(",") ] kwargs = {k: v for k, v in kwargs.items() if v is not None} return trtllm.KvCacheConfig(**kwargs) def get_parallel_config(self, model_config): kwargs = {} gpu_device_ids = get_parameter(model_config, "gpu_device_ids") if gpu_device_ids: kwargs["device_ids"] = [int(x) for x in gpu_device_ids.split(",")] self.use_orchestrator_mode = os.environ.get("TRTLLM_ORCHESTRATOR", "0") == "1" if self.use_orchestrator_mode: kwargs[ "communication_mode"] = trtllm.CommunicationMode.ORCHESTRATOR worker_path = get_parameter(model_config, "worker_path") spawn_processes = os.environ.get( "TRTLLM_ORCHESTRATOR_SPAWN_PROCESSES", "1") == "1" if not spawn_processes: raise pb_utils.TritonModelException( "Orchestrator mode with --disable-spawn-processes is not supported in the Python backend." ) is_orchestrator = (mpi_rank() == 0) if spawn_processes else True if worker_path is not None: raise pb_utils.TritonModelException( "worker_path parameter is specified, but this is no longer supported. Please specify executor_worker_path instead to specify the location of the trtllmExecutorWorker executable." 
    def get_parallel_config(self, model_config):
        kwargs = {}
        gpu_device_ids = get_parameter(model_config, "gpu_device_ids")
        if gpu_device_ids:
            kwargs["device_ids"] = [int(x) for x in gpu_device_ids.split(",")]
        self.use_orchestrator_mode = os.environ.get("TRTLLM_ORCHESTRATOR",
                                                    "0") == "1"
        if self.use_orchestrator_mode:
            kwargs[
                "communication_mode"] = trtllm.CommunicationMode.ORCHESTRATOR
            worker_path = get_parameter(model_config, "worker_path")
            spawn_processes = os.environ.get(
                "TRTLLM_ORCHESTRATOR_SPAWN_PROCESSES", "1") == "1"
            if not spawn_processes:
                raise pb_utils.TritonModelException(
                    "Orchestrator mode with --disable-spawn-processes is not supported in the Python backend."
                )
            is_orchestrator = (mpi_rank() == 0) if spawn_processes else True
            if worker_path is not None:
                raise pb_utils.TritonModelException(
                    "worker_path parameter is specified, but this is no longer supported. Please specify executor_worker_path instead to specify the location of the trtllmExecutorWorker executable."
                )
            executor_worker_path = get_parameter(model_config,
                                                 "executor_worker_path")
            kwargs["orchestrator_config"] = trtllm.OrchestratorConfig(
                is_orchestrator, executor_worker_path)
        if len(kwargs) > 0:
            return trtllm.ParallelConfig(**kwargs)
        return None

    def get_peft_cache_config(self, model_config):
        kwargs = {
            "optimal_adapter_size":
            get_parameter(model_config, "lora_cache_optimal_adapter_size",
                          int),
            "max_adapter_size":
            get_parameter(model_config, "lora_cache_max_adapter_size", int),
            "device_cache_percent":
            get_parameter(model_config, "lora_cache_gpu_memory_fraction",
                          float),
            "host_cache_size":
            get_parameter(model_config, "lora_cache_host_memory_bytes", int),
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.PeftCacheConfig(**kwargs)

    def get_decoding_config(self, model_config):
        eagle_choices = parse_eagle_choices(
            get_parameter(model_config, "eagle_choices"))
        kwargs = {
            "medusa_choices":
            parse_medusa_choices(get_parameter(model_config,
                                               "medusa_choices")),
            "eagle_config":
            None
            if eagle_choices is None else trtllm.EagleConfig(eagle_choices),
            "decoding_mode":
            convert_decoding_mode(get_parameter(model_config,
                                                "decoding_mode")),
        }
        print(kwargs)
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.DecodingConfig(**kwargs)

    def get_extended_runtime_perf_knob_config(self, model_config):
        kwargs = {
            "multi_block_mode":
            get_parameter(model_config, "multi_block_mode", bool),
            "enable_context_fmha_fp32_acc":
            get_parameter(model_config, "enable_context_fmha_fp32_acc",
                          bool),
            "cuda_graph_mode":
            get_parameter(model_config, "cuda_graph_mode", bool),
            "cuda_graph_cache_size":
            get_parameter(model_config, "cuda_graph_cache_size", int),
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.ExtendedRuntimePerfKnobConfig(**kwargs)

    def get_guided_decoding_config(self, model_config):
        guided_decoding_backend = get_parameter(model_config,
                                                "guided_decoding_backend",
                                                str)
        tokenizer_dir = get_parameter(model_config, "tokenizer_dir", str)
        if guided_decoding_backend not in ['xgrammar']:
            if tokenizer_dir:
                pb_utils.Logger.log_warn(
                    f"Guided decoding backend has not been set but tokenizer_dir is given. Tokenizer_dir will be ignored."
                )
            return None

        if guided_decoding_backend == 'xgrammar':
            guided_decoding_backend = trtllm.GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR

        if not tokenizer_dir:
            raise ValueError(
                "Guided decoding requires tokenizer's information. Please provide 'tokenizer_dir'."
            )
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
        pb_utils.Logger.log_info(
            f"Guided decoding has been set with {guided_decoding_backend} backend"
        )
        return trtllm.GuidedDecodingConfig(
            backend=guided_decoding_backend,
            **_xgrammar_tokenizer_info(tokenizer))
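
    # Note (descriptive, not in the original source): unlike the other knobs,
    # "max_queue_size" is not a custom parameter; it is read from the standard
    # dynamic_batching { default_queue_policy { max_queue_size: ... } } block of
    # config.pbtxt and forwarded to the executor below.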
    def get_executor_config(self, model_config):
        kwargs = {
            "max_beam_width":
            get_parameter(model_config, "max_beam_width", int),
            "scheduler_config":
            self.get_scheduler_config(model_config),
            "kv_cache_config":
            self.get_kv_cache_config(model_config),
            "enable_chunked_context":
            get_parameter(model_config, "enable_chunked_context", bool),
            "normalize_log_probs":
            get_parameter(model_config, "normalize_log_probs", bool),
            "batching_type":
            convert_batching_type(get_parameter(model_config,
                                                "gpt_model_type")),
            "parallel_config":
            self.get_parallel_config(model_config),
            "peft_cache_config":
            self.get_peft_cache_config(model_config),
            "decoding_config":
            self.get_decoding_config(model_config),
            "max_queue_size":
            model_config.get(
                "dynamic_batching",
                {},
            ).get(
                "default_queue_policy",
                {},
            ).get("max_queue_size"),
            "extended_runtime_perf_knob_config":
            self.get_extended_runtime_perf_knob_config(model_config),
            "guided_decoding_config":
            self.get_guided_decoding_config(model_config)
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.ExecutorConfig(**kwargs)
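
    # Note (descriptive, not in the original source): the metric families created
    # below are registered with Triton's metrics subsystem and surface on the
    # server's Prometheus /metrics endpoint; the request/runtime/KV-cache
    # families are gauges refreshed by metrics_loop(), while the input/output
    # token counters are histograms observed once per completed request.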
"gpu", **common_labels }), "pinned_mem_usage": self.runtime_memory_metric_family.Metric(labels={ "memory_type": "pinned", **common_labels }), # KV cache metrics "max_num_blocks": self.kv_cache_metric_family.Metric(labels={ "kv_cache_block_type": "max", **common_labels }), "free_num_blocks": self.kv_cache_metric_family.Metric(labels={ "kv_cache_block_type": "free", **common_labels }), "used_num_blocks": self.kv_cache_metric_family.Metric(labels={ "kv_cache_block_type": "used", **common_labels }), "tokens_per_block": self.kv_cache_metric_family.Metric(labels={ "kv_cache_block_type": "tokens_per", **common_labels }), # General metrics "timestamp": self.general_metric_family.Metric(labels={ "general_type": "timestamp", **common_labels }), "iter": self.general_metric_family.Metric(labels={ "general_type": "iteration_counter", **common_labels }), METRIC_TOTAL_OUTPUT_TOKENS: self.response_tokens_metric_family.Metric( labels={ "response_metric_type": METRIC_TOTAL_OUTPUT_TOKENS, **common_labels }, buckets=build_1_2_5_buckets(1000)), METRIC_TOTAL_INPUT_TOKENS: self.request_tokens_metric_family.Metric( labels={ "response_metric_type": METRIC_TOTAL_INPUT_TOKENS, **common_labels }, buckets=build_1_2_5_buckets(1000)), } if is_v1_model: self.all_metrics.update({ "num_ctx_tokens": self.model_type_metric_family.Metric(labels={ "v1_specific_metric": "total_context_tokens", **common_labels }), "num_gen_tokens": self.model_type_metric_family.Metric( labels={ "v1_specific_metric": "total_generation_tokens", **common_labels }), "empty_gen_slots": self.model_type_metric_family.Metric( labels={ "v1_specific_metric": "empty_generation_slots", **common_labels }), }) else: self.all_metrics.update({ "num_ctx_tokens": self.model_type_metric_family.Metric( labels={ "inflight_batcher_specific_metric": "total_context_tokens", **common_labels }), "num_gen_requests": self.model_type_metric_family.Metric( labels={ "inflight_batcher_specific_metric": "generation_requests", **common_labels }), "micro_batch_id": self.model_type_metric_family.Metric( labels={ "inflight_batcher_specific_metric": "micro_batch_id", **common_labels }), "num_paused_requests": self.model_type_metric_family.Metric( labels={ "inflight_batcher_specific_metric": "paused_requests", **common_labels }), }) def initialize(self, args): """`initialize` is called only once when the model is being loaded. Implementing `initialize` function is optional. This function allows the model to initialize any state associated with this model. Parameters ---------- args : dict Both keys and values are strings. 

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        model_config = json.loads(args['model_config'])
        gpt_model_path = get_parameter(model_config, "gpt_model_path")
        if get_parameter(model_config, "enable_trt_overlap", bool):
            raise pb_utils.TritonModelException(
                f"enable_trt_overlap=true is not supported.")
        self.exclude_input_from_output = get_parameter(
            model_config, "exclude_input_in_output", bool)
        executor_config = self.get_executor_config(model_config)
        self.executor = trtllm.Executor(gpt_model_path,
                                        trtllm.ModelType.DECODER_ONLY,
                                        executor_config)
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(
            model_config)
        self.cancellation_check_period_ms = get_parameter(
            model_config, "cancellation_check_period_ms", int) or 100
        self.stats_check_period_ms = get_parameter(
            model_config, "stats_check_period_ms", int) or 100
        self.logits_dtype = None
        for output in model_config['output']:
            if output['name'] == 'context_logits' or output[
                    'name'] == 'generation_logits':
                self.logits_dtype = triton_string_to_torch(
                    output['data_type'])

        self.create_metrics(args["model_name"],
                            args["model_version"],
                            is_v1_model=executor_config.batching_type ==
                            trtllm.BatchingType.STATIC)
        self.triton_user_id_to_req_ids = {}
        self.triton_req_id_to_req_ids = {}
        self.req_id_to_request_data = {}
        self.lock = Lock()
        self.running = False
        self.awaiter_thread = Thread(target=self.awaiter_loop)
        self.cancellation_thread = Thread(target=self.cancellation_loop)
        self.metrics_thread = Thread(target=self.metrics_loop)
        if self.executor.can_enqueue_requests():
            self.running = True
            self.awaiter_thread.start()
            self.cancellation_thread.start()
            self.metrics_thread.start()
        else:
            # In leader mode, worker ranks will wait here until leader is done.
            self.executor.shutdown()
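
    # Note (descriptive, not in the original source): clients cancel an in-flight
    # request by sending a new request that reuses the same request_id and sets
    # the "stop" input to true; execute() routes it here, and every executor
    # request associated with that id is cancelled.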
    def handle_stop_request(self, triton_user_id, response_sender):
        if triton_user_id is None or triton_user_id == "":
            response_sender.send(
                pb_utils.InferenceResponse(error=pb_utils.TritonError(
                    "A request id must be provided for request cancellation")
                                           ),
                flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            return

        with self.lock:
            if triton_user_id in self.triton_user_id_to_req_ids:
                req_ids = self.triton_user_id_to_req_ids[triton_user_id]
                for req_id in req_ids:
                    self.executor.cancel_request(req_id)

        response_sender.send(
            pb_utils.InferenceResponse(),
            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)

    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        if not self.executor.can_enqueue_requests():
            return

        # Convert to executor requests.
        triton_requests = []
        executor_requests = []
        batch_indices = []
        triton_user_ids = []
        triton_req_ids = []

        for request in requests:

            triton_user_id = request.request_id()

            response_sender = request.get_response_sender()
            stop = get_input_scalar_by_name(request, 'stop')

            if stop:
                self.handle_stop_request(triton_user_id, response_sender)
            else:
                # Unique request id used to identify each triton request
                triton_req_id = str(randint(0, sys.maxsize))
                self.triton_req_id_to_req_ids[triton_req_id] = set()
                if triton_user_id is not None and triton_user_id != "":
                    self.triton_user_id_to_req_ids[triton_user_id] = set()

                try:
                    converted_reqs = convert_request(
                        request, self.exclude_input_from_output,
                        self.decoupled)
                except Exception as e:
                    response_sender.send(
                        pb_utils.InferenceResponse(error=pb_utils.TritonError(
                            f"An error occurred when processing the input values for request id {request.request_id()}, the error was '{e}'"
                        )),
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                else:
                    for batch_index, converted_req in enumerate(
                            converted_reqs):
                        triton_requests.append(request)
                        executor_requests.append(converted_req)
                        triton_user_ids.append(triton_user_id)
                        triton_req_ids.append(triton_req_id)
                        batch_indices.append(batch_index)

        with self.lock:
            request_ids = self.executor.enqueue_requests(executor_requests)
            for req_id, triton_req_id, triton_user_id, executor_request, triton_request, batch_index in zip(
                    request_ids, triton_req_ids, triton_user_ids,
                    executor_requests, triton_requests, batch_indices):
                self.req_id_to_request_data[req_id] = RequestData(
                    triton_req_id, triton_user_id, batch_index,
                    len(batch_indices),
                    executor_request.sampling_config.num_return_sequences, 0,
                    0, triton_request.get_response_sender())
                self.triton_req_id_to_req_ids[triton_req_id].add(req_id)
                input_len = len(
                    executor_request.input_token_ids
                ) if executor_request.input_token_ids is not None else 0
                self.req_id_to_request_data[
                    req_id].num_input_tokens += input_len
                # This checks both request level and instance config level
                if executor_request.output_config.exclude_input_from_output == False and executor_request.streaming == False:
                    self.req_id_to_request_data[
                        req_id].num_output_tokens -= self.req_id_to_request_data[
                            req_id].num_input_tokens * executor_request.sampling_config.beam_width
                if triton_user_id is not None and triton_user_id != "":
                    self.triton_user_id_to_req_ids[triton_user_id].add(req_id)

        return None

    def awaiter_loop(self):
        """Gets responses from executor and returns the results."""
        while self.running:
            for response in self.executor.await_responses(
                    timeout=datetime.timedelta(milliseconds=1)):
                req_id = response.request_id
                request_data = None
                with self.lock:
                    if req_id not in self.req_id_to_request_data:
                        continue
                    request_data = self.req_id_to_request_data[req_id]

                triton_response, is_final, output_length = convert_response(
                    response, request_data.batch_index,
                    request_data.batch_size,
                    request_data.num_return_sequences, self.logits_dtype)
                with self.lock:
                    self.req_id_to_request_data[
                        req_id].num_output_tokens += output_length

                triton_request_final = False
                if is_final:
                    with self.lock:
                        # Check if all executor requests part of that triton request are finished
                        self.triton_req_id_to_req_ids[
                            request_data.triton_req_id].remove(req_id)
                        if len(self.triton_req_id_to_req_ids[
                                request_data.triton_req_id]) == 0:
                            pb_utils.Logger.log_info(
                                f"DELETING Req id {req_id}, triton_req_id {request_data.triton_req_id} "
                            )
                            triton_request_final = True
                            del self.triton_req_id_to_req_ids[
                                request_data.triton_req_id]
                            if request_data.triton_user_id is not None and request_data.triton_user_id != "":
                                del self.triton_user_id_to_req_ids[
                                    request_data.triton_user_id]
                        self.update_metrics_per_request(req_id)
                        del self.req_id_to_request_data[req_id]

                request_data.response_sender.send(
                    triton_response,
                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                    if triton_request_final else 0)
request_data.triton_user_id != "": del self.triton_user_id_to_req_ids[ request_data.triton_user_id] self.update_metrics_per_request(req_id) del self.req_id_to_request_data[req_id] request_data.response_sender.send( triton_response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL if triton_request_final else 0) def cancellation_loop(self): """Checks if any pending requests have been cancelled.""" while self.running: time.sleep(self.cancellation_check_period_ms / 1000.0) with self.lock: for req_id, request_data in self.req_id_to_request_data.items( ): if request_data.response_sender.is_cancelled(): self.executor.cancel_request(req_id) def update_metrics_per_request(self, req_id): """Updates triton metrics after completing one request""" output_tokens = self.req_id_to_request_data[req_id].num_output_tokens input_tokens = self.req_id_to_request_data[req_id].num_input_tokens self.all_metrics[METRIC_TOTAL_OUTPUT_TOKENS].observe(output_tokens) self.all_metrics[METRIC_TOTAL_INPUT_TOKENS].observe(input_tokens) def metrics_loop(self): """Updates triton metrics using stats from the executor.""" while self.running: time.sleep(self.stats_check_period_ms / 1000.0) for stat in self.executor.get_latest_iteration_stats(): try: for key, metric in self.all_metrics.items(): # Skip processing for both histogram metrics if isinstance(key, str) and key in [ METRIC_TOTAL_OUTPUT_TOKENS, METRIC_TOTAL_INPUT_TOKENS ]: continue value = None if hasattr(stat, key): value = getattr(stat, key) elif stat.kv_cache_stats is not None and hasattr( stat.kv_cache_stats, key): value = getattr(stat.kv_cache_stats, key) elif stat.static_batching_stats is not None and hasattr( stat.static_batching_stats, key): value = getattr(stat.static_batching_stats, key) elif stat.inflight_batching_stats is not None and hasattr( stat.inflight_batching_stats, key): value = getattr(stat.inflight_batching_stats, key) if value is not None: if key == "timestamp": value = convert_timestamp_to_seconds(value) metric.set(value) else: pb_utils.Logger.log_warn( f"Metric \"{key}\" not found.") except Exception as e: pb_utils.Logger.log_warn( f"Error while processing metrics: {e}") def finalize(self): """`finalize` is called only once when the model is being unloaded. Implementing `finalize` function is optional. This function allows the model to perform any necessary clean ups before exit. """ if self.executor.can_enqueue_requests(): self.running = False self.awaiter_thread.join() self.cancellation_thread.join() self.metrics_thread.join() self.executor.shutdown()