# Source provenance: phidata, commit 105b369 ("adding all files") by AmmarFahmy.
from typing import Optional, List, Iterator, Dict, Any, Union
from phi.llm.base import LLM
from phi.llm.message import Message
from phi.tools.function import FunctionCall
from phi.utils.log import logger
from phi.utils.timer import Timer
from phi.utils.tools import get_function_call_for_tool_call
# The Mistral integration requires the optional `mistralai` package; log a
# clear error and re-raise so the failure is visible instead of a bare
# ImportError deep in a stack trace.
try:
    from mistralai.client import MistralClient
    from mistralai.models.chat_completion import (
        ChatMessage,
        DeltaMessage,
        ResponseFormat as ChatCompletionResponseFormat,
        ChatCompletionResponse,
        ChatCompletionStreamResponse,
        ToolCall as ChoiceDeltaToolCall,
    )
except ImportError:
    logger.error("`mistralai` not installed")
    raise
class Mistral(LLM):
    """LLM implementation backed by the Mistral AI chat-completions API.

    Wraps ``mistralai.client.MistralClient`` and supports blocking
    (:meth:`response`) and streaming (:meth:`response_stream`) generation,
    including tool (function) calling with recursive follow-up turns.
    """

    name: str = "Mistral"
    model: str = "mistral-large-latest"
    # -*- Request parameters
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    random_seed: Optional[int] = None
    safe_mode: bool = False
    safe_prompt: bool = False
    response_format: Optional[Union[Dict[str, Any], ChatCompletionResponseFormat]] = None
    # Extra request parameters merged last into the API kwargs (may override the above).
    request_params: Optional[Dict[str, Any]] = None
    # -*- Client parameters
    api_key: Optional[str] = None
    endpoint: Optional[str] = None
    max_retries: Optional[int] = None
    timeout: Optional[int] = None
    # Extra client-constructor parameters merged last (may override the above).
    client_params: Optional[Dict[str, Any]] = None
    # -*- Provide the MistralClient manually
    mistral_client: Optional[MistralClient] = None

    @property
    def client(self) -> MistralClient:
        """Return the client to use: the manually-provided one if set,
        otherwise a new ``MistralClient`` built from the configured parameters."""
        if self.mistral_client:
            return self.mistral_client

        _client_params: Dict[str, Any] = {}
        if self.api_key:
            _client_params["api_key"] = self.api_key
        if self.endpoint:
            _client_params["endpoint"] = self.endpoint
        # `is not None` so explicit zero values (max_retries=0, timeout=0) are honored.
        if self.max_retries is not None:
            _client_params["max_retries"] = self.max_retries
        if self.timeout is not None:
            _client_params["timeout"] = self.timeout
        if self.client_params:
            _client_params.update(self.client_params)
        return MistralClient(**_client_params)

    @property
    def api_kwargs(self) -> Dict[str, Any]:
        """Build the keyword arguments forwarded to ``chat``/``chat_stream``.

        Uses ``is not None`` checks so valid falsy values (temperature=0.0,
        top_p=0.0, random_seed=0) are still sent to the API — truthiness
        checks would silently drop them.
        """
        _request_params: Dict[str, Any] = {}
        if self.temperature is not None:
            _request_params["temperature"] = self.temperature
        if self.max_tokens is not None:
            _request_params["max_tokens"] = self.max_tokens
        if self.top_p is not None:
            _request_params["top_p"] = self.top_p
        if self.random_seed is not None:
            _request_params["random_seed"] = self.random_seed
        if self.safe_mode:
            _request_params["safe_mode"] = self.safe_mode
        if self.safe_prompt:
            _request_params["safe_prompt"] = self.safe_prompt
        if self.response_format is not None:
            # Previously declared on the class but never forwarded to the API.
            _request_params["response_format"] = self.response_format
        if self.tools:
            _request_params["tools"] = self.get_tools_for_api()
            # Only send tool_choice when tools are actually configured.
            if self.tool_choice is None:
                _request_params["tool_choice"] = "auto"
            else:
                _request_params["tool_choice"] = self.tool_choice
        if self.request_params:
            _request_params.update(self.request_params)
        return _request_params

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this LLM's configuration on top of the base-class dict."""
        _dict = super().to_dict()
        if self.temperature is not None:
            _dict["temperature"] = self.temperature
        if self.max_tokens is not None:
            _dict["max_tokens"] = self.max_tokens
        if self.top_p is not None:
            _dict["top_p"] = self.top_p
        if self.random_seed is not None:
            _dict["random_seed"] = self.random_seed
        if self.safe_mode:
            _dict["safe_mode"] = self.safe_mode
        if self.safe_prompt:
            _dict["safe_prompt"] = self.safe_prompt
        if self.response_format is not None:
            _dict["response_format"] = self.response_format
        return _dict

    def invoke(self, messages: List[Message]) -> ChatCompletionResponse:
        """Send `messages` to the chat endpoint and return the full response."""
        return self.client.chat(
            messages=[m.to_dict() for m in messages],
            model=self.model,
            **self.api_kwargs,
        )

    def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionStreamResponse]:
        """Send `messages` to the chat endpoint and yield streamed chunks."""
        yield from self.client.chat_stream(
            messages=[m.to_dict() for m in messages],
            model=self.model,
            **self.api_kwargs,
        )  # type: ignore

    def _get_function_calls_to_run(
        self, assistant_message: Message, messages: List[Message]
    ) -> List[FunctionCall]:
        """Resolve the assistant message's tool calls into runnable FunctionCalls.

        For any tool call that cannot be resolved (unknown function or a
        parse error), a `tool`-role error message is appended to `messages`
        so the conversation stays consistent, and that call is skipped.
        """
        function_calls_to_run: List[FunctionCall] = []
        for tool_call in assistant_message.tool_calls or []:
            _tool_call_id = tool_call.get("id")
            _function_call = get_function_call_for_tool_call(tool_call, self.functions)
            if _function_call is None:
                messages.append(
                    Message(role="tool", tool_call_id=_tool_call_id, content="Could not find function to call.")
                )
                continue
            if _function_call.error is not None:
                messages.append(Message(role="tool", tool_call_id=_tool_call_id, content=_function_call.error))
                continue
            function_calls_to_run.append(_function_call)
        return function_calls_to_run

    def response(self, messages: List[Message]) -> str:
        """Run one chat turn (recursing after tool calls) and return the final text.

        Side effects: appends the assistant message and any tool-result
        messages to `messages`, and records timing/usage in `self.metrics`.
        """
        logger.debug("---------- Mistral Response Start ----------")
        # -*- Log messages for debugging
        for m in messages:
            m.log()

        response_timer = Timer()
        response_timer.start()
        response: ChatCompletionResponse = self.invoke(messages=messages)
        response_timer.stop()
        logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")

        # -*- Parse response
        response_message: ChatMessage = response.choices[0].message

        # -*- Create assistant message
        assistant_message = Message(
            role=response_message.role or "assistant",
            content=response_message.content,
        )
        if response_message.tool_calls is not None and len(response_message.tool_calls) > 0:
            assistant_message.tool_calls = [t.model_dump() for t in response_message.tool_calls]

        # -*- Update usage metrics
        assistant_message.metrics["time"] = response_timer.elapsed
        if "response_times" not in self.metrics:
            self.metrics["response_times"] = []
        self.metrics["response_times"].append(response_timer.elapsed)
        # Token usage as reported by the API (guarded: usage may be absent).
        if response.usage is not None:
            self.metrics.update(response.usage.model_dump())

        # -*- Add assistant message to messages
        messages.append(assistant_message)
        assistant_message.log()

        # -*- Parse and run tool calls
        if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0:
            final_response = ""
            function_calls_to_run = self._get_function_calls_to_run(assistant_message, messages)

            if self.show_tool_calls:
                if len(function_calls_to_run) == 1:
                    final_response += f"\n - Running: {function_calls_to_run[0].get_call_str()}\n\n"
                elif len(function_calls_to_run) > 1:
                    final_response += "\nRunning:"
                    for _f in function_calls_to_run:
                        final_response += f"\n - {_f.get_call_str()}"
                    final_response += "\n\n"

            function_call_results = self.run_function_calls(function_calls_to_run)
            if len(function_call_results) > 0:
                messages.extend(function_call_results)
            # -*- Get new response using result of tool call
            final_response += self.response(messages=messages)
            return final_response

        logger.debug("---------- Mistral Response End ----------")
        # -*- Return content if no function calls are present
        if assistant_message.content is not None:
            return assistant_message.get_content_string()
        return "Something went wrong, please try again."

    def response_stream(self, messages: List[Message]) -> Iterator[str]:
        """Stream one chat turn, yielding content deltas as they arrive.

        Tool-call deltas are accumulated across chunks; after the stream
        ends, any tool calls are run and a follow-up streamed turn is
        yielded recursively. Same side effects as :meth:`response`.
        """
        logger.debug("---------- Mistral Response Start ----------")
        # -*- Log messages for debugging
        for m in messages:
            m.log()

        assistant_message_role = None
        assistant_message_content = ""
        assistant_message_tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
        response_timer = Timer()
        response_timer.start()
        for response in self.invoke_stream(messages=messages):
            # -*- Parse response delta
            response_delta: DeltaMessage = response.choices[0].delta
            if assistant_message_role is None and response_delta.role is not None:
                assistant_message_role = response_delta.role
            response_content: Optional[str] = response_delta.content
            response_tool_calls: Optional[List[ChoiceDeltaToolCall]] = response_delta.tool_calls

            # -*- Yield content if present, otherwise accumulate tool calls
            if response_content is not None:
                assistant_message_content += response_content
                yield response_content

            if response_tool_calls is not None and len(response_tool_calls) > 0:
                if assistant_message_tool_calls is None:
                    assistant_message_tool_calls = []
                assistant_message_tool_calls.extend(response_tool_calls)
        response_timer.stop()
        logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")

        # -*- Create assistant message
        assistant_message = Message(role=(assistant_message_role or "assistant"))
        if assistant_message_content != "":
            assistant_message.content = assistant_message_content
        if assistant_message_tool_calls is not None:
            assistant_message.tool_calls = [t.model_dump() for t in assistant_message_tool_calls]

        # -*- Update usage metrics
        assistant_message.metrics["time"] = response_timer.elapsed
        if "response_times" not in self.metrics:
            self.metrics["response_times"] = []
        self.metrics["response_times"].append(response_timer.elapsed)

        # -*- Add assistant message to messages
        messages.append(assistant_message)
        assistant_message.log()

        # -*- Parse and run tool calls
        if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0:
            function_calls_to_run = self._get_function_calls_to_run(assistant_message, messages)

            if self.show_tool_calls:
                if len(function_calls_to_run) == 1:
                    yield f"\n - Running: {function_calls_to_run[0].get_call_str()}\n\n"
                elif len(function_calls_to_run) > 1:
                    yield "\nRunning:"
                    for _f in function_calls_to_run:
                        yield f"\n - {_f.get_call_str()}"
                    yield "\n\n"

            function_call_results = self.run_function_calls(function_calls_to_run)
            if len(function_call_results) > 0:
                messages.extend(function_call_results)
            # -*- Yield new response using results of tool calls
            yield from self.response_stream(messages=messages)
        logger.debug("---------- Mistral Response End ----------")