File size: 3,727 Bytes
98caf15
1eaad00
98caf15
 
 
 
92238be
25dbca2
 
 
 
92238be
5974bb1
 
 
 
92238be
98caf15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eaad00
 
 
 
 
 
 
 
 
 
92238be
66906e4
92238be
ed15883
25dbca2
1eaad00
 
 
 
 
 
 
98caf15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92238be
 
 
 
 
 
 
 
25dbca2
 
 
 
 
 
 
 
98caf15
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
import logging
from fastapi import APIRouter, Body
from typing import List, Dict
from pydantic import BaseModel

# Try the bare module names first, then the `routers.` package path.
# NOTE(review): presumably the bare form is for running this file directly
# from its own directory and the package form for importing it from the
# application root - confirm against how the app is launched.
try:
    from tool_gpu_checker import gpu_checker_get_message
    from tool_bpy_doc import bpy_doc_get_documentation
    from tool_find_related import find_related
    from tool_wiki_search import wiki_search
except ImportError:
    # Only a failed import should trigger the fallback. A bare `except:` would
    # also swallow KeyboardInterrupt/SystemExit and hide genuine errors raised
    # while a tool module is being initialised.
    from routers.tool_gpu_checker import gpu_checker_get_message
    from routers.tool_bpy_doc import bpy_doc_get_documentation
    from routers.tool_find_related import find_related
    from routers.tool_wiki_search import wiki_search


# Function part of a tool call: which tool to run and its raw arguments.
class ToolCallFunction(BaseModel):
    # Tool name; process_tool_call compares it against the supported tool names.
    name: str
    # JSON-encoded arguments; process_tool_call decodes this with json.loads.
    arguments: str


# One tool call as accepted in the /function_call request body.
class ToolCallInput(BaseModel):
    # Echoed back as "tool_call_id" in the output built by process_tool_call.
    id: str
    # Not read by the code in this module; the sample data below uses "function".
    type: str
    # Name and JSON-encoded arguments of the tool to run.
    function: ToolCallFunction


# Router on which this module's /function_call endpoint is registered.
router = APIRouter()


def process_tool_call(tool_call: ToolCallInput) -> Dict:
    """
    Run one tool call and wrap its result for the response.

    The JSON-encoded arguments are decoded and handed to the tool selected by
    ``tool_call.function.name``. Problems with the request itself (malformed
    JSON, arguments that are not a JSON object, missing or unexpected
    arguments, an unknown tool name) are reported as text in ``output``
    instead of being raised, so a single bad call does not abort the whole
    batch handled by the caller.

    Args:
        tool_call (ToolCallInput): The tool call to execute.
    Returns:
        Dict: ``{"tool_call_id": <id of the call>, "output": <tool result or error text>}``.
    """
    output = {"tool_call_id": tool_call.id, "output": ""}
    function_name = tool_call.function.name
    raw_arguments = tool_call.function.arguments

    # Keep the try body minimal: only the decode step can raise JSONDecodeError.
    try:
        function_args = json.loads(raw_arguments)
    except json.JSONDecodeError as e:
        error_message = f"Malformed JSON encountered at position {e.pos}: {e.msg}\n {raw_arguments}"
        output["output"] = error_message

        # Logging the error for further investigation
        logging.error("JSONDecodeError in process_tool_call: %s", error_message)
        return output

    # Valid JSON is not necessarily an object ("[1, 2]", "null", "3"); the
    # dispatch below needs key lookups and ** expansion, so reject anything else.
    if not isinstance(function_args, dict):
        error_message = (
            f"Tool arguments must be a JSON object, got "
            f"{type(function_args).__name__}\n {raw_arguments}"
        )
        output["output"] = error_message
        logging.error("Invalid arguments in process_tool_call: %s", error_message)
        return output

    try:
        if function_name == "get_bpy_api_info":
            output["output"] = bpy_doc_get_documentation(
                function_args.get("api", ""))
        elif function_name == "check_gpu":
            output["output"] = gpu_checker_get_message(
                function_args.get("gpu", ""))
        elif function_name == "find_related":
            # Both keys are required; read them before the call so a missing
            # one is reported as an argument error.
            repo = function_args["repo"]
            number = function_args["number"]
            output["output"] = find_related(repo, number)
        elif function_name == "wiki_search":
            output["output"] = wiki_search(**function_args)
        else:
            # Previously an unknown name silently produced an empty output.
            error_message = f"Unknown tool function: {function_name!r}"
            output["output"] = error_message
            logging.warning("Unknown function in process_tool_call: %s", error_message)
    except (KeyError, TypeError) as e:
        # KeyError: a required argument is missing. TypeError: wiki_search got
        # an unexpected or missing keyword argument.
        # NOTE: the same exception types raised inside a tool are reported the
        # same way; the logged traceback tells the two cases apart.
        error_message = f"Invalid arguments for {function_name}: {e!r}\n {raw_arguments}"
        output["output"] = error_message
        logging.exception("Argument error in process_tool_call: %s", error_message)

    return output


@router.post("/function_call", response_model=List[Dict])
def function_call(tool_calls: List[ToolCallInput] = Body(..., description="List of tool calls in the request body")):
    """
    Endpoint to process tool calls.
    Args:
        tool_calls (List[ToolCallInput]): List of tool calls.
    Returns:
        List[Dict]: List of tool outputs with tool_call_id and output.
    """
    # The docstring is kept verbatim: FastAPI publishes an endpoint's
    # docstring as its description in the generated API docs.

    # Handle the calls one by one, keeping the outputs in request order.
    results = []
    for call in tool_calls:
        results.append(process_tool_call(call))
    return results


if __name__ == "__main__":
    # Manual smoke test: one sample request per supported tool, pushed through
    # the endpoint function directly (no HTTP server involved).
    sample_payloads = [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "get_bpy_api_info",
                "arguments": '{"api":"bpy.context.scene.world"}',
            },
        },
        {
            "id": "call_abc456",
            "type": "function",
            "function": {
                "name": "check_gpu",
                "arguments": '{"gpu":"Mesa Intel(R) Iris(R) Plus Graphics 640 (Kaby Lake GT3e) (KBL GT3) Intel 4.6 (Core Profile) Mesa 22.2.5"}',
            },
        },
        {
            "id": "call_abc789",
            "type": "function",
            "function": {
                "name": "find_related",
                "arguments": '{"repo":"blender","number":111434}',
            },
        },
        {
            "id": "call_abc101112",
            "type": "function",
            "function": {
                "name": "wiki_search",
                "arguments": '{"query":"Set Snap Base","groups":["manual"]}',
            },
        },
    ]

    # Build the request models by hand, as the original demo did.
    parsed_calls = []
    for payload in sample_payloads:
        fn = ToolCallFunction(**payload["function"])
        parsed_calls.append(
            ToolCallInput(id=payload["id"], type=payload["type"], function=fn))

    results = function_call(parsed_calls)
    print(results)