AmmarFahmy
adding all files
105b369
import httpx
from typing import Optional, List, Iterator, Dict, Any, Union
from phi.llm.base import LLM
from phi.llm.message import Message
from phi.tools.function import FunctionCall
from phi.utils.log import logger
from phi.utils.timer import Timer
from phi.utils.tools import get_function_call_for_tool_call
try:
from groq import Groq as GroqClient
from groq.types.chat.chat_completion import ChatCompletion, ChoiceMessage
from groq.lib.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall
except ImportError:
logger.error("`groq` not installed")
raise
class Groq(LLM):
    """LLM implementation that sends chat-completion requests through the `groq` client.

    Request parameters left as ``None`` are omitted from the API call. Numeric
    and boolean parameters are forwarded whenever they are not ``None`` (so
    ``temperature=0`` or ``seed=0`` reach the API); string and container
    parameters are forwarded only when non-empty.
    """

    name: str = "Groq"
    model: str = "mixtral-8x7b-32768"
    # -*- Request parameters
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[Any] = None
    logprobs: Optional[bool] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: Optional[Dict[str, Any]] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    temperature: Optional[float] = None
    top_logprobs: Optional[int] = None
    top_p: Optional[float] = None
    user: Optional[str] = None
    extra_headers: Optional[Any] = None
    extra_query: Optional[Any] = None
    # Merged last into the request kwargs, so its keys override the fields above.
    request_params: Optional[Dict[str, Any]] = None
    # -*- Client parameters
    api_key: Optional[str] = None
    base_url: Optional[Union[str, httpx.URL]] = None
    timeout: Optional[int] = None
    max_retries: Optional[int] = None
    default_headers: Optional[Any] = None
    default_query: Optional[Any] = None
    # Merged last into the client kwargs, so its keys override the fields above.
    client_params: Optional[Dict[str, Any]] = None
    # -*- Provide a pre-built Groq client manually (skips all client parameters)
    groq_client: Optional[GroqClient] = None

    @property
    def client(self) -> GroqClient:
        """Return the Groq client used for requests.

        Returns:
            ``groq_client`` when one was provided; otherwise a new
            ``GroqClient`` built from the client parameters. A new client is
            constructed on every access in the latter case.
        """
        if self.groq_client is not None:
            return self.groq_client

        _client_params: Dict[str, Any] = {}
        if self.api_key:
            _client_params["api_key"] = self.api_key
        if self.base_url:
            _client_params["base_url"] = self.base_url
        if self.timeout:
            _client_params["timeout"] = self.timeout
        # `max_retries=0` is a meaningful setting (no retries), so only None means unset.
        if self.max_retries is not None:
            _client_params["max_retries"] = self.max_retries
        if self.default_headers:
            _client_params["default_headers"] = self.default_headers
        if self.default_query:
            _client_params["default_query"] = self.default_query
        if self.client_params:
            _client_params.update(self.client_params)
        return GroqClient(**_client_params)

    def _request_options(self) -> Dict[str, Any]:
        """Collect the request fields that are set, keyed by API parameter name.

        Shared by `api_kwargs` and `to_dict` so both agree on what counts as set.
        """
        candidates: Dict[str, Any] = {
            "frequency_penalty": self.frequency_penalty,
            "logit_bias": self.logit_bias,
            "logprobs": self.logprobs,
            "max_tokens": self.max_tokens,
            "presence_penalty": self.presence_penalty,
            "response_format": self.response_format,
            "seed": self.seed,
            "stop": self.stop,
            "temperature": self.temperature,
            "top_logprobs": self.top_logprobs,
            "top_p": self.top_p,
            "user": self.user,
            "extra_headers": self.extra_headers,
            "extra_query": self.extra_query,
        }
        # For these fields 0 / 0.0 / False are real values and must not be
        # mistaken for "unset"; the remaining fields are strings/containers
        # where an empty value is treated as unset.
        falsy_is_meaningful = (
            "frequency_penalty",
            "logprobs",
            "max_tokens",
            "presence_penalty",
            "seed",
            "temperature",
            "top_logprobs",
            "top_p",
        )
        options: Dict[str, Any] = {
            key: value
            for key, value in candidates.items()
            if value is not None and (value or key in falsy_is_meaningful)
        }
        if self.tools:
            options["tools"] = self.get_tools_for_api()
            options["tool_choice"] = "auto" if self.tool_choice is None else self.tool_choice
        return options

    @property
    def api_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments added to every chat-completion request.

        Returns:
            The set request fields, plus `tools`/`tool_choice` when tools are
            configured, with `request_params` merged on top.
        """
        _request_params = self._request_options()
        if self.request_params:
            _request_params.update(self.request_params)
        return _request_params

    def to_dict(self) -> Dict[str, Any]:
        """Return the base-class dict extended with the set request fields.

        Unlike `api_kwargs`, `request_params` is not merged in.
        """
        _dict = super().to_dict()
        _dict.update(self._request_options())
        return _dict

    def invoke(self, messages: List[Message]) -> ChatCompletion:
        """Send `messages` as a single (non-streaming) chat-completion request."""
        return self.client.chat.completions.create(
            model=self.model,
            messages=[m.to_dict() for m in messages],  # type: ignore
            **self.api_kwargs,
        )

    def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk]:
        """Send `messages` with ``stream=True`` and yield the chunks as they arrive."""
        yield from self.client.chat.completions.create(
            model=self.model,
            messages=[m.to_dict() for m in messages],  # type: ignore
            stream=True,
            **self.api_kwargs,
        )

    def _record_response_time(self, assistant_message: Message, elapsed: float) -> None:
        """Store `elapsed` on the message and append it to this LLM's `response_times`."""
        assistant_message.metrics["time"] = elapsed
        self.metrics.setdefault("response_times", []).append(elapsed)

    def _collect_function_calls(
        self, messages: List[Message], tool_calls: List[Dict[str, Any]]
    ) -> List[FunctionCall]:
        """Resolve `tool_calls` into runnable function calls.

        For a tool call that cannot be resolved, or that resolved with an
        error, a ``tool`` message describing the problem is appended to
        `messages` instead of the call being returned.
        """
        function_calls_to_run: List[FunctionCall] = []
        for tool_call in tool_calls:
            _tool_call_id = tool_call.get("id")
            _function_call = get_function_call_for_tool_call(tool_call, self.functions)
            if _function_call is None:
                messages.append(
                    Message(role="tool", tool_call_id=_tool_call_id, content="Could not find function to call.")
                )
                continue
            if _function_call.error is not None:
                messages.append(Message(role="tool", tool_call_id=_tool_call_id, content=_function_call.error))
                continue
            function_calls_to_run.append(_function_call)
        return function_calls_to_run

    def _running_notice(self, function_calls: List[FunctionCall]) -> List[str]:
        """Return the user-facing "Running" banner for `function_calls` as text pieces.

        The pieces are kept separate so the streaming path can yield them one
        by one; the non-streaming path joins them.
        """
        if len(function_calls) == 1:
            return [f"\n - Running: {function_calls[0].get_call_str()}\n\n"]
        if len(function_calls) > 1:
            pieces = ["\nRunning:"]
            pieces.extend(f"\n - {_f.get_call_str()}" for _f in function_calls)
            pieces.append("\n\n")
            return pieces
        return []

    def response(self, messages: List[Message]) -> str:
        """Generate a complete reply, running any tool calls the model requests.

        `messages` is mutated in place: the assistant message and any tool
        results are appended. When the model requests tool calls, they are run
        and this method calls itself again with the extended conversation.

        Args:
            messages: The conversation so far.

        Returns:
            The reply text, prefixed by the "Running" banner(s) when
            `show_tool_calls` is enabled; a fixed fallback string if the model
            returned neither content nor tool calls.
        """
        logger.debug("---------- Groq Response Start ----------")
        # -*- Log messages for debugging
        for m in messages:
            m.log()

        response_timer = Timer()
        response_timer.start()
        response: ChatCompletion = self.invoke(messages=messages)
        response_timer.stop()
        logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")

        # -*- Parse response
        response_message: ChoiceMessage = response.choices[0].message

        # -*- Create assistant message
        assistant_message = Message(
            role=response_message.role or "assistant",
            content=response_message.content,
        )
        if response_message.tool_calls:
            assistant_message.tool_calls = [t.model_dump() for t in response_message.tool_calls]

        # -*- Update usage metrics
        self._record_response_time(assistant_message, response_timer.elapsed)
        # Token usage: keys from this response replace same-named keys from earlier calls.
        if response.usage is not None:
            self.metrics.update(response.usage.model_dump())

        # -*- Add assistant message to messages
        messages.append(assistant_message)
        assistant_message.log()

        # -*- Parse and run tool calls
        if assistant_message.tool_calls:
            function_calls_to_run = self._collect_function_calls(messages, assistant_message.tool_calls)
            final_response = ""
            if self.show_tool_calls:
                final_response += "".join(self._running_notice(function_calls_to_run))

            function_call_results = self.run_function_calls(function_calls_to_run)
            if len(function_call_results) > 0:
                messages.extend(function_call_results)
            # -*- Get new response using result of tool call
            final_response += self.response(messages=messages)
            return final_response
        logger.debug("---------- Groq Response End ----------")

        # -*- Return content if no function calls are present
        if assistant_message.content is not None:
            return assistant_message.get_content_string()
        return "Something went wrong, please try again."

    def response_stream(self, messages: List[Message]) -> Iterator[str]:
        """Stream a reply, running any tool calls the model requests.

        Content deltas are yielded as they arrive. `messages` is mutated in
        place: the assembled assistant message and any tool results are
        appended. When tool calls were received, they are run and the stream
        continues with a recursive call on the extended conversation.

        Args:
            messages: The conversation so far.

        Yields:
            Text fragments of the reply, plus the "Running" banner pieces when
            `show_tool_calls` is enabled.
        """
        logger.debug("---------- Groq Response Start ----------")
        # -*- Log messages for debugging
        for m in messages:
            m.log()

        assistant_message_role = None
        assistant_message_content = ""
        assistant_message_tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
        response_timer = Timer()
        response_timer.start()
        for response in self.invoke_stream(messages=messages):
            # A chunk without choices carries no delta; skip it rather than
            # raising IndexError on `choices[0]`.
            if not response.choices:
                continue
            # -*- Parse response
            response_delta: ChoiceDelta = response.choices[0].delta
            if assistant_message_role is None and response_delta.role is not None:
                assistant_message_role = response_delta.role
            response_content: Optional[str] = response_delta.content
            response_tool_calls: Optional[List[ChoiceDeltaToolCall]] = response_delta.tool_calls

            # -*- Return content if present, otherwise get tool call
            if response_content is not None:
                assistant_message_content += response_content
                yield response_content

            # -*- Parse tool calls
            if response_tool_calls:
                if assistant_message_tool_calls is None:
                    assistant_message_tool_calls = []
                assistant_message_tool_calls.extend(response_tool_calls)

        response_timer.stop()
        logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")

        # -*- Create assistant message
        assistant_message = Message(role=(assistant_message_role or "assistant"))
        # -*- Add content to assistant message
        if assistant_message_content != "":
            assistant_message.content = assistant_message_content
        # -*- Add tool calls to assistant message
        if assistant_message_tool_calls is not None:
            assistant_message.tool_calls = [t.model_dump() for t in assistant_message_tool_calls]

        # -*- Update usage metrics (response time only; no token usage is read here)
        self._record_response_time(assistant_message, response_timer.elapsed)

        # -*- Add assistant message to messages
        messages.append(assistant_message)
        assistant_message.log()

        # -*- Parse and run tool calls
        if assistant_message.tool_calls:
            function_calls_to_run = self._collect_function_calls(messages, assistant_message.tool_calls)
            if self.show_tool_calls:
                yield from self._running_notice(function_calls_to_run)

            function_call_results = self.run_function_calls(function_calls_to_run)
            if len(function_call_results) > 0:
                messages.extend(function_call_results)
            # -*- Yield new response using results of tool calls
            yield from self.response_stream(messages=messages)
        logger.debug("---------- Groq Response End ----------")