# Spaces:
# Runtime error
# Runtime error
import httpx | |
from typing import Optional, List, Iterator, Dict, Any, Union | |
from phi.llm.base import LLM | |
from phi.llm.message import Message | |
from phi.tools.function import FunctionCall | |
from phi.utils.log import logger | |
from phi.utils.timer import Timer | |
from phi.utils.tools import get_function_call_for_tool_call | |
try:
    from groq import Groq as GroqClient
except ImportError:
    logger.error("`groq` not installed")
    raise

try:
    # Module layout this file was originally written against.
    from groq.types.chat.chat_completion import ChatCompletion, ChoiceMessage
    from groq.lib.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall
except ImportError:
    try:
        # Fallback for groq SDK releases that expose these types under ``groq.types.chat``,
        # where the non-streaming message type is named ``ChatCompletionMessage``. It is
        # aliased to ``ChoiceMessage`` so the rest of the module can keep using that name.
        # NOTE(review): verify these paths against the groq version pinned for this project.
        from groq.types.chat import ChatCompletion
        from groq.types.chat import ChatCompletionMessage as ChoiceMessage
        from groq.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall
    except ImportError:
        # groq itself imported fine above, so "not installed" would be misleading here.
        logger.error("Installed `groq` SDK does not provide the expected chat completion types")
        raise
class Groq(LLM):
    """phi ``LLM`` implementation that talks to the Groq chat-completions API via the ``groq`` SDK.

    Unset request/client parameters are not forwarded, so the SDK/API defaults apply.
    ``request_params`` and ``client_params`` are merged in last and therefore override the
    individual fields.
    """

    name: str = "Groq"
    model: str = "mixtral-8x7b-32768"

    # -*- Request parameters
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[Any] = None
    logprobs: Optional[bool] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: Optional[Dict[str, Any]] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    temperature: Optional[float] = None
    top_logprobs: Optional[int] = None
    top_p: Optional[float] = None
    user: Optional[str] = None
    extra_headers: Optional[Any] = None
    extra_query: Optional[Any] = None
    # Extra keyword arguments merged into every request; they override the fields above.
    request_params: Optional[Dict[str, Any]] = None

    # -*- Client parameters
    api_key: Optional[str] = None
    base_url: Optional[Union[str, httpx.URL]] = None
    timeout: Optional[int] = None
    max_retries: Optional[int] = None
    default_headers: Optional[Any] = None
    default_query: Optional[Any] = None
    # Extra keyword arguments passed to the GroqClient constructor; they override the fields above.
    client_params: Optional[Dict[str, Any]] = None

    # -*- Provide a pre-built Groq client manually; when set, the client parameters are ignored
    groq_client: Optional[GroqClient] = None

    @property
    def client(self) -> GroqClient:
        """Return the Groq SDK client used for API calls.

        ``groq_client`` is returned as-is when provided. Otherwise a new ``GroqClient`` is
        constructed on every access from the client-parameter fields, with
        ``client_params`` applied last.
        """
        if self.groq_client is not None:
            return self.groq_client

        _client_params: Dict[str, Any] = {}
        if self.api_key:
            _client_params["api_key"] = self.api_key
        if self.base_url:
            _client_params["base_url"] = self.base_url
        if self.timeout:
            _client_params["timeout"] = self.timeout
        # 0 is a meaningful retry count (no retries), so only skip it when unset.
        if self.max_retries is not None:
            _client_params["max_retries"] = self.max_retries
        if self.default_headers:
            _client_params["default_headers"] = self.default_headers
        if self.default_query:
            _client_params["default_query"] = self.default_query
        if self.client_params:
            _client_params.update(self.client_params)
        return GroqClient(**_client_params)

    def _request_options(self) -> Dict[str, Any]:
        """Collect the configured per-request options (excluding ``request_params``).

        Shared by ``api_kwargs`` and ``to_dict`` so both always report the same options.
        ``seed``, ``temperature`` and ``top_p`` are tested against ``None`` because 0 is a
        valid value for each (e.g. ``temperature=0``); the remaining options are forwarded
        only when truthy.

        Returns:
            A new dict whose keys appear in a fixed order.
        """
        options: Dict[str, Any] = {}
        if self.frequency_penalty:
            options["frequency_penalty"] = self.frequency_penalty
        if self.logit_bias:
            options["logit_bias"] = self.logit_bias
        if self.logprobs:
            options["logprobs"] = self.logprobs
        if self.max_tokens:
            options["max_tokens"] = self.max_tokens
        if self.presence_penalty:
            options["presence_penalty"] = self.presence_penalty
        if self.response_format:
            options["response_format"] = self.response_format
        if self.seed is not None:
            options["seed"] = self.seed
        if self.stop:
            options["stop"] = self.stop
        if self.temperature is not None:
            options["temperature"] = self.temperature
        if self.top_logprobs:
            options["top_logprobs"] = self.top_logprobs
        if self.top_p is not None:
            options["top_p"] = self.top_p
        if self.user:
            options["user"] = self.user
        if self.extra_headers:
            options["extra_headers"] = self.extra_headers
        if self.extra_query:
            options["extra_query"] = self.extra_query
        if self.tools:
            options["tools"] = self.get_tools_for_api()
            # Default to "auto" when tools are available but no explicit choice was configured.
            options["tool_choice"] = "auto" if self.tool_choice is None else self.tool_choice
        return options

    @property
    def api_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments for ``chat.completions.create`` (besides ``model``/``messages``).

        ``request_params`` is merged last and overrides the individual options.
        """
        _request_params = self._request_options()
        if self.request_params:
            _request_params.update(self.request_params)
        return _request_params

    def to_dict(self) -> Dict[str, Any]:
        """Return the base ``LLM.to_dict()`` output extended with the configured request options."""
        _dict = super().to_dict()
        _dict.update(self._request_options())
        return _dict

    def invoke(self, messages: List[Message]) -> ChatCompletion:
        """Send ``messages`` to the chat-completions endpoint and return the full completion."""
        return self.client.chat.completions.create(
            model=self.model,
            messages=[m.to_dict() for m in messages],  # type: ignore
            **self.api_kwargs,
        )

    def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk]:
        """Send ``messages`` with ``stream=True`` and yield the completion chunks as they arrive."""
        yield from self.client.chat.completions.create(
            model=self.model,
            messages=[m.to_dict() for m in messages],  # type: ignore
            stream=True,
            **self.api_kwargs,
        )

    def _record_response_time(self, assistant_message: Message, elapsed: float) -> None:
        """Store ``elapsed`` on the assistant message and append it to ``self.metrics["response_times"]``."""
        assistant_message.metrics["time"] = elapsed
        if "response_times" not in self.metrics:
            self.metrics["response_times"] = []
        self.metrics["response_times"].append(elapsed)

    def _get_function_calls_to_run(self, assistant_message: Message, messages: List[Message]) -> List[FunctionCall]:
        """Resolve the assistant's tool calls into runnable ``FunctionCall`` objects.

        Tool calls that cannot be resolved, or whose resolution reports an error, are answered
        immediately by appending a ``role="tool"`` message to ``messages`` (mutated in place).

        Returns:
            The function calls that resolved successfully, in tool-call order.
        """
        function_calls_to_run: List[FunctionCall] = []
        for tool_call in assistant_message.tool_calls or []:
            _tool_call_id = tool_call.get("id")
            _function_call = get_function_call_for_tool_call(tool_call, self.functions)
            if _function_call is None:
                messages.append(
                    Message(role="tool", tool_call_id=_tool_call_id, content="Could not find function to call.")
                )
                continue
            if _function_call.error is not None:
                messages.append(Message(role="tool", tool_call_id=_tool_call_id, content=_function_call.error))
                continue
            function_calls_to_run.append(_function_call)
        return function_calls_to_run

    @staticmethod
    def _running_tool_calls_notice(function_calls: List[FunctionCall]) -> List[str]:
        """Return the text fragments announcing which function calls are about to run (empty if none)."""
        if len(function_calls) == 1:
            return [f"\n - Running: {function_calls[0].get_call_str()}\n\n"]
        if len(function_calls) > 1:
            return ["\nRunning:", *(f"\n - {_f.get_call_str()}" for _f in function_calls), "\n\n"]
        return []

    def response(self, messages: List[Message]) -> str:
        """Get a complete (non-streaming) response for ``messages``.

        The assistant message is appended to ``messages``. If it contains tool calls, they are
        run, their results are appended to ``messages`` and the model is queried again
        recursively; the returned text is then any "Running: ..." notice (only when
        ``show_tool_calls`` is set) followed by the follow-up response.

        Args:
            messages: Conversation so far; mutated in place.

        Returns:
            The response text, or "Something went wrong, please try again." when the model
            returned neither content nor tool calls.
        """
        logger.debug("---------- Groq Response Start ----------")
        # -*- Log messages for debugging
        for m in messages:
            m.log()

        response_timer = Timer()
        response_timer.start()
        response: ChatCompletion = self.invoke(messages=messages)
        response_timer.stop()
        logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")

        # -*- Parse response
        response_message: ChoiceMessage = response.choices[0].message

        # -*- Create assistant message
        assistant_message = Message(
            role=response_message.role or "assistant",
            content=response_message.content,
        )
        if response_message.tool_calls is not None and len(response_message.tool_calls) > 0:
            assistant_message.tool_calls = [t.model_dump() for t in response_message.tool_calls]

        # -*- Update usage metrics
        self._record_response_time(assistant_message, response_timer.elapsed)
        # Add token usage to metrics
        if response.usage is not None:
            self.metrics.update(response.usage.model_dump())

        # -*- Add assistant message to messages
        messages.append(assistant_message)
        assistant_message.log()

        # -*- Parse and run tool calls
        if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0:
            final_response = ""
            function_calls_to_run = self._get_function_calls_to_run(assistant_message, messages)
            if self.show_tool_calls:
                final_response += "".join(self._running_tool_calls_notice(function_calls_to_run))
            function_call_results = self.run_function_calls(function_calls_to_run)
            if len(function_call_results) > 0:
                messages.extend(function_call_results)
            # -*- Get new response using result of tool call
            final_response += self.response(messages=messages)
            # Log the end marker on this path too (it used to be skipped by the early return).
            logger.debug("---------- Groq Response End ----------")
            return final_response

        logger.debug("---------- Groq Response End ----------")
        # -*- Return content if no function calls are present
        if assistant_message.content is not None:
            return assistant_message.get_content_string()
        return "Something went wrong, please try again."

    def response_stream(self, messages: List[Message]) -> Iterator[str]:
        """Stream the response for ``messages`` as text fragments.

        Content deltas are yielded as they arrive. Once the stream ends, the assembled
        assistant message is appended to ``messages``; if it contains tool calls they are run,
        their results appended to ``messages``, and the follow-up response is streamed
        recursively (preceded by a "Running: ..." notice when ``show_tool_calls`` is set).

        Args:
            messages: Conversation so far; mutated in place.

        Yields:
            Text fragments of the response.
        """
        logger.debug("---------- Groq Response Start ----------")
        # -*- Log messages for debugging
        for m in messages:
            m.log()

        assistant_message_role = None
        content_parts: List[str] = []
        assistant_message_tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
        response_timer = Timer()
        response_timer.start()
        for response in self.invoke_stream(messages=messages):
            # Guard against chunks without choices, which would otherwise raise IndexError.
            if not response.choices:
                continue
            # -*- Parse response
            response_delta: ChoiceDelta = response.choices[0].delta
            if assistant_message_role is None and response_delta.role is not None:
                assistant_message_role = response_delta.role
            response_content: Optional[str] = response_delta.content
            response_tool_calls: Optional[List[ChoiceDeltaToolCall]] = response_delta.tool_calls

            # -*- Return content if present, otherwise get tool call
            if response_content is not None:
                content_parts.append(response_content)
                yield response_content

            # -*- Parse tool calls
            if response_tool_calls is not None and len(response_tool_calls) > 0:
                if assistant_message_tool_calls is None:
                    assistant_message_tool_calls = []
                assistant_message_tool_calls.extend(response_tool_calls)
        response_timer.stop()
        logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")

        # -*- Create assistant message
        assistant_message = Message(role=(assistant_message_role or "assistant"))
        # -*- Add content to assistant message
        assistant_message_content = "".join(content_parts)
        if assistant_message_content != "":
            assistant_message.content = assistant_message_content
        # -*- Add tool calls to assistant message
        if assistant_message_tool_calls is not None:
            assistant_message.tool_calls = [t.model_dump() for t in assistant_message_tool_calls]

        # -*- Update usage metrics
        self._record_response_time(assistant_message, response_timer.elapsed)

        # -*- Add assistant message to messages
        messages.append(assistant_message)
        assistant_message.log()

        # -*- Parse and run tool calls
        if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0:
            function_calls_to_run = self._get_function_calls_to_run(assistant_message, messages)
            if self.show_tool_calls:
                yield from self._running_tool_calls_notice(function_calls_to_run)
            function_call_results = self.run_function_calls(function_calls_to_run)
            if len(function_call_results) > 0:
                messages.extend(function_call_results)
            # -*- Yield new response using results of tool calls
            yield from self.response_stream(messages=messages)
        logger.debug("---------- Groq Response End ----------")