File size: 3,262 Bytes
98caf15
1eaad00
98caf15
 
 
 
92238be
 
 
 
 
 
 
 
 
98caf15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eaad00
 
 
 
 
 
 
 
 
 
92238be
 
 
1eaad00
 
 
 
 
 
 
98caf15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92238be
 
 
 
 
 
 
 
98caf15
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
import logging
from fastapi import APIRouter, Body
from typing import List, Dict
from pydantic import BaseModel

try:
    from .tool_gpu_checker import gpu_checker_get_message
    from .tool_bpy_doc import bpy_doc_get_documentation
    from .tool_find_related import find_relatedness
except:
    from tool_gpu_checker import gpu_checker_get_message
    from tool_bpy_doc import bpy_doc_get_documentation
    from tool_find_related import find_relatedness


class ToolCallFunction(BaseModel):
    """Function name and raw, still-encoded arguments of a single tool call."""
    # Tool to invoke; process_tool_call dispatches on this value
    # ("get_bpy_api_info", "check_gpu", "find_related").
    name: str
    # JSON-encoded arguments; process_tool_call decodes this with json.loads
    # and expects the result to be an object (it calls .get() / [] on it).
    arguments: str


class ToolCallInput(BaseModel):
    """One tool call as received in the /function_call request body."""
    # Caller-supplied identifier, echoed back as "tool_call_id" in the result.
    id: str
    # Call kind; not read by process_tool_call (the sample data uses "function").
    type: str
    # Which tool to run and with what (JSON-encoded) arguments.
    function: ToolCallFunction


# Router that carries the /function_call endpoint defined in this module;
# presumably mounted on the FastAPI app elsewhere via include_router — confirm.
router = APIRouter()


def process_tool_call(tool_call: ToolCallInput) -> Dict:
    """Dispatch a single tool call to its implementation and wrap the result.

    The call's ``function.arguments`` string is decoded as JSON and the
    resulting object supplies the tool's parameters. Recognized tool names:

    * ``get_bpy_api_info`` -> ``bpy_doc_get_documentation(args.get("api", ""))``
    * ``check_gpu``        -> ``gpu_checker_get_message(args.get("gpu", ""))``
    * ``find_related``     -> ``find_relatedness(args["repo"], args["number"])``

    Problems with the request itself are reported in the ``output`` field and
    logged instead of raising, so one bad call does not fail a whole batch.
    These problems are: malformed JSON, arguments that are not a JSON object,
    missing required arguments, and an unknown tool name.

    Args:
        tool_call: The tool call to execute.

    Returns:
        Dict with ``tool_call_id`` (echo of ``tool_call.id``) and ``output``
        (the tool's return value, or an explanatory error message string).
    """
    output = {"tool_call_id": tool_call.id, "output": ""}
    function_name = tool_call.function.name
    raw_arguments = tool_call.function.arguments

    # Only the decode sits inside the try: a JSONDecodeError raised from inside
    # a tool implementation must not be misreported as malformed arguments.
    try:
        function_args = json.loads(raw_arguments)
    except json.JSONDecodeError as e:
        error_message = f"Malformed JSON encountered at position {e.pos}: {e.msg}\n {raw_arguments}"
        output["output"] = error_message
        # Logging the error for further investigation
        logging.error("JSONDecodeError in process_tool_call: %s", error_message)
        return output

    # Valid JSON such as "[]" or "null" would otherwise crash on .get() / [].
    if not isinstance(function_args, dict):
        error_message = f"Arguments for '{function_name}' must be a JSON object: {raw_arguments}"
        output["output"] = error_message
        logging.error("Invalid arguments in process_tool_call: %s", error_message)
        return output

    if function_name == "get_bpy_api_info":
        output["output"] = bpy_doc_get_documentation(
            function_args.get("api", ""))
    elif function_name == "check_gpu":
        output["output"] = gpu_checker_get_message(
            function_args.get("gpu", ""))
    elif function_name == "find_related":
        # Check required keys up front rather than catching KeyError around the
        # call, which would also swallow KeyErrors raised inside the tool.
        missing = [key for key in ("repo", "number") if key not in function_args]
        if missing:
            error_message = f"Missing required argument(s) for 'find_related': {', '.join(missing)}"
            output["output"] = error_message
            logging.error("Invalid arguments in process_tool_call: %s", error_message)
        else:
            output["output"] = find_relatedness(
                function_args["repo"], function_args["number"])
    else:
        # Previously answered with an empty string, which gives the caller no
        # hint that the tool name was not recognized.
        output["output"] = f"Unknown function: {function_name}"
        logging.warning("Unknown function in process_tool_call: %s", function_name)

    return output


@router.post("/function_call", response_model=List[Dict])
def function_call(tool_calls: List[ToolCallInput] = Body(..., description="List of tool calls in the request body")):
    """
    Run a batch of tool calls posted to /function_call.

    Every call is handled independently by process_tool_call; results come
    back in the same order as the request.

    Args:
        tool_calls (List[ToolCallInput]): Tool calls parsed from the request body.
    Returns:
        List[Dict]: One dict per input call, holding tool_call_id and output.
    """
    return list(map(process_tool_call, tool_calls))


if __name__ == "__main__":
    # Manual smoke test: exercise each supported tool once and print the
    # collected outputs.
    sample_calls = [
        ToolCallInput(
            id="call_abc123",
            type="function",
            function=ToolCallFunction(
                name="get_bpy_api_info",
                arguments='{"api":"bpy.context.scene.world"}',
            ),
        ),
        ToolCallInput(
            id="call_abc456",
            type="function",
            function=ToolCallFunction(
                name="check_gpu",
                arguments='{"gpu":"Mesa Intel(R) Iris(R) Plus Graphics 640 (Kaby Lake GT3e) (KBL GT3) Intel 4.6 (Core Profile) Mesa 22.2.5"}',
            ),
        ),
        ToolCallInput(
            id="call_abc789",
            type="function",
            function=ToolCallFunction(
                name="find_related",
                arguments='{"repo":"blender","number":111434}',
            ),
        ),
    ]

    print(function_call(sample_calls))