import json import time from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union import httpx import litellm from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException from litellm.types.llms.openai import AllMessageValues from litellm.types.utils import ModelResponse, Usage from ..common_utils import ModelResponseIterator as CohereModelResponseIterator from ..common_utils import validate_environment as cohere_validate_environment if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj LiteLLMLoggingObj = _LiteLLMLoggingObj else: LiteLLMLoggingObj = Any class CohereError(BaseLLMException): def __init__( self, status_code: int, message: str, headers: Optional[httpx.Headers] = None, ): self.status_code = status_code self.message = message self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat") self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( status_code=status_code, message=message, headers=headers, ) class CohereChatConfig(BaseConfig): """ Configuration class for Cohere's API interface. Args: preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one. chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model. generation_id (str, optional): Unique identifier for the generated reply. response_id (str, optional): Unique identifier for the response. conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation. prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'. connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply. search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries. documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite. temperature (float, optional): A non-negative float that tunes the degree of randomness in generation. max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response. k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step. p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation. frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens. presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens. tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking. tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools. seed (int, optional): A seed to assist reproducibility of the model's response. """ preamble: Optional[str] = None chat_history: Optional[list] = None generation_id: Optional[str] = None response_id: Optional[str] = None conversation_id: Optional[str] = None prompt_truncation: Optional[str] = None connectors: Optional[list] = None search_queries_only: Optional[bool] = None documents: Optional[list] = None temperature: Optional[int] = None max_tokens: Optional[int] = None k: Optional[int] = None p: Optional[int] = None frequency_penalty: Optional[int] = None presence_penalty: Optional[int] = None tools: Optional[list] = None tool_results: Optional[list] = None seed: Optional[int] = None def __init__( self, preamble: Optional[str] = None, chat_history: Optional[list] = None, generation_id: Optional[str] = None, response_id: Optional[str] = None, conversation_id: Optional[str] = None, prompt_truncation: Optional[str] = None, connectors: Optional[list] = None, search_queries_only: Optional[bool] = None, documents: Optional[list] = None, temperature: Optional[int] = None, max_tokens: Optional[int] = None, k: Optional[int] = None, p: Optional[int] = None, frequency_penalty: Optional[int] = None, presence_penalty: Optional[int] = None, tools: Optional[list] = None, tool_results: Optional[list] = None, seed: Optional[int] = None, ) -> None: locals_ = locals() for key, value in locals_.items(): if key != "self" and value is not None: setattr(self.__class__, key, value) def validate_environment( self, headers: dict, model: str, messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: return cohere_validate_environment( headers=headers, model=model, messages=messages, optional_params=optional_params, api_key=api_key, ) def get_supported_openai_params(self, model: str) -> List[str]: return [ "stream", "temperature", "max_tokens", "top_p", "frequency_penalty", "presence_penalty", "stop", "n", "tools", "tool_choice", "seed", "extra_headers", ] def map_openai_params( self, non_default_params: dict, optional_params: dict, model: str, drop_params: bool, ) -> dict: for param, value in non_default_params.items(): if param == "stream": optional_params["stream"] = value if param == "temperature": optional_params["temperature"] = value if param == "max_tokens": optional_params["max_tokens"] = value if param == "n": optional_params["num_generations"] = value if param == "top_p": optional_params["p"] = value if param == "frequency_penalty": optional_params["frequency_penalty"] = value if param == "presence_penalty": optional_params["presence_penalty"] = value if param == "stop": optional_params["stop_sequences"] = value if param == "tools": optional_params["tools"] = value if param == "seed": optional_params["seed"] = value return optional_params def transform_request( self, model: str, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, headers: dict, ) -> dict: ## Load Config for k, v in litellm.CohereChatConfig.get_config().items(): if ( k not in optional_params ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v most_recent_message, chat_history = cohere_messages_pt_v2( messages=messages, model=model, llm_provider="cohere_chat" ) ## Handle Tool Calling if "tools" in optional_params: _is_function_call = True cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"]) optional_params["tools"] = cohere_tools if isinstance(most_recent_message, dict): optional_params["tool_results"] = [most_recent_message] elif isinstance(most_recent_message, str): optional_params["message"] = most_recent_message ## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails if len(chat_history) > 0 and chat_history[-1]["role"] == "USER": optional_params["force_single_step"] = True return optional_params def transform_response( self, model: str, raw_response: httpx.Response, model_response: ModelResponse, logging_obj: LiteLLMLoggingObj, request_data: dict, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, encoding: Any, api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: try: raw_response_json = raw_response.json() model_response.choices[0].message.content = raw_response_json["text"] # type: ignore except Exception: raise CohereError( message=raw_response.text, status_code=raw_response.status_code ) ## ADD CITATIONS if "citations" in raw_response_json: setattr(model_response, "citations", raw_response_json["citations"]) ## Tool calling response cohere_tools_response = raw_response_json.get("tool_calls", None) if cohere_tools_response is not None and cohere_tools_response != []: # convert cohere_tools_response to OpenAI response format tool_calls = [] for tool in cohere_tools_response: function_name = tool.get("name", "") generation_id = tool.get("generation_id", "") parameters = tool.get("parameters", {}) tool_call = { "id": f"call_{generation_id}", "type": "function", "function": { "name": function_name, "arguments": json.dumps(parameters), }, } tool_calls.append(tool_call) _message = litellm.Message( tool_calls=tool_calls, content=None, ) model_response.choices[0].message = _message # type: ignore ## CALCULATING USAGE - use cohere `billed_units` for returning usage billed_units = raw_response_json.get("meta", {}).get("billed_units", {}) prompt_tokens = billed_units.get("input_tokens", 0) completion_tokens = billed_units.get("output_tokens", 0) model_response.created = int(time.time()) model_response.model = model usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ) setattr(model_response, "usage", usage) return model_response def _construct_cohere_tool( self, tools: Optional[list] = None, ): if tools is None: tools = [] cohere_tools = [] for tool in tools: cohere_tool = self._translate_openai_tool_to_cohere(tool) cohere_tools.append(cohere_tool) return cohere_tools def _translate_openai_tool_to_cohere( self, openai_tool: dict, ): # cohere tools look like this """ { "name": "query_daily_sales_report", "description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.", "parameter_definitions": { "day": { "description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.", "type": "str", "required": True } } } """ # OpenAI tools look like this """ { "type": "function", "function": { "name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city and state, e.g. San Francisco, CA", }, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], }, }, } """ cohere_tool = { "name": openai_tool["function"]["name"], "description": openai_tool["function"]["description"], "parameter_definitions": {}, } for param_name, param_def in openai_tool["function"]["parameters"][ "properties" ].items(): required_params = ( openai_tool.get("function", {}) .get("parameters", {}) .get("required", []) ) cohere_param_def = { "description": param_def.get("description", ""), "type": param_def.get("type", ""), "required": param_name in required_params, } cohere_tool["parameter_definitions"][param_name] = cohere_param_def return cohere_tool def get_model_response_iterator( self, streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], sync_stream: bool, json_mode: Optional[bool] = False, ): return CohereModelResponseIterator( streaming_response=streaming_response, sync_stream=sync_stream, json_mode=json_mode, ) def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> BaseLLMException: return CohereError(status_code=status_code, message=error_message)