TestLLM / litellm /llms /cohere /chat /transformation.py
Raju2024's picture
Upload 1072 files
e3278e4 verified
import json
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
from ..common_utils import validate_environment as cohere_validate_environment
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class CohereError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[httpx.Headers] = None,
):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
status_code=status_code,
message=message,
headers=headers,
)
class CohereChatConfig(BaseConfig):
"""
Configuration class for Cohere's API interface.
Args:
preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
generation_id (str, optional): Unique identifier for the generated reply.
response_id (str, optional): Unique identifier for the response.
conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
"""
preamble: Optional[str] = None
chat_history: Optional[list] = None
generation_id: Optional[str] = None
response_id: Optional[str] = None
conversation_id: Optional[str] = None
prompt_truncation: Optional[str] = None
connectors: Optional[list] = None
search_queries_only: Optional[bool] = None
documents: Optional[list] = None
temperature: Optional[int] = None
max_tokens: Optional[int] = None
k: Optional[int] = None
p: Optional[int] = None
frequency_penalty: Optional[int] = None
presence_penalty: Optional[int] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__(
self,
preamble: Optional[str] = None,
chat_history: Optional[list] = None,
generation_id: Optional[str] = None,
response_id: Optional[str] = None,
conversation_id: Optional[str] = None,
prompt_truncation: Optional[str] = None,
connectors: Optional[list] = None,
search_queries_only: Optional[bool] = None,
documents: Optional[list] = None,
temperature: Optional[int] = None,
max_tokens: Optional[int] = None,
k: Optional[int] = None,
p: Optional[int] = None,
frequency_penalty: Optional[int] = None,
presence_penalty: Optional[int] = None,
tools: Optional[list] = None,
tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return cohere_validate_environment(
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
api_key=api_key,
)
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"tools",
"tool_choice",
"seed",
"extra_headers",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "stream":
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "n":
optional_params["num_generations"] = value
if param == "top_p":
optional_params["p"] = value
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "tools":
optional_params["tools"] = value
if param == "seed":
optional_params["seed"] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
## Load Config
for k, v in litellm.CohereChatConfig.get_config().items():
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
most_recent_message, chat_history = cohere_messages_pt_v2(
messages=messages, model=model, llm_provider="cohere_chat"
)
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
optional_params["tools"] = cohere_tools
if isinstance(most_recent_message, dict):
optional_params["tool_results"] = [most_recent_message]
elif isinstance(most_recent_message, str):
optional_params["message"] = most_recent_message
## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
optional_params["force_single_step"] = True
return optional_params
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
raw_response_json = raw_response.json()
model_response.choices[0].message.content = raw_response_json["text"] # type: ignore
except Exception:
raise CohereError(
message=raw_response.text, status_code=raw_response.status_code
)
## ADD CITATIONS
if "citations" in raw_response_json:
setattr(model_response, "citations", raw_response_json["citations"])
## Tool calling response
cohere_tools_response = raw_response_json.get("tool_calls", None)
if cohere_tools_response is not None and cohere_tools_response != []:
# convert cohere_tools_response to OpenAI response format
tool_calls = []
for tool in cohere_tools_response:
function_name = tool.get("name", "")
generation_id = tool.get("generation_id", "")
parameters = tool.get("parameters", {})
tool_call = {
"id": f"call_{generation_id}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(parameters),
},
}
tool_calls.append(tool_call)
_message = litellm.Message(
tool_calls=tool_calls,
content=None,
)
model_response.choices[0].message = _message # type: ignore
## CALCULATING USAGE - use cohere `billed_units` for returning usage
billed_units = raw_response_json.get("meta", {}).get("billed_units", {})
prompt_tokens = billed_units.get("input_tokens", 0)
completion_tokens = billed_units.get("output_tokens", 0)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def _construct_cohere_tool(
self,
tools: Optional[list] = None,
):
if tools is None:
tools = []
cohere_tools = []
for tool in tools:
cohere_tool = self._translate_openai_tool_to_cohere(tool)
cohere_tools.append(cohere_tool)
return cohere_tools
def _translate_openai_tool_to_cohere(
self,
openai_tool: dict,
):
# cohere tools look like this
"""
{
"name": "query_daily_sales_report",
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
"parameter_definitions": {
"day": {
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
"type": "str",
"required": True
}
}
}
"""
# OpenAI tools look like this
"""
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
"""
cohere_tool = {
"name": openai_tool["function"]["name"],
"description": openai_tool["function"]["description"],
"parameter_definitions": {},
}
for param_name, param_def in openai_tool["function"]["parameters"][
"properties"
].items():
required_params = (
openai_tool.get("function", {})
.get("parameters", {})
.get("required", [])
)
cohere_param_def = {
"description": param_def.get("description", ""),
"type": param_def.get("type", ""),
"required": param_name in required_params,
}
cohere_tool["parameter_definitions"][param_name] = cohere_param_def
return cohere_tool
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return CohereModelResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(status_code=status_code, message=error_message)