"""Message and content models serialized into LLM provider API formats."""
from enum import Enum
from typing import Any, Literal

from litellm import ChatCompletionMessageToolCall
from pydantic import BaseModel, Field, model_serializer
class ContentType(Enum):
    """Discriminator values for the kinds of message content items."""

    # Plain-text content block
    TEXT = 'text'
    # Image-by-URL content block
    IMAGE_URL = 'image_url'
class Content(BaseModel):
    """Abstract base for a single piece of message content."""

    # Discriminator string; see ContentType ('text' or 'image_url')
    type: str
    # When True, subclasses attach an ephemeral cache_control marker on serialization
    cache_prompt: bool = False

    def serialize_model(
        self,
    ) -> dict[str, str | dict[str, str]] | list[dict[str, str | dict[str, str]]]:
        """Serialize this content item into the provider API shape.

        Subclasses must override; the base class has no serializable form.
        """
        raise NotImplementedError('Subclasses should implement this method.')
class TextContent(Content):
    """A text block within a message."""

    type: str = ContentType.TEXT.value
    text: str

    # Fix: `model_serializer` is imported but was never applied, so
    # `model_dump()` produced pydantic's default field dump instead of the
    # provider wire format that Message._list_serializer expects.
    @model_serializer
    def serialize_model(self) -> dict[str, str | dict[str, str]]:
        """Return the provider-API dict for this text block.

        Adds an ephemeral ``cache_control`` marker when ``cache_prompt``
        is set, so prompt caching can be requested per content item.
        """
        data: dict[str, str | dict[str, str]] = {
            'type': self.type,
            'text': self.text,
        }
        if self.cache_prompt:
            data['cache_control'] = {'type': 'ephemeral'}
        return data
class ImageContent(Content):
    """One or more images, referenced by URL, within a message."""

    type: str = ContentType.IMAGE_URL.value
    image_urls: list[str]

    # Fix: `model_serializer` is imported but was never applied, so
    # `model_dump()` did not return the list shape that
    # Message._list_serializer's comments explicitly rely on.
    @model_serializer
    def serialize_model(self) -> list[dict[str, str | dict[str, str]]]:
        """Return a list of provider-API dicts, one per image URL.

        When ``cache_prompt`` is set, only the last image carries the
        ephemeral ``cache_control`` marker.
        """
        images: list[dict[str, str | dict[str, str]]] = [
            {'type': self.type, 'image_url': {'url': url}}
            for url in self.image_urls
        ]
        if self.cache_prompt and images:
            images[-1]['cache_control'] = {'type': 'ephemeral'}
        return images
class Message(BaseModel):
    """A chat message plus the flags controlling how it serializes for an LLM API."""

    # NOTE: this is not the same as EventSource
    # These are the roles in the LLM's APIs
    role: Literal['user', 'system', 'assistant', 'tool']
    content: list[TextContent | ImageContent] = Field(default_factory=list)
    # prompt caching (ephemeral cache_control markers) is in use
    cache_enabled: bool = False
    # image content may be emitted
    vision_enabled: bool = False
    # function calling
    function_calling_enabled: bool = False
    # - tool calls (from LLM)
    tool_calls: list[ChatCompletionMessageToolCall] | None = None
    # - tool execution result (to LLM)
    tool_call_id: str | None = None
    name: str | None = None  # name of the tool
    # force string serializer
    force_string_serializer: bool = False

    def contains_image(self) -> bool:
        """Return True if any content item is an ImageContent."""
        return any(isinstance(content, ImageContent) for content in self.content)

    # Fix: `model_serializer` is imported at the top of the file but was never
    # applied, so `Message.model_dump()` bypassed this logic entirely and
    # emitted pydantic's default field dump instead of the API payload.
    @model_serializer
    def serialize_model(self) -> dict[str, Any]:
        """Serialize into the dict shape expected by the provider API.

        We need two kinds of serializations:
        - into a single string: for providers that don't support list of
          content items (e.g. no vision, no tool calls)
        - into a list of content items: the new APIs of providers with
          vision/prompt caching/tool calls
        NOTE: remove this when litellm or providers support the new API
        """
        if not self.force_string_serializer and (
            self.cache_enabled or self.vision_enabled or self.function_calling_enabled
        ):
            return self._list_serializer()
        # some providers, like HF and Groq/llama, don't support a list here, but a single string
        return self._string_serializer()

    def _string_serializer(self) -> dict[str, Any]:
        """Serialize with `content` as one newline-joined string (text only)."""
        # convert content to a single string; image items are dropped here
        content = '\n'.join(
            item.text for item in self.content if isinstance(item, TextContent)
        )
        message_dict: dict[str, Any] = {'content': content, 'role': self.role}
        # add tool call keys if we have a tool call or response
        return self._add_tool_call_keys(message_dict)

    def _list_serializer(self) -> dict[str, Any]:
        """Serialize with `content` as a list of typed content items."""
        content: list[dict[str, Any]] = []
        role_tool_with_prompt_caching = False
        for item in self.content:
            d = item.model_dump()
            # We have to remove cache_prompt for tool content and move it up to the message level
            # See discussion here for details: https://github.com/BerriAI/litellm/issues/6422#issuecomment-2438765472
            if self.role == 'tool' and item.cache_prompt:
                role_tool_with_prompt_caching = True
                if isinstance(item, TextContent):
                    d.pop('cache_control', None)
                elif isinstance(item, ImageContent):
                    # ImageContent.model_dump() always returns a list
                    # We know d is a list of dicts for ImageContent
                    if hasattr(d, '__iter__'):
                        for d_item in d:
                            if hasattr(d_item, 'pop'):
                                d_item.pop('cache_control', None)
            if isinstance(item, TextContent):
                content.append(d)
            elif isinstance(item, ImageContent) and self.vision_enabled:
                # ImageContent.model_dump() always returns a list
                # We know d is a list for ImageContent
                content.extend([d] if isinstance(d, dict) else d)
        message_dict: dict[str, Any] = {'content': content, 'role': self.role}
        if role_tool_with_prompt_caching:
            # cache marker hoisted from the content items to the message level
            message_dict['cache_control'] = {'type': 'ephemeral'}
        # add tool call keys if we have a tool call or response
        return self._add_tool_call_keys(message_dict)

    def _add_tool_call_keys(self, message_dict: dict[str, Any]) -> dict[str, Any]:
        """Add tool call keys if we have a tool call or response.

        NOTE: this is necessary for both native and non-native tool calling
        """
        # an assistant message calling a tool
        if self.tool_calls is not None:
            message_dict['tool_calls'] = [
                {
                    'id': tool_call.id,
                    'type': 'function',
                    'function': {
                        'name': tool_call.function.name,
                        'arguments': tool_call.function.arguments,
                    },
                }
                for tool_call in self.tool_calls
            ]
        # an observation message with tool response
        if self.tool_call_id is not None:
            assert self.name is not None, (
                'name is required when tool_call_id is not None'
            )
            message_dict['tool_call_id'] = self.tool_call_id
            message_dict['name'] = self.name
        return message_dict