import copy
import os
import time
import warnings
from functools import partial
from typing import Any, Callable

import httpx

from openhands.core.config import LLMConfig

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import litellm

from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
from litellm import Message as LiteLLMMessage
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
    RateLimitError,
)
from litellm.types.utils import CostPerToken, ModelResponse, Usage
from litellm.utils import create_pretrained_tokenizer

from openhands.core.exceptions import LLMNoResponseError
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.llm.debug_mixin import DebugMixin
from openhands.llm.fn_call_converter import (
    STOP_WORDS,
    convert_fncall_messages_to_non_fncall_messages,
    convert_non_fncall_messages_to_fncall_messages,
)
from openhands.llm.metrics import Metrics
from openhands.llm.retry_mixin import RetryMixin

__all__ = ['LLM']

# tuple of exceptions to retry on
LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
    RateLimitError,
    litellm.Timeout,
    litellm.InternalServerError,
    LLMNoResponseError,
)
# models that support prompt caching
# remove this list once Gemini and DeepSeek prompt caching are supported
CACHE_PROMPT_SUPPORTED_MODELS = [
    'claude-3-7-sonnet-20250219',
    'claude-sonnet-3-7-latest',
    'claude-3.7-sonnet',
    'claude-3-5-sonnet-20241022',
    'claude-3-5-sonnet-20240620',
    'claude-3-5-haiku-20241022',
    'claude-3-haiku-20240307',
    'claude-3-opus-20240229',
    'claude-sonnet-4-20250514',
    'claude-opus-4-20250514',
]

# models that support function calling
FUNCTION_CALLING_SUPPORTED_MODELS = [
    'claude-3-7-sonnet-20250219',
    'claude-sonnet-3-7-latest',
    'claude-3-5-sonnet',
    'claude-3-5-sonnet-20240620',
    'claude-3-5-sonnet-20241022',
    'claude-3.5-haiku',
    'claude-3-5-haiku-20241022',
    'claude-sonnet-4-20250514',
    'claude-opus-4-20250514',
    'gpt-4o-mini',
    'gpt-4o',
    'o1-2024-12-17',
    'o3-mini-2025-01-31',
    'o3-mini',
    'o3',
    'o3-2025-04-16',
    'o4-mini',
    'o4-mini-2025-04-16',
    'gemini-2.5-pro',
    'gpt-4.1',
]

REASONING_EFFORT_SUPPORTED_MODELS = [
    'o1-2024-12-17',
    'o1',
    'o3',
    'o3-2025-04-16',
    'o3-mini-2025-01-31',
    'o3-mini',
    'o4-mini',
    'o4-mini-2025-04-16',
]

MODELS_WITHOUT_STOP_WORDS = [
    'o1-mini',
    'o1-preview',
    'o1',
    'o1-2024-12-17',
]
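
# Illustrative construction sketch (placeholder values, not part of this module's API):
#
#     config = LLMConfig(model='claude-3-5-sonnet-20241022', temperature=0.0)
#     llm = LLM(config)
#
# Only LLMConfig fields that are read in this module (model, api_key, temperature,
# base_url, ...) are assumed above; the concrete values are placeholders.
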
class LLM(RetryMixin, DebugMixin):
    """The LLM class represents a Language Model instance.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
        retry_listener: Callable[[int, int], None] | None = None,
    ) -> None:
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration.
            metrics: The metrics to use.
        """
        self._tried_model_info = False
        self.metrics: Metrics = (
            metrics if metrics is not None else Metrics(model_name=config.model)
        )
        self.cost_metric_supported: bool = True
        self.config: LLMConfig = copy.deepcopy(config)
        self.model_info: ModelInfo | None = None
        self.retry_listener = retry_listener

        if self.config.log_completions:
            if self.config.log_completions_folder is None:
                raise RuntimeError(
                    'log_completions_folder is required when log_completions is enabled'
                )
            os.makedirs(self.config.log_completions_folder, exist_ok=True)

        # call init_model_info to initialize config.max_output_tokens
        # which is used in partial function
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            self.init_model_info()

        if self.vision_is_active():
            logger.debug('LLM: model has vision enabled')
        if self.is_caching_prompt_active():
            logger.debug('LLM: caching prompt enabled')
        if self.is_function_calling_active():
            logger.debug('LLM: model supports function calling')

        # if using a custom tokenizer, make sure it's loaded and accessible in the format expected by litellm
        if self.config.custom_tokenizer is not None:
            self.tokenizer = create_pretrained_tokenizer(self.config.custom_tokenizer)
        else:
            self.tokenizer = None

        # set up the completion function
        kwargs: dict[str, Any] = {
            'temperature': self.config.temperature,
            'max_completion_tokens': self.config.max_output_tokens,
        }
        if self.config.top_k is not None:
            # openai doesn't expose top_k
            # litellm will handle it a bit differently than the openai-compatible params
            kwargs['top_k'] = self.config.top_k

        if (
            self.config.model.lower() in REASONING_EFFORT_SUPPORTED_MODELS
            or self.config.model.split('/')[-1] in REASONING_EFFORT_SUPPORTED_MODELS
        ):
            kwargs['reasoning_effort'] = self.config.reasoning_effort
            kwargs.pop(
                'temperature'
            )  # temperature is not supported for reasoning models

        # Azure issue: https://github.com/All-Hands-AI/OpenHands/issues/6777
        if self.config.model.startswith('azure'):
            kwargs['max_tokens'] = self.config.max_output_tokens
            kwargs.pop('max_completion_tokens')

        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key.get_secret_value()
            if self.config.api_key
            else None,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            timeout=self.config.timeout,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
            seed=self.config.seed,
            **kwargs,
        )
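
        # At this point self._completion is (roughly) litellm_completion with the
        # provider, credentials, and sampling parameters above already bound, so a
        # caller only needs to pass messages (and optionally tools). Illustrative
        # shape, with a placeholder message dict:
        #
        #     self._completion(messages=[{'role': 'user', 'content': '...'}])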

        self._completion_unwrapped = self._completion

        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            from openhands.io import json

            messages_kwarg: list[dict[str, Any]] | dict[str, Any] = []
            mock_function_calling = not self.is_function_calling_active()

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            if len(args) > 1:
                # ignore the first argument if it's provided (it would be the model)
                # design wise: we don't allow overriding the configured values
                # implementation wise: the partial function sets the model as a kwarg already
                # as well as other kwargs
                messages_kwarg = args[1] if len(args) > 1 else args[0]
                kwargs['messages'] = messages_kwarg

                # remove the first args, they're sent in kwargs
                args = args[2:]
            elif 'messages' in kwargs:
                messages_kwarg = kwargs['messages']

            # ensure we work with a list of messages
            messages: list[dict[str, Any]] = (
                messages_kwarg if isinstance(messages_kwarg, list) else [messages_kwarg]
            )
            # handle conversion to non-function calling messages if needed
            original_fncall_messages = copy.deepcopy(messages)
            mock_fncall_tools = None
            # if the agent or caller has defined tools, and we mock via prompting, convert the messages
            if mock_function_calling and 'tools' in kwargs:
                add_in_context_learning_example = True
                if (
                    'openhands-lm' in self.config.model
                    or 'devstral' in self.config.model
                ):
                    add_in_context_learning_example = False
                messages = convert_fncall_messages_to_non_fncall_messages(
                    messages,
                    kwargs['tools'],
                    add_in_context_learning_example=add_in_context_learning_example,
                )
                kwargs['messages'] = messages

                # add stop words if the model supports it
                if self.config.model not in MODELS_WITHOUT_STOP_WORDS:
                    kwargs['stop'] = STOP_WORDS

                mock_fncall_tools = kwargs.pop('tools')
                if 'openhands-lm' in self.config.model:
                    # If we don't have this, we might run into issues when serving openhands-lm
                    # using SGLang
                    # BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'object': 'error', 'message': '400', 'type': 'Failed to parse fc related info to json format!', 'param': None, 'code': 400}
                    kwargs['tool_choice'] = 'none'
                else:
                    # tool_choice should not be specified when mocking function calling
                    kwargs.pop('tool_choice', None)

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            # log the entire LLM prompt
            self.log_prompt(messages)

            # set litellm modify_params to the configured value
            # True by default to allow litellm to do transformations like adding a default message, when a message is empty
            # NOTE: this setting is global; unlike drop_params, it cannot be overridden in the litellm completion partial
            litellm.modify_params = self.config.modify_params

            # if we're not using litellm proxy, remove the extra_body
            if 'litellm_proxy' not in self.config.model:
                kwargs.pop('extra_body', None)

            # Record start time for latency measurement
            start_time = time.time()

            # we don't support streaming here, thus we get a ModelResponse
            resp: ModelResponse = self._completion_unwrapped(*args, **kwargs)

            # Calculate and record latency
            latency = time.time() - start_time
            response_id = resp.get('id', 'unknown')
            self.metrics.add_response_latency(latency, response_id)

            non_fncall_response = copy.deepcopy(resp)

            # if we mocked function calling, and we have tools, convert the response back to function calling format
            if mock_function_calling and mock_fncall_tools is not None:
                if len(resp.choices) < 1:
                    raise LLMNoResponseError(
                        'Response choices is less than 1 - This is only seen in Gemini models so far. Response: '
                        + str(resp)
                    )

                non_fncall_response_message = resp.choices[0].message
                # messages is already a list with proper typing from line 223
                fn_call_messages_with_response = (
                    convert_non_fncall_messages_to_fncall_messages(
                        messages + [non_fncall_response_message], mock_fncall_tools
                    )
                )
                fn_call_response_message = fn_call_messages_with_response[-1]
                if not isinstance(fn_call_response_message, LiteLLMMessage):
                    fn_call_response_message = LiteLLMMessage(
                        **fn_call_response_message
                    )
                resp.choices[0].message = fn_call_response_message

            # Check if resp has 'choices' key with at least one item
            if not resp.get('choices') or len(resp['choices']) < 1:
                raise LLMNoResponseError(
                    'Response choices is less than 1 - This is only seen in Gemini models so far. Response: '
                    + str(resp)
                )

            message_back: str = resp['choices'][0]['message']['content'] or ''
            tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
                'message'
            ].get('tool_calls', [])
            if tool_calls:
                for tool_call in tool_calls:
                    fn_name = tool_call.function.name
                    fn_args = tool_call.function.arguments
                    message_back += f'\nFunction call: {fn_name}({fn_args})'

            # log the LLM response
            self.log_response(message_back)

            # post-process the response first to calculate cost
            cost = self._post_completion(resp)

            # log for evals or other scripts that need the raw completion
            if self.config.log_completions:
                assert self.config.log_completions_folder is not None
                log_file = os.path.join(
                    self.config.log_completions_folder,
                    # use the metric model name (for draft editor)
                    f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
                )

                # set up the dict to be logged
                _d = {
                    'messages': messages,
                    'response': resp,
                    'args': args,
                    'kwargs': {
                        k: v
                        for k, v in kwargs.items()
                        if k not in ('messages', 'client')
                    },
                    'timestamp': time.time(),
                    'cost': cost,
                }

                # if non-native function calling, save messages/response separately
                if mock_function_calling:
                    # Overwrite response as non-fncall to be consistent with messages
                    _d['response'] = non_fncall_response
                    # Save fncall_messages/response separately
                    _d['fncall_messages'] = original_fncall_messages
                    _d['fncall_response'] = resp

                with open(log_file, 'w') as f:
                    f.write(json.dumps(_d))

            return resp

        self._completion = wrapper

    @property
    def completion(self) -> Callable:
        """Decorator for the litellm completion function.

        Check the complete documentation at https://litellm.vercel.app/docs/completion
        """
        return self._completion
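
    # Illustrative call shape for the wrapped completion (hypothetical values):
    #
    #     llm = LLM(LLMConfig(model='gpt-4o'))
    #     resp = llm.completion(messages=[{'role': 'user', 'content': 'Say hello'}])
    #     print(resp['choices'][0]['message']['content'])
    #
    # The keyword arguments are forwarded to litellm.completion; model, api_key and the
    # other provider settings are already bound by the partial configured in __init__.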

    def init_model_info(self) -> None:
        if self._tried_model_info:
            return
        self._tried_model_info = True
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
        except Exception as e:
            logger.debug(f'Error getting model info: {e}')

        if self.config.model.startswith('litellm_proxy/'):
            # If we are using LiteLLM proxy, get model info from LiteLLM proxy
            # GET {base_url}/v1/model/info with litellm_model_id as path param
            base_url = self.config.base_url.strip() if self.config.base_url else ''
            if not base_url.startswith(('http://', 'https://')):
                base_url = 'http://' + base_url
            response = httpx.get(
                f'{base_url}/v1/model/info',
                headers={
                    'Authorization': f'Bearer {self.config.api_key.get_secret_value() if self.config.api_key else None}'
                },
            )
            resp_json = response.json()
            if 'data' not in resp_json:
                logger.error(
                    f'Error getting model info from LiteLLM proxy: {resp_json}'
                )
            all_model_info = resp_json.get('data', [])
            current_model_info = next(
                (
                    info
                    for info in all_model_info
                    if info['model_name']
                    == self.config.model.removeprefix('litellm_proxy/')
                ),
                None,
            )
            if current_model_info:
                self.model_info = current_model_info['model_info']
                logger.debug(f'Got model info from litellm proxy: {self.model_info}')

        # Last two attempts to get model info from NAME
        if not self.model_info:
            try:
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
            # noinspection PyBroadException
            except Exception:
                pass
        if not self.model_info:
            try:
                self.model_info = litellm.get_model_info(
                    self.config.model.split('/')[-1]
                )
            # noinspection PyBroadException
            except Exception:
                pass

        from openhands.io import json

        logger.debug(
            f'Model info: {json.dumps({"model": self.config.model, "base_url": self.config.base_url}, indent=2)}'
        )

        if self.config.model.startswith('huggingface'):
            # HF doesn't support the OpenAI default value for top_p (1)
            logger.debug(
                f'Setting top_p to 0.9 for Hugging Face model: {self.config.model}'
            )
            self.config.top_p = 0.9 if self.config.top_p == 1 else self.config.top_p

        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Safe fallback for any potentially viable model
                self.config.max_input_tokens = 4096

        if self.config.max_output_tokens is None:
            # Safe default for any potentially viable model
            self.config.max_output_tokens = 4096
            if self.model_info is not None:
                # max_output_tokens has precedence over max_tokens, if either exists.
                # litellm has models with both, one or none of these 2 parameters!
                if 'max_output_tokens' in self.model_info and isinstance(
                    self.model_info['max_output_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_output_tokens']
                elif 'max_tokens' in self.model_info and isinstance(
                    self.model_info['max_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_tokens']

            if any(
                model in self.config.model
                for model in ['claude-3-7-sonnet', 'claude-3.7-sonnet']
            ):
                self.config.max_output_tokens = 64000  # litellm set max to 128k, but that requires a header to be set

        # Initialize function calling capability
        # Check if model name is in our supported list
        model_name_supported = (
            self.config.model in FUNCTION_CALLING_SUPPORTED_MODELS
            or self.config.model.split('/')[-1] in FUNCTION_CALLING_SUPPORTED_MODELS
            or any(m in self.config.model for m in FUNCTION_CALLING_SUPPORTED_MODELS)
        )

        # Handle native_tool_calling user-defined configuration
        if self.config.native_tool_calling is None:
            self._function_calling_active = model_name_supported
        else:
            self._function_calling_active = self.config.native_tool_calling
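
    # For reference, the litellm model_info dict consulted above exposes (among other
    # keys) the ones read in this module: 'max_input_tokens', 'max_output_tokens',
    # 'max_tokens' and 'supports_vision'. A minimal illustrative shape, with
    # placeholder numbers:
    #
    #     {'max_input_tokens': 200000, 'max_output_tokens': 8192, 'supports_vision': True}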

    def vision_is_active(self) -> bool:
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            return not self.config.disable_vision and self._supports_vision()

    def _supports_vision(self) -> bool:
        """Acquire from litellm if model is vision capable.

        Returns:
            bool: True if model is vision capable. Return False if model not supported by litellm.
        """
        # litellm.supports_vision currently returns False for 'openai/gpt-...' or 'anthropic/claude-...' (with prefixes)
        # but model_info will have the correct value for some reason.
        # we can go with it, but we will need to keep an eye if model_info is correct for Vertex or other providers
        # remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
        # Check both the full model name and the name after proxy prefix for vision support
        return (
            litellm.supports_vision(self.config.model)
            or litellm.supports_vision(self.config.model.split('/')[-1])
            or (
                self.model_info is not None
                and self.model_info.get('supports_vision', False)
            )
        )

    def is_caching_prompt_active(self) -> bool:
        """Check if prompt caching is supported and enabled for current model.

        Returns:
            boolean: True if prompt caching is supported and enabled for the given model.
        """
        return (
            self.config.caching_prompt is True
            and (
                self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
                or self.config.model.split('/')[-1] in CACHE_PROMPT_SUPPORTED_MODELS
            )
            # We don't need to look up model_info, because only Anthropic models need the explicit caching breakpoint
        )

    def is_function_calling_active(self) -> bool:
        """Returns whether function calling is supported and enabled for this LLM instance.

        The result is cached during initialization for performance.
        """
        return self._function_calling_active

    def _post_completion(self, response: ModelResponse) -> float:
        """Post-process the completion response.

        Logs the cost and usage stats of the completion call.
        """
        try:
            cur_cost = self._completion_cost(response)
        except Exception:
            cur_cost = 0

        stats = ''
        if self.cost_metric_supported:
            # keep track of the cost
            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                cur_cost,
                self.metrics.accumulated_cost,
            )

        # Add latency to stats if available
        if self.metrics.response_latencies:
            latest_latency = self.metrics.response_latencies[-1]
            stats += 'Response Latency: %.3f seconds\n' % latest_latency.latency

        usage: Usage | None = response.get('usage')
        response_id = response.get('id', 'unknown')

        if usage:
            # keep track of the input and output tokens
            prompt_tokens = usage.get('prompt_tokens', 0)
            completion_tokens = usage.get('completion_tokens', 0)

            if prompt_tokens:
                stats += 'Input tokens: ' + str(prompt_tokens)

            if completion_tokens:
                stats += (
                    (' | ' if prompt_tokens else '')
                    + 'Output tokens: '
                    + str(completion_tokens)
                    + '\n'
                )

            # read the prompt cache hit, if any
            prompt_tokens_details: PromptTokensDetails = usage.get(
                'prompt_tokens_details'
            )
            cache_hit_tokens = (
                prompt_tokens_details.cached_tokens
                if prompt_tokens_details and prompt_tokens_details.cached_tokens
                else 0
            )
            if cache_hit_tokens:
                stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'

            # For Anthropic, the cache writes have a different cost than regular input tokens
            # but litellm doesn't separate them in the usage stats
            # we can read it from the provider-specific extra field
            model_extra = usage.get('model_extra', {})
            cache_write_tokens = model_extra.get('cache_creation_input_tokens', 0)
            if cache_write_tokens:
                stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'

            # Get context window from model info
            context_window = 0
            if self.model_info and 'max_input_tokens' in self.model_info:
                context_window = self.model_info['max_input_tokens']
                logger.debug(f'Using context window: {context_window}')

            # Record in metrics
            # We'll treat cache_hit_tokens as "cache read" and cache_write_tokens as "cache write"
            self.metrics.add_token_usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                cache_read_tokens=cache_hit_tokens,
                cache_write_tokens=cache_write_tokens,
                context_window=context_window,
                response_id=response_id,
            )

        # log the stats
        if stats:
            logger.debug(stats)

        return cur_cost
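
    # Example of the debug stats string assembled above (numbers are placeholders):
    #
    #     Cost: 0.01 USD | Accumulated Cost: 0.05 USD
    #     Response Latency: 1.234 seconds
    #     Input tokens: 1200 | Output tokens: 300
    #     Input tokens (cache hit): 800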

    def get_token_count(self, messages: list[dict] | list[Message]) -> int:
        """Get the number of tokens in a list of messages. Use dicts for better token counting.

        Args:
            messages (list): A list of messages, either as a list of dicts or as a list of Message objects.

        Returns:
            int: The number of tokens.
        """
        # attempt to convert Message objects to dicts, litellm expects dicts
        if (
            isinstance(messages, list)
            and len(messages) > 0
            and isinstance(messages[0], Message)
        ):
            logger.info(
                'Message objects now include serialized tool calls in token counting'
            )
            # Assert the expected type for format_messages_for_llm
            assert isinstance(messages, list) and all(
                isinstance(m, Message) for m in messages
            ), 'Expected list of Message objects'
            # We've already asserted that messages is a list of Message objects
            # Use explicit typing to satisfy mypy
            messages_typed: list[Message] = messages  # type: ignore
            messages = self.format_messages_for_llm(messages_typed)

        # try to get the token count with the default litellm tokenizers
        # or the custom tokenizer if set for this LLM configuration
        try:
            return int(
                litellm.token_counter(
                    model=self.config.model,
                    messages=messages,
                    custom_tokenizer=self.tokenizer,
                )
            )
        except Exception as e:
            # limit logspam in case token count is not supported
            logger.error(
                f'Error getting token count for\n model {self.config.model}\n{e}'
                + (
                    f'\ncustom_tokenizer: {self.config.custom_tokenizer}'
                    if self.config.custom_tokenizer is not None
                    else ''
                )
            )
            return 0

    def _is_local(self) -> bool:
        """Determines if the system is using a locally running LLM.

        Returns:
            boolean: True if executing a local model.
        """
        if self.config.base_url is not None:
            for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
                if substring in self.config.base_url:
                    return True
        elif self.config.model is not None:
            if self.config.model.startswith('ollama'):
                return True
        return False

    def _completion_cost(self, response: Any) -> float:
        """Calculate completion cost and update metrics with running total.

        Calculate the cost of a completion response based on the model. Local models are treated as free.
        Add the current cost into total cost in metrics.

        Args:
            response: A response from a model invocation.

        Returns:
            number: The cost of the response.
        """
        if not self.cost_metric_supported:
            return 0.0

        extra_kwargs = {}
        if (
            self.config.input_cost_per_token is not None
            and self.config.output_cost_per_token is not None
        ):
            cost_per_token = CostPerToken(
                input_cost_per_token=self.config.input_cost_per_token,
                output_cost_per_token=self.config.output_cost_per_token,
            )
            logger.debug(f'Using custom cost per token: {cost_per_token}')
            extra_kwargs['custom_cost_per_token'] = cost_per_token

        # try to get response_cost directly from the response
        _hidden_params = getattr(response, '_hidden_params', {})
        cost = _hidden_params.get('additional_headers', {}).get(
            'llm_provider-x-litellm-response-cost', None
        )
        if cost is not None:
            cost = float(cost)
            logger.debug(f'Got response_cost from response: {cost}')

        try:
            if cost is None:
                try:
                    cost = litellm_completion_cost(
                        completion_response=response, **extra_kwargs
                    )
                except Exception as e:
                    logger.debug(f'Error getting cost from litellm: {e}')

            if cost is None:
                _model_name = '/'.join(self.config.model.split('/')[1:])
                cost = litellm_completion_cost(
                    completion_response=response, model=_model_name, **extra_kwargs
                )
                logger.debug(
                    f'Using fallback model name {_model_name} to get cost: {cost}'
                )
            self.metrics.add_cost(float(cost))
            return float(cost)
        except Exception:
            self.cost_metric_supported = False
            logger.debug('Cost calculation not supported for this model.')
            return 0.0

    def __str__(self) -> str:
        if self.config.api_version:
            return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
        elif self.config.base_url:
            return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
        return f'LLM(model={self.config.model})'

    def __repr__(self) -> str:
        return str(self)

    def reset(self) -> None:
        self.metrics.reset()

    def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
        if isinstance(messages, Message):
            messages = [messages]

        # set flags to know how to serialize the messages
        for message in messages:
            message.cache_enabled = self.is_caching_prompt_active()
            message.vision_enabled = self.vision_is_active()
            message.function_calling_enabled = self.is_function_calling_active()
            if 'deepseek' in self.config.model:
                message.force_string_serializer = True

        # let pydantic handle the serialization
        return [message.model_dump() for message in messages]
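
# Minimal manual smoke test (illustrative sketch only; the model name below is a
# placeholder and no completion request is sent):
if __name__ == '__main__':
    _config = LLMConfig(model='gpt-4o-mini')
    _llm = LLM(_config)
    print(_llm)
    print('function calling active:', _llm.is_function_calling_active())
    print('prompt caching active:', _llm.is_caching_prompt_active())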