Spaces:
Runtime error
Runtime error
import fastapi | |
import uvicorn | |
from .registry import build_tool, list_tools | |
from .retriever import Retriever | |
from typing import List | |
from pydantic import BaseModel | |
class RetrieveRequest(BaseModel): | |
query: str | |
topk: int = 3 | |
def _bind_tool_server(t : "ToolServer"): | |
""" Add property API to ToolServer. | |
t.api is a FastAPI object | |
""" | |
def health(): | |
return { | |
"status": "ok", | |
} | |
def get_tools_list(): | |
return { | |
"tools": t.list_tools(), | |
} | |
def get_loaded_tools(): | |
return { | |
"tools": list(t.loaded_tools), | |
} | |
def get_api_info(): | |
return { | |
"schema_version": "v1", | |
"name_for_human": "BMTools", | |
"name_for_model": "BMTools", | |
"description_for_human": "tools to big models", | |
"description_for_model": "tools to big models", | |
"auth": { | |
"type": "none", | |
}, | |
"api": { | |
"type": "openapi", | |
"url": "/openapi.json", | |
"is_user_authenticated": False, | |
}, | |
"logo_url": None, | |
"contact_email": "", | |
"legal_info_url": "", | |
} | |
def retrieve(request: RetrieveRequest): | |
tool_list = t.retrieve(request.query, request.topk) | |
return { | |
"tools": tool_list | |
} | |
class ToolServer: | |
""" This class host your own API backend. | |
""" | |
def __init__(self) -> None: | |
# define the root API server | |
self.api = fastapi.FastAPI( | |
title="BMTools", | |
description="Tools for bigmodels", | |
) | |
self.loaded_tools = dict() | |
self.retriever = Retriever() | |
_bind_tool_server(self) | |
def load_tool(self, name : str, config = {}): | |
if self.is_loaded(name): | |
raise ValueError(f"Tool {name} is already loaded") | |
try: | |
tool = build_tool(name, config) | |
except BaseException as e: | |
print(f"Cannot load tool {name}: {repr(e)}") | |
return | |
# tool = build_tool(name, config) | |
self.loaded_tools[name] = tool.api_info | |
self.retriever.add_tool(name, tool.api_info) | |
# mount sub API server to the root API server, thus can mount all urls of sub API server to /tools/{name} route | |
self.api.mount(f"/tools/{name}", tool, name) | |
return | |
def is_loaded(self, name : str): | |
return name in self.loaded_tools | |
def serve(self, host : str = "0.0.0.0", port : int = 8079): | |
uvicorn.run(self.api, host=host, port=port) | |
def list_tools(self) -> List[str]: | |
return list_tools() | |
def retrieve(self, query: str, topk: int = 3) -> List[str]: | |
return self.retriever.query(query, topk) |