Spaces:
Running
Running
File size: 3,727 Bytes
98caf15 1eaad00 98caf15 92238be 25dbca2 92238be 5974bb1 92238be 98caf15 1eaad00 92238be 66906e4 92238be ed15883 25dbca2 1eaad00 98caf15 92238be 25dbca2 98caf15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import json
import logging
from fastapi import APIRouter, Body
from typing import List, Dict
from pydantic import BaseModel
try:
from tool_gpu_checker import gpu_checker_get_message
from tool_bpy_doc import bpy_doc_get_documentation
from tool_find_related import find_related
from tool_wiki_search import wiki_search
except:
from routers.tool_gpu_checker import gpu_checker_get_message
from routers.tool_bpy_doc import bpy_doc_get_documentation
from routers.tool_find_related import find_related
from routers.tool_wiki_search import wiki_search
class ToolCallFunction(BaseModel):
    """Function part of a tool call: the tool name plus its arguments as a JSON string."""

    # Name that process_tool_call uses to pick which tool to run.
    name: str
    # JSON-encoded arguments; decoded with json.loads before dispatch.
    arguments: str
class ToolCallInput(BaseModel):
    """A single tool call from the request body: an id, a type tag and the function to run."""

    # Identifier copied into the response as "tool_call_id".
    # (Name fixed by the wire format, even though it shadows a builtin.)
    id: str
    # Call type tag; accepted as given and not read by process_tool_call.
    type: str
    # The function to invoke and its JSON-encoded arguments.
    function: ToolCallFunction
# Router holding this module's endpoint(s); presumably attached to the app
# elsewhere via include_router — confirm against the application setup.
router: APIRouter = APIRouter()
def process_tool_call(tool_call: "ToolCallInput") -> Dict:
    """Execute a single tool call and wrap its result for the client.

    The tool is chosen by ``tool_call.function.name``. It receives the
    arguments decoded from the JSON string ``tool_call.function.arguments``.
    Argument problems and unknown tool names are reported in the returned
    ``output`` instead of being raised, so one bad call does not abort a
    whole batch.

    Args:
        tool_call: The tool call to run. Only ``id``, ``function.name`` and
            ``function.arguments`` are read.

    Returns:
        Dict: ``{"tool_call_id": tool_call.id, "output": ...}``. ``output``
        holds the tool's return value, or an error message when:

        - the arguments are not valid JSON, or not a JSON object;
        - the function name is not recognised;
        - the tool raises KeyError, TypeError or ValueError.
    """
    output = {"tool_call_id": tool_call.id, "output": ""}
    function_name = tool_call.function.name
    raw_arguments = tool_call.function.arguments

    # Only the decode is guarded here. A JSONDecodeError raised inside a tool
    # must not be reported as malformed call arguments.
    try:
        function_args = json.loads(raw_arguments)
    except json.JSONDecodeError as e:
        error_message = f"Malformed JSON encountered at position {e.pos}: {e.msg}\n {raw_arguments}"
        output["output"] = error_message
        # Logging the error for further investigation
        logging.error(f"JSONDecodeError in process_tool_call: {error_message}")
        return output

    # Valid JSON need not be an object (e.g. "null" or "[1, 2]"). The tools
    # below need mapping access (.get / [] / **), so reject anything else.
    if not isinstance(function_args, dict):
        error_message = (
            f"Tool arguments must be a JSON object, got "
            f"{type(function_args).__name__}: {raw_arguments}"
        )
        output["output"] = error_message
        logging.error("Invalid arguments in process_tool_call: %s", error_message)
        return output

    try:
        if function_name == "get_bpy_api_info":
            output["output"] = bpy_doc_get_documentation(
                function_args.get("api", ""))
        elif function_name == "check_gpu":
            output["output"] = gpu_checker_get_message(
                function_args.get("gpu", ""))
        elif function_name == "find_related":
            output["output"] = find_related(
                function_args["repo"], function_args["number"])
        elif function_name == "wiki_search":
            output["output"] = wiki_search(**function_args)
        else:
            # Report unknown names explicitly; an empty output gives the
            # caller nothing to act on.
            output["output"] = f"Unknown function: {function_name}"
            logging.error("Unknown function in process_tool_call: %s", function_name)
    except (KeyError, TypeError, ValueError) as e:
        # Missing keys (find_related) or unexpected keywords (wiki_search)
        # land here. Report them for this call only, rather than failing
        # the whole request with a server error.
        error_message = f"Error while calling {function_name}: {e!r}"
        output["output"] = error_message
        logging.exception("Tool call failed in process_tool_call: %s", error_message)

    return output
@router.post("/function_call", response_model=List[Dict])
def function_call(tool_calls: List[ToolCallInput] = Body(..., description="List of tool calls in the request body")):
    """
    Run every tool call in the request body and collect the results.

    Calls are handled one after another in request order. The i-th entry of
    the response therefore belongs to the i-th tool call.

    Args:
        tool_calls (List[ToolCallInput]): Tool calls to execute.

    Returns:
        List[Dict]: One dict per tool call, with keys "tool_call_id" and "output".
    """
    return list(map(process_tool_call, tool_calls))
if __name__ == "__main__":
    # Manual smoke test: one sample call per supported tool, passed straight
    # to the endpoint function (no HTTP server involved).
    sample_calls = (
        ("call_abc123", "get_bpy_api_info",
         '{"api":"bpy.context.scene.world"}'),
        ("call_abc456", "check_gpu",
         '{"gpu":"Mesa Intel(R) Iris(R) Plus Graphics 640 (Kaby Lake GT3e) (KBL GT3) Intel 4.6 (Core Profile) Mesa 22.2.5"}'),
        ("call_abc789", "find_related",
         '{"repo":"blender","number":111434}'),
        ("call_abc101112", "wiki_search",
         '{"query":"Set Snap Base","groups":["manual"]}'),
    )

    requests = []
    for call_id, tool_name, json_arguments in sample_calls:
        requests.append(
            ToolCallInput(
                id=call_id,
                type="function",
                function=ToolCallFunction(name=tool_name, arguments=json_arguments),
            )
        )

    print(function_call(requests))
|