ricklamers's picture
fix: get rid of custom cache
99e6996
raw
history blame
6.18 kB
import gradio as gr
import json
import os
import numexpr
from groq import Groq
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
# Groq model id used for every chat completion request (see get_model_response).
MODEL = "llama3-groq-8b-8192-tool-use-preview"

# Single shared Groq client. The key is read once, at import time, so a
# missing GROQ_API_KEY fails fast with a KeyError before the UI is launched.
_groq_api_key = os.environ["GROQ_API_KEY"]
client = Groq(api_key=_groq_api_key)
def evaluate_math_expression(expression: str):
    """Evaluate a numeric *expression* with numexpr and return the result as JSON text.

    Args:
        expression: Expression text requested by the model through the
            calculator tool.

    Returns:
        ``json.dumps`` of the evaluated result after converting it to plain
        Python values with ``.tolist()``.

    Any exception raised by ``numexpr.evaluate`` (for example on invalid
    syntax) propagates to the caller.
    """
    # NOTE(review): *expression* is model-generated, so treat it as untrusted.
    # numexpr releases before 2.8.5 were reported to pass input to eval()
    # (CVE-2023-39631) — confirm a patched numexpr version is pinned.
    evaluated = numexpr.evaluate(expression)
    plain_value = evaluated.tolist()
    return json.dumps(plain_value)
# Tool definition advertised to the model on every completion request.
# "name" must match a key of the available_functions mapping built in
# respond(), because call_function() looks tools up by that name.
# NOTE(review): the descriptions say "valid Python syntax", but the expression
# is actually evaluated by numexpr (see evaluate_math_expression), whose
# syntax is only a subset of Python — consider aligning the wording.
calculator_tool: ChatCompletionToolParam = {
    "type": "function",
    "function": {
        "name": "evaluate_math_expression",
        "description": "Calculator tool: use this for evaluating numeric expressions with Python. Ensure the expression is valid Python syntax (e.g., use '**' for exponentiation, not '^').",
        "parameters": {
            "type": "object",
            "properties": {
                # The single argument the model must supply.
                "expression": {
                    "type": "string",
                    "description": "The mathematical expression to evaluate. Must be valid Python syntax.",
                },
            },
            "required": ["expression"],
        },
    },
}
# All tools the model may call; passed as ``tools=`` in get_model_response().
tools = [calculator_tool]
def _encode_tool_content(value):
    """Serialize a tool's return value into JSON text for a ``tool`` message.

    Strings that already hold a JSON document (such as the output of
    ``evaluate_math_expression``) are passed through unchanged so they are not
    double-encoded; anything else is encoded with ``json.dumps``.
    """
    if isinstance(value, str):
        try:
            json.loads(value)
        except ValueError:
            pass
        else:
            return value
    return json.dumps(value)


def _tool_message(tool_call, function_name, content):
    """Build the ``tool`` role message that answers *tool_call*."""
    return {
        "tool_call_id": tool_call.id,
        "role": "tool",
        "name": function_name,
        "content": content,
    }


def call_function(tool_call, available_functions):
    """Execute one model-requested tool call and build the ``tool`` reply message.

    Args:
        tool_call: Tool call from the model response; ``.id``,
            ``.function.name`` and ``.function.arguments`` (JSON text) are read.
        available_functions: Mapping of tool name to the Python callable
            implementing it.

    Returns:
        A dict with ``tool_call_id``, ``role`` (``"tool"``), ``name`` and
        ``content``. ``content`` is always valid JSON text. It is either the
        function's result or a JSON string explaining why the call failed:
        unknown function, malformed arguments, or an exception raised by the
        function. This lets the model see the problem and recover instead of
        crashing the chat turn.
    """
    function_name = tool_call.function.name
    function_to_call = available_functions.get(function_name)
    if function_to_call is None:
        # JSON-encode the error text so callers can always json.loads(content).
        return _tool_message(
            tool_call,
            function_name,
            json.dumps(f"Function {function_name} does not exist."),
        )
    try:
        # Arguments are model-generated and may be malformed; the tool itself
        # may also raise (e.g. on an invalid expression).
        function_args = json.loads(tool_call.function.arguments)
        content = _encode_tool_content(function_to_call(**function_args))
    except Exception as e:
        print(f"Tool {function_name} failed: {e}")
        return _tool_message(
            tool_call,
            function_name,
            json.dumps(f"Error calling {function_name}: {e}"),
        )
    return _tool_message(tool_call, function_name, content)
def _native_messages_for(msg):
    """Return the API-level messages that one Gradio history entry stands for.

    Assistant turns produced by ``respond`` store the exact messages exchanged
    with the API under ``metadata["native_messages"]``; any other entry is
    forwarded as-is.
    """
    # NOTE(review): metadata may be absent or None on some entries (e.g. user
    # turns, depending on the Gradio version) — treat both as "no metadata".
    metadata = msg.get("metadata") or {}
    native_messages = metadata.get("native_messages")
    if native_messages is None:
        return [msg]
    if isinstance(native_messages, list):
        return native_messages
    return [native_messages]


def get_model_response(messages, inner_messages, message, system_message):
    """Request the next chat completion from Groq for the current turn.

    The request contains, in order:
    1. the system prompt;
    2. the prior conversation, with each entry expanded to its stored native
       API messages when present;
    3. the new user message;
    4. the tool-call and tool-result messages produced earlier in this turn.

    Args:
        messages: Gradio chat history (``type="messages"`` dicts).
        inner_messages: API messages accumulated during the current turn.
        message: The user's new message text.
        system_message: System prompt text.

    Returns:
        The object returned by ``client.chat.completions.create``, or ``None``
        if the request raised; in that case the error and the request
        payload are printed.
    """
    messages_for_model = [{"role": "system", "content": system_message}]
    for msg in messages:
        messages_for_model.extend(_native_messages_for(msg))
    messages_for_model.append({"role": "user", "content": message})
    # Tool calls / tool results from earlier rounds of this same turn.
    messages_for_model.extend(inner_messages)
    try:
        return client.chat.completions.create(
            model=MODEL,
            messages=messages_for_model,
            tools=tools,
            temperature=0.5,
            top_p=0.65,
            max_tokens=4096,
        )
    except Exception as e:
        # Boundary with the remote API: report, and let the caller handle None.
        print(f"An error occurred while getting model response: {str(e)}")
        print(messages_for_model)
        return None
# Upper bound on model requests per user message, so a model that keeps
# requesting tools (or keeps returning empty replies) cannot loop forever.
_MAX_MODEL_CALLS = 10


def _json_block(payload):
    """Render *payload* as a fenced JSON code block for display in the chat."""
    return "```json\n" + json.dumps(payload, indent=2) + "\n```\n"


def _loads_or_raw(text):
    """Parse *text* as JSON for display, falling back to the raw value if it is not JSON."""
    try:
        return json.loads(text)
    except (TypeError, ValueError):
        return text


def _chat_message(content, native_messages):
    """Build the assistant message yielded to Gradio, with a snapshot of the native API messages."""
    return {
        "role": "assistant",
        "content": content,
        "metadata": {"native_messages": list(native_messages)},
    }


def respond(message, history, system_message):
    """Answer *message*, executing tool calls until the model replies with text.

    Args:
        message: The user's new message text.
        history: Gradio chat history (``type="messages"`` dicts).
        system_message: System prompt text.

    Yields:
        Assistant message dicts. Each carries the full displayed content so
        far: tool-call and tool-result JSON blocks, then the final answer.
        ``metadata["native_messages"]`` holds the raw API messages of this
        turn, so ``get_model_response`` can replay them on later turns.

    The turn ends with a short notice instead of an exception if the model
    request fails (``get_model_response`` returns ``None``). It also ends with
    a notice after ``_MAX_MODEL_CALLS`` requests without a final text answer.
    """
    inner_history = []
    available_functions = {
        "evaluate_math_expression": evaluate_math_expression,
    }
    assistant_content = ""
    assistant_native_message_list = []
    for _ in range(_MAX_MODEL_CALLS):
        response = get_model_response(history, inner_history, message, system_message)
        if response is None:
            # get_model_response already printed the error; tell the user
            # instead of crashing on ``None.choices``.
            assistant_content += "Sorry, something went wrong while contacting the model. Please try again."
            yield _chat_message(assistant_content, assistant_native_message_list)
            return
        response_message = response.choices[0].message

        if not response_message.tool_calls:
            if response_message.content is None:
                # Neither text nor tool calls: ask again (bounded by the loop).
                continue
            assistant_content += response_message.content
            assistant_native_message_list.append(response_message)
            yield _chat_message(assistant_content, assistant_native_message_list)
            return

        # Record the tool-call request before its results, so later requests
        # replay them in the same order.
        assistant_native_message_list.append(response_message)
        inner_history.append(response_message)
        assistant_content += _json_block(
            [tool_call.model_dump() for tool_call in response_message.tool_calls]
        )
        yield _chat_message(assistant_content, assistant_native_message_list)

        for tool_call in response_message.tool_calls:
            function_response = call_function(tool_call, available_functions)
            # Tolerate non-JSON arguments / tool output: show them verbatim.
            assistant_content += _json_block(
                {
                    "name": tool_call.function.name,
                    "arguments": _loads_or_raw(tool_call.function.arguments),
                    "response": _loads_or_raw(function_response["content"]),
                }
            )
            native_tool_message = {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "content": function_response["content"],
            }
            assistant_native_message_list.append(native_tool_message)
            yield _chat_message(assistant_content, assistant_native_message_list)
            inner_history.append(native_tool_message)

    assistant_content += "(Stopped: too many tool calls without a final answer.)"
    yield _chat_message(assistant_content, assistant_native_message_list)
# Default system prompt shown (and editable) in the "System message" textbox.
_DEFAULT_SYSTEM_MESSAGE = "You are a friendly Chatbot with access to a calculator. Don't mention that we are using functions defined in Python."

# Chat UI: ``respond`` receives (message, history, system_message), the last
# coming from the additional Textbox input; ``type="messages"`` means history
# entries are role/content dicts.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=_DEFAULT_SYSTEM_MESSAGE,
            label="System message",
        ),
    ],
    type="messages",
    title="Groq Tool Use Chat",
    # Derived from MODEL so the UI text cannot drift from the model actually called.
    description=f"This chatbot uses the `{MODEL}` LLM with tool use capabilities, including a calculator function.",
)

if __name__ == "__main__":
    demo.launch()