import httpx
from typing import Optional, List, Iterator, Dict, Any, Union

from phi.llm.base import LLM
from phi.llm.message import Message
from phi.tools.function import FunctionCall
from phi.utils.log import logger
from phi.utils.timer import Timer
from phi.utils.tools import get_function_call_for_tool_call

try:
    from groq import Groq as GroqClient
    from groq.types.chat.chat_completion import ChatCompletion, ChoiceMessage
    from groq.lib.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall
except ImportError:
    logger.error("`groq` not installed")
    raise

# Request parameters forwarded to `chat.completions.create`, in the order they are emitted.
_REQUEST_PARAM_NAMES = (
    "frequency_penalty",
    "logit_bias",
    "logprobs",
    "max_tokens",
    "presence_penalty",
    "response_format",
    "seed",
    "stop",
    "temperature",
    "top_logprobs",
    "top_p",
    "user",
    "extra_headers",
    "extra_query",
)

# Numeric/boolean parameters for which 0, 0.0 or False is a deliberate setting
# (e.g. temperature=0, seed=0); these are sent whenever they are not None.
# Every other parameter is sent only when truthy, so empty strings/dicts/lists
# are still treated as "not set".
_NONE_ONLY_REQUEST_PARAMS = frozenset(
    {
        "frequency_penalty",
        "logprobs",
        "max_tokens",
        "presence_penalty",
        "seed",
        "temperature",
        "top_logprobs",
        "top_p",
    }
)


class Groq(LLM):
    """LLM backed by the Groq chat-completions API, with optional tool calling."""

    name: str = "Groq"
    model: str = "mixtral-8x7b-32768"

    # -*- Request parameters
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[Any] = None
    logprobs: Optional[bool] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: Optional[Dict[str, Any]] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    temperature: Optional[float] = None
    top_logprobs: Optional[int] = None
    top_p: Optional[float] = None
    user: Optional[str] = None
    extra_headers: Optional[Any] = None
    extra_query: Optional[Any] = None
    # Merged last into the request kwargs, so it overrides the fields above.
    request_params: Optional[Dict[str, Any]] = None

    # -*- Client parameters
    api_key: Optional[str] = None
    base_url: Optional[Union[str, httpx.URL]] = None
    timeout: Optional[int] = None
    max_retries: Optional[int] = None
    default_headers: Optional[Any] = None
    default_query: Optional[Any] = None
    # Merged last into the client kwargs, so it overrides the fields above.
    client_params: Optional[Dict[str, Any]] = None

    # -*- Provide the Groq manually
    groq_client: Optional[GroqClient] = None

    @property
    def client(self) -> GroqClient:
        """Return the user-supplied client, or build a new GroqClient from the client fields."""
        if self.groq_client:
            return self.groq_client

        _client_params: Dict[str, Any] = {}
        if self.api_key:
            _client_params["api_key"] = self.api_key
        if self.base_url:
            _client_params["base_url"] = self.base_url
        if self.timeout:
            _client_params["timeout"] = self.timeout
        # max_retries=0 is a deliberate choice (no retries), so only None means "unset".
        if self.max_retries is not None:
            _client_params["max_retries"] = self.max_retries
        if self.default_headers:
            _client_params["default_headers"] = self.default_headers
        if self.default_query:
            _client_params["default_query"] = self.default_query
        if self.client_params:
            _client_params.update(self.client_params)
        return GroqClient(**_client_params)

    def _explicit_request_params(self) -> Dict[str, Any]:
        """Collect the request fields that are set, plus tool settings when tools exist.

        Numeric/boolean fields are included whenever they are not None (so 0 and
        False are honoured); other fields are included only when truthy.
        `tool_choice` defaults to "auto" when tools are present but no choice is set.
        """
        params: Dict[str, Any] = {}
        for param_name in _REQUEST_PARAM_NAMES:
            value = getattr(self, param_name)
            is_set = value is not None if param_name in _NONE_ONLY_REQUEST_PARAMS else bool(value)
            if is_set:
                params[param_name] = value
        if self.tools:
            params["tools"] = self.get_tools_for_api()
            params["tool_choice"] = "auto" if self.tool_choice is None else self.tool_choice
        return params

    @property
    def api_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments passed to `chat.completions.create`, with `request_params` applied last."""
        _request_params = self._explicit_request_params()
        if self.request_params:
            _request_params.update(self.request_params)
        return _request_params

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the base LLM fields plus every request setting that is set."""
        _dict = super().to_dict()
        _dict.update(self._explicit_request_params())
        return _dict

    def invoke(self, messages: List[Message]) -> ChatCompletion:
        """Send `messages` to Groq and return the complete (non-streamed) completion."""
        return self.client.chat.completions.create(
            model=self.model,
            messages=[m.to_dict() for m in messages],  # type: ignore
            **self.api_kwargs,
        )

    def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk]:
        """Send `messages` to Groq with streaming enabled and yield each chunk as it arrives."""
        yield from self.client.chat.completions.create(
            model=self.model,
            messages=[m.to_dict() for m in messages],  # type: ignore
            stream=True,
            **self.api_kwargs,
        )

    def _record_response_time(self, assistant_message: Message, elapsed: float) -> None:
        """Store `elapsed` on the message and append it to `self.metrics["response_times"]`."""
        assistant_message.metrics["time"] = elapsed
        if "response_times" not in self.metrics:
            self.metrics["response_times"] = []
        self.metrics["response_times"].append(elapsed)

    def _collect_function_calls(self, assistant_message: Message, messages: List[Message]) -> List[FunctionCall]:
        """Resolve the assistant's tool calls into runnable FunctionCalls.

        A tool call that cannot be resolved, or whose resolution reports an error,
        is answered immediately by appending a `tool` message to `messages`.
        Only the successfully resolved calls are returned.
        """
        function_calls_to_run: List[FunctionCall] = []
        for tool_call in assistant_message.tool_calls or []:
            _tool_call_id = tool_call.get("id")
            _function_call = get_function_call_for_tool_call(tool_call, self.functions)
            if _function_call is None:
                messages.append(
                    Message(role="tool", tool_call_id=_tool_call_id, content="Could not find function to call.")
                )
                continue
            if _function_call.error is not None:
                messages.append(Message(role="tool", tool_call_id=_tool_call_id, content=_function_call.error))
                continue
            function_calls_to_run.append(_function_call)
        return function_calls_to_run

    @staticmethod
    def _tool_call_preview(function_calls: List[FunctionCall]) -> List[str]:
        """Return the text chunks announcing `function_calls` (shown when `show_tool_calls` is set)."""
        if len(function_calls) == 1:
            return [f"\n - Running: {function_calls[0].get_call_str()}\n\n"]
        if len(function_calls) > 1:
            return ["\nRunning:", *[f"\n - {_f.get_call_str()}" for _f in function_calls], "\n\n"]
        return []

    def response(self, messages: List[Message]) -> str:
        """Run one non-streamed turn, executing tool calls and recursing until plain text comes back.

        The assistant message, any tool results and any follow-up messages are appended to `messages`.
        """
        logger.debug("---------- Groq Response Start ----------")
        # -*- Log messages for debugging
        for m in messages:
            m.log()

        response_timer = Timer()
        response_timer.start()
        response: ChatCompletion = self.invoke(messages=messages)
        response_timer.stop()
        logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")

        # -*- Parse response
        response_message: ChoiceMessage = response.choices[0].message

        # -*- Create assistant message
        assistant_message = Message(
            role=response_message.role or "assistant",
            content=response_message.content,
        )
        if response_message.tool_calls is not None and len(response_message.tool_calls) > 0:
            assistant_message.tool_calls = [t.model_dump() for t in response_message.tool_calls]

        # -*- Update usage metrics
        self._record_response_time(assistant_message, response_timer.elapsed)
        # Token usage from this call replaces any previously stored usage keys.
        if response.usage is not None:
            self.metrics.update(response.usage.model_dump())

        # -*- Add assistant message to messages
        messages.append(assistant_message)
        assistant_message.log()

        # -*- Parse and run tool calls
        if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0:
            final_response = ""
            function_calls_to_run = self._collect_function_calls(assistant_message, messages)
            if self.show_tool_calls:
                final_response += "".join(self._tool_call_preview(function_calls_to_run))

            function_call_results = self.run_function_calls(function_calls_to_run)
            if len(function_call_results) > 0:
                messages.extend(function_call_results)
            logger.debug("---------- Groq Response End ----------")
            # -*- Get new response using result of tool call
            final_response += self.response(messages=messages)
            return final_response

        logger.debug("---------- Groq Response End ----------")
        # -*- Return content if no function calls are present
        if assistant_message.content is not None:
            return assistant_message.get_content_string()
        return "Something went wrong, please try again."

    def response_stream(self, messages: List[Message]) -> Iterator[str]:
        """Stream one turn's text, executing tool calls and streaming the follow-up turns.

        The assistant message, any tool results and any follow-up messages are appended to `messages`.
        """
        logger.debug("---------- Groq Response Start ----------")
        # -*- Log messages for debugging
        for m in messages:
            m.log()

        assistant_message_role = None
        assistant_message_content = ""
        assistant_message_tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
        response_timer = Timer()
        response_timer.start()
        for response in self.invoke_stream(messages=messages):
            # Skip chunks without choices instead of raising IndexError on choices[0].
            if not response.choices:
                continue

            # -*- Parse response
            response_delta: ChoiceDelta = response.choices[0].delta
            if assistant_message_role is None and response_delta.role is not None:
                assistant_message_role = response_delta.role
            response_content: Optional[str] = response_delta.content
            response_tool_calls: Optional[List[ChoiceDeltaToolCall]] = response_delta.tool_calls

            # -*- Return content if present, otherwise get tool call
            if response_content is not None:
                assistant_message_content += response_content
                yield response_content

            # -*- Parse tool calls
            # NOTE(review): deltas are collected as-is, not merged by index — this assumes each
            # tool call arrives complete in a single chunk; confirm against Groq streaming output.
            if response_tool_calls is not None and len(response_tool_calls) > 0:
                if assistant_message_tool_calls is None:
                    assistant_message_tool_calls = []
                assistant_message_tool_calls.extend(response_tool_calls)

        response_timer.stop()
        logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")

        # -*- Create assistant message
        assistant_message = Message(role=(assistant_message_role or "assistant"))
        # -*- Add content to assistant message
        if assistant_message_content != "":
            assistant_message.content = assistant_message_content
        # -*- Add tool calls to assistant message
        if assistant_message_tool_calls is not None:
            assistant_message.tool_calls = [t.model_dump() for t in assistant_message_tool_calls]

        # -*- Update usage metrics
        self._record_response_time(assistant_message, response_timer.elapsed)

        # -*- Add assistant message to messages
        messages.append(assistant_message)
        assistant_message.log()

        # -*- Parse and run tool calls
        if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0:
            function_calls_to_run = self._collect_function_calls(assistant_message, messages)
            if self.show_tool_calls:
                yield from self._tool_call_preview(function_calls_to_run)

            function_call_results = self.run_function_calls(function_calls_to_run)
            if len(function_call_results) > 0:
                messages.extend(function_call_results)
            # -*- Yield new response using results of tool calls
            yield from self.response_stream(messages=messages)
        logger.debug("---------- Groq Response End ----------")