github-actions[bot]
GitHub deploy: ebecf4caf2619658e728c20434b275a3a0ba2270
03f850e
raw
history blame
17.3 kB
import logging
from pathlib import Path
from typing import Optional
import time
import re
import aiohttp
from pydantic import BaseModel, HttpUrl
from open_webui.models.tools import (
ToolForm,
ToolModel,
ToolResponse,
ToolUserResponse,
Tools,
)
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.tools import get_tool_specs
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.tools import get_tool_servers_data
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
############################
# GetTools
############################
@router.get("/", response_model=list[ToolUserResponse])
async def get_tools(request: Request, user=Depends(get_verified_user)):
if not request.app.state.TOOL_SERVERS:
# If the tool servers are not set, we need to set them
# This is done only once when the server starts
# This is done to avoid loading the tool servers every time
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
tools = Tools.get_tools()
for server in request.app.state.TOOL_SERVERS:
tools.append(
ToolUserResponse(
**{
"id": f"server:{server['idx']}",
"user_id": f"server:{server['idx']}",
"name": server.get("openapi", {})
.get("info", {})
.get("title", "Tool Server"),
"meta": {
"description": server.get("openapi", {})
.get("info", {})
.get("description", ""),
},
"access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
server["idx"]
]
.get("config", {})
.get("access_control", None),
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
)
if user.role != "admin":
tools = [
tool
for tool in tools
if tool.user_id == user.id
or has_access(user.id, "read", tool.access_control)
]
return tools
############################
# GetToolList
############################
@router.get("/list", response_model=list[ToolUserResponse])
async def get_tool_list(user=Depends(get_verified_user)):
if user.role == "admin":
tools = Tools.get_tools()
else:
tools = Tools.get_tools_by_user_id(user.id, "write")
return tools
############################
# LoadFunctionFromLink
############################
class LoadUrlForm(BaseModel):
url: HttpUrl
def github_url_to_raw_url(url: str) -> str:
# Handle 'tree' (folder) URLs (add main.py at the end)
m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
if m1:
org, repo, branch, path = m1.groups()
return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
# Handle 'blob' (file) URLs
m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
if m2:
org, repo, branch, path = m2.groups()
return (
f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
)
# No match; return as-is
return url
@router.post("/load/url", response_model=Optional[dict])
async def load_tool_from_url(
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
):
# NOTE: This is NOT a SSRF vulnerability:
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
# and does NOT accept untrusted user input. Access is enforced by authentication.
url = str(form_data.url)
if not url:
raise HTTPException(status_code=400, detail="Please enter a valid URL")
url = github_url_to_raw_url(url)
url_parts = url.rstrip("/").split("/")
file_name = url_parts[-1]
tool_name = (
file_name[:-3]
if (
file_name.endswith(".py")
and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
)
else url_parts[-2] if len(url_parts) > 1 else "function"
)
try:
async with aiohttp.ClientSession() as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
if resp.status != 200:
raise HTTPException(
status_code=resp.status, detail="Failed to fetch the tool"
)
data = await resp.text()
if not data:
raise HTTPException(
status_code=400, detail="No data received from the URL"
)
return {
"name": tool_name,
"content": data,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
############################
# ExportTools
############################
@router.get("/export", response_model=list[ToolModel])
async def export_tools(user=Depends(get_admin_user)):
tools = Tools.get_tools()
return tools
############################
# CreateNewTools
############################
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_tools(
request: Request,
form_data: ToolForm,
user=Depends(get_verified_user),
):
if user.role != "admin" and not has_permission(
user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
tools = Tools.get_tool_by_id(form_data.id)
if tools is None:
try:
form_data.content = replace_imports(form_data.content)
tool_module, frontmatter = load_tool_module_by_id(
form_data.id, content=form_data.content
)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = tool_module
specs = get_tool_specs(TOOLS[form_data.id])
tools = Tools.insert_new_tool(user.id, form_data, specs)
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True)
if tools:
return tools
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
)
except Exception as e:
log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetToolsById
############################
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
if tools:
if (
user.role == "admin"
or tools.user_id == user.id
or has_access(user.id, "read", tools.access_control)
):
return tools
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolsById
############################
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_tools_by_id(
request: Request,
id: str,
form_data: ToolForm,
user=Depends(get_verified_user),
):
tools = Tools.get_tool_by_id(id)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
# Is the user the original creator, in a group with write access, or an admin
if (
tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
try:
form_data.content = replace_imports(form_data.content)
tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
TOOLS[id] = tool_module
specs = get_tool_specs(TOOLS[id])
updated = {
**form_data.model_dump(exclude={"id"}),
"specs": specs,
}
log.debug(updated)
tools = Tools.update_tool_by_id(id, updated)
if tools:
return tools
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
############################
# DeleteToolsById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_tools_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
tools = Tools.get_tool_by_id(id)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
result = Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
if id in TOOLS:
del TOOLS[id]
return result
############################
# GetToolValves
############################
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
if tools:
try:
valves = Tools.get_tool_valves_by_id(id)
return valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# GetToolValvesSpec
############################
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_tools_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
tools = Tools.get_tool_by_id(id)
if tools:
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "Valves"):
Valves = tools_module.Valves
return Valves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolValves
############################
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_tools_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
tools = Tools.get_tool_by_id(id)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if not hasattr(tools_module, "Valves"):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
Valves = tools_module.Valves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
Tools.update_tool_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
log.exception(f"Failed to update tool valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
############################
# ToolUserValves
############################
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
if tools:
try:
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_tools_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
tools = Tools.get_tool_by_id(id)
if tools:
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "UserValves"):
UserValves = tools_module.UserValves
return UserValves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_tools_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
tools = Tools.get_tool_by_id(id)
if tools:
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "UserValves"):
UserValves = tools_module.UserValves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data)
Tools.update_user_valves_by_id_and_user_id(
id, user.id, user_valves.model_dump()
)
return user_valves.model_dump()
except Exception as e:
log.exception(f"Failed to update user valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)