diff --git a/.dockerignore b/.dockerignore index 7f462b6f584af00ed7797e26493e15e4a559db01..88479ab6e4fa6334d6ad63c6e48d303f801b2ed7 100644 --- a/.dockerignore +++ b/.dockerignore @@ -16,4 +16,5 @@ _old uploads .ipynb_checkpoints **/*.db -_test \ No newline at end of file +_test +backend/data/* diff --git a/.env.example b/.env.example index 34375da804d819cdba45d66aec99752a9d08ca3d..e3b5cad5c3f0b5c1e211ddfbac3e62ff23ef3672 100644 --- a/.env.example +++ b/.env.example @@ -1,15 +1,13 @@ -# Ollama URL for the backend to connect -# The path '/ollama' will be redirected to the specified backend URL -OLLAMA_BASE_URL='http://localhost:11434' - -OPENAI_API_BASE_URL='' -OPENAI_API_KEY='' - -# AUTOMATIC1111_BASE_URL="http://localhost:7860" - -# DO NOT TRACK -SCARF_NO_ANALYTICS=true -DO_NOT_TRACK=true -ANONYMIZED_TELEMETRY=false - -GLOBAL_LOG_LEVEL="ERROR" \ No newline at end of file +# Ollama URL for the backend to connect +# The path '/ollama' will be redirected to the specified backend URL +OLLAMA_BASE_URL='http://localhost:11434' + +OPENAI_API_BASE_URL='' +OPENAI_API_KEY='' + +# AUTOMATIC1111_BASE_URL="http://localhost:7860" + +# DO NOT TRACK +SCARF_NO_ANALYTICS=true +DO_NOT_TRACK=true +ANONYMIZED_TELEMETRY=false \ No newline at end of file diff --git a/.gitattributes b/.gitattributes index 8f743fba7597607212e15a82fe2c95d9820326ee..16375626a0d8b4f058b9b53630072c1b138dc2fc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,3 @@ *.sh text eol=lf *.ttf filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/deploy-to-hf-spaces.yml b/.github/workflows/deploy-to-hf-spaces.yml index f37530bb6d2237a1f5ad4b4938e1d696fe33e2aa..43474a5b4683ad90baa9917194165665af0c027b 100644 --- a/.github/workflows/deploy-to-hf-spaces.yml +++ b/.github/workflows/deploy-to-hf-spaces.yml @@ -28,6 +28,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 + with: + lfs: true - name: Remove git history run: rm -rf .git @@ -52,7 +54,9 @@ jobs: - name: Set up Git and push to Space run: | git init --initial-branch=main + git lfs install git lfs track "*.ttf" + git lfs track "*.jpg" rm demo.gif git add . git commit -m "GitHub deploy: ${{ github.sha }}" diff --git a/CHANGELOG.md b/CHANGELOG.md index 7aad55b8ff78a04c8da713a68c579ec32ee8019e..1a56d79bbacf75096b832a9c6ea5773c69023b37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,10 +5,71 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.4.1] - 2024-11-19 + +### Added + +- **🛠️ Tool Descriptions on Hover**: When enabled, tool descriptions now appear upon hovering over the tool icon in the message input, giving you more context instantly and improving workflow fluidity. + +### Fixed + +- **🚫 Graceful Handling of Deleted Users**: Resolved an issue where deleted users caused models, knowledge, prompts, or tools to fail loading in the workspace, ensuring smoother operation and fewer interruptions. +- **🔗 Proxy Fix for HTTPS Models Endpoint**: Fixed issues with proxies affecting the secure `/api/v1/models/` endpoint, ensuring stable connections and reliable access. +- **🔒 API Key Creation**: Addressed a bug that previously prevented API keys from being created. + +## [0.4.0] - 2024-11-19 + +### Added + +- **👥 User Groups**: You can now create and manage user groups, making user organization seamless. +- **🔐 Group-Based Access Control**: Set granular access to models, knowledge, prompts, and tools based on user groups, allowing for more controlled and secure environments. +- **🛠️ Group-Based User Permissions**: Easily manage workspace permissions. Grant users the ability to upload files, delete, edit, or create temporary chats, as well as define their ability to create models, knowledge, prompts, and tools. +- **🔑 LDAP Support**: Newly introduced LDAP authentication adds robust security and scalability to user management. +- **🌐 Enhanced OpenAI-Compatible Connections**: Added prefix ID support to avoid model ID clashes, with explicit model ID support for APIs lacking '/models' endpoint support, ensuring smooth operation with custom setups. +- **🔐 Ollama API Key Support**: Now manage credentials for Ollama when set behind proxies, including the option to utilize prefix ID for proper distinction across multiple Ollama instances. +- **🔄 Connection Enable/Disable Toggle**: Easily enable or disable individual OpenAI and Ollama connections as needed. +- **🎨 Redesigned Model Workspace**: Freshly redesigned to improve usability for managing models across users and groups. +- **🎨 Redesigned Prompt Workspace**: A fresh UI to conveniently organize and manage prompts. +- **🧩 Sorted Functions Workspace**: Functions are now automatically categorized by type (Action, Filter, Pipe), streamlining management. +- **💻 Redesigned Collaborative Workspace**: Enhanced support for multiple users contributing to models, knowledge, prompts, or tools, improving collaboration. +- **🔧 Auto-Selected Tools in Model Editor**: Tools enabled through the model editor are now automatically selected, whereas previously it only gave users the option to enable the tool, reducing manual steps and enhancing efficiency. +- **🔔 Web Search & Tools Indicator**: A clear indication now shows when web search or tools are active, reducing confusion. +- **🔑 Toggle API Key Auth**: Tighten security by easily enabling or disabling API key authentication option for Open WebUI. +- **🗂️ Agentic Retrieval**: Improve RAG accuracy via smart pre-processing of chat history to determine the best queries before retrieval. +- **📁 Large Text as File Option**: Optionally convert large pasted text into a file upload, keeping the chat interface cleaner. +- **🗂️ Toggle Citations for Models**: Ability to disable citations has been introduced in the model editor. +- **🔍 User Settings Search**: Quickly search for settings fields, improving ease of use and navigation. +- **🗣️ Experimental SpeechT5 TTS**: Local SpeechT5 support added for improved text-to-speech capabilities. +- **🔄 Unified Reset for Models**: A one-click option has been introduced to reset and remove all models from the Admin Settings. +- **🛠️ Initial Setup Wizard**: The setup process now explicitly informs users that they are creating an admin account during the first-time setup, ensuring clarity. Previously, users encountered the login page right away without this distinction. +- **🌐 Enhanced Translations**: Several language translations, including Ukrainian, Norwegian, and Brazilian Portuguese, were refined for better localization. + +### Fixed + +- **🎥 YouTube Video Attachments**: Fixed issues preventing proper loading and attachment of YouTube videos as files. +- **🔄 Shared Chat Update**: Corrected issues where shared chats were not updating, improving collaboration consistency. +- **🔍 DuckDuckGo Rate Limit Fix**: Addressed issues with DuckDuckGo search integration, enhancing search stability and performance when operating within rate limits. +- **🧾 Citations Relevance Fix**: Adjusted the relevance percentage calculation for citations, so that Open WebUI properly reflect the accuracy of a retrieved document in RAG, ensuring users get clearer insights into sources. +- **🔑 Jina Search API Key Requirement**: Added the option to input an API key for Jina Search, ensuring smooth functionality as keys are now mandatory. + +### Changed + +- **🛠️ Functions Moved to Admin Panel**: As Functions operate as advanced plugins, they are now accessible from the Admin Panel instead of the workspace. +- **🛠️ Manage Ollama Connections**: The "Models" section in Admin Settings has been relocated to Admin Settings > "Connections" > Ollama Connections. You can now manage Ollama instances via a dedicated "Manage Ollama" modal from "Connections", streamlining the setup and configuration of Ollama models. +- **📊 Base Models in Admin Settings**: Admins can now find all base models, both connections or functions, in the "Models" Admin setting. Global model accessibility can be enabled or disabled here. Models are private by default, requiring explicit permission assignment for user access. +- **📌 Sticky Model Selection for New Chats**: The model chosen from a previous chat now persists when creating a new chat. If you click "New Chat" again from the new chat page, it will revert to your default model. +- **🎨 Design Refactoring**: Overall design refinements across the platform have been made, providing a more cohesive and polished user experience. + +### Removed + +- **📂 Model List Reordering**: Temporarily removed and will be reintroduced in upcoming user group settings improvements. +- **⚙️ Default Model Setting**: Removed the ability to set a default model for users, will be reintroduced with user group settings in the future. + ## [0.3.35] - 2024-10-26 ### Added +- **🌐 Translation Update**: Added translation labels in the SearchInput and CreateCollection components and updated Brazilian Portuguese translation (pt-BR) - **📁 Robust File Handling**: Enhanced file input handling for chat. If the content extraction fails or is empty, users will now receive a clear warning, preventing silent failures and ensuring you always know what's happening with your uploads. - **🌍 New Language Support**: Introduced Hungarian translations and updated French translations, expanding the platform's language accessibility for a more global user base. diff --git a/Dockerfile b/Dockerfile index e8211a14537864ad3a443b90efa2afb88a15b023..91238b4d4ea6a26d1472ba6d8181cdabbcc7e235 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,177 +1,176 @@ -# syntax=docker/dockerfile:1 -# Initialize device type args -# use build args in the docker build command with --build-arg="BUILDARG=true" -ARG USE_CUDA=false -ARG USE_OLLAMA=false -# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default) -ARG USE_CUDA_VER=cu121 -# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers -# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard -# for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB) -# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. -ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 -ARG USE_RERANKING_MODEL="" - -# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken -ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base" - -ARG BUILD_HASH=dev-build -# Override at your own risk - non-root configurations are untested -ARG UID=0 -ARG GID=0 - -######## WebUI frontend ######## -FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build -ARG BUILD_HASH - -WORKDIR /app - -COPY package.json package-lock.json ./ -RUN npm ci - -COPY . . -ENV APP_BUILD_HASH=${BUILD_HASH} -RUN npm run build - -######## WebUI backend ######## -FROM python:3.11-slim-bookworm AS base - -# Use args -ARG USE_CUDA -ARG USE_OLLAMA -ARG USE_CUDA_VER -ARG USE_EMBEDDING_MODEL -ARG USE_RERANKING_MODEL -ARG UID -ARG GID - -## Basis ## -ENV ENV=prod \ - PORT=8080 \ - # pass build args to the build - USE_OLLAMA_DOCKER=${USE_OLLAMA} \ - USE_CUDA_DOCKER=${USE_CUDA} \ - USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \ - USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \ - USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL} - -## Basis URL Config ## -ENV OLLAMA_BASE_URL="/ollama" \ - OPENAI_API_BASE_URL="" - -## API Key and Security Config ## -ENV OPENAI_API_KEY="" \ - WEBUI_SECRET_KEY="" \ - SCARF_NO_ANALYTICS=true \ - DO_NOT_TRACK=true \ - ANONYMIZED_TELEMETRY=false - -#### Other models ######################################################### -## whisper TTS model settings ## -ENV WHISPER_MODEL="base" \ - WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" - -## RAG Embedding model settings ## -ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \ - RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \ - SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" - -## Tiktoken model settings ## -ENV TIKTOKEN_ENCODING_NAME="cl100k_base" \ - TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken" - -## Hugging Face download cache ## -ENV HF_HOME="/app/backend/data/cache/embedding/models" - -## Torch Extensions ## -# ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions" - -#### Other models ########################################################## - -WORKDIR /app/backend - -ENV HOME=/root -# Create user and group if not root -RUN if [ $UID -ne 0 ]; then \ - if [ $GID -ne 0 ]; then \ - addgroup --gid $GID app; \ - fi; \ - adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \ - fi - -RUN mkdir -p $HOME/.cache/chroma -RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id - -# Make sure the user has access to the app and root directory -RUN chown -R $UID:$GID /app $HOME - -RUN if [ "$USE_OLLAMA" = "true" ]; then \ - apt-get update && \ - # Install pandoc and netcat - apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \ - apt-get install -y --no-install-recommends gcc python3-dev && \ - # for RAG OCR - apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \ - # install helper tools - apt-get install -y --no-install-recommends curl jq && \ - # install ollama - curl -fsSL https://ollama.com/install.sh | sh && \ - # cleanup - rm -rf /var/lib/apt/lists/*; \ - else \ - apt-get update && \ - # Install pandoc, netcat and gcc - apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \ - apt-get install -y --no-install-recommends gcc python3-dev && \ - # for RAG OCR - apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \ - # cleanup - rm -rf /var/lib/apt/lists/*; \ - fi - -# install python dependencies -COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt - -RUN pip3 install uv && \ - if [ "$USE_CUDA" = "true" ]; then \ - # If you use CUDA the whisper and embedding model will be downloaded on first use - pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \ - uv pip install --system -r requirements.txt --no-cache-dir && \ - python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ - python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ - python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ - else \ - pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ - uv pip install --system -r requirements.txt --no-cache-dir && \ - python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ - python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ - python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ - fi; \ - chown -R $UID:$GID /app/backend/data/ - - - -# copy embedding weight from build -# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2 -# COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx - -# copy built frontend files -COPY --chown=$UID:$GID --from=build /app/build /app/build -COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md -COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json - -# copy backend files -COPY --chown=$UID:$GID ./backend . - -EXPOSE 8080 - -HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1 - -USER $UID:$GID - -ARG BUILD_HASH -ENV WEBUI_BUILD_VERSION=${BUILD_HASH} -ENV DOCKER=true -ENV GLOBAL_LOG_LEVEL="ERROR" - -CMD [ "bash", "start.sh"] +# syntax=docker/dockerfile:1 +# Initialize device type args +# use build args in the docker build command with --build-arg="BUILDARG=true" +ARG USE_CUDA=false +ARG USE_OLLAMA=false +# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default) +ARG USE_CUDA_VER=cu121 +# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers +# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard +# for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB) +# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. +ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 +ARG USE_RERANKING_MODEL="" + +# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken +ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base" + +ARG BUILD_HASH=dev-build +# Override at your own risk - non-root configurations are untested +ARG UID=0 +ARG GID=0 + +######## WebUI frontend ######## +FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build +ARG BUILD_HASH + +WORKDIR /app + +COPY package.json package-lock.json ./ +RUN npm ci + +COPY . . +ENV APP_BUILD_HASH=${BUILD_HASH} +RUN npm run build + +######## WebUI backend ######## +FROM python:3.11-slim-bookworm AS base + +# Use args +ARG USE_CUDA +ARG USE_OLLAMA +ARG USE_CUDA_VER +ARG USE_EMBEDDING_MODEL +ARG USE_RERANKING_MODEL +ARG UID +ARG GID + +## Basis ## +ENV ENV=prod \ + PORT=8080 \ + # pass build args to the build + USE_OLLAMA_DOCKER=${USE_OLLAMA} \ + USE_CUDA_DOCKER=${USE_CUDA} \ + USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \ + USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \ + USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL} + +## Basis URL Config ## +ENV OLLAMA_BASE_URL="/ollama" \ + OPENAI_API_BASE_URL="" + +## API Key and Security Config ## +ENV OPENAI_API_KEY="" \ + WEBUI_SECRET_KEY="" \ + SCARF_NO_ANALYTICS=true \ + DO_NOT_TRACK=true \ + ANONYMIZED_TELEMETRY=false + +#### Other models ######################################################### +## whisper TTS model settings ## +ENV WHISPER_MODEL="base" \ + WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" + +## RAG Embedding model settings ## +ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \ + RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \ + SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" + +## Tiktoken model settings ## +ENV TIKTOKEN_ENCODING_NAME="cl100k_base" \ + TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken" + +## Hugging Face download cache ## +ENV HF_HOME="/app/backend/data/cache/embedding/models" + +## Torch Extensions ## +# ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions" + +#### Other models ########################################################## + +WORKDIR /app/backend + +ENV HOME=/root +# Create user and group if not root +RUN if [ $UID -ne 0 ]; then \ + if [ $GID -ne 0 ]; then \ + addgroup --gid $GID app; \ + fi; \ + adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \ + fi + +RUN mkdir -p $HOME/.cache/chroma +RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id + +# Make sure the user has access to the app and root directory +RUN chown -R $UID:$GID /app $HOME + +RUN if [ "$USE_OLLAMA" = "true" ]; then \ + apt-get update && \ + # Install pandoc and netcat + apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \ + apt-get install -y --no-install-recommends gcc python3-dev && \ + # for RAG OCR + apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \ + # install helper tools + apt-get install -y --no-install-recommends curl jq && \ + # install ollama + curl -fsSL https://ollama.com/install.sh | sh && \ + # cleanup + rm -rf /var/lib/apt/lists/*; \ + else \ + apt-get update && \ + # Install pandoc, netcat and gcc + apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \ + apt-get install -y --no-install-recommends gcc python3-dev && \ + # for RAG OCR + apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \ + # cleanup + rm -rf /var/lib/apt/lists/*; \ + fi + +# install python dependencies +COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt + +RUN pip3 install uv && \ + if [ "$USE_CUDA" = "true" ]; then \ + # If you use CUDA the whisper and embedding model will be downloaded on first use + pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \ + uv pip install --system -r requirements.txt --no-cache-dir && \ + python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ + python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ + python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ + else \ + pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ + uv pip install --system -r requirements.txt --no-cache-dir && \ + python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ + python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ + python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ + fi; \ + chown -R $UID:$GID /app/backend/data/ + + + +# copy embedding weight from build +# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2 +# COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx + +# copy built frontend files +COPY --chown=$UID:$GID --from=build /app/build /app/build +COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md +COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json + +# copy backend files +COPY --chown=$UID:$GID ./backend . + +EXPOSE 8080 + +HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1 + +USER $UID:$GID + +ARG BUILD_HASH +ENV WEBUI_BUILD_VERSION=${BUILD_HASH} +ENV DOCKER=true + +CMD [ "bash", "start.sh"] diff --git a/README.md b/README.md index efc5efa9a7c20893b5f4979b6217fbb6e8aaaf82..2d5162ec77890cf21478cbbaa58f1075f1e5f052 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature- - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query. -- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch` and `SearchApi` and inject the results directly into your chat experience. +- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience. - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions. @@ -195,18 +195,6 @@ docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --a Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/). -## Supporters ✨ - -A big shoutout to our amazing supporters who's helping to make this project possible! 🙏 - -### Platinum Sponsors 🤍 - -- We're looking for Sponsors! - -### Acknowledgments - -Special thanks to [Prof. Lawrence Kim](https://www.lhkim.com/) and [Prof. Nick Vincent](https://www.nickmvincent.com/) for their invaluable support and guidance in shaping this project into a research endeavor. Grateful for your mentorship throughout the journey! 🙌 - ## License 📜 This project is licensed under the [MIT License](LICENSE) - see the [LICENSE](LICENSE) file for details. 📄 diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py index a1cc75b43b2490421bb6d342e0c1cc75d23cc00f..72f04896fcb310e864e3edcbb11d3bf425930762 100644 --- a/backend/open_webui/apps/audio/main.py +++ b/backend/open_webui/apps/audio/main.py @@ -32,7 +32,13 @@ from open_webui.config import ( ) from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE +from open_webui.env import ( + ENV, + SRC_LOG_LEVELS, + DEVICE_TYPE, + ENABLE_FORWARD_USER_INFO_HEADERS, +) + from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse @@ -47,7 +53,12 @@ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["AUDIO"]) -app = FastAPI() +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) + app.add_middleware( CORSMiddleware, allow_origins=CORS_ALLOW_ORIGIN, @@ -74,6 +85,10 @@ app.state.config.TTS_VOICE = AUDIO_TTS_VOICE app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON + +app.state.speech_synthesiser = None +app.state.speech_speaker_embeddings_dataset = None + app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT @@ -231,6 +246,21 @@ async def update_audio_config( } +def load_speech_pipeline(): + from transformers import pipeline + from datasets import load_dataset + + if app.state.speech_synthesiser is None: + app.state.speech_synthesiser = pipeline( + "text-to-speech", "microsoft/speecht5_tts" + ) + + if app.state.speech_speaker_embeddings_dataset is None: + app.state.speech_speaker_embeddings_dataset = load_dataset( + "Matthijs/cmu-arctic-xvectors", split="validation" + ) + + @app.post("/speech") async def speech(request: Request, user=Depends(get_verified_user)): body = await request.body() @@ -248,6 +278,12 @@ async def speech(request: Request, user=Depends(get_verified_user)): headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}" headers["Content-Type"] = "application/json" + if ENABLE_FORWARD_USER_INFO_HEADERS: + headers["X-OpenWebUI-User-Name"] = user.name + headers["X-OpenWebUI-User-Id"] = user.id + headers["X-OpenWebUI-User-Email"] = user.email + headers["X-OpenWebUI-User-Role"] = user.role + try: body = body.decode("utf-8") body = json.loads(body) @@ -391,6 +427,43 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException( status_code=500, detail=f"Error synthesizing speech - {response.reason}" ) + elif app.state.config.TTS_ENGINE == "transformers": + payload = None + try: + payload = json.loads(body.decode("utf-8")) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + import torch + import soundfile as sf + + load_speech_pipeline() + + embeddings_dataset = app.state.speech_speaker_embeddings_dataset + + speaker_index = 6799 + try: + speaker_index = embeddings_dataset["filename"].index( + app.state.config.TTS_MODEL + ) + except Exception: + pass + + speaker_embedding = torch.tensor( + embeddings_dataset[speaker_index]["xvector"] + ).unsqueeze(0) + + speech = app.state.speech_synthesiser( + payload["input"], + forward_params={"speaker_embeddings": speaker_embedding}, + ) + + sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"]) + with open(file_body_path, "w") as f: + json.dump(json.loads(body.decode("utf-8")), f) + + return FileResponse(file_path) def transcribe(file_path): diff --git a/backend/open_webui/apps/images/main.py b/backend/open_webui/apps/images/main.py index 2849e264014a70b6d2c8da6b494b8d03b2d7da31..c4bbaec174481257d7b15ef252e6bd03640ab396 100644 --- a/backend/open_webui/apps/images/main.py +++ b/backend/open_webui/apps/images/main.py @@ -35,7 +35,8 @@ from open_webui.config import ( AppConfig, ) from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS +from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS + from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel @@ -47,7 +48,12 @@ log.setLevel(SRC_LOG_LEVELS["IMAGES"]) IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) -app = FastAPI() +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) + app.add_middleware( CORSMiddleware, allow_origins=CORS_ALLOW_ORIGIN, @@ -456,6 +462,12 @@ async def image_generations( headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" headers["Content-Type"] = "application/json" + if ENABLE_FORWARD_USER_INFO_HEADERS: + headers["X-OpenWebUI-User-Name"] = user.name + headers["X-OpenWebUI-User-Id"] = user.id + headers["X-OpenWebUI-User-Email"] = user.email + headers["X-OpenWebUI-User-Role"] = user.role + data = { "model": ( app.state.config.MODEL diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index f0f8877f49aba97f2f8fceaa155b59e5921e9118..7b09acbf9be9595e23f2ce9d6b2fbf1937043236 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -13,18 +13,20 @@ import requests from open_webui.apps.webui.models.models import Models from open_webui.config import ( CORS_ALLOW_ORIGIN, - ENABLE_MODEL_FILTER, ENABLE_OLLAMA_API, - MODEL_FILTER_LIST, OLLAMA_BASE_URLS, + OLLAMA_API_CONFIGS, UPLOAD_DIR, AppConfig, ) -from open_webui.env import AIOHTTP_CLIENT_TIMEOUT +from open_webui.env import ( + AIOHTTP_CLIENT_TIMEOUT, + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, +) from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS +from open_webui.env import ENV, SRC_LOG_LEVELS from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse @@ -41,11 +43,18 @@ from open_webui.utils.payload import ( apply_model_system_prompt_to_body, ) from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) -app = FastAPI() + +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) + app.add_middleware( CORSMiddleware, allow_origins=CORS_ALLOW_ORIGIN, @@ -56,12 +65,9 @@ app.add_middleware( app.state.config = AppConfig() -app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS -app.state.MODELS = {} +app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. @@ -69,60 +75,98 @@ app.state.MODELS = {} # least connections, or least response time for better resource utilization and performance optimization. -@app.middleware("http") -async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - else: - pass - - response = await call_next(request) - return response - - @app.head("/") @app.get("/") async def get_status(): return {"status": True} +class ConnectionVerificationForm(BaseModel): + url: str + key: Optional[str] = None + + +@app.post("/verify") +async def verify_connection( + form_data: ConnectionVerificationForm, user=Depends(get_admin_user) +): + url = form_data.url + key = form_data.key + + headers = {} + if key: + headers["Authorization"] = f"Bearer {key}" + + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get(f"{url}/api/version", headers=headers) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() + if "error" in res: + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) + + response_data = await r.json() + return response_data + + except aiohttp.ClientError as e: + # ClientError covers all aiohttp requests issues + log.exception(f"Client error: {str(e)}") + # Handle aiohttp-specific connection issues, timeout etc. + raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) + except Exception as e: + log.exception(f"Unexpected error: {e}") + # Generic error handler in case parsing JSON or other steps fail + error_detail = f"Unexpected error: {str(e)}" + raise HTTPException(status_code=500, detail=error_detail) + + @app.get("/config") async def get_config(user=Depends(get_admin_user)): - return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} + return { + "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, + "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS, + } class OllamaConfigForm(BaseModel): - enable_ollama_api: Optional[bool] = None + ENABLE_OLLAMA_API: Optional[bool] = None + OLLAMA_BASE_URLS: list[str] + OLLAMA_API_CONFIGS: dict @app.post("/config/update") async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api - return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} - - -@app.get("/urls") -async def get_ollama_api_urls(user=Depends(get_admin_user)): - return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} - + app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API + app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS -class UrlUpdateForm(BaseModel): - urls: list[str] + app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS + # Remove any extra configs + config_urls = app.state.config.OLLAMA_API_CONFIGS.keys() + for url in list(app.state.config.OLLAMA_BASE_URLS): + if url not in config_urls: + app.state.config.OLLAMA_API_CONFIGS.pop(url, None) -@app.post("/urls/update") -async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.config.OLLAMA_BASE_URLS = form_data.urls + return { + "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, + "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS, + } - log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}") - return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} - -async def fetch_url(url): - timeout = aiohttp.ClientTimeout(total=3) +async def aiohttp_get(url, key=None): + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: + headers = {"Authorization": f"Bearer {key}"} if key else {} async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get(url) as response: + async with session.get(url, headers=headers) as response: return await response.json() except Exception as e: # Handle connection error here @@ -148,10 +192,18 @@ async def post_streaming_url( session = aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = await session.post( url, data=payload, - headers={"Content-Type": "application/json"}, + headers=headers, ) r.raise_for_status() @@ -194,29 +246,62 @@ def merge_models_lists(model_lists): for idx, model_list in enumerate(model_lists): if model_list is not None: for model in model_list: - digest = model["digest"] - if digest not in merged_models: + id = model["model"] + if id not in merged_models: model["urls"] = [idx] - merged_models[digest] = model + merged_models[id] = model else: - merged_models[digest]["urls"].append(idx) + merged_models[id]["urls"].append(idx) return list(merged_models.values()) async def get_all_models(): log.info("get_all_models()") - if app.state.config.ENABLE_OLLAMA_API: - tasks = [ - fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS - ] + tasks = [] + for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS): + if url not in app.state.config.OLLAMA_API_CONFIGS: + tasks.append(aiohttp_get(f"{url}/api/tags")) + else: + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + enable = api_config.get("enable", True) + key = api_config.get("key", None) + + if enable: + tasks.append(aiohttp_get(f"{url}/api/tags", key)) + else: + tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) + responses = await asyncio.gather(*tasks) + for idx, response in enumerate(responses): + if response: + url = app.state.config.OLLAMA_BASE_URLS[idx] + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + + prefix_id = api_config.get("prefix_id", None) + model_ids = api_config.get("model_ids", []) + + if len(model_ids) != 0 and "models" in response: + response["models"] = list( + filter( + lambda model: model["model"] in model_ids, + response["models"], + ) + ) + + if prefix_id: + for model in response.get("models", []): + model["model"] = f"{prefix_id}.{model['model']}" + + print(responses) + models = { "models": merge_models_lists( map( - lambda response: response["models"] if response else None, responses + lambda response: response.get("models", []) if response else None, + responses, ) ) } @@ -224,8 +309,6 @@ async def get_all_models(): else: models = {"models": []} - app.state.MODELS = {model["model"]: model for model in models["models"]} - return models @@ -234,29 +317,25 @@ async def get_all_models(): async def get_ollama_tags( url_idx: Optional[int] = None, user=Depends(get_verified_user) ): + models = [] if url_idx is None: models = await get_all_models() - - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models["models"] = list( - filter( - lambda model: model["name"] - in app.state.config.MODEL_FILTER_LIST, - models["models"], - ) - ) - return models - return models else: url = app.state.config.OLLAMA_BASE_URLS[url_idx] + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {} + if key: + headers["Authorization"] = f"Bearer {key}" + r = None try: - r = requests.request(method="GET", url=f"{url}/api/tags") + r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers) r.raise_for_status() - return r.json() + models = r.json() except Exception as e: log.exception(e) error_detail = "Open WebUI: Server Connection Error" @@ -273,6 +352,20 @@ async def get_ollama_tags( detail=error_detail, ) + if user.role == "user": + # Filter models based on user access control + filtered_models = [] + for model in models.get("models", []): + model_info = Models.get_model_by_id(model["model"]) + if model_info: + if user.id == model_info.user_id or has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + models["models"] = filtered_models + + return models + @app.get("/api/version") @app.get("/api/version/{url_idx}") @@ -281,7 +374,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None): if url_idx is None: # returns lowest version tasks = [ - fetch_url(f"{url}/api/version") + aiohttp_get( + f"{url}/api/version", + app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None), + ) for url in app.state.config.OLLAMA_BASE_URLS ] responses = await asyncio.gather(*tasks) @@ -361,8 +457,11 @@ async def push_model( user=Depends(get_admin_user), ): if url_idx is None: - if form_data.name in app.state.MODELS: - url_idx = app.state.MODELS[form_data.name]["urls"][0] + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.name in models: + url_idx = models[form_data.name]["urls"][0] else: raise HTTPException( status_code=400, @@ -411,8 +510,11 @@ async def copy_model( user=Depends(get_admin_user), ): if url_idx is None: - if form_data.source in app.state.MODELS: - url_idx = app.state.MODELS[form_data.source]["urls"][0] + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.source in models: + url_idx = models[form_data.source]["urls"][0] else: raise HTTPException( status_code=400, @@ -421,10 +523,18 @@ async def copy_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = requests.request( method="POST", url=f"{url}/api/copy", - headers={"Content-Type": "application/json"}, + headers=headers, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -459,8 +569,11 @@ async def delete_model( user=Depends(get_admin_user), ): if url_idx is None: - if form_data.name in app.state.MODELS: - url_idx = app.state.MODELS[form_data.name]["urls"][0] + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.name in models: + url_idx = models[form_data.name]["urls"][0] else: raise HTTPException( status_code=400, @@ -470,11 +583,18 @@ async def delete_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = requests.request( method="DELETE", url=f"{url}/api/delete", - headers={"Content-Type": "application/json"}, data=form_data.model_dump_json(exclude_none=True).encode(), + headers=headers, ) try: r.raise_for_status() @@ -501,20 +621,30 @@ async def delete_model( @app.post("/api/show") async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)): - if form_data.name not in app.state.MODELS: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.name not in models: raise HTTPException( status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) + url_idx = random.choice(models[form_data.name]["urls"]) url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = requests.request( method="POST", url=f"{url}/api/show", - headers={"Content-Type": "application/json"}, + headers=headers, data=form_data.model_dump_json(exclude_none=True).encode(), ) try: @@ -570,23 +700,26 @@ async def generate_embeddings( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) + return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) -def generate_ollama_embeddings( +async def generate_ollama_embeddings( form_data: GenerateEmbeddingsForm, url_idx: Optional[int] = None, ): log.info(f"generate_ollama_embeddings {form_data}") if url_idx is None: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + model = form_data.model if ":" not in model: model = f"{model}:latest" - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if model in models: + url_idx = random.choice(models[model]["urls"]) else: raise HTTPException( status_code=400, @@ -596,10 +729,17 @@ def generate_ollama_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = requests.request( method="POST", url=f"{url}/api/embeddings", - headers={"Content-Type": "application/json"}, + headers=headers, data=form_data.model_dump_json(exclude_none=True).encode(), ) try: @@ -630,20 +770,23 @@ def generate_ollama_embeddings( ) -def generate_ollama_batch_embeddings( +async def generate_ollama_batch_embeddings( form_data: GenerateEmbedForm, url_idx: Optional[int] = None, ): log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + model = form_data.model if ":" not in model: model = f"{model}:latest" - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if model in models: + url_idx = random.choice(models[model]["urls"]) else: raise HTTPException( status_code=400, @@ -653,10 +796,17 @@ def generate_ollama_batch_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = requests.request( method="POST", url=f"{url}/api/embed", - headers={"Content-Type": "application/json"}, + headers=headers, data=form_data.model_dump_json(exclude_none=True).encode(), ) try: @@ -706,13 +856,16 @@ async def generate_completion( user=Depends(get_verified_user), ): if url_idx is None: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + model = form_data.model if ":" not in model: model = f"{model}:latest" - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if model in models: + url_idx = random.choice(models[model]["urls"]) else: raise HTTPException( status_code=400, @@ -720,6 +873,10 @@ async def generate_completion( ) url = app.state.config.OLLAMA_BASE_URLS[url_idx] + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) + if prefix_id: + form_data.model = form_data.model.replace(f"{prefix_id}.", "") log.info(f"url: {url}") return await post_streaming_url( @@ -743,14 +900,17 @@ class GenerateChatCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -def get_ollama_url(url_idx: Optional[int], model: str): +async def get_ollama_url(url_idx: Optional[int], model: str): if url_idx is None: - if model not in app.state.MODELS: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if model not in models: raise HTTPException( status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), ) - url_idx = random.choice(app.state.MODELS[model]["urls"]) + url_idx = random.choice(models[model]["urls"]) url = app.state.config.OLLAMA_BASE_URLS[url_idx] return url @@ -768,15 +928,7 @@ async def generate_chat_completion( if "metadata" in payload: del payload["metadata"] - model_id = form_data.model - - if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - raise HTTPException( - status_code=403, - detail="Model not found", - ) - + model_id = payload["model"] model_info = Models.get_model_by_id(model_id) if model_info: @@ -794,13 +946,37 @@ async def generate_chat_completion( ) payload = apply_model_system_prompt_to_body(params, payload, user) + # Check if user has access to the model + if not bypass_filter and user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + elif not bypass_filter: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) + if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = get_ollama_url(url_idx, payload["model"]) + url = await get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") log.debug(f"generate_chat_completion() - 2.payload = {payload}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + return await post_streaming_url( f"{url}/api/chat", json.dumps(payload), @@ -817,7 +993,7 @@ class OpenAIChatMessageContent(BaseModel): class OpenAIChatMessage(BaseModel): role: str - content: Union[str, OpenAIChatMessageContent] + content: Union[str, list[OpenAIChatMessageContent]] model_config = ConfigDict(extra="allow") @@ -836,22 +1012,24 @@ async def generate_openai_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - completion_form = OpenAIChatCompletionForm(**form_data) + try: + completion_form = OpenAIChatCompletionForm(**form_data) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=400, + detail=str(e), + ) + payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])} if "metadata" in payload: del payload["metadata"] model_id = completion_form.model - - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - raise HTTPException( - status_code=403, - detail="Model not found", - ) + if ":" not in model_id: + model_id = f"{model_id}:latest" model_info = Models.get_model_by_id(model_id) - if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -862,12 +1040,36 @@ async def generate_openai_chat_completion( payload = apply_model_params_to_body_openai(params, payload) payload = apply_model_system_prompt_to_body(params, payload, user) + # Check if user has access to the model + if user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) + if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = get_ollama_url(url_idx, payload["model"]) + url = await get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + return await post_streaming_url( f"{url}/v1/chat/completions", json.dumps(payload), @@ -881,21 +1083,29 @@ async def get_openai_models( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): + + models = [] if url_idx is None: - models = await get_all_models() + model_list = await get_all_models() + models = [ + { + "id": model["model"], + "object": "model", + "created": int(time.time()), + "owned_by": "openai", + } + for model in model_list["models"] + ] - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models["models"] = list( - filter( - lambda model: model["name"] - in app.state.config.MODEL_FILTER_LIST, - models["models"], - ) - ) + else: + url = app.state.config.OLLAMA_BASE_URLS[url_idx] + try: + r = requests.request(method="GET", url=f"{url}/api/tags") + r.raise_for_status() + + model_list = r.json() - return { - "data": [ + models = [ { "id": model["model"], "object": "model", @@ -903,31 +1113,7 @@ async def get_openai_models( "owned_by": "openai", } for model in models["models"] - ], - "object": "list", - } - - else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - try: - r = requests.request(method="GET", url=f"{url}/api/tags") - r.raise_for_status() - - models = r.json() - - return { - "data": [ - { - "id": model["model"], - "object": "model", - "created": int(time.time()), - "owned_by": "openai", - } - for model in models["models"] - ], - "object": "list", - } - + ] except Exception as e: log.exception(e) error_detail = "Open WebUI: Server Connection Error" @@ -944,6 +1130,23 @@ async def get_openai_models( detail=error_detail, ) + if user.role == "user": + # Filter models based on user access control + filtered_models = [] + for model in models: + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if user.id == model_info.user_id or has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + models = filtered_models + + return { + "data": models, + "object": "list", + } + class UrlForm(BaseModel): url: str diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index 8e7f9cc88ec97d5ff1f3730af20eb3ed77aa69cf..2e2da944f10514ecefc308d4278654d829c916b6 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -11,20 +11,20 @@ from open_webui.apps.webui.models.models import Models from open_webui.config import ( CACHE_DIR, CORS_ALLOW_ORIGIN, - ENABLE_MODEL_FILTER, ENABLE_OPENAI_API, - MODEL_FILTER_LIST, OPENAI_API_BASE_URLS, OPENAI_API_KEYS, + OPENAI_API_CONFIGS, AppConfig, ) from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, + ENABLE_FORWARD_USER_INFO_HEADERS, ) from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS +from open_webui.env import ENV, SRC_LOG_LEVELS from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, StreamingResponse @@ -37,11 +37,20 @@ from open_webui.utils.payload import ( ) from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OPENAI"]) -app = FastAPI() + +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) + + app.add_middleware( CORSMiddleware, allow_origins=CORS_ALLOW_ORIGIN, @@ -52,69 +61,66 @@ app.add_middleware( app.state.config = AppConfig() -app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS - -app.state.MODELS = {} - - -@app.middleware("http") -async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - - response = await call_next(request) - return response +app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS @app.get("/config") async def get_config(user=Depends(get_admin_user)): - return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} + return { + "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + } class OpenAIConfigForm(BaseModel): - enable_openai_api: Optional[bool] = None + ENABLE_OPENAI_API: Optional[bool] = None + OPENAI_API_BASE_URLS: list[str] + OPENAI_API_KEYS: list[str] + OPENAI_API_CONFIGS: dict @app.post("/config/update") async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api - return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} - - -class UrlsUpdateForm(BaseModel): - urls: list[str] + app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API + app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS + app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS -class KeysUpdateForm(BaseModel): - keys: list[str] - - -@app.get("/urls") -async def get_openai_urls(user=Depends(get_admin_user)): - return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} - - -@app.post("/urls/update") -async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): - await get_all_models() - app.state.config.OPENAI_API_BASE_URLS = form_data.urls - return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} - + # Check if API KEYS length is same than API URLS length + if len(app.state.config.OPENAI_API_KEYS) != len( + app.state.config.OPENAI_API_BASE_URLS + ): + if len(app.state.config.OPENAI_API_KEYS) > len( + app.state.config.OPENAI_API_BASE_URLS + ): + app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[ + : len(app.state.config.OPENAI_API_BASE_URLS) + ] + else: + app.state.config.OPENAI_API_KEYS += [""] * ( + len(app.state.config.OPENAI_API_BASE_URLS) + - len(app.state.config.OPENAI_API_KEYS) + ) -@app.get("/keys") -async def get_openai_keys(user=Depends(get_admin_user)): - return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} + app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS + # Remove any extra configs + config_urls = app.state.config.OPENAI_API_CONFIGS.keys() + for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): + if url not in config_urls: + app.state.config.OPENAI_API_CONFIGS.pop(url, None) -@app.post("/keys/update") -async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): - app.state.config.OPENAI_API_KEYS = form_data.keys - return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} + return { + "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + } @app.post("/audio/speech") @@ -140,6 +146,11 @@ async def speech(request: Request, user=Depends(get_verified_user)): if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: headers["HTTP-Referer"] = "https://openwebui.com/" headers["X-Title"] = "Open WebUI" + if ENABLE_FORWARD_USER_INFO_HEADERS: + headers["X-OpenWebUI-User-Name"] = user.name + headers["X-OpenWebUI-User-Id"] = user.id + headers["X-OpenWebUI-User-Email"] = user.email + headers["X-OpenWebUI-User-Role"] = user.role r = None try: r = requests.post( @@ -181,10 +192,10 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -async def fetch_url(url, key): +async def aiohttp_get(url, key=None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: - headers = {"Authorization": f"Bearer {key}"} + headers = {"Authorization": f"Bearer {key}"} if key else {} async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get(url, headers=headers) as response: return await response.json() @@ -239,12 +250,8 @@ def merge_models_lists(model_lists): return merged_list -def is_openai_api_disabled(): - return not app.state.config.ENABLE_OPENAI_API - - -async def get_all_models_raw() -> list: - if is_openai_api_disabled(): +async def get_all_models_responses() -> list: + if not app.state.config.ENABLE_OPENAI_API: return [] # Check if API KEYS length is same than API URLS length @@ -260,33 +267,67 @@ async def get_all_models_raw() -> list: else: app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) - tasks = [ - fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) - for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS) - ] + tasks = [] + for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): + if url not in app.state.config.OPENAI_API_CONFIGS: + tasks.append( + aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) + ) + else: + api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {}) + + enable = api_config.get("enable", True) + model_ids = api_config.get("model_ids", []) + + if enable: + if len(model_ids) == 0: + tasks.append( + aiohttp_get( + f"{url}/models", app.state.config.OPENAI_API_KEYS[idx] + ) + ) + else: + model_list = { + "object": "list", + "data": [ + { + "id": model_id, + "name": model_id, + "owned_by": "openai", + "openai": {"id": model_id}, + "urlIdx": idx, + } + for model_id in model_ids + ], + } + + tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list))) responses = await asyncio.gather(*tasks) - log.debug(f"get_all_models:responses() {responses}") - return responses + for idx, response in enumerate(responses): + if response: + url = app.state.config.OPENAI_API_BASE_URLS[idx] + api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) -@overload -async def get_all_models(raw: Literal[True]) -> list: ... + if prefix_id: + for model in response["data"]: + model["id"] = f"{prefix_id}.{model['id']}" + log.debug(f"get_all_models:responses() {responses}") -@overload -async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ... + return responses -async def get_all_models(raw=False) -> dict[str, list] | list: +async def get_all_models() -> dict[str, list]: log.info("get_all_models()") - if is_openai_api_disabled(): - return [] if raw else {"data": []} - responses = await get_all_models_raw() - if raw: - return responses + if not app.state.config.ENABLE_OPENAI_API: + return {"data": []} + + responses = await get_all_models_responses() def extract_data(response): if response and "data" in response: @@ -296,9 +337,7 @@ async def get_all_models(raw=False) -> dict[str, list] | list: return None models = {"data": merge_models_lists(map(extract_data, responses))} - log.debug(f"models: {models}") - app.state.MODELS = {model["id"]: model for model in models["data"]} return models @@ -306,18 +345,12 @@ async def get_all_models(raw=False) -> dict[str, list] | list: @app.get("/models") @app.get("/models/{url_idx}") async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)): + models = { + "data": [], + } + if url_idx is None: models = await get_all_models() - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models["data"] = list( - filter( - lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, - models["data"], - ) - ) - return models - return models else: url = app.state.config.OPENAI_API_BASE_URLS[url_idx] key = app.state.config.OPENAI_API_KEYS[url_idx] @@ -326,56 +359,126 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us headers["Authorization"] = f"Bearer {key}" headers["Content-Type"] = "application/json" + if ENABLE_FORWARD_USER_INFO_HEADERS: + headers["X-OpenWebUI-User-Name"] = user.name + headers["X-OpenWebUI-User-Id"] = user.id + headers["X-OpenWebUI-User-Email"] = user.email + headers["X-OpenWebUI-User-Role"] = user.role + r = None - try: - r = requests.request(method="GET", url=f"{url}/models", headers=headers) - r.raise_for_status() + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get(f"{url}/models", headers=headers) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() + if "error" in res: + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) + + response_data = await r.json() + + # Check if we're calling OpenAI API based on the URL + if "api.openai.com" in url: + # Filter models according to the specified conditions + response_data["data"] = [ + model + for model in response_data.get("data", []) + if not any( + name in model["id"] + for name in [ + "babbage", + "dall-e", + "davinci", + "embedding", + "tts", + "whisper", + ] + ) + ] - response_data = r.json() + models = response_data + except aiohttp.ClientError as e: + # ClientError covers all aiohttp requests issues + log.exception(f"Client error: {str(e)}") + # Handle aiohttp-specific connection issues, timeout etc. + raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) + except Exception as e: + log.exception(f"Unexpected error: {e}") + # Generic error handler in case parsing JSON or other steps fail + error_detail = f"Unexpected error: {str(e)}" + raise HTTPException(status_code=500, detail=error_detail) + + if user.role == "user": + # Filter models based on user access control + filtered_models = [] + for model in models.get("data", []): + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if user.id == model_info.user_id or has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + models["data"] = filtered_models - if "api.openai.com" in url: - # Filter the response data - response_data["data"] = [ - model - for model in response_data["data"] - if not any( - name in model["id"] - for name in [ - "babbage", - "dall-e", - "davinci", - "embedding", - "tts", - "whisper", - ] - ) - ] + return models - return response_data - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() + +class ConnectionVerificationForm(BaseModel): + url: str + key: str + + +@app.post("/verify") +async def verify_connection( + form_data: ConnectionVerificationForm, user=Depends(get_admin_user) +): + url = form_data.url + key = form_data.key + + headers = {} + headers["Authorization"] = f"Bearer {key}" + headers["Content-Type"] = "application/json" + + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get(f"{url}/models", headers=headers) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() if "error" in res: - error_detail = f"External: {res['error']}" - except Exception: - error_detail = f"External: {e}" + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) + + response_data = await r.json() + return response_data + except aiohttp.ClientError as e: + # ClientError covers all aiohttp requests issues + log.exception(f"Client error: {str(e)}") + # Handle aiohttp-specific connection issues, timeout etc. raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, + status_code=500, detail="Open WebUI: Server Connection Error" ) + except Exception as e: + log.exception(f"Unexpected error: {e}") + # Generic error handler in case parsing JSON or other steps fail + error_detail = f"Unexpected error: {str(e)}" + raise HTTPException(status_code=500, detail=error_detail) @app.post("/chat/completions") -@app.post("/chat/completions/{url_idx}") async def generate_chat_completion( form_data: dict, - url_idx: Optional[int] = None, user=Depends(get_verified_user), + bypass_filter: Optional[bool] = False, ): idx = 0 payload = {**form_data} @@ -386,6 +489,7 @@ async def generate_chat_completion( model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) + # Check model info and override the payload if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -394,9 +498,52 @@ async def generate_chat_completion( payload = apply_model_params_to_body_openai(params, payload) payload = apply_model_system_prompt_to_body(params, payload, user) - model = app.state.MODELS[payload.get("model")] - idx = model["urlIdx"] + # Check if user has access to the model + if not bypass_filter and user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + elif not bypass_filter: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) + + # Attemp to get urlIdx from the model + models = await get_all_models() + # Find the model from the list + model = next( + (model for model in models["data"] if model["id"] == payload.get("model")), + None, + ) + + if model: + idx = model["urlIdx"] + else: + raise HTTPException( + status_code=404, + detail="Model not found", + ) + + # Get the API config for the model + api_config = app.state.config.OPENAI_API_CONFIGS.get( + app.state.config.OPENAI_API_BASE_URLS[idx], {} + ) + prefix_id = api_config.get("prefix_id", None) + + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + + # Add user info to the payload if the model is a pipeline if "pipeline" in model and model.get("pipeline"): payload["user"] = { "name": user.name, @@ -407,8 +554,9 @@ async def generate_chat_completion( url = app.state.config.OPENAI_API_BASE_URLS[idx] key = app.state.config.OPENAI_API_KEYS[idx] - is_o1 = payload["model"].lower().startswith("o1-") + # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens" + is_o1 = payload["model"].lower().startswith("o1-") # Change max_completion_tokens to max_tokens (Backward compatible) if "api.openai.com" not in url and not is_o1: if "max_completion_tokens" in payload: @@ -437,6 +585,11 @@ async def generate_chat_completion( if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: headers["HTTP-Referer"] = "https://openwebui.com/" headers["X-Title"] = "Open WebUI" + if ENABLE_FORWARD_USER_INFO_HEADERS: + headers["X-OpenWebUI-User-Name"] = user.name + headers["X-OpenWebUI-User-Id"] = user.id + headers["X-OpenWebUI-User-Email"] = user.email + headers["X-OpenWebUI-User-Role"] = user.role r = None session = None @@ -505,6 +658,11 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): headers = {} headers["Authorization"] = f"Bearer {key}" headers["Content-Type"] = "application/json" + if ENABLE_FORWARD_USER_INFO_HEADERS: + headers["X-OpenWebUI-User-Name"] = user.name + headers["X-OpenWebUI-User-Id"] = user.id + headers["X-OpenWebUI-User-Email"] = user.email + headers["X-OpenWebUI-User-Role"] = user.role r = None session = None diff --git a/backend/open_webui/apps/retrieval/loaders/main.py b/backend/open_webui/apps/retrieval/loaders/main.py index ceb868a828088153b1bc26af135540a4440741ef..e01e756bd7cc4f9d7f6dff33d3a8197790291951 100644 --- a/backend/open_webui/apps/retrieval/loaders/main.py +++ b/backend/open_webui/apps/retrieval/loaders/main.py @@ -159,7 +159,7 @@ class Loader: elif file_ext in ["htm", "html"]: loader = BSHTMLLoader(file_path, open_encoding="unicode_escape") elif file_ext == "md": - loader = UnstructuredMarkdownLoader(file_path) + loader = TextLoader(file_path, autodetect_encoding=True) elif file_content_type == "application/epub+zip": loader = UnstructuredEPubLoader(file_path) elif ( diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 49772d5dc9ab4405eecc9200ce524ed22a251b7b..753239bc46bbfe10550c2cd8f646643145df5f35 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -37,6 +37,7 @@ from open_webui.apps.retrieval.web.serper import search_serper from open_webui.apps.retrieval.web.serply import search_serply from open_webui.apps.retrieval.web.serpstack import search_serpstack from open_webui.apps.retrieval.web.tavily import search_tavily +from open_webui.apps.retrieval.web.bing import search_bing from open_webui.apps.retrieval.utils import ( @@ -74,6 +75,8 @@ from open_webui.config import ( RAG_FILE_MAX_SIZE, RAG_OPENAI_API_BASE_URL, RAG_OPENAI_API_KEY, + RAG_OLLAMA_BASE_URL, + RAG_OLLAMA_API_KEY, RAG_RELEVANCE_THRESHOLD, RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, @@ -85,6 +88,7 @@ from open_webui.config import ( RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, RAG_WEB_SEARCH_ENGINE, RAG_WEB_SEARCH_RESULT_COUNT, + JINA_API_KEY, SEARCHAPI_API_KEY, SEARCHAPI_ENGINE, SEARXNG_QUERY_URL, @@ -93,13 +97,20 @@ from open_webui.config import ( SERPSTACK_API_KEY, SERPSTACK_HTTPS, TAVILY_API_KEY, + BING_SEARCH_V7_ENDPOINT, + BING_SEARCH_V7_SUBSCRIPTION_KEY, TIKA_SERVER_URL, UPLOAD_DIR, YOUTUBE_LOADER_LANGUAGE, + DEFAULT_LOCALE, AppConfig, ) from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER +from open_webui.env import ( + SRC_LOG_LEVELS, + DEVICE_TYPE, + DOCKER, +) from open_webui.utils.misc import ( calculate_sha256, calculate_sha256_string, @@ -118,7 +129,11 @@ from langchain_core.documents import Document log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -app = FastAPI() +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) app.state.config = AppConfig() @@ -150,6 +165,9 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY +app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL +app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY + app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE @@ -171,6 +189,10 @@ app.state.config.SERPLY_API_KEY = SERPLY_API_KEY app.state.config.TAVILY_API_KEY = TAVILY_API_KEY app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE +app.state.config.JINA_API_KEY = JINA_API_KEY +app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT +app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY + app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS @@ -182,11 +204,15 @@ def update_embedding_model( if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": from sentence_transformers import SentenceTransformer - app.state.sentence_transformer_ef = SentenceTransformer( - get_model_path(embedding_model, auto_update), - device=DEVICE_TYPE, - trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - ) + try: + app.state.sentence_transformer_ef = SentenceTransformer( + get_model_path(embedding_model, auto_update), + device=DEVICE_TYPE, + trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + ) + except Exception as e: + log.debug(f"Error loading SentenceTransformer: {e}") + app.state.sentence_transformer_ef = None else: app.state.sentence_transformer_ef = None @@ -240,8 +266,16 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, + ( + app.state.config.OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_BASE_URL + ), + ( + app.state.config.OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_API_KEY + ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) @@ -291,6 +325,10 @@ async def get_embedding_config(user=Depends(get_admin_user)): "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, }, + "ollama_config": { + "url": app.state.config.OLLAMA_BASE_URL, + "key": app.state.config.OLLAMA_API_KEY, + }, } @@ -307,8 +345,14 @@ class OpenAIConfigForm(BaseModel): key: str +class OllamaConfigForm(BaseModel): + url: str + key: str + + class EmbeddingModelUpdateForm(BaseModel): openai_config: Optional[OpenAIConfigForm] = None + ollama_config: Optional[OllamaConfigForm] = None embedding_engine: str embedding_model: str embedding_batch_size: Optional[int] = 1 @@ -329,6 +373,11 @@ async def update_embedding_config( if form_data.openai_config is not None: app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url app.state.config.OPENAI_API_KEY = form_data.openai_config.key + + if form_data.ollama_config is not None: + app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url + app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key + app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL) @@ -337,8 +386,16 @@ async def update_embedding_config( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, + ( + app.state.config.OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_BASE_URL + ), + ( + app.state.config.OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_API_KEY + ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) @@ -351,6 +408,10 @@ async def update_embedding_config( "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, }, + "ollama_config": { + "url": app.state.config.OLLAMA_BASE_URL, + "key": app.state.config.OLLAMA_API_KEY, + }, } except Exception as e: log.exception(f"Problem updating embedding model: {e}") @@ -411,7 +472,7 @@ async def get_rag_config(user=Depends(get_admin_user)): "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, "web": { - "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "search": { "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH, "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, @@ -426,6 +487,9 @@ async def get_rag_config(user=Depends(get_admin_user)): "tavily_api_key": app.state.config.TAVILY_API_KEY, "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY, "seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE, + "jina_api_key": app.state.config.JINA_API_KEY, + "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT, + "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, @@ -468,6 +532,9 @@ class WebSearchConfig(BaseModel): tavily_api_key: Optional[str] = None searchapi_api_key: Optional[str] = None searchapi_engine: Optional[str] = None + jina_api_key: Optional[str] = None + bing_search_v7_endpoint: Optional[str] = None + bing_search_v7_subscription_key: Optional[str] = None result_count: Optional[int] = None concurrent_requests: Optional[int] = None @@ -514,6 +581,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ if form_data.web is not None: app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + # Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False form_data.web.web_loader_ssl_verification ) @@ -534,6 +602,15 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine + + app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key + app.state.config.BING_SEARCH_V7_ENDPOINT = ( + form_data.web.search.bing_search_v7_endpoint + ) + app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( + form_data.web.search.bing_search_v7_subscription_key + ) + app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( form_data.web.search.concurrent_requests @@ -560,7 +637,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, "web": { - "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "search": { "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH, "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, @@ -575,6 +652,9 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ "serachapi_api_key": app.state.config.SEARCHAPI_API_KEY, "searchapi_engine": app.state.config.SEARCHAPI_ENGINE, "tavily_api_key": app.state.config.TAVILY_API_KEY, + "jina_api_key": app.state.config.JINA_API_KEY, + "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT, + "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, @@ -636,6 +716,23 @@ async def update_query_settings( #################################### +def _get_docs_info(docs: list[Document]) -> str: + docs_info = set() + + # Trying to select relevant metadata identifying the document. + for doc in docs: + metadata = getattr(doc, "metadata", {}) + doc_name = metadata.get("name", "") + if not doc_name: + doc_name = metadata.get("title", "") + if not doc_name: + doc_name = metadata.get("source", "") + if doc_name: + docs_info.add(doc_name) + + return ", ".join(docs_info) + + def save_docs_to_vector_db( docs, collection_name, @@ -644,7 +741,9 @@ def save_docs_to_vector_db( split: bool = True, add: bool = False, ) -> bool: - log.info(f"save_docs_to_vector_db {docs} {collection_name}") + log.info( + f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}" + ) # Check if entries with the same hash (metadata.hash) already exist if metadata and "hash" in metadata: @@ -726,8 +825,16 @@ def save_docs_to_vector_db( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, + ( + app.state.config.OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_BASE_URL + ), + ( + app.state.config.OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_API_KEY + ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) @@ -954,7 +1061,7 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u loader = YoutubeLoader.from_youtube_url( form_data.url, - add_video_info=True, + add_video_info=False, language=app.state.config.YOUTUBE_LOADER_LANGUAGE, translation=app.state.YOUTUBE_LOADER_TRANSLATION, ) @@ -1132,7 +1239,20 @@ def search_web(engine: str, query: str) -> list[SearchResult]: else: raise Exception("No SEARCHAPI_API_KEY found in environment variables") elif engine == "jina": - return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT) + return search_jina( + app.state.config.JINA_API_KEY, + query, + app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + ) + elif engine == "bing": + return search_bing( + app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + app.state.config.BING_SEARCH_V7_ENDPOINT, + str(DEFAULT_LOCALE), + query, + app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) else: raise Exception("No search engine API key found in environment variables") @@ -1162,8 +1282,12 @@ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)): urls = [result.link for result in web_results] - loader = get_web_loader(urls) - docs = loader.load() + loader = get_web_loader( + urls, + verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + ) + docs = loader.aload() save_docs_to_vector_db(docs, collection_name, overwrite=True) diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 77903de0a7196ab7c7f8c8fb88f5eeb2a45e7f38..d992f1b34ba1341b83449606c47da93386ae9f7b 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -3,6 +3,7 @@ import os import uuid from typing import Optional, Union +import asyncio import requests from huggingface_hub import snapshot_download @@ -10,11 +11,6 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document - -from open_webui.apps.ollama.main import ( - GenerateEmbedForm, - generate_ollama_batch_embeddings, -) from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message @@ -76,7 +72,7 @@ def query_doc( limit=k, ) - log.info(f"query_doc:result {result}") + log.info(f"query_doc:result {result.ids} {result.metadatas}") return result except Exception as e: print(e) @@ -127,7 +123,10 @@ def query_doc_with_hybrid_search( "metadatas": [[d.metadata for d in result]], } - log.info(f"query_doc_with_hybrid_search:result {result}") + log.info( + "query_doc_with_hybrid_search:result " + + f'{result["metadatas"]} {result["distances"]}' + ) return result except Exception as e: raise e @@ -178,35 +177,34 @@ def merge_and_sort_query_results( def query_collection( collection_names: list[str], - query: str, + queries: list[str], embedding_function, k: int, ) -> dict: - results = [] - query_embedding = embedding_function(query) - - for collection_name in collection_names: - if collection_name: - try: - result = query_doc( - collection_name=collection_name, - k=k, - query_embedding=query_embedding, - ) - if result is not None: - results.append(result.model_dump()) - except Exception as e: - log.exception(f"Error when querying the collection: {e}") - else: - pass + for query in queries: + query_embedding = embedding_function(query) + for collection_name in collection_names: + if collection_name: + try: + result = query_doc( + collection_name=collection_name, + k=k, + query_embedding=query_embedding, + ) + if result is not None: + results.append(result.model_dump()) + except Exception as e: + log.exception(f"Error when querying the collection: {e}") + else: + pass return merge_and_sort_query_results(results, k=k) def query_collection_with_hybrid_search( collection_names: list[str], - query: str, + queries: list[str], embedding_function, k: int, reranking_function, @@ -216,15 +214,16 @@ def query_collection_with_hybrid_search( error = False for collection_name in collection_names: try: - result = query_doc_with_hybrid_search( - collection_name=collection_name, - query=query, - embedding_function=embedding_function, - k=k, - reranking_function=reranking_function, - r=r, - ) - results.append(result) + for query in queries: + result = query_doc_with_hybrid_search( + collection_name=collection_name, + query=query, + embedding_function=embedding_function, + k=k, + reranking_function=reranking_function, + r=r, + ) + results.append(result) except Exception as e: log.exception( "Error when querying the collection with " f"hybrid_search: {e}" @@ -281,8 +280,8 @@ def get_embedding_function( embedding_engine, embedding_model, embedding_function, - openai_key, - openai_url, + url, + key, embedding_batch_size, ): if embedding_engine == "": @@ -292,8 +291,8 @@ def get_embedding_function( engine=embedding_engine, model=embedding_model, text=query, - key=openai_key if embedding_engine == "openai" else "", - url=openai_url if embedding_engine == "openai" else "", + url=url, + key=key, ) def generate_multiple(query, func): @@ -310,15 +309,14 @@ def get_embedding_function( def get_rag_context( files, - messages, + queries, embedding_function, k, reranking_function, r, hybrid_search, ): - log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}") - query = get_last_user_message(messages) + log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}") extracted_collections = [] relevant_contexts = [] @@ -360,7 +358,7 @@ def get_rag_context( try: context = query_collection_with_hybrid_search( collection_names=collection_names, - query=query, + queries=queries, embedding_function=embedding_function, k=k, reranking_function=reranking_function, @@ -375,7 +373,7 @@ def get_rag_context( if (not hybrid_search) or (context is None): context = query_collection( collection_names=collection_names, - query=query, + queries=queries, embedding_function=embedding_function, k=k, ) @@ -467,7 +465,7 @@ def get_model_path(model: str, update_model: bool = False): def generate_openai_batch_embeddings( - model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" + model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "" ) -> Optional[list[list[float]]]: try: r = requests.post( @@ -489,29 +487,50 @@ def generate_openai_batch_embeddings( return None +def generate_ollama_batch_embeddings( + model: str, texts: list[str], url: str, key: str +) -> Optional[list[list[float]]]: + try: + r = requests.post( + f"{url}/api/embed", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {key}", + }, + json={"input": texts, "model": model}, + ) + r.raise_for_status() + data = r.json() + + print(data) + if "embeddings" in data: + return data["embeddings"] + else: + raise "Something went wrong :/" + except Exception as e: + print(e) + return None + + def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs): + url = kwargs.get("url", "") + key = kwargs.get("key", "") + if engine == "ollama": if isinstance(text, list): embeddings = generate_ollama_batch_embeddings( - GenerateEmbedForm(**{"model": model, "input": text}) + **{"model": model, "texts": text, "url": url, "key": key} ) else: embeddings = generate_ollama_batch_embeddings( - GenerateEmbedForm(**{"model": model, "input": [text]}) + **{"model": model, "texts": [text], "url": url, "key": key} ) - return ( - embeddings["embeddings"][0] - if isinstance(text, str) - else embeddings["embeddings"] - ) + return embeddings[0] if isinstance(text, str) else embeddings elif engine == "openai": - key = kwargs.get("key", "") - url = kwargs.get("url", "https://api.openai.com/v1") - if isinstance(text, list): - embeddings = generate_openai_batch_embeddings(model, text, key, url) + embeddings = generate_openai_batch_embeddings(model, text, url, key) else: - embeddings = generate_openai_batch_embeddings(model, [text], key, url) + embeddings = generate_openai_batch_embeddings(model, [text], url, key) return embeddings[0] if isinstance(text, str) else embeddings diff --git a/backend/open_webui/apps/retrieval/vector/connector.py b/backend/open_webui/apps/retrieval/vector/connector.py index c9d8c2d787f1dca19dd0ed83e9434e704652008a..acfc526a9e20781959da915f02c1ace889c9f765 100644 --- a/backend/open_webui/apps/retrieval/vector/connector.py +++ b/backend/open_webui/apps/retrieval/vector/connector.py @@ -8,6 +8,14 @@ elif VECTOR_DB == "qdrant": from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient VECTOR_DB_CLIENT = QdrantClient() +elif VECTOR_DB == "opensearch": + from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient + + VECTOR_DB_CLIENT = OpenSearchClient() +elif VECTOR_DB == "pgvector": + from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient + + VECTOR_DB_CLIENT = PgvectorClient() else: from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient diff --git a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py index cb4d6283f6052c28548a2968857b6173254dbb26..d61ef0cff2b493b637d6daa7d0078e8147d789e0 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py @@ -27,7 +27,9 @@ class ChromaClient: if CHROMA_CLIENT_AUTH_PROVIDER is not None: settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER if CHROMA_CLIENT_AUTH_CREDENTIALS is not None: - settings_dict["chroma_client_auth_credentials"] = CHROMA_CLIENT_AUTH_CREDENTIALS + settings_dict["chroma_client_auth_credentials"] = ( + CHROMA_CLIENT_AUTH_CREDENTIALS + ) if CHROMA_HTTP_HOST != "": self.client = chromadb.HttpClient( diff --git a/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py b/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py new file mode 100644 index 0000000000000000000000000000000000000000..57d5636443ce2fa694f9e31bc9f3faa2dbe4ce0e --- /dev/null +++ b/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py @@ -0,0 +1,178 @@ +from opensearchpy import OpenSearch +from typing import Optional + +from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import ( + OPENSEARCH_URI, + OPENSEARCH_SSL, + OPENSEARCH_CERT_VERIFY, + OPENSEARCH_USERNAME, + OPENSEARCH_PASSWORD, +) + + +class OpenSearchClient: + def __init__(self): + self.index_prefix = "open_webui" + self.client = OpenSearch( + hosts=[OPENSEARCH_URI], + use_ssl=OPENSEARCH_SSL, + verify_certs=OPENSEARCH_CERT_VERIFY, + http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD), + ) + + def _result_to_get_result(self, result) -> GetResult: + ids = [] + documents = [] + metadatas = [] + + for hit in result["hits"]["hits"]: + ids.append(hit["_id"]) + documents.append(hit["_source"].get("text")) + metadatas.append(hit["_source"].get("metadata")) + + return GetResult(ids=ids, documents=documents, metadatas=metadatas) + + def _result_to_search_result(self, result) -> SearchResult: + ids = [] + distances = [] + documents = [] + metadatas = [] + + for hit in result["hits"]["hits"]: + ids.append(hit["_id"]) + distances.append(hit["_score"]) + documents.append(hit["_source"].get("text")) + metadatas.append(hit["_source"].get("metadata")) + + return SearchResult( + ids=ids, distances=distances, documents=documents, metadatas=metadatas + ) + + def _create_index(self, index_name: str, dimension: int): + body = { + "mappings": { + "properties": { + "id": {"type": "keyword"}, + "vector": { + "type": "dense_vector", + "dims": dimension, # Adjust based on your vector dimensions + "index": true, + "similarity": "faiss", + "method": { + "name": "hnsw", + "space_type": "ip", # Use inner product to approximate cosine similarity + "engine": "faiss", + "ef_construction": 128, + "m": 16, + }, + }, + "text": {"type": "text"}, + "metadata": {"type": "object"}, + } + } + } + self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body) + + def _create_batches(self, items: list[VectorItem], batch_size=100): + for i in range(0, len(items), batch_size): + yield items[i : i + batch_size] + + def has_collection(self, index_name: str) -> bool: + # has_collection here means has index. + # We are simply adapting to the norms of the other DBs. + return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}") + + def delete_colleciton(self, index_name: str): + # delete_collection here means delete index. + # We are simply adapting to the norms of the other DBs. + self.client.indices.delete(index=f"{self.index_prefix}_{index_name}") + + def search( + self, index_name: str, vectors: list[list[float]], limit: int + ) -> Optional[SearchResult]: + query = { + "size": limit, + "_source": ["text", "metadata"], + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": "cosineSimilarity(params.vector, 'vector') + 1.0", + "params": { + "vector": vectors[0] + }, # Assuming single query vector + }, + } + }, + } + + result = self.client.search( + index=f"{self.index_prefix}_{index_name}", body=query + ) + + return self._result_to_search_result(result) + + def get_or_create_index(self, index_name: str, dimension: int): + if not self.has_index(index_name): + self._create_index(index_name, dimension) + + def get(self, index_name: str) -> Optional[GetResult]: + query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]} + + result = self.client.search( + index=f"{self.index_prefix}_{index_name}", body=query + ) + return self._result_to_get_result(result) + + def insert(self, index_name: str, items: list[VectorItem]): + if not self.has_index(index_name): + self._create_index(index_name, dimension=len(items[0]["vector"])) + + for batch in self._create_batches(items): + actions = [ + { + "index": { + "_id": item["id"], + "_source": { + "vector": item["vector"], + "text": item["text"], + "metadata": item["metadata"], + }, + } + } + for item in batch + ] + self.client.bulk(actions) + + def upsert(self, index_name: str, items: list[VectorItem]): + if not self.has_index(index_name): + self._create_index(index_name, dimension=len(items[0]["vector"])) + + for batch in self._create_batches(items): + actions = [ + { + "index": { + "_id": item["id"], + "_source": { + "vector": item["vector"], + "text": item["text"], + "metadata": item["metadata"], + }, + } + } + for item in batch + ] + self.client.bulk(actions) + + def delete(self, index_name: str, ids: list[str]): + actions = [ + {"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}} + for id in ids + ] + self.client.bulk(body=actions) + + def reset(self): + indices = self.client.indices.get(index=f"{self.index_prefix}_*") + for index in indices: + self.client.indices.delete(index=index) diff --git a/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py b/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py new file mode 100644 index 0000000000000000000000000000000000000000..11fcab05cf6ffaa88af32ce0807f8ef3dadb8c8d --- /dev/null +++ b/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py @@ -0,0 +1,354 @@ +from typing import Optional, List, Dict, Any +from sqlalchemy import ( + cast, + column, + create_engine, + Column, + Integer, + select, + text, + Text, + values, +) +from sqlalchemy.sql import true +from sqlalchemy.pool import NullPool + +from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker +from sqlalchemy.dialects.postgresql import JSONB, array +from pgvector.sqlalchemy import Vector +from sqlalchemy.ext.mutable import MutableDict + +from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import PGVECTOR_DB_URL + +VECTOR_LENGTH = 1536 +Base = declarative_base() + + +class DocumentChunk(Base): + __tablename__ = "document_chunk" + + id = Column(Text, primary_key=True) + vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True) + collection_name = Column(Text, nullable=False) + text = Column(Text, nullable=True) + vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) + + +class PgvectorClient: + def __init__(self) -> None: + + # if no pgvector uri, use the existing database connection + if not PGVECTOR_DB_URL: + from open_webui.apps.webui.internal.db import Session + + self.session = Session + else: + engine = create_engine( + PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool + ) + SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=engine, expire_on_commit=False + ) + self.session = scoped_session(SessionLocal) + + try: + # Ensure the pgvector extension is available + self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;")) + + # Create the tables if they do not exist + # Base.metadata.create_all requires a bind (engine or connection) + # Get the connection from the session + connection = self.session.connection() + Base.metadata.create_all(bind=connection) + + # Create an index on the vector column if it doesn't exist + self.session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector " + "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);" + ) + ) + self.session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name " + "ON document_chunk (collection_name);" + ) + ) + self.session.commit() + print("Initialization complete.") + except Exception as e: + self.session.rollback() + print(f"Error during initialization: {e}") + raise + + def adjust_vector_length(self, vector: List[float]) -> List[float]: + # Adjust vector to have length VECTOR_LENGTH + current_length = len(vector) + if current_length < VECTOR_LENGTH: + # Pad the vector with zeros + vector += [0.0] * (VECTOR_LENGTH - current_length) + elif current_length > VECTOR_LENGTH: + raise Exception( + f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}" + ) + return vector + + def insert(self, collection_name: str, items: List[VectorItem]) -> None: + try: + new_items = [] + for item in items: + vector = self.adjust_vector_length(item["vector"]) + new_chunk = DocumentChunk( + id=item["id"], + vector=vector, + collection_name=collection_name, + text=item["text"], + vmetadata=item["metadata"], + ) + new_items.append(new_chunk) + self.session.bulk_save_objects(new_items) + self.session.commit() + print( + f"Inserted {len(new_items)} items into collection '{collection_name}'." + ) + except Exception as e: + self.session.rollback() + print(f"Error during insert: {e}") + raise + + def upsert(self, collection_name: str, items: List[VectorItem]) -> None: + try: + for item in items: + vector = self.adjust_vector_length(item["vector"]) + existing = ( + self.session.query(DocumentChunk) + .filter(DocumentChunk.id == item["id"]) + .first() + ) + if existing: + existing.vector = vector + existing.text = item["text"] + existing.vmetadata = item["metadata"] + existing.collection_name = ( + collection_name # Update collection_name if necessary + ) + else: + new_chunk = DocumentChunk( + id=item["id"], + vector=vector, + collection_name=collection_name, + text=item["text"], + vmetadata=item["metadata"], + ) + self.session.add(new_chunk) + self.session.commit() + print(f"Upserted {len(items)} items into collection '{collection_name}'.") + except Exception as e: + self.session.rollback() + print(f"Error during upsert: {e}") + raise + + def search( + self, + collection_name: str, + vectors: List[List[float]], + limit: Optional[int] = None, + ) -> Optional[SearchResult]: + try: + if not vectors: + return None + + # Adjust query vectors to VECTOR_LENGTH + vectors = [self.adjust_vector_length(vector) for vector in vectors] + num_queries = len(vectors) + + def vector_expr(vector): + return cast(array(vector), Vector(VECTOR_LENGTH)) + + # Create the values for query vectors + qid_col = column("qid", Integer) + q_vector_col = column("q_vector", Vector(VECTOR_LENGTH)) + query_vectors = ( + values(qid_col, q_vector_col) + .data( + [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)] + ) + .alias("query_vectors") + ) + + # Build the lateral subquery for each query vector + subq = ( + select( + DocumentChunk.id, + DocumentChunk.text, + DocumentChunk.vmetadata, + ( + DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector) + ).label("distance"), + ) + .where(DocumentChunk.collection_name == collection_name) + .order_by( + (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) + ) + ) + if limit is not None: + subq = subq.limit(limit) + subq = subq.lateral("result") + + # Build the main query by joining query_vectors and the lateral subquery + stmt = ( + select( + query_vectors.c.qid, + subq.c.id, + subq.c.text, + subq.c.vmetadata, + subq.c.distance, + ) + .select_from(query_vectors) + .join(subq, true()) + .order_by(query_vectors.c.qid, subq.c.distance) + ) + + result_proxy = self.session.execute(stmt) + results = result_proxy.all() + + ids = [[] for _ in range(num_queries)] + distances = [[] for _ in range(num_queries)] + documents = [[] for _ in range(num_queries)] + metadatas = [[] for _ in range(num_queries)] + + if not results: + return SearchResult( + ids=ids, + distances=distances, + documents=documents, + metadatas=metadatas, + ) + + for row in results: + qid = int(row.qid) + ids[qid].append(row.id) + distances[qid].append(row.distance) + documents[qid].append(row.text) + metadatas[qid].append(row.vmetadata) + + return SearchResult( + ids=ids, distances=distances, documents=documents, metadatas=metadatas + ) + except Exception as e: + print(f"Error during search: {e}") + return None + + def query( + self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None + ) -> Optional[GetResult]: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + + for key, value in filter.items(): + query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) + + if limit is not None: + query = query.limit(limit) + + results = query.all() + + if not results: + return None + + ids = [[result.id for result in results]] + documents = [[result.text for result in results]] + metadatas = [[result.vmetadata for result in results]] + + return GetResult( + ids=ids, + documents=documents, + metadatas=metadatas, + ) + except Exception as e: + print(f"Error during query: {e}") + return None + + def get( + self, collection_name: str, limit: Optional[int] = None + ) -> Optional[GetResult]: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if limit is not None: + query = query.limit(limit) + + results = query.all() + + if not results: + return None + + ids = [[result.id for result in results]] + documents = [[result.text for result in results]] + metadatas = [[result.vmetadata for result in results]] + + return GetResult(ids=ids, documents=documents, metadatas=metadatas) + except Exception as e: + print(f"Error during get: {e}") + return None + + def delete( + self, + collection_name: str, + ids: Optional[List[str]] = None, + filter: Optional[Dict[str, Any]] = None, + ) -> None: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if ids: + query = query.filter(DocumentChunk.id.in_(ids)) + if filter: + for key, value in filter.items(): + query = query.filter( + DocumentChunk.vmetadata[key].astext == str(value) + ) + deleted = query.delete(synchronize_session=False) + self.session.commit() + print(f"Deleted {deleted} items from collection '{collection_name}'.") + except Exception as e: + self.session.rollback() + print(f"Error during delete: {e}") + raise + + def reset(self) -> None: + try: + deleted = self.session.query(DocumentChunk).delete() + self.session.commit() + print( + f"Reset complete. Deleted {deleted} items from 'document_chunk' table." + ) + except Exception as e: + self.session.rollback() + print(f"Error during reset: {e}") + raise + + def close(self) -> None: + pass + + def has_collection(self, collection_name: str) -> bool: + try: + exists = ( + self.session.query(DocumentChunk) + .filter(DocumentChunk.collection_name == collection_name) + .first() + is not None + ) + return exists + except Exception as e: + print(f"Error checking collection existence: {e}") + return False + + def delete_collection(self, collection_name: str) -> None: + self.delete(collection_name) + print(f"Collection '{collection_name}' deleted.") diff --git a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py index 09669e96abdf950c424f3f9f004fb6c3c0a2d601..822f0de29bab9d0d77389cdf4bc1d46a2f4ea4c2 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py @@ -5,7 +5,7 @@ from qdrant_client.http.models import PointStruct from qdrant_client.models import models from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult -from open_webui.config import QDRANT_URI +from open_webui.config import QDRANT_URI, QDRANT_API_KEY NO_LIMIT = 999999999 @@ -14,7 +14,12 @@ class QdrantClient: def __init__(self): self.collection_prefix = "open-webui" self.QDRANT_URI = QDRANT_URI - self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None + self.QDRANT_API_KEY = QDRANT_API_KEY + self.client = ( + Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) + if self.QDRANT_URI + else None + ) def _result_to_get_result(self, points) -> GetResult: ids = [] diff --git a/backend/open_webui/apps/retrieval/web/bing.py b/backend/open_webui/apps/retrieval/web/bing.py new file mode 100644 index 0000000000000000000000000000000000000000..c675187f469031df399c0325feca2fa7e7249583 --- /dev/null +++ b/backend/open_webui/apps/retrieval/web/bing.py @@ -0,0 +1,73 @@ +import logging +import os +from pprint import pprint +from typing import Optional +import requests +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.env import SRC_LOG_LEVELS +import argparse + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) +""" +Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview +""" + + +def search_bing( + subscription_key: str, + endpoint: str, + locale: str, + query: str, + count: int, + filter_list: Optional[list[str]] = None, +) -> list[SearchResult]: + mkt = locale + params = {"q": query, "mkt": mkt, "answerCount": count} + headers = {"Ocp-Apim-Subscription-Key": subscription_key} + + try: + response = requests.get(endpoint, headers=headers, params=params) + response.raise_for_status() + json_response = response.json() + results = json_response.get("webPages", {}).get("value", []) + if filter_list: + results = get_filtered_results(results, filter_list) + return [ + SearchResult( + link=result["url"], + title=result.get("name"), + snippet=result.get("snippet"), + ) + for result in results + ] + except Exception as ex: + log.error(f"Error: {ex}") + raise ex + + +def main(): + parser = argparse.ArgumentParser(description="Search Bing from the command line.") + parser.add_argument( + "query", + type=str, + default="Top 10 international news today", + help="The search query.", + ) + parser.add_argument( + "--count", type=int, default=10, help="Number of search results to return." + ) + parser.add_argument( + "--filter", nargs="*", help="List of filters to apply to the search results." + ) + parser.add_argument( + "--locale", + type=str, + default="en-US", + help="The locale to use for the search, maps to market in api", + ) + + args = parser.parse_args() + + results = search_bing(args.locale, args.query, args.count, args.filter) + pprint(results) diff --git a/backend/open_webui/apps/retrieval/web/jina_search.py b/backend/open_webui/apps/retrieval/web/jina_search.py index 03288dd4ee01ef50c0aee998d8ffbbb79502e329..baed89934787ca8e9a6c2a7a43f75d477f612c88 100644 --- a/backend/open_webui/apps/retrieval/web/jina_search.py +++ b/backend/open_webui/apps/retrieval/web/jina_search.py @@ -9,7 +9,7 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -def search_jina(query: str, count: int) -> list[SearchResult]: +def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]: """ Search using Jina's Search API and return the results as a list of SearchResult objects. Args: @@ -20,9 +20,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]: list[SearchResult]: A list of search results """ jina_search_endpoint = "https://s.jina.ai/" - headers = { - "Accept": "application/json", - } + headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"} url = str(URL(jina_search_endpoint + query)) response = requests.get(url, headers=headers) response.raise_for_status() diff --git a/backend/open_webui/apps/retrieval/web/testdata/bing.json b/backend/open_webui/apps/retrieval/web/testdata/bing.json new file mode 100644 index 0000000000000000000000000000000000000000..01fe0973b2968f6ce0b01fc174d05f33dd2f9bd9 --- /dev/null +++ b/backend/open_webui/apps/retrieval/web/testdata/bing.json @@ -0,0 +1,58 @@ +{ + "_type": "SearchResponse", + "queryContext": { + "originalQuery": "Top 10 international results" + }, + "webPages": { + "webSearchUrl": "https://www.bing.com/search?q=Top+10+international+results", + "totalEstimatedMatches": 687, + "value": [ + { + "id": "https://api.bing.microsoft.com/api/v7/#WebPages.0", + "name": "2024 Mexican Grand Prix - F1 results and latest standings ... - PlanetF1", + "url": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings", + "datePublished": "2024-10-27T00:00:00.0000000", + "datePublishedFreshnessText": "1 day ago", + "isFamilyFriendly": true, + "displayUrl": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings", + "snippet": "Nico Hulkenberg and Pierre Gasly completed the top 10. A full report of the Mexican Grand Prix is available at the bottom of this article. F1 results – 2024 Mexican Grand Prix", + "dateLastCrawled": "2024-10-28T07:15:00.0000000Z", + "cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=916492551782&mkt=en-US&setlang=en-US&w=zBsfaAPyF2tUrHFHr_vFFdUm8sng4g34", + "language": "en", + "isNavigational": false, + "noCache": false + }, + { + "id": "https://api.bing.microsoft.com/api/v7/#WebPages.1", + "name": "F1 Results Today: HUGE Verstappen penalties cause major title change", + "url": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max-verstappen-penalties-cause-major-title-change/", + "datePublished": "2024-10-27T00:00:00.0000000", + "datePublishedFreshnessText": "1 day ago", + "isFamilyFriendly": true, + "displayUrl": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max...", + "snippet": "Elsewhere, Mercedes duo Lewis Hamilton and George Russell came home in P4 and P5 respectively. Meanwhile, the surprise package of the day were Haas, with both Kevin Magnussen and Nico Hulkenberg finishing inside the points.. READ MORE: RB star issues apology after red flag CRASH at Mexican GP Mexican Grand Prix 2024 results. 1. Carlos Sainz [Ferrari] 2. Lando Norris [McLaren] - +4.705", + "dateLastCrawled": "2024-10-28T06:06:00.0000000Z", + "cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=2840656522642&mkt=en-US&setlang=en-US&w=-Tbkwxnq52jZCvG7l3CtgcwT1vwAjIUD", + "language": "en", + "isNavigational": false, + "noCache": false + }, + { + "id": "https://api.bing.microsoft.com/api/v7/#WebPages.2", + "name": "International Power Rankings: England flying, Kangaroos cruising, Fiji rise", + "url": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying-kangaroos-cruising-fiji-rise", + "datePublished": "2024-10-28T00:00:00.0000000", + "datePublishedFreshnessText": "7 hours ago", + "isFamilyFriendly": true, + "displayUrl": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying...", + "snippet": "LRL RECOMMENDS: England player ratings from first Test against Samoa as omnificent George Williams scores perfect 10. 2. Australia (Men) – SAME. The Kangaroos remain 2nd in our Power Rankings after their 22-10 win against New Zealand in Christchurch on Sunday. As was the case in their win against Tonga last week, Mal Meninga’s side weren ...", + "dateLastCrawled": "2024-10-28T07:09:00.0000000Z", + "cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=1535008462672&mkt=en-US&setlang=en-US&w=82ujhH4Kp0iuhCS7wh1xLUFYUeetaVVm", + "language": "en", + "isNavigational": false, + "noCache": false + } + ], + "someResultsRemoved": true + } +} diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py index c05e682f6b2cdb44f8838e337017a607e911a115..22273e8946aa12dfcf3e74f9634216095c956c7a 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/apps/socket/main.py @@ -1,3 +1,5 @@ +# TODO: move socket to webui app + import asyncio import socketio import logging diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 730abcc400c0401e54ad9f72b2c06d7a373eaa73..e579d8ee83af8a98e5e7f4f2ad9b750ed9776e40 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -12,6 +12,7 @@ from open_webui.apps.webui.routers import ( chats, folders, configs, + groups, files, functions, memories, @@ -34,6 +35,7 @@ from open_webui.config import ( ENABLE_LOGIN_FORM, ENABLE_MESSAGE_RATING, ENABLE_SIGNUP, + ENABLE_API_KEY, ENABLE_EVALUATION_ARENA_MODELS, EVALUATION_ARENA_MODELS, DEFAULT_ARENA_MODEL, @@ -50,9 +52,22 @@ from open_webui.config import ( WEBHOOK_URL, WEBUI_AUTH, WEBUI_BANNERS, + ENABLE_LDAP, + LDAP_SERVER_LABEL, + LDAP_SERVER_HOST, + LDAP_SERVER_PORT, + LDAP_ATTRIBUTE_FOR_USERNAME, + LDAP_SEARCH_FILTERS, + LDAP_SEARCH_BASE, + LDAP_APP_DN, + LDAP_APP_PASSWORD, + LDAP_USE_TLS, + LDAP_CA_CERT_FILE, + LDAP_CIPHERS, AppConfig, ) from open_webui.env import ( + ENV, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, ) @@ -72,7 +87,11 @@ from open_webui.utils.payload import ( from open_webui.utils.tools import get_tools -app = FastAPI() +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) log = logging.getLogger(__name__) @@ -80,6 +99,8 @@ app.state.config = AppConfig() app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM +app.state.config.ENABLE_API_KEY = ENABLE_API_KEY + app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER @@ -92,6 +113,8 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL app.state.config.DEFAULT_MODELS = DEFAULT_MODELS app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE + + app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.BANNERS = WEBUI_BANNERS @@ -111,7 +134,19 @@ app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES -app.state.MODELS = {} +app.state.config.ENABLE_LDAP = ENABLE_LDAP +app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL +app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST +app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT +app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME +app.state.config.LDAP_APP_DN = LDAP_APP_DN +app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD +app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE +app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS +app.state.config.LDAP_USE_TLS = LDAP_USE_TLS +app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE +app.state.config.LDAP_CIPHERS = LDAP_CIPHERS + app.state.TOOLS = {} app.state.FUNCTIONS = {} @@ -135,13 +170,15 @@ app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(tools.router, prefix="/tools", tags=["tools"]) -app.include_router(functions.router, prefix="/functions", tags=["functions"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) -app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"]) - app.include_router(folders.router, prefix="/folders", tags=["folders"]) + +app.include_router(groups.router, prefix="/groups", tags=["groups"]) app.include_router(files.router, prefix="/files", tags=["files"]) +app.include_router(functions.router, prefix="/functions", tags=["functions"]) +app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"]) + app.include_router(utils.router, prefix="/utils", tags=["utils"]) @@ -336,7 +373,7 @@ def get_function_params(function_module, form_data, user, extra_params=None): return params -async def generate_function_chat_completion(form_data, user): +async def generate_function_chat_completion(form_data, user, models: dict = {}): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) @@ -372,6 +409,7 @@ async def generate_function_chat_completion(form_data, user): "name": user.name, "role": user.role, }, + "__metadata__": metadata, } extra_params["__tools__"] = get_tools( app, @@ -379,7 +417,7 @@ async def generate_function_chat_completion(form_data, user): user, { **extra_params, - "__model__": app.state.MODELS[form_data["model"]], + "__model__": models.get(form_data["model"], None), "__messages__": form_data["messages"], "__files__": files, }, diff --git a/backend/open_webui/apps/webui/models/auths.py b/backend/open_webui/apps/webui/models/auths.py index c88b2b5f22ec795bf586579627a81db937e1a882..77ffbbeedbc548b95a97df732d936ebbf018d7fc 100644 --- a/backend/open_webui/apps/webui/models/auths.py +++ b/backend/open_webui/apps/webui/models/auths.py @@ -64,6 +64,11 @@ class SigninForm(BaseModel): password: str +class LdapForm(BaseModel): + user: str + password: str + + class ProfileImageUrlForm(BaseModel): profile_image_url: str diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/apps/webui/models/chats.py index 231842b0b62dc04ce95fcfdb3331747bacacaaa5..c86417fa1c787bf4966e77d1b3b0cbc602890a99 100644 --- a/backend/open_webui/apps/webui/models/chats.py +++ b/backend/open_webui/apps/webui/models/chats.py @@ -203,15 +203,22 @@ class ChatTable: def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: try: with get_db() as db: - print("update_shared_chat_by_id") chat = db.get(Chat, chat_id) - print(chat) - chat.title = chat.title - chat.chat = chat.chat + shared_chat = ( + db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first() + ) + + if shared_chat is None: + return self.insert_shared_chat_by_chat_id(chat_id) + + shared_chat.title = chat.title + shared_chat.chat = chat.chat + + shared_chat.updated_at = int(time.time()) db.commit() - db.refresh(chat) + db.refresh(shared_chat) - return self.get_chat_by_id(chat.share_id) + return ChatModel.model_validate(shared_chat) except Exception: return None diff --git a/backend/open_webui/apps/webui/models/groups.py b/backend/open_webui/apps/webui/models/groups.py new file mode 100644 index 0000000000000000000000000000000000000000..963c2e162003beaae20c1822900c28a86251fd5d --- /dev/null +++ b/backend/open_webui/apps/webui/models/groups.py @@ -0,0 +1,186 @@ +import json +import logging +import time +from typing import Optional +import uuid + +from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.env import SRC_LOG_LEVELS + +from open_webui.apps.webui.models.files import FileMetadataResponse + + +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text, JSON, func + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +#################### +# UserGroup DB Schema +#################### + + +class Group(Base): + __tablename__ = "group" + + id = Column(Text, unique=True, primary_key=True) + user_id = Column(Text) + + name = Column(Text) + description = Column(Text) + + data = Column(JSON, nullable=True) + meta = Column(JSON, nullable=True) + + permissions = Column(JSON, nullable=True) + user_ids = Column(JSON, nullable=True) + + created_at = Column(BigInteger) + updated_at = Column(BigInteger) + + +class GroupModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str + user_id: str + + name: str + description: str + + data: Optional[dict] = None + meta: Optional[dict] = None + + permissions: Optional[dict] = None + user_ids: list[str] = [] + + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class GroupResponse(BaseModel): + id: str + user_id: str + name: str + description: str + permissions: Optional[dict] = None + data: Optional[dict] = None + meta: Optional[dict] = None + user_ids: list[str] = [] + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + +class GroupForm(BaseModel): + name: str + description: str + + +class GroupUpdateForm(GroupForm): + permissions: Optional[dict] = None + user_ids: Optional[list[str]] = None + admin_ids: Optional[list[str]] = None + + +class GroupTable: + def insert_new_group( + self, user_id: str, form_data: GroupForm + ) -> Optional[GroupModel]: + with get_db() as db: + group = GroupModel( + **{ + **form_data.model_dump(), + "id": str(uuid.uuid4()), + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + try: + result = Group(**group.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return GroupModel.model_validate(result) + else: + return None + + except Exception: + return None + + def get_groups(self) -> list[GroupModel]: + with get_db() as db: + return [ + GroupModel.model_validate(group) + for group in db.query(Group).order_by(Group.updated_at.desc()).all() + ] + + def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: + with get_db() as db: + return [ + GroupModel.model_validate(group) + for group in db.query(Group) + .filter( + func.json_array_length(Group.user_ids) > 0 + ) # Ensure array exists + .filter( + Group.user_ids.cast(String).like(f'%"{user_id}"%') + ) # String-based check + .order_by(Group.updated_at.desc()) + .all() + ] + + def get_group_by_id(self, id: str) -> Optional[GroupModel]: + try: + with get_db() as db: + group = db.query(Group).filter_by(id=id).first() + return GroupModel.model_validate(group) if group else None + except Exception: + return None + + def update_group_by_id( + self, id: str, form_data: GroupUpdateForm, overwrite: bool = False + ) -> Optional[GroupModel]: + try: + with get_db() as db: + db.query(Group).filter_by(id=id).update( + { + **form_data.model_dump(exclude_none=True), + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_group_by_id(id=id) + except Exception as e: + log.exception(e) + return None + + def delete_group_by_id(self, id: str) -> bool: + try: + with get_db() as db: + db.query(Group).filter_by(id=id).delete() + db.commit() + return True + except Exception: + return False + + def delete_all_groups(self) -> bool: + with get_db() as db: + try: + db.query(Group).delete() + db.commit() + + return True + except Exception: + return False + + +Groups = GroupTable() diff --git a/backend/open_webui/apps/webui/models/knowledge.py b/backend/open_webui/apps/webui/models/knowledge.py index 9bc07ae4445c41506e96883d24889d6b60186d10..4973199cb70ed91c413460809910cfbf4ec76965 100644 --- a/backend/open_webui/apps/webui/models/knowledge.py +++ b/backend/open_webui/apps/webui/models/knowledge.py @@ -8,11 +8,13 @@ from open_webui.apps.webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS from open_webui.apps.webui.models.files import FileMetadataResponse +from open_webui.apps.webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON +from open_webui.utils.access_control import has_access log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -34,6 +36,23 @@ class Knowledge(Base): data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. + # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -50,6 +69,8 @@ class KnowledgeModel(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None + access_control: Optional[dict] = None + created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -59,15 +80,15 @@ class KnowledgeModel(BaseModel): #################### -class KnowledgeResponse(BaseModel): - id: str - name: str - description: str - data: Optional[dict] = None - meta: Optional[dict] = None - created_at: int # timestamp in epoch - updated_at: int # timestamp in epoch +class KnowledgeUserModel(KnowledgeModel): + user: Optional[UserResponse] = None + + +class KnowledgeResponse(KnowledgeModel): + files: Optional[list[FileMetadataResponse | dict]] = None + +class KnowledgeUserResponse(KnowledgeUserModel): files: Optional[list[FileMetadataResponse | dict]] = None @@ -75,12 +96,7 @@ class KnowledgeForm(BaseModel): name: str description: str data: Optional[dict] = None - - -class KnowledgeUpdateForm(BaseModel): - name: Optional[str] = None - description: Optional[str] = None - data: Optional[dict] = None + access_control: Optional[dict] = None class KnowledgeTable: @@ -110,14 +126,33 @@ class KnowledgeTable: except Exception: return None - def get_knowledge_items(self) -> list[KnowledgeModel]: + def get_knowledge_bases(self) -> list[KnowledgeUserModel]: with get_db() as db: - return [ - KnowledgeModel.model_validate(knowledge) - for knowledge in db.query(Knowledge) - .order_by(Knowledge.updated_at.desc()) - .all() - ] + knowledge_bases = [] + for knowledge in ( + db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() + ): + user = Users.get_user_by_id(knowledge.user_id) + knowledge_bases.append( + KnowledgeUserModel.model_validate( + { + **KnowledgeModel.model_validate(knowledge).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + return knowledge_bases + + def get_knowledge_bases_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[KnowledgeUserModel]: + knowledge_bases = self.get_knowledge_bases() + return [ + knowledge_base + for knowledge_base in knowledge_bases + if knowledge_base.user_id == user_id + or has_access(user_id, permission, knowledge_base.access_control) + ] def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: try: @@ -128,14 +163,32 @@ class KnowledgeTable: return None def update_knowledge_by_id( - self, id: str, form_data: KnowledgeUpdateForm, overwrite: bool = False + self, id: str, form_data: KnowledgeForm, overwrite: bool = False + ) -> Optional[KnowledgeModel]: + try: + with get_db() as db: + knowledge = self.get_knowledge_by_id(id=id) + db.query(Knowledge).filter_by(id=id).update( + { + **form_data.model_dump(), + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_knowledge_by_id(id=id) + except Exception as e: + log.exception(e) + return None + + def update_knowledge_data_by_id( + self, id: str, data: dict ) -> Optional[KnowledgeModel]: try: with get_db() as db: knowledge = self.get_knowledge_by_id(id=id) db.query(Knowledge).filter_by(id=id).update( { - **form_data.model_dump(exclude_none=True), + "data": data, "updated_at": int(time.time()), } ) diff --git a/backend/open_webui/apps/webui/models/models.py b/backend/open_webui/apps/webui/models/models.py index ea72c10453384e9b59dff39aa2c5ccdb12f8a479..e3a208f522ab542b4919b5eea84b4da205ca1c52 100644 --- a/backend/open_webui/apps/webui/models/models.py +++ b/backend/open_webui/apps/webui/models/models.py @@ -4,8 +4,19 @@ from typing import Optional from open_webui.apps.webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS + +from open_webui.apps.webui.models.users import Users, UserResponse + + from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, Text + +from sqlalchemy import or_, and_, func +from sqlalchemy.dialects import postgresql, sqlite +from sqlalchemy import BigInteger, Column, Text, JSON, Boolean + + +from open_webui.utils.access_control import has_access + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -67,6 +78,25 @@ class Model(Base): Holds a JSON encoded blob of metadata, see `ModelMeta`. """ + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. + # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + + is_active = Column(Boolean, default=True) + updated_at = Column(BigInteger) created_at = Column(BigInteger) @@ -80,6 +110,9 @@ class ModelModel(BaseModel): params: ModelParams meta: ModelMeta + access_control: Optional[dict] = None + + is_active: bool updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -91,12 +124,12 @@ class ModelModel(BaseModel): #################### -class ModelResponse(BaseModel): - id: str - name: str - meta: ModelMeta - updated_at: int # timestamp in epoch - created_at: int # timestamp in epoch +class ModelUserResponse(ModelModel): + user: Optional[UserResponse] = None + + +class ModelResponse(ModelModel): + pass class ModelForm(BaseModel): @@ -105,6 +138,8 @@ class ModelForm(BaseModel): name: str meta: ModelMeta params: ModelParams + access_control: Optional[dict] = None + is_active: bool = True class ModelsTable: @@ -138,6 +173,39 @@ class ModelsTable: with get_db() as db: return [ModelModel.model_validate(model) for model in db.query(Model).all()] + def get_models(self) -> list[ModelUserResponse]: + with get_db() as db: + models = [] + for model in db.query(Model).filter(Model.base_model_id != None).all(): + user = Users.get_user_by_id(model.user_id) + models.append( + ModelUserResponse.model_validate( + { + **ModelModel.model_validate(model).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + return models + + def get_base_models(self) -> list[ModelModel]: + with get_db() as db: + return [ + ModelModel.model_validate(model) + for model in db.query(Model).filter(Model.base_model_id == None).all() + ] + + def get_models_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[ModelUserResponse]: + models = self.get_models() + return [ + model + for model in models + if model.user_id == user_id + or has_access(user_id, permission, model.access_control) + ] + def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: with get_db() as db: @@ -146,6 +214,23 @@ class ModelsTable: except Exception: return None + def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: + with get_db() as db: + try: + is_active = db.query(Model).filter_by(id=id).first().is_active + + db.query(Model).filter_by(id=id).update( + { + "is_active": not is_active, + "updated_at": int(time.time()), + } + ) + db.commit() + + return self.get_model_by_id(id) + except Exception: + return None + def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: with get_db() as db: @@ -153,7 +238,7 @@ class ModelsTable: result = ( db.query(Model) .filter_by(id=id) - .update(model.model_dump(exclude={"id"}, exclude_none=True)) + .update(model.model_dump(exclude={"id"})) ) db.commit() @@ -175,5 +260,15 @@ class ModelsTable: except Exception: return False + def delete_all_models(self) -> bool: + try: + with get_db() as db: + db.query(Model).delete() + db.commit() + + return True + except Exception: + return False + Models = ModelsTable() diff --git a/backend/open_webui/apps/webui/models/prompts.py b/backend/open_webui/apps/webui/models/prompts.py index f3f17270ed326acc780c5a78308b0fcab1f85442..3a82007111b5935f149f2648bdbaa0eefc2a9811 100644 --- a/backend/open_webui/apps/webui/models/prompts.py +++ b/backend/open_webui/apps/webui/models/prompts.py @@ -2,8 +2,12 @@ import time from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.apps.webui.models.users import Users, UserResponse + from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, Text, JSON + +from open_webui.utils.access_control import has_access #################### # Prompts DB Schema @@ -19,6 +23,23 @@ class Prompt(Base): content = Column(Text) timestamp = Column(BigInteger) + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. + # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + class PromptModel(BaseModel): command: str @@ -27,6 +48,7 @@ class PromptModel(BaseModel): content: str timestamp: int # timestamp in epoch + access_control: Optional[dict] = None model_config = ConfigDict(from_attributes=True) @@ -35,10 +57,15 @@ class PromptModel(BaseModel): #################### +class PromptUserResponse(PromptModel): + user: Optional[UserResponse] = None + + class PromptForm(BaseModel): command: str title: str content: str + access_control: Optional[dict] = None class PromptsTable: @@ -48,16 +75,14 @@ class PromptsTable: prompt = PromptModel( **{ "user_id": user_id, - "command": form_data.command, - "title": form_data.title, - "content": form_data.content, + **form_data.model_dump(), "timestamp": int(time.time()), } ) try: with get_db() as db: - result = Prompt(**prompt.dict()) + result = Prompt(**prompt.model_dump()) db.add(result) db.commit() db.refresh(result) @@ -76,11 +101,34 @@ class PromptsTable: except Exception: return None - def get_prompts(self) -> list[PromptModel]: + def get_prompts(self) -> list[PromptUserResponse]: with get_db() as db: - return [ - PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() - ] + prompts = [] + + for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all(): + user = Users.get_user_by_id(prompt.user_id) + prompts.append( + PromptUserResponse.model_validate( + { + **PromptModel.model_validate(prompt).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + + return prompts + + def get_prompts_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[PromptUserResponse]: + prompts = self.get_prompts() + + return [ + prompt + for prompt in prompts + if prompt.user_id == user_id + or has_access(user_id, permission, prompt.access_control) + ] def update_prompt_by_command( self, command: str, form_data: PromptForm @@ -90,6 +138,7 @@ class PromptsTable: prompt = db.query(Prompt).filter_by(command=command).first() prompt.title = form_data.title prompt.content = form_data.content + prompt.access_control = form_data.access_control prompt.timestamp = int(time.time()) db.commit() return PromptModel.model_validate(prompt) diff --git a/backend/open_webui/apps/webui/models/tools.py b/backend/open_webui/apps/webui/models/tools.py index bef1a3596d0af2a02cff36856ec2aa3a9ad61063..76fbe43faa3b599f79fbb47a36b597db4a795574 100644 --- a/backend/open_webui/apps/webui/models/tools.py +++ b/backend/open_webui/apps/webui/models/tools.py @@ -3,10 +3,13 @@ import time from typing import Optional from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.users import Users +from open_webui.apps.webui.models.users import Users, UserResponse from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, Text, JSON + +from open_webui.utils.access_control import has_access + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -26,6 +29,24 @@ class Tool(Base): specs = Column(JSONField) meta = Column(JSONField) valves = Column(JSONField) + + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. + # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + updated_at = Column(BigInteger) created_at = Column(BigInteger) @@ -42,6 +63,8 @@ class ToolModel(BaseModel): content: str specs: list[dict] meta: ToolMeta + access_control: Optional[dict] = None + updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -58,15 +81,21 @@ class ToolResponse(BaseModel): user_id: str name: str meta: ToolMeta + access_control: Optional[dict] = None updated_at: int # timestamp in epoch created_at: int # timestamp in epoch +class ToolUserResponse(ToolResponse): + user: Optional[UserResponse] = None + + class ToolForm(BaseModel): id: str name: str content: str meta: ToolMeta + access_control: Optional[dict] = None class ToolValves(BaseModel): @@ -109,9 +138,32 @@ class ToolsTable: except Exception: return None - def get_tools(self) -> list[ToolModel]: + def get_tools(self) -> list[ToolUserResponse]: with get_db() as db: - return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] + tools = [] + for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all(): + user = Users.get_user_by_id(tool.user_id) + tools.append( + ToolUserResponse.model_validate( + { + **ToolModel.model_validate(tool).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + return tools + + def get_tools_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[ToolUserResponse]: + tools = self.get_tools() + + return [ + tool + for tool in tools + if tool.user_id == user_id + or has_access(user_id, permission, tool.access_control) + ] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: diff --git a/backend/open_webui/apps/webui/models/users.py b/backend/open_webui/apps/webui/models/users.py index 3523ad1f9e8be46d0475105bd2d4db2023dce0a4..f0a04a3efca470b67e2eaf1737f26b60e9f10849 100644 --- a/backend/open_webui/apps/webui/models/users.py +++ b/backend/open_webui/apps/webui/models/users.py @@ -62,6 +62,14 @@ class UserModel(BaseModel): #################### +class UserResponse(BaseModel): + id: str + name: str + email: str + role: str + profile_image_url: str + + class UserRoleUpdateForm(BaseModel): id: str role: str diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/apps/webui/routers/auths.py index e8ff87a22f00f7dd9fe6505f7f6c5216e0f28134..e777f09ad6d25f4993020aa1c784db8938e18490 100644 --- a/backend/open_webui/apps/webui/routers/auths.py +++ b/backend/open_webui/apps/webui/routers/auths.py @@ -2,12 +2,14 @@ import re import uuid import time import datetime +import logging from open_webui.apps.webui.models.auths import ( AddUserForm, ApiKey, Auths, Token, + LdapForm, SigninForm, SigninResponse, SignupForm, @@ -16,13 +18,15 @@ from open_webui.apps.webui.models.auths import ( UserResponse, ) from open_webui.apps.webui.models.users import Users -from open_webui.config import WEBUI_AUTH + from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( + WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, + SRC_LOG_LEVELS, ) from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import Response @@ -37,10 +41,19 @@ from open_webui.utils.utils import ( get_password_hash, ) from open_webui.utils.webhook import post_webhook -from typing import Optional +from open_webui.utils.access_control import get_permissions + +from typing import Optional, List + +from ssl import CERT_REQUIRED, PROTOCOL_TLS +from ldap3 import Server, Connection, ALL, Tls +from ldap3.utils.conv import escape_filter_chars router = APIRouter() +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + ############################ # GetSessionUser ############################ @@ -48,6 +61,7 @@ router = APIRouter() class SessionUserResponse(Token, UserResponse): expires_at: Optional[int] = None + permissions: Optional[dict] = None @router.get("/", response_model=SessionUserResponse) @@ -80,6 +94,10 @@ async def get_session_user( secure=WEBUI_SESSION_COOKIE_SECURE, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -89,6 +107,7 @@ async def get_session_user( "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } @@ -137,6 +156,140 @@ async def update_password( raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) +############################ +# LDAP Authentication +############################ +@router.post("/ldap", response_model=SigninResponse) +async def ldap_auth(request: Request, response: Response, form_data: LdapForm): + ENABLE_LDAP = request.app.state.config.ENABLE_LDAP + LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL + LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST + LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT + LDAP_ATTRIBUTE_FOR_USERNAME = request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME + LDAP_SEARCH_BASE = request.app.state.config.LDAP_SEARCH_BASE + LDAP_SEARCH_FILTERS = request.app.state.config.LDAP_SEARCH_FILTERS + LDAP_APP_DN = request.app.state.config.LDAP_APP_DN + LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD + LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS + LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE + LDAP_CIPHERS = ( + request.app.state.config.LDAP_CIPHERS + if request.app.state.config.LDAP_CIPHERS + else "ALL" + ) + + if not ENABLE_LDAP: + raise HTTPException(400, detail="LDAP authentication is not enabled") + + try: + tls = Tls( + validate=CERT_REQUIRED, + version=PROTOCOL_TLS, + ca_certs_file=LDAP_CA_CERT_FILE, + ciphers=LDAP_CIPHERS, + ) + except Exception as e: + log.error(f"An error occurred on TLS: {str(e)}") + raise HTTPException(400, detail=str(e)) + + try: + server = Server( + host=LDAP_SERVER_HOST, + port=LDAP_SERVER_PORT, + get_info=ALL, + use_ssl=LDAP_USE_TLS, + tls=tls, + ) + connection_app = Connection( + server, + LDAP_APP_DN, + LDAP_APP_PASSWORD, + auto_bind="NONE", + authentication="SIMPLE", + ) + if not connection_app.bind(): + raise HTTPException(400, detail="Application account bind failed") + + search_success = connection_app.search( + search_base=LDAP_SEARCH_BASE, + search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})", + attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"], + ) + + if not search_success: + raise HTTPException(400, detail="User not found in the LDAP server") + + entry = connection_app.entries[0] + username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower() + mail = str(entry["mail"]) + cn = str(entry["cn"]) + user_dn = entry.entry_dn + + if username == form_data.user.lower(): + connection_user = Connection( + server, + user_dn, + form_data.password, + auto_bind="NONE", + authentication="SIMPLE", + ) + if not connection_user.bind(): + raise HTTPException(400, f"Authentication failed for {form_data.user}") + + user = Users.get_user_by_email(mail) + if not user: + + try: + hashed = get_password_hash(form_data.password) + user = Auths.insert_new_auth(mail, hashed, cn) + + if not user: + raise HTTPException( + 500, detail=ERROR_MESSAGES.CREATE_USER_ERROR + ) + + except HTTPException: + raise + except Exception as err: + raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + + user = Auths.authenticate_user(mail, password=str(form_data.password)) + + if user: + token = create_token( + data={"id": user.id}, + expires_delta=parse_duration( + request.app.state.config.JWT_EXPIRES_IN + ), + ) + + # Set the cookie token + response.set_cookie( + key="token", + value=token, + httponly=True, # Ensures the cookie is not accessible via JavaScript + ) + + return { + "token": token, + "token_type": "Bearer", + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + "profile_image_url": user.profile_image_url, + } + else: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + else: + raise HTTPException( + 400, + f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}", + ) + except Exception as e: + raise HTTPException(400, detail=str(e)) + + ############################ # SignIn ############################ @@ -211,6 +364,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm): secure=WEBUI_SESSION_COOKIE_SECURE, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -220,6 +377,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) @@ -260,6 +418,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm): if Users.get_num_users() == 0 else request.app.state.config.DEFAULT_USER_ROLE ) + + if Users.get_num_users() == 0: + # Disable signup after the first user is created + request.app.state.config.ENABLE_SIGNUP = False + hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( form_data.email.lower(), @@ -307,6 +470,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm): }, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -316,6 +483,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } else: raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) @@ -413,6 +581,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)): return { "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, + "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, @@ -423,6 +592,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)): class AdminConfig(BaseModel): SHOW_ADMIN_DETAILS: bool ENABLE_SIGNUP: bool + ENABLE_API_KEY: bool DEFAULT_USER_ROLE: str JWT_EXPIRES_IN: str ENABLE_COMMUNITY_SHARING: bool @@ -435,6 +605,7 @@ async def update_admin_config( ): request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP + request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]: request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE @@ -453,6 +624,7 @@ async def update_admin_config( return { "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, + "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, @@ -460,6 +632,105 @@ async def update_admin_config( } +class LdapServerConfig(BaseModel): + label: str + host: str + port: Optional[int] = None + attribute_for_username: str = "uid" + app_dn: str + app_dn_password: str + search_base: str + search_filters: str = "" + use_tls: bool = True + certificate_path: Optional[str] = None + ciphers: Optional[str] = "ALL" + + +@router.get("/admin/config/ldap/server", response_model=LdapServerConfig) +async def get_ldap_server(request: Request, user=Depends(get_admin_user)): + return { + "label": request.app.state.config.LDAP_SERVER_LABEL, + "host": request.app.state.config.LDAP_SERVER_HOST, + "port": request.app.state.config.LDAP_SERVER_PORT, + "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME, + "app_dn": request.app.state.config.LDAP_APP_DN, + "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD, + "search_base": request.app.state.config.LDAP_SEARCH_BASE, + "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, + "use_tls": request.app.state.config.LDAP_USE_TLS, + "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, + "ciphers": request.app.state.config.LDAP_CIPHERS, + } + + +@router.post("/admin/config/ldap/server") +async def update_ldap_server( + request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user) +): + required_fields = [ + "label", + "host", + "attribute_for_username", + "app_dn", + "app_dn_password", + "search_base", + ] + for key in required_fields: + value = getattr(form_data, key) + if not value: + raise HTTPException(400, detail=f"Required field {key} is empty") + + if form_data.use_tls and not form_data.certificate_path: + raise HTTPException( + 400, detail="TLS is enabled but certificate file path is missing" + ) + + request.app.state.config.LDAP_SERVER_LABEL = form_data.label + request.app.state.config.LDAP_SERVER_HOST = form_data.host + request.app.state.config.LDAP_SERVER_PORT = form_data.port + request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = ( + form_data.attribute_for_username + ) + request.app.state.config.LDAP_APP_DN = form_data.app_dn + request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password + request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base + request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters + request.app.state.config.LDAP_USE_TLS = form_data.use_tls + request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path + request.app.state.config.LDAP_CIPHERS = form_data.ciphers + + return { + "label": request.app.state.config.LDAP_SERVER_LABEL, + "host": request.app.state.config.LDAP_SERVER_HOST, + "port": request.app.state.config.LDAP_SERVER_PORT, + "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME, + "app_dn": request.app.state.config.LDAP_APP_DN, + "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD, + "search_base": request.app.state.config.LDAP_SEARCH_BASE, + "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, + "use_tls": request.app.state.config.LDAP_USE_TLS, + "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, + "ciphers": request.app.state.config.LDAP_CIPHERS, + } + + +@router.get("/admin/config/ldap") +async def get_ldap_config(request: Request, user=Depends(get_admin_user)): + return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} + + +class LdapConfigForm(BaseModel): + enable_ldap: Optional[bool] = None + + +@router.post("/admin/config/ldap") +async def update_ldap_config( + request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENABLE_LDAP = form_data.enable_ldap + return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} + + ############################ # API Key ############################ @@ -467,9 +738,16 @@ async def update_admin_config( # create api key @router.post("/api_key", response_model=ApiKey) -async def create_api_key_(user=Depends(get_current_user)): +async def generate_api_key(request: Request, user=Depends(get_current_user)): + if not request.app.state.config.ENABLE_API_KEY: + raise HTTPException( + status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED, + ) + api_key = create_api_key() success = Users.update_user_api_key_by_id(user.id, api_key) + if success: return { "api_key": api_key, diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py index 53a7386452bc2a416a851ef04c5a3e9df1a17e58..b3dbd8f763a384011abea4372036798029dcbb63 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/apps/webui/routers/chats.py @@ -17,7 +17,10 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel + + from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_permission log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -50,9 +53,10 @@ async def get_session_user_chat_list( @router.delete("/", response_model=bool) async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)): - if user.role == "user" and not request.app.state.config.USER_PERMISSIONS.get( - "chat", {} - ).get("deletion", {}): + + if user.role == "user" and not has_permission( + user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS + ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -385,8 +389,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified return result else: - if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get( - "deletion", {} + if not has_permission( + user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/open_webui/apps/webui/routers/groups.py b/backend/open_webui/apps/webui/routers/groups.py new file mode 100644 index 0000000000000000000000000000000000000000..fce6f62dc099e737c3c4d3d02d14fd7dbdca082e --- /dev/null +++ b/backend/open_webui/apps/webui/routers/groups.py @@ -0,0 +1,120 @@ +import os +from pathlib import Path +from typing import Optional + +from open_webui.apps.webui.models.groups import ( + Groups, + GroupForm, + GroupUpdateForm, + GroupResponse, +) + +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES +from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.utils.utils import get_admin_user, get_verified_user + +router = APIRouter() + +############################ +# GetFunctions +############################ + + +@router.get("/", response_model=list[GroupResponse]) +async def get_groups(user=Depends(get_verified_user)): + if user.role == "admin": + return Groups.get_groups() + else: + return Groups.get_groups_by_member_id(user.id) + + +############################ +# CreateNewGroup +############################ + + +@router.post("/create", response_model=Optional[GroupResponse]) +async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)): + try: + group = Groups.insert_new_group(user.id, form_data) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error creating group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# GetGroupById +############################ + + +@router.get("/id/{id}", response_model=Optional[GroupResponse]) +async def get_group_by_id(id: str, user=Depends(get_admin_user)): + group = Groups.get_group_by_id(id) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateGroupById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[GroupResponse]) +async def update_group_by_id( + id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user) +): + try: + group = Groups.update_group_by_id(id, form_data) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# DeleteGroupById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_group_by_id(id: str, user=Depends(get_admin_user)): + try: + result = Groups.delete_group_by_id(id) + if result: + return result + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error deleting group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py index c07ccdffd11d333dd280984108535590a81ffd4e..22f0b34c0f1d734b1b86c3b67765942888d1fedb 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/apps/webui/routers/knowledge.py @@ -1,14 +1,14 @@ import json from typing import Optional, Union from pydantic import BaseModel -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status, Request import logging from open_webui.apps.webui.models.knowledge import ( Knowledges, - KnowledgeUpdateForm, KnowledgeForm, KnowledgeResponse, + KnowledgeUserResponse, ) from open_webui.apps.webui.models.files import Files, FileModel from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT @@ -17,6 +17,9 @@ from open_webui.apps.retrieval.main import process_file, ProcessFileForm from open_webui.constants import ERROR_MESSAGES from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access, has_permission + + from open_webui.env import SRC_LOG_LEVELS @@ -26,64 +29,98 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) router = APIRouter() ############################ -# GetKnowledgeItems +# getKnowledgeBases ############################ -@router.get( - "/", response_model=Optional[Union[list[KnowledgeResponse], KnowledgeResponse]] -) -async def get_knowledge_items( - id: Optional[str] = None, user=Depends(get_verified_user) -): - if id: - knowledge = Knowledges.get_knowledge_by_id(id=id) +@router.get("/", response_model=list[KnowledgeUserResponse]) +async def get_knowledge(user=Depends(get_verified_user)): + knowledge_bases = [] - if knowledge: - return knowledge - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) + if user.role == "admin": + knowledge_bases = Knowledges.get_knowledge_bases() else: - knowledge_bases = [] - - for knowledge in Knowledges.get_knowledge_items(): + knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read") + + # Get files for each knowledge base + for knowledge_base in knowledge_bases: + files = [] + if knowledge_base.data: + files = Files.get_file_metadatas_by_ids( + knowledge_base.data.get("file_ids", []) + ) - files = [] - if knowledge.data: - files = Files.get_file_metadatas_by_ids( - knowledge.data.get("file_ids", []) + # Check if all files exist + if len(files) != len(knowledge_base.data.get("file_ids", [])): + missing_files = list( + set(knowledge_base.data.get("file_ids", [])) + - set([file.id for file in files]) ) + if missing_files: + data = knowledge_base.data or {} + file_ids = data.get("file_ids", []) + + for missing_file in missing_files: + file_ids.remove(missing_file) - # Check if all files exist - if len(files) != len(knowledge.data.get("file_ids", [])): - missing_files = list( - set(knowledge.data.get("file_ids", [])) - - set([file.id for file in files]) + data["file_ids"] = file_ids + Knowledges.update_knowledge_data_by_id( + id=knowledge_base.id, data=data ) - if missing_files: - data = knowledge.data or {} - file_ids = data.get("file_ids", []) - for missing_file in missing_files: - file_ids.remove(missing_file) + files = Files.get_file_metadatas_by_ids(file_ids) + + knowledge_base = KnowledgeResponse( + **knowledge_base.model_dump(), + files=files, + ) - data["file_ids"] = file_ids - Knowledges.update_knowledge_by_id( - id=knowledge.id, form_data=KnowledgeUpdateForm(data=data) - ) + return knowledge_bases - files = Files.get_file_metadatas_by_ids(file_ids) - knowledge_bases.append( - KnowledgeResponse( - **knowledge.model_dump(), - files=files, - ) +@router.get("/list", response_model=list[KnowledgeUserResponse]) +async def get_knowledge_list(user=Depends(get_verified_user)): + knowledge_bases = [] + + if user.role == "admin": + knowledge_bases = Knowledges.get_knowledge_bases() + else: + knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write") + + # Get files for each knowledge base + for knowledge_base in knowledge_bases: + files = [] + if knowledge_base.data: + files = Files.get_file_metadatas_by_ids( + knowledge_base.data.get("file_ids", []) ) - return knowledge_bases + + # Check if all files exist + if len(files) != len(knowledge_base.data.get("file_ids", [])): + missing_files = list( + set(knowledge_base.data.get("file_ids", [])) + - set([file.id for file in files]) + ) + if missing_files: + data = knowledge_base.data or {} + file_ids = data.get("file_ids", []) + + for missing_file in missing_files: + file_ids.remove(missing_file) + + data["file_ids"] = file_ids + Knowledges.update_knowledge_data_by_id( + id=knowledge_base.id, data=data + ) + + files = Files.get_file_metadatas_by_ids(file_ids) + + knowledge_base = KnowledgeResponse( + **knowledge_base.model_dump(), + files=files, + ) + + return knowledge_bases ############################ @@ -92,7 +129,17 @@ async def get_knowledge_items( @router.post("/create", response_model=Optional[KnowledgeResponse]) -async def create_new_knowledge(form_data: KnowledgeForm, user=Depends(get_admin_user)): +async def create_new_knowledge( + request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user) +): + if user.role != "admin" and not has_permission( + user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + knowledge = Knowledges.insert_new_knowledge(user.id, form_data) if knowledge: @@ -118,13 +165,20 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): knowledge = Knowledges.get_knowledge_by_id(id=id) if knowledge: - file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] - files = Files.get_files_by_ids(file_ids) - return KnowledgeFilesResponse( - **knowledge.model_dump(), - files=files, - ) + if ( + user.role == "admin" + or knowledge.user_id == user.id + or has_access(user.id, "read", knowledge.access_control) + ): + + file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] + files = Files.get_files_by_ids(file_ids) + + return KnowledgeFilesResponse( + **knowledge.model_dump(), + files=files, + ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -140,11 +194,23 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse]) async def update_knowledge_by_id( id: str, - form_data: KnowledgeUpdateForm, - user=Depends(get_admin_user), + form_data: KnowledgeForm, + user=Depends(get_verified_user), ): - knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) if knowledge: file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] files = Files.get_files_by_ids(file_ids) @@ -173,9 +239,22 @@ class KnowledgeFileIdForm(BaseModel): def add_file_to_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -206,9 +285,7 @@ def add_file_to_knowledge_by_id( file_ids.append(form_data.file_id) data["file_ids"] = file_ids - knowledge = Knowledges.update_knowledge_by_id( - id=id, form_data=KnowledgeUpdateForm(data=data) - ) + knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data) if knowledge: files = Files.get_files_by_ids(file_ids) @@ -238,9 +315,21 @@ def add_file_to_knowledge_by_id( def update_file_from_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -288,9 +377,21 @@ def update_file_from_knowledge_by_id( def remove_file_from_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -318,9 +419,7 @@ def remove_file_from_knowledge_by_id( file_ids.remove(form_data.file_id) data["file_ids"] = file_ids - knowledge = Knowledges.update_knowledge_by_id( - id=id, form_data=KnowledgeUpdateForm(data=data) - ) + knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data) if knowledge: files = Files.get_files_by_ids(file_ids) @@ -347,35 +446,60 @@ def remove_file_from_knowledge_by_id( ############################ -# ResetKnowledgeById +# DeleteKnowledgeById ############################ -@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse]) -async def reset_knowledge_by_id(id: str, user=Depends(get_admin_user)): +@router.delete("/{id}/delete", response_model=bool) +async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + try: VECTOR_DB_CLIENT.delete_collection(collection_name=id) except Exception as e: log.debug(e) pass - - knowledge = Knowledges.update_knowledge_by_id( - id=id, form_data=KnowledgeUpdateForm(data={"file_ids": []}) - ) - return knowledge + result = Knowledges.delete_knowledge_by_id(id=id) + return result ############################ -# DeleteKnowledgeById +# ResetKnowledgeById ############################ -@router.delete("/{id}/delete", response_model=bool) -async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)): +@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse]) +async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + try: VECTOR_DB_CLIENT.delete_collection(collection_name=id) except Exception as e: log.debug(e) pass - result = Knowledges.delete_knowledge_by_id(id=id) - return result + + knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []}) + + return knowledge diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/apps/webui/routers/models.py index 64bc20b78a05d58ee651eeeb3c59e566cfa36ba5..3ed1d686d4832d250f7c86f63bd8b64c33154876 100644 --- a/backend/open_webui/apps/webui/routers/models.py +++ b/backend/open_webui/apps/webui/routers/models.py @@ -4,53 +4,71 @@ from open_webui.apps.webui.models.models import ( ModelForm, ModelModel, ModelResponse, + ModelUserResponse, Models, ) from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status + + from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access, has_permission + router = APIRouter() + ########################### -# getModels +# GetModels ########################### -@router.get("/", response_model=list[ModelResponse]) +@router.get("/", response_model=list[ModelUserResponse]) async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): - if id: - model = Models.get_model_by_id(id) - if model: - return [model] - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) + if user.role == "admin": + return Models.get_models() else: - return Models.get_all_models() + return Models.get_models_by_user_id(user.id) + + +########################### +# GetBaseModels +########################### + + +@router.get("/base", response_model=list[ModelResponse]) +async def get_base_models(user=Depends(get_admin_user)): + return Models.get_base_models() ############################ -# AddNewModel +# CreateNewModel ############################ -@router.post("/add", response_model=Optional[ModelModel]) -async def add_new_model( +@router.post("/create", response_model=Optional[ModelModel]) +async def create_new_model( request: Request, form_data: ModelForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): - if form_data.id in request.app.state.MODELS: + if user.role != "admin" and not has_permission( + user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + model = Models.get_model_by_id(form_data.id) + if model: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.MODEL_ID_TAKEN, ) + else: model = Models.insert_new_model(form_data, user.id) - if model: return model else: @@ -60,37 +78,84 @@ async def add_new_model( ) +########################### +# GetModelById +########################### + + +@router.get("/id/{id}", response_model=Optional[ModelResponse]) +async def get_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if model: + if ( + user.role == "admin" + or model.user_id == user.id + or has_access(user.id, "read", model.access_control) + ): + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ -# UpdateModelById +# ToggelModelById ############################ -@router.post("/update", response_model=Optional[ModelModel]) -async def update_model_by_id( - request: Request, - id: str, - form_data: ModelForm, - user=Depends(get_admin_user), -): +@router.post("/id/{id}/toggle", response_model=Optional[ModelResponse]) +async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): model = Models.get_model_by_id(id) if model: - model = Models.update_model_by_id(id, form_data) - return model - else: - if form_data.id in request.app.state.MODELS: - model = Models.insert_new_model(form_data, user.id) + if ( + user.role == "admin" + or model.user_id == user.id + or has_access(user.id, "write", model.access_control) + ): + model = Models.toggle_model_by_id(id) + if model: return model else: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating function"), ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), + detail=ERROR_MESSAGES.UNAUTHORIZED, ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateModelById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[ModelModel]) +async def update_model_by_id( + id: str, + form_data: ModelForm, + user=Depends(get_verified_user), +): + model = Models.get_model_by_id(id) + + if not model: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + model = Models.update_model_by_id(id, form_data) + return model ############################ @@ -98,7 +163,26 @@ async def update_model_by_id( ############################ -@router.delete("/delete", response_model=bool) -async def delete_model_by_id(id: str, user=Depends(get_admin_user)): +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if not model: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if model.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + result = Models.delete_model_by_id(id) return result + + +@router.delete("/delete/all", response_model=bool) +async def delete_all_models(user=Depends(get_admin_user)): + result = Models.delete_all_models() + return result diff --git a/backend/open_webui/apps/webui/routers/prompts.py b/backend/open_webui/apps/webui/routers/prompts.py index 6692ca6be1b1db3b3c21a24c5163705f1b66030c..4bab3754f012aa5da531883142f4d7bef14251f5 100644 --- a/backend/open_webui/apps/webui/routers/prompts.py +++ b/backend/open_webui/apps/webui/routers/prompts.py @@ -1,9 +1,15 @@ from typing import Optional -from open_webui.apps.webui.models.prompts import PromptForm, PromptModel, Prompts +from open_webui.apps.webui.models.prompts import ( + PromptForm, + PromptUserResponse, + PromptModel, + Prompts, +) from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status, Request from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access, has_permission router = APIRouter() @@ -14,7 +20,22 @@ router = APIRouter() @router.get("/", response_model=list[PromptModel]) async def get_prompts(user=Depends(get_verified_user)): - return Prompts.get_prompts() + if user.role == "admin": + prompts = Prompts.get_prompts() + else: + prompts = Prompts.get_prompts_by_user_id(user.id, "read") + + return prompts + + +@router.get("/list", response_model=list[PromptUserResponse]) +async def get_prompt_list(user=Depends(get_verified_user)): + if user.role == "admin": + prompts = Prompts.get_prompts() + else: + prompts = Prompts.get_prompts_by_user_id(user.id, "write") + + return prompts ############################ @@ -23,7 +44,17 @@ async def get_prompts(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[PromptModel]) -async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)): +async def create_new_prompt( + request: Request, form_data: PromptForm, user=Depends(get_verified_user) +): + if user.role != "admin" and not has_permission( + user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + prompt = Prompts.get_prompt_by_command(form_data.command) if prompt is None: prompt = Prompts.insert_new_prompt(user.id, form_data) @@ -50,7 +81,12 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): prompt = Prompts.get_prompt_by_command(f"/{command}") if prompt: - return prompt + if ( + user.role == "admin" + or prompt.user_id == user.id + or has_access(user.id, "read", prompt.access_control) + ): + return prompt else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -67,8 +103,21 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): async def update_prompt_by_command( command: str, form_data: PromptForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): + prompt = Prompts.get_prompt_by_command(f"/{command}") + if not prompt: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if prompt.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) if prompt: return prompt @@ -85,6 +134,19 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) -async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)): +async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)): + prompt = Prompts.get_prompt_by_command(f"/{command}") + if not prompt: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if prompt.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + result = Prompts.delete_prompt_by_command(f"/{command}") return result diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/apps/webui/routers/tools.py index 897d1dc6214d786d68a975851baf945f7a302d20..c0479907af51e2c2bd2ad1ba0fcd4d4b07d8b8b4 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/apps/webui/routers/tools.py @@ -2,50 +2,82 @@ import os from pathlib import Path from typing import Optional -from open_webui.apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools -from open_webui.apps.webui.utils import load_toolkit_module_by_id, replace_imports +from open_webui.apps.webui.models.tools import ( + ToolForm, + ToolModel, + ToolResponse, + ToolUserResponse, + Tools, +) +from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports from open_webui.config import CACHE_DIR, DATA_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.tools import get_tools_specs from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access, has_permission router = APIRouter() ############################ -# GetToolkits +# GetTools ############################ -@router.get("/", response_model=list[ToolResponse]) -async def get_toolkits(user=Depends(get_verified_user)): - toolkits = [toolkit for toolkit in Tools.get_tools()] - return toolkits +@router.get("/", response_model=list[ToolUserResponse]) +async def get_tools(user=Depends(get_verified_user)): + if user.role == "admin": + tools = Tools.get_tools() + else: + tools = Tools.get_tools_by_user_id(user.id, "read") + return tools ############################ -# ExportToolKits +# GetToolList +############################ + + +@router.get("/list", response_model=list[ToolUserResponse]) +async def get_tool_list(user=Depends(get_verified_user)): + if user.role == "admin": + tools = Tools.get_tools() + else: + tools = Tools.get_tools_by_user_id(user.id, "write") + return tools + + +############################ +# ExportTools ############################ @router.get("/export", response_model=list[ToolModel]) -async def get_toolkits(user=Depends(get_admin_user)): - toolkits = [toolkit for toolkit in Tools.get_tools()] - return toolkits +async def export_tools(user=Depends(get_admin_user)): + tools = Tools.get_tools() + return tools ############################ -# CreateNewToolKit +# CreateNewTools ############################ @router.post("/create", response_model=Optional[ToolResponse]) -async def create_new_toolkit( +async def create_new_tools( request: Request, form_data: ToolForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): + if user.role != "admin" and not has_permission( + user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + if not form_data.id.isidentifier(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -54,30 +86,30 @@ async def create_new_toolkit( form_data.id = form_data.id.lower() - toolkit = Tools.get_tool_by_id(form_data.id) - if toolkit is None: + tools = Tools.get_tool_by_id(form_data.id) + if tools is None: try: form_data.content = replace_imports(form_data.content) - toolkit_module, frontmatter = load_toolkit_module_by_id( + tools_module, frontmatter = load_tools_module_by_id( form_data.id, content=form_data.content ) form_data.meta.manifest = frontmatter TOOLS = request.app.state.TOOLS - TOOLS[form_data.id] = toolkit_module + TOOLS[form_data.id] = tools_module specs = get_tools_specs(TOOLS[form_data.id]) - toolkit = Tools.insert_new_tool(user.id, form_data, specs) + tools = Tools.insert_new_tool(user.id, form_data, specs) tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id tool_cache_dir.mkdir(parents=True, exist_ok=True) - if toolkit: - return toolkit + if tools: + return tools else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"), + detail=ERROR_MESSAGES.DEFAULT("Error creating tools"), ) except Exception as e: print(e) @@ -93,16 +125,21 @@ async def create_new_toolkit( ############################ -# GetToolkitById +# GetToolsById ############################ @router.get("/id/{id}", response_model=Optional[ToolModel]) -async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): - toolkit = Tools.get_tool_by_id(id) - - if toolkit: - return toolkit +async def get_tools_by_id(id: str, user=Depends(get_verified_user)): + tools = Tools.get_tool_by_id(id) + + if tools: + if ( + user.role == "admin" + or tools.user_id == user.id + or has_access(user.id, "read", tools.access_control) + ): + return tools else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -111,26 +148,39 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): ############################ -# UpdateToolkitById +# UpdateToolsById ############################ @router.post("/id/{id}/update", response_model=Optional[ToolModel]) -async def update_toolkit_by_id( +async def update_tools_by_id( request: Request, id: str, form_data: ToolForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): + tools = Tools.get_tool_by_id(id) + if not tools: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if tools.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + try: form_data.content = replace_imports(form_data.content) - toolkit_module, frontmatter = load_toolkit_module_by_id( + tools_module, frontmatter = load_tools_module_by_id( id, content=form_data.content ) form_data.meta.manifest = frontmatter TOOLS = request.app.state.TOOLS - TOOLS[id] = toolkit_module + TOOLS[id] = tools_module specs = get_tools_specs(TOOLS[id]) @@ -140,14 +190,14 @@ async def update_toolkit_by_id( } print(updated) - toolkit = Tools.update_tool_by_id(id, updated) + tools = Tools.update_tool_by_id(id, updated) - if toolkit: - return toolkit + if tools: + return tools else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"), + detail=ERROR_MESSAGES.DEFAULT("Error updating tools"), ) except Exception as e: @@ -158,14 +208,28 @@ async def update_toolkit_by_id( ############################ -# DeleteToolkitById +# DeleteToolsById ############################ @router.delete("/id/{id}/delete", response_model=bool) -async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)): - result = Tools.delete_tool_by_id(id) +async def delete_tools_by_id( + request: Request, id: str, user=Depends(get_verified_user) +): + tools = Tools.get_tool_by_id(id) + if not tools: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + if tools.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + result = Tools.delete_tool_by_id(id) if result: TOOLS = request.app.state.TOOLS if id in TOOLS: @@ -180,9 +244,9 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin @router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): - toolkit = Tools.get_tool_by_id(id) - if toolkit: +async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): + tools = Tools.get_tool_by_id(id) + if tools: try: valves = Tools.get_tool_valves_by_id(id) return valves @@ -204,19 +268,19 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): @router.get("/id/{id}/valves/spec", response_model=Optional[dict]) -async def get_toolkit_valves_spec_by_id( - request: Request, id: str, user=Depends(get_admin_user) +async def get_tools_valves_spec_by_id( + request: Request, id: str, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) - if toolkit: + tools = Tools.get_tool_by_id(id) + if tools: if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] + tools_module = request.app.state.TOOLS[id] else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module - if hasattr(toolkit_module, "Valves"): - Valves = toolkit_module.Valves + if hasattr(tools_module, "Valves"): + Valves = tools_module.Valves return Valves.schema() return None else: @@ -232,19 +296,19 @@ async def get_toolkit_valves_spec_by_id( @router.post("/id/{id}/valves/update", response_model=Optional[dict]) -async def update_toolkit_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_admin_user) +async def update_tools_valves_by_id( + request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) - if toolkit: + tools = Tools.get_tool_by_id(id) + if tools: if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] + tools_module = request.app.state.TOOLS[id] else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module - if hasattr(toolkit_module, "Valves"): - Valves = toolkit_module.Valves + if hasattr(tools_module, "Valves"): + Valves = tools_module.Valves try: form_data = {k: v for k, v in form_data.items() if v is not None} @@ -276,9 +340,9 @@ async def update_toolkit_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)): - toolkit = Tools.get_tool_by_id(id) - if toolkit: +async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): + tools = Tools.get_tool_by_id(id) + if tools: try: user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id) return user_valves @@ -295,19 +359,19 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user) @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) -async def get_toolkit_user_valves_spec_by_id( +async def get_tools_user_valves_spec_by_id( request: Request, id: str, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) - if toolkit: + tools = Tools.get_tool_by_id(id) + if tools: if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] + tools_module = request.app.state.TOOLS[id] else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module - if hasattr(toolkit_module, "UserValves"): - UserValves = toolkit_module.UserValves + if hasattr(tools_module, "UserValves"): + UserValves = tools_module.UserValves return UserValves.schema() return None else: @@ -318,20 +382,20 @@ async def get_toolkit_user_valves_spec_by_id( @router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) -async def update_toolkit_user_valves_by_id( +async def update_tools_user_valves_by_id( request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id) - if toolkit: + if tools: if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] + tools_module = request.app.state.TOOLS[id] else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module - if hasattr(toolkit_module, "UserValves"): - UserValves = toolkit_module.UserValves + if hasattr(tools_module, "UserValves"): + UserValves = tools_module.UserValves try: form_data = {k: v for k, v in form_data.items() if v is not None} diff --git a/backend/open_webui/apps/webui/routers/users.py b/backend/open_webui/apps/webui/routers/users.py index 4485f7f2bd324e37fa1a88d14b19172c5adf330c..59fe67a84367dd2f7b4279a55bcd9401129fc2b2 100644 --- a/backend/open_webui/apps/webui/routers/users.py +++ b/backend/open_webui/apps/webui/routers/users.py @@ -31,21 +31,58 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user) return Users.get_users(skip, limit) +############################ +# User Groups +############################ + + +@router.get("/groups") +async def get_user_groups(user=Depends(get_verified_user)): + return Users.get_user_groups(user.id) + + ############################ # User Permissions ############################ -@router.get("/permissions/user") +@router.get("/permissions") +async def get_user_permissisions(user=Depends(get_verified_user)): + return Users.get_user_groups(user.id) + + +############################ +# User Default Permissions +############################ +class WorkspacePermissions(BaseModel): + models: bool + knowledge: bool + prompts: bool + tools: bool + + +class ChatPermissions(BaseModel): + file_upload: bool + delete: bool + edit: bool + temporary: bool + + +class UserPermissions(BaseModel): + workspace: WorkspacePermissions + chat: ChatPermissions + + +@router.get("/default/permissions") async def get_user_permissions(request: Request, user=Depends(get_admin_user)): return request.app.state.config.USER_PERMISSIONS -@router.post("/permissions/user") +@router.post("/default/permissions") async def update_user_permissions( - request: Request, form_data: dict, user=Depends(get_admin_user) + request: Request, form_data: UserPermissions, user=Depends(get_admin_user) ): - request.app.state.config.USER_PERMISSIONS = form_data + request.app.state.config.USER_PERMISSIONS = form_data.model_dump() return request.app.state.config.USER_PERMISSIONS diff --git a/backend/open_webui/apps/webui/utils.py b/backend/open_webui/apps/webui/utils.py index e8968facd6ab2af52de367c057dc104975919d3f..465d36745c19c61ece2c9a84dbef745f27cf0df1 100644 --- a/backend/open_webui/apps/webui/utils.py +++ b/backend/open_webui/apps/webui/utils.py @@ -63,7 +63,7 @@ def replace_imports(content): return content -def load_toolkit_module_by_id(toolkit_id, content=None): +def load_tools_module_by_id(toolkit_id, content=None): if content is None: tool = Tools.get_tool_by_id(toolkit_id) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 6348c61c11a7cde5e8df0c5475ad3c8025cd16f4..874cd2e9d9401b99d665fec3b08b5adef6af7400 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -20,6 +20,7 @@ from open_webui.env import ( WEBUI_FAVICON_URL, WEBUI_NAME, log, + DATABASE_URL, ) from pydantic import BaseModel from sqlalchemy import JSON, Column, DateTime, Integer, func @@ -264,6 +265,13 @@ class AppConfig: # WEBUI_AUTH (Required for security) #################################### +ENABLE_API_KEY = PersistentConfig( + "ENABLE_API_KEY", + "auth.api_key.enable", + os.environ.get("ENABLE_API_KEY", "True").lower() == "true", +) + + JWT_EXPIRES_IN = PersistentConfig( "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") ) @@ -606,6 +614,12 @@ OLLAMA_BASE_URLS = PersistentConfig( "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS ) +OLLAMA_API_CONFIGS = PersistentConfig( + "OLLAMA_API_CONFIGS", + "ollama.api_configs", + {}, +) + #################################### # OPENAI_API #################################### @@ -646,15 +660,20 @@ OPENAI_API_BASE_URLS = PersistentConfig( "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS ) -OPENAI_API_KEY = "" +OPENAI_API_CONFIGS = PersistentConfig( + "OPENAI_API_CONFIGS", + "openai.api_configs", + {}, +) +# Get the actual OpenAI API key based on the base URL +OPENAI_API_KEY = "" try: OPENAI_API_KEY = OPENAI_API_KEYS.value[ OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1") ] except Exception: pass - OPENAI_API_BASE_URL = "https://api.openai.com/v1" #################################### @@ -727,12 +746,36 @@ DEFAULT_USER_ROLE = PersistentConfig( os.getenv("DEFAULT_USER_ROLE", "pending"), ) -USER_PERMISSIONS_CHAT_DELETION = ( - os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" + +USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower() + == "true" +) + +USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower() + == "true" +) + +USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower() + == "true" +) + +USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true" +) + +USER_PERMISSIONS_CHAT_FILE_UPLOAD = ( + os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true" +) + +USER_PERMISSIONS_CHAT_DELETE = ( + os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true" ) -USER_PERMISSIONS_CHAT_EDITING = ( - os.environ.get("USER_PERMISSIONS_CHAT_EDITING", "True").lower() == "true" +USER_PERMISSIONS_CHAT_EDIT = ( + os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true" ) USER_PERMISSIONS_CHAT_TEMPORARY = ( @@ -741,13 +784,20 @@ USER_PERMISSIONS_CHAT_TEMPORARY = ( USER_PERMISSIONS = PersistentConfig( "USER_PERMISSIONS", - "ui.user_permissions", + "user.permissions", { + "workspace": { + "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS, + "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS, + "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS, + "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS, + }, "chat": { - "deletion": USER_PERMISSIONS_CHAT_DELETION, - "editing": USER_PERMISSIONS_CHAT_EDITING, + "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD, + "delete": USER_PERMISSIONS_CHAT_DELETE, + "edit": USER_PERMISSIONS_CHAT_EDIT, "temporary": USER_PERMISSIONS_CHAT_TEMPORARY, - } + }, }, ) @@ -773,18 +823,6 @@ DEFAULT_ARENA_MODEL = { }, } -ENABLE_MODEL_FILTER = PersistentConfig( - "ENABLE_MODEL_FILTER", - "model_filter.enable", - os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true", -) -MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") -MODEL_FILTER_LIST = PersistentConfig( - "MODEL_FILTER_LIST", - "model_filter.list", - [model.strip() for model in MODEL_FILTER_LIST.split(";")], -) - WEBHOOK_URL = PersistentConfig( "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "") ) @@ -904,19 +942,55 @@ TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig( os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""), ) -ENABLE_SEARCH_QUERY = PersistentConfig( - "ENABLE_SEARCH_QUERY", - "task.search.enable", - os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true", +ENABLE_TAGS_GENERATION = PersistentConfig( + "ENABLE_TAGS_GENERATION", + "task.tags.enable", + os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true", +) + + +ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig( + "ENABLE_SEARCH_QUERY_GENERATION", + "task.query.search.enable", + os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true", +) + +ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig( + "ENABLE_RETRIEVAL_QUERY_GENERATION", + "task.query.retrieval.enable", + os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true", ) -SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", - "task.search.prompt_template", - os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""), +QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "QUERY_GENERATION_PROMPT_TEMPLATE", + "task.query.prompt_template", + os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""), ) +DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task: +Based on the chat history, determine whether a search is necessary, and if so, generate a 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list. + +### Guidelines: +- Respond exclusively with a JSON object. +- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise. +- If no search query is necessary, output should be: { "queries": [] } +- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required. +- Be concise, focusing strictly on composing search queries with no additional commentary or text. +- When in doubt, prefer to suggest a search for comprehensiveness. +- Today's date is: {{CURRENT_DATE}} + +### Output: +JSON format: { + "queries": ["query1", "query2"] +} + +### Chat History: + +{{MESSAGES:END:6}} + +""" + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", @@ -956,6 +1030,21 @@ MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") # Qdrant QDRANT_URI = os.environ.get("QDRANT_URI", None) +QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None) + +# OpenSearch +OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200") +OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", True) +OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False) +OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None) +OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None) + +# Pgvector +PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL) +if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"): + raise ValueError( + "Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database." + ) #################################### # Information Retrieval (RAG) @@ -1035,11 +1124,11 @@ RAG_EMBEDDING_MODEL = PersistentConfig( log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}") RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( - os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" + os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true" ) RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( - os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" + os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true" ) RAG_EMBEDDING_BATCH_SIZE = PersistentConfig( @@ -1060,11 +1149,11 @@ if RAG_RERANKING_MODEL.value != "": log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}") RAG_RERANKING_MODEL_AUTO_UPDATE = ( - os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true" + os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true" ) RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( - os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" + os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true" ) @@ -1129,6 +1218,19 @@ RAG_OPENAI_API_KEY = PersistentConfig( os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), ) +RAG_OLLAMA_BASE_URL = PersistentConfig( + "RAG_OLLAMA_BASE_URL", + "rag.ollama.url", + os.getenv("RAG_OLLAMA_BASE_URL", OLLAMA_BASE_URL), +) + +RAG_OLLAMA_API_KEY = PersistentConfig( + "RAG_OLLAMA_API_KEY", + "rag.ollama.key", + os.getenv("RAG_OLLAMA_API_KEY", ""), +) + + ENABLE_RAG_LOCAL_WEB_FETCH = ( os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true" ) @@ -1218,6 +1320,12 @@ TAVILY_API_KEY = PersistentConfig( os.getenv("TAVILY_API_KEY", ""), ) +JINA_API_KEY = PersistentConfig( + "JINA_API_KEY", + "rag.web.search.jina_api_key", + os.getenv("JINA_API_KEY", ""), +) + SEARCHAPI_API_KEY = PersistentConfig( "SEARCHAPI_API_KEY", "rag.web.search.searchapi_api_key", @@ -1230,6 +1338,21 @@ SEARCHAPI_ENGINE = PersistentConfig( os.getenv("SEARCHAPI_ENGINE", ""), ) +BING_SEARCH_V7_ENDPOINT = PersistentConfig( + "BING_SEARCH_V7_ENDPOINT", + "rag.web.search.bing_search_v7_endpoint", + os.environ.get( + "BING_SEARCH_V7_ENDPOINT", "https://api.bing.microsoft.com/v7.0/search" + ), +) + +BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig( + "BING_SEARCH_V7_SUBSCRIPTION_KEY", + "rag.web.search.bing_search_v7_subscription_key", + os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""), +) + + RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig( "RAG_WEB_SEARCH_RESULT_COUNT", "rag.web.search.result_count", @@ -1281,7 +1404,7 @@ AUTOMATIC1111_CFG_SCALE = PersistentConfig( AUTOMATIC1111_SAMPLER = PersistentConfig( - "AUTOMATIC1111_SAMPLERE", + "AUTOMATIC1111_SAMPLER", "image_generation.automatic1111.sampler", ( os.environ.get("AUTOMATIC1111_SAMPLER") @@ -1550,3 +1673,74 @@ AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig( "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3" ), ) + + +#################################### +# LDAP +#################################### + +ENABLE_LDAP = PersistentConfig( + "ENABLE_LDAP", + "ldap.enable", + os.environ.get("ENABLE_LDAP", "false").lower() == "true", +) + +LDAP_SERVER_LABEL = PersistentConfig( + "LDAP_SERVER_LABEL", + "ldap.server.label", + os.environ.get("LDAP_SERVER_LABEL", "LDAP Server"), +) + +LDAP_SERVER_HOST = PersistentConfig( + "LDAP_SERVER_HOST", + "ldap.server.host", + os.environ.get("LDAP_SERVER_HOST", "localhost"), +) + +LDAP_SERVER_PORT = PersistentConfig( + "LDAP_SERVER_PORT", + "ldap.server.port", + int(os.environ.get("LDAP_SERVER_PORT", "389")), +) + +LDAP_ATTRIBUTE_FOR_USERNAME = PersistentConfig( + "LDAP_ATTRIBUTE_FOR_USERNAME", + "ldap.server.attribute_for_username", + os.environ.get("LDAP_ATTRIBUTE_FOR_USERNAME", "uid"), +) + +LDAP_APP_DN = PersistentConfig( + "LDAP_APP_DN", "ldap.server.app_dn", os.environ.get("LDAP_APP_DN", "") +) + +LDAP_APP_PASSWORD = PersistentConfig( + "LDAP_APP_PASSWORD", + "ldap.server.app_password", + os.environ.get("LDAP_APP_PASSWORD", ""), +) + +LDAP_SEARCH_BASE = PersistentConfig( + "LDAP_SEARCH_BASE", "ldap.server.users_dn", os.environ.get("LDAP_SEARCH_BASE", "") +) + +LDAP_SEARCH_FILTERS = PersistentConfig( + "LDAP_SEARCH_FILTER", + "ldap.server.search_filter", + os.environ.get("LDAP_SEARCH_FILTER", ""), +) + +LDAP_USE_TLS = PersistentConfig( + "LDAP_USE_TLS", + "ldap.server.use_tls", + os.environ.get("LDAP_USE_TLS", "True").lower() == "true", +) + +LDAP_CA_CERT_FILE = PersistentConfig( + "LDAP_CA_CERT_FILE", + "ldap.server.ca_cert_file", + os.environ.get("LDAP_CA_CERT_FILE", ""), +) + +LDAP_CIPHERS = PersistentConfig( + "LDAP_CIPHERS", "ldap.server.ciphers", os.environ.get("LDAP_CIPHERS", "ALL") +) diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index 62ae8293160662430726fd660cfd45874b6b8d1f..863ad3e34b2d5a270853644e2fe854424cb437a0 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -62,6 +62,7 @@ class ERROR_MESSAGES(str, Enum): NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." + API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment." MALICIOUS = "Unusual activities detected, please try again in a few minutes." @@ -75,6 +76,7 @@ class ERROR_MESSAGES(str, Enum): OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." + API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment." EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 2b2fecf5ae5495253d77fdb2b9319ce1887f9a06..4485c713eb1334864b3b1c26c1c8bc9d826eda08 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -1,384 +1,393 @@ -import importlib.metadata -import json -import logging -import os -import pkgutil -import sys -import shutil -from pathlib import Path - -import markdown -from bs4 import BeautifulSoup -from open_webui.constants import ERROR_MESSAGES - -#################################### -# Load .env file -#################################### - -OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file -print(OPEN_WEBUI_DIR) - -BACKEND_DIR = OPEN_WEBUI_DIR.parent # the path containing this file -BASE_DIR = BACKEND_DIR.parent # the path containing the backend/ - -print(BACKEND_DIR) -print(BASE_DIR) - -try: - from dotenv import find_dotenv, load_dotenv - - load_dotenv(find_dotenv(str(BASE_DIR / ".env"))) -except ImportError: - print("dotenv not installed, skipping...") - -DOCKER = os.environ.get("DOCKER", "False").lower() == "true" - -# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance -USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") - -if USE_CUDA.lower() == "true": - try: - import torch - - assert torch.cuda.is_available(), "CUDA not available" - DEVICE_TYPE = "cuda" - except Exception as e: - cuda_error = ( - "Error when testing CUDA but USE_CUDA_DOCKER is true. " - f"Resetting USE_CUDA_DOCKER to false: {e}" - ) - os.environ["USE_CUDA_DOCKER"] = "false" - USE_CUDA = "false" - DEVICE_TYPE = "cpu" -else: - DEVICE_TYPE = "cpu" - - -#################################### -# LOGGING -#################################### - -log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] - -GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "ERROR").upper() -if GLOBAL_LOG_LEVEL in log_levels: - logging.basicConfig(stream=sys.stdout, level="ERROR", force=False) -else: - GLOBAL_LOG_LEVEL = "ERROR" - -log = logging.getLogger(__name__) -log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}") - -if "cuda_error" in locals(): - log.exception(cuda_error) - -log_sources = [ - "AUDIO", - "COMFYUI", - "CONFIG", - "DB", - "IMAGES", - "MAIN", - "MODELS", - "OLLAMA", - "OPENAI", - "RAG", - "WEBHOOK", - "SOCKET", -] - -SRC_LOG_LEVELS = {} - -for source in log_sources: - log_env_var = source + "_LOG_LEVEL" - SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper() - if SRC_LOG_LEVELS[source] not in log_levels: - SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL - log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}") - -log.setLevel(SRC_LOG_LEVELS["CONFIG"]) - - -WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") -if WEBUI_NAME != "Open WebUI": - WEBUI_NAME += " (Open WebUI)" - -WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000") - -WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" - - -#################################### -# ENV (dev,test,prod) -#################################### - -ENV = os.environ.get("ENV", "dev") - -FROM_INIT_PY = os.environ.get("FROM_INIT_PY", "False").lower() == "true" - -if FROM_INIT_PY: - PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")} -else: - try: - PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text()) - except Exception: - PACKAGE_DATA = {"version": "0.0.0"} - - -VERSION = PACKAGE_DATA["version"] - - -# Function to parse each section -def parse_section(section): - items = [] - for li in section.find_all("li"): - # Extract raw HTML string - raw_html = str(li) - - # Extract text without HTML tags - text = li.get_text(separator=" ", strip=True) - - # Split into title and content - parts = text.split(": ", 1) - title = parts[0].strip() if len(parts) > 1 else "" - content = parts[1].strip() if len(parts) > 1 else text - - items.append({"title": title, "content": content, "raw": raw_html}) - return items - - -try: - changelog_path = BASE_DIR / "CHANGELOG.md" - with open(str(changelog_path.absolute()), "r", encoding="utf8") as file: - changelog_content = file.read() - -except Exception: - changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode() - - -# Convert markdown content to HTML -html_content = markdown.markdown(changelog_content) - -# Parse the HTML content -soup = BeautifulSoup(html_content, "html.parser") - -# Initialize JSON structure -changelog_json = {} - -# Iterate over each version -for version in soup.find_all("h2"): - version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets - date = version.get_text().strip().split(" - ")[1] - - version_data = {"date": date} - - # Find the next sibling that is a h3 tag (section title) - current = version.find_next_sibling() - - while current and current.name != "h2": - if current.name == "h3": - section_title = current.get_text().lower() # e.g., "added", "fixed" - section_items = parse_section(current.find_next_sibling("ul")) - version_data[section_title] = section_items - - # Move to the next element - current = current.find_next_sibling() - - changelog_json[version_number] = version_data - - -CHANGELOG = changelog_json - -#################################### -# SAFE_MODE -#################################### - -SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true" - -#################################### -# WEBUI_BUILD_HASH -#################################### - -WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build") - -#################################### -# DATA/FRONTEND BUILD DIR -#################################### - -DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve() - -if FROM_INIT_PY: - NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve() - NEW_DATA_DIR.mkdir(parents=True, exist_ok=True) - - # Check if the data directory exists in the package directory - if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR: - log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}") - for item in DATA_DIR.iterdir(): - dest = NEW_DATA_DIR / item.name - if item.is_dir(): - shutil.copytree(item, dest, dirs_exist_ok=True) - else: - shutil.copy2(item, dest) - - # Zip the data directory - shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR) - - # Remove the old data directory - shutil.rmtree(DATA_DIR) - - DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")) - - -STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")) - -FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts")) - -FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve() - -if FROM_INIT_PY: - FRONTEND_BUILD_DIR = Path( - os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend") - ).resolve() - - -#################################### -# Database -#################################### - -# Check if the file exists -if os.path.exists(f"{DATA_DIR}/ollama.db"): - # Rename the file - os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") - log.info("Database migrated from Ollama-WebUI successfully.") -else: - pass - -DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") - -# Replace the postgres:// with postgresql:// -if "postgres://" in DATABASE_URL: - DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://") - -DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0) - -if DATABASE_POOL_SIZE == "": - DATABASE_POOL_SIZE = 0 -else: - try: - DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE) - except Exception: - DATABASE_POOL_SIZE = 0 - -DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0) - -if DATABASE_POOL_MAX_OVERFLOW == "": - DATABASE_POOL_MAX_OVERFLOW = 0 -else: - try: - DATABASE_POOL_MAX_OVERFLOW = int(DATABASE_POOL_MAX_OVERFLOW) - except Exception: - DATABASE_POOL_MAX_OVERFLOW = 0 - -DATABASE_POOL_TIMEOUT = os.environ.get("DATABASE_POOL_TIMEOUT", 30) - -if DATABASE_POOL_TIMEOUT == "": - DATABASE_POOL_TIMEOUT = 30 -else: - try: - DATABASE_POOL_TIMEOUT = int(DATABASE_POOL_TIMEOUT) - except Exception: - DATABASE_POOL_TIMEOUT = 30 - -DATABASE_POOL_RECYCLE = os.environ.get("DATABASE_POOL_RECYCLE", 3600) - -if DATABASE_POOL_RECYCLE == "": - DATABASE_POOL_RECYCLE = 3600 -else: - try: - DATABASE_POOL_RECYCLE = int(DATABASE_POOL_RECYCLE) - except Exception: - DATABASE_POOL_RECYCLE = 3600 - -RESET_CONFIG_ON_START = ( - os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" -) - -#################################### -# REDIS -#################################### - -REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0") - -#################################### -# WEBUI_AUTH (Required for security) -#################################### - -WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" -WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( - "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None -) -WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) - - -#################################### -# WEBUI_SECRET_KEY -#################################### - -WEBUI_SECRET_KEY = os.environ.get( - "WEBUI_SECRET_KEY", - os.environ.get( - "WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t" - ), # DEPRECATED: remove at next major version -) - -WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get( - "WEBUI_SESSION_COOKIE_SAME_SITE", - os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"), -) - -WEBUI_SESSION_COOKIE_SECURE = os.environ.get( - "WEBUI_SESSION_COOKIE_SECURE", - os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true", -) - -if WEBUI_AUTH and WEBUI_SECRET_KEY == "": - raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) - -ENABLE_WEBSOCKET_SUPPORT = ( - os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true" -) - -WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") - -WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) - -AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") - -if AIOHTTP_CLIENT_TIMEOUT == "": - AIOHTTP_CLIENT_TIMEOUT = None -else: - try: - AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT) - except Exception: - AIOHTTP_CLIENT_TIMEOUT = 300 - -AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get( - "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3" -) - -if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "": - AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None -else: - try: - AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int( - AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST - ) - except Exception: - AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3 - -#################################### -# OFFLINE_MODE -#################################### - -OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" +import importlib.metadata +import json +import logging +import os +import pkgutil +import sys +import shutil +from pathlib import Path + +import markdown +from bs4 import BeautifulSoup +from open_webui.constants import ERROR_MESSAGES + +#################################### +# Load .env file +#################################### + +OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file +print(OPEN_WEBUI_DIR) + +BACKEND_DIR = OPEN_WEBUI_DIR.parent # the path containing this file +BASE_DIR = BACKEND_DIR.parent # the path containing the backend/ + +print(BACKEND_DIR) +print(BASE_DIR) + +try: + from dotenv import find_dotenv, load_dotenv + + load_dotenv(find_dotenv(str(BASE_DIR / ".env"))) +except ImportError: + print("dotenv not installed, skipping...") + +DOCKER = os.environ.get("DOCKER", "False").lower() == "true" + +# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance +USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") + +if USE_CUDA.lower() == "true": + try: + import torch + + assert torch.cuda.is_available(), "CUDA not available" + DEVICE_TYPE = "cuda" + except Exception as e: + cuda_error = ( + "Error when testing CUDA but USE_CUDA_DOCKER is true. " + f"Resetting USE_CUDA_DOCKER to false: {e}" + ) + os.environ["USE_CUDA_DOCKER"] = "false" + USE_CUDA = "false" + DEVICE_TYPE = "cpu" +else: + DEVICE_TYPE = "cpu" + + +#################################### +# LOGGING +#################################### + +log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] + +GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper() +if GLOBAL_LOG_LEVEL in log_levels: + logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True) +else: + GLOBAL_LOG_LEVEL = "INFO" + +log = logging.getLogger(__name__) +log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}") + +if "cuda_error" in locals(): + log.exception(cuda_error) + +log_sources = [ + "AUDIO", + "COMFYUI", + "CONFIG", + "DB", + "IMAGES", + "MAIN", + "MODELS", + "OLLAMA", + "OPENAI", + "RAG", + "WEBHOOK", + "SOCKET", +] + +SRC_LOG_LEVELS = {} + +for source in log_sources: + log_env_var = source + "_LOG_LEVEL" + SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper() + if SRC_LOG_LEVELS[source] not in log_levels: + SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL + log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}") + +log.setLevel(SRC_LOG_LEVELS["CONFIG"]) + + +WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") +if WEBUI_NAME != "Open WebUI": + WEBUI_NAME += " (Open WebUI)" + +WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000") + +WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" + + +#################################### +# ENV (dev,test,prod) +#################################### + +ENV = os.environ.get("ENV", "dev") + +FROM_INIT_PY = os.environ.get("FROM_INIT_PY", "False").lower() == "true" + +if FROM_INIT_PY: + PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")} +else: + try: + PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text()) + except Exception: + PACKAGE_DATA = {"version": "0.0.0"} + + +VERSION = PACKAGE_DATA["version"] + + +# Function to parse each section +def parse_section(section): + items = [] + for li in section.find_all("li"): + # Extract raw HTML string + raw_html = str(li) + + # Extract text without HTML tags + text = li.get_text(separator=" ", strip=True) + + # Split into title and content + parts = text.split(": ", 1) + title = parts[0].strip() if len(parts) > 1 else "" + content = parts[1].strip() if len(parts) > 1 else text + + items.append({"title": title, "content": content, "raw": raw_html}) + return items + + +try: + changelog_path = BASE_DIR / "CHANGELOG.md" + with open(str(changelog_path.absolute()), "r", encoding="utf8") as file: + changelog_content = file.read() + +except Exception: + changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode() + + +# Convert markdown content to HTML +html_content = markdown.markdown(changelog_content) + +# Parse the HTML content +soup = BeautifulSoup(html_content, "html.parser") + +# Initialize JSON structure +changelog_json = {} + +# Iterate over each version +for version in soup.find_all("h2"): + version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets + date = version.get_text().strip().split(" - ")[1] + + version_data = {"date": date} + + # Find the next sibling that is a h3 tag (section title) + current = version.find_next_sibling() + + while current and current.name != "h2": + if current.name == "h3": + section_title = current.get_text().lower() # e.g., "added", "fixed" + section_items = parse_section(current.find_next_sibling("ul")) + version_data[section_title] = section_items + + # Move to the next element + current = current.find_next_sibling() + + changelog_json[version_number] = version_data + + +CHANGELOG = changelog_json + +#################################### +# SAFE_MODE +#################################### + +SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true" + +#################################### +# ENABLE_FORWARD_USER_INFO_HEADERS +#################################### + +ENABLE_FORWARD_USER_INFO_HEADERS = ( + os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true" +) + + +#################################### +# WEBUI_BUILD_HASH +#################################### + +WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build") + +#################################### +# DATA/FRONTEND BUILD DIR +#################################### + +DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve() + +if FROM_INIT_PY: + NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve() + NEW_DATA_DIR.mkdir(parents=True, exist_ok=True) + + # Check if the data directory exists in the package directory + if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR: + log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}") + for item in DATA_DIR.iterdir(): + dest = NEW_DATA_DIR / item.name + if item.is_dir(): + shutil.copytree(item, dest, dirs_exist_ok=True) + else: + shutil.copy2(item, dest) + + # Zip the data directory + shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR) + + # Remove the old data directory + shutil.rmtree(DATA_DIR) + + DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")) + + +STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")) + +FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts")) + +FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve() + +if FROM_INIT_PY: + FRONTEND_BUILD_DIR = Path( + os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend") + ).resolve() + + +#################################### +# Database +#################################### + +# Check if the file exists +if os.path.exists(f"{DATA_DIR}/ollama.db"): + # Rename the file + os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") + log.info("Database migrated from Ollama-WebUI successfully.") +else: + pass + +DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") + +# Replace the postgres:// with postgresql:// +if "postgres://" in DATABASE_URL: + DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://") + +DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0) + +if DATABASE_POOL_SIZE == "": + DATABASE_POOL_SIZE = 0 +else: + try: + DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE) + except Exception: + DATABASE_POOL_SIZE = 0 + +DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0) + +if DATABASE_POOL_MAX_OVERFLOW == "": + DATABASE_POOL_MAX_OVERFLOW = 0 +else: + try: + DATABASE_POOL_MAX_OVERFLOW = int(DATABASE_POOL_MAX_OVERFLOW) + except Exception: + DATABASE_POOL_MAX_OVERFLOW = 0 + +DATABASE_POOL_TIMEOUT = os.environ.get("DATABASE_POOL_TIMEOUT", 30) + +if DATABASE_POOL_TIMEOUT == "": + DATABASE_POOL_TIMEOUT = 30 +else: + try: + DATABASE_POOL_TIMEOUT = int(DATABASE_POOL_TIMEOUT) + except Exception: + DATABASE_POOL_TIMEOUT = 30 + +DATABASE_POOL_RECYCLE = os.environ.get("DATABASE_POOL_RECYCLE", 3600) + +if DATABASE_POOL_RECYCLE == "": + DATABASE_POOL_RECYCLE = 3600 +else: + try: + DATABASE_POOL_RECYCLE = int(DATABASE_POOL_RECYCLE) + except Exception: + DATABASE_POOL_RECYCLE = 3600 + +RESET_CONFIG_ON_START = ( + os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" +) + +#################################### +# REDIS +#################################### + +REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0") + +#################################### +# WEBUI_AUTH (Required for security) +#################################### + +WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" +WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( + "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None +) +WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) + + +#################################### +# WEBUI_SECRET_KEY +#################################### + +WEBUI_SECRET_KEY = os.environ.get( + "WEBUI_SECRET_KEY", + os.environ.get( + "WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t" + ), # DEPRECATED: remove at next major version +) + +WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get( + "WEBUI_SESSION_COOKIE_SAME_SITE", + os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"), +) + +WEBUI_SESSION_COOKIE_SECURE = os.environ.get( + "WEBUI_SESSION_COOKIE_SECURE", + os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true", +) + +if WEBUI_AUTH and WEBUI_SECRET_KEY == "": + raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) + +ENABLE_WEBSOCKET_SUPPORT = ( + os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true" +) + +WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") + +WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) + +AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") + +if AIOHTTP_CLIENT_TIMEOUT == "": + AIOHTTP_CLIENT_TIMEOUT = None +else: + try: + AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT) + except Exception: + AIOHTTP_CLIENT_TIMEOUT = 300 + +AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get( + "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3" +) + +if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "": + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None +else: + try: + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int( + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST + ) + except Exception: + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3 + +#################################### +# OFFLINE_MODE +#################################### + +OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 5703f06cbc239a7b10a0559c47e0095b7fa10bc0..2ac63682aa1e349a9b5db52b666b9c6df6c504dd 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1,2455 +1,2627 @@ -import asyncio -import inspect -import json -import logging -import mimetypes -import os -import shutil -import sys -import time -import random -from contextlib import asynccontextmanager -from typing import Optional - -import aiohttp -import requests -from fastapi import ( - Depends, - FastAPI, - File, - Form, - HTTPException, - Request, - UploadFile, - status, -) -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, RedirectResponse -from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel -from sqlalchemy import text -from starlette.exceptions import HTTPException as StarletteHTTPException -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.middleware.sessions import SessionMiddleware -from starlette.responses import Response, StreamingResponse - -from open_webui.apps.audio.main import app as audio_app -from open_webui.apps.images.main import app as images_app -from open_webui.apps.ollama.main import ( - app as ollama_app, - get_all_models as get_ollama_models, - generate_chat_completion as generate_ollama_chat_completion, - GenerateChatCompletionForm, -) -from open_webui.apps.openai.main import ( - app as openai_app, - generate_chat_completion as generate_openai_chat_completion, - get_all_models as get_openai_models, -) -from open_webui.apps.retrieval.main import app as retrieval_app -from open_webui.apps.retrieval.utils import get_rag_context, rag_template -from open_webui.apps.socket.main import ( - app as socket_app, - periodic_usage_pool_cleanup, - get_event_call, - get_event_emitter, -) -from open_webui.apps.webui.internal.db import Session -from open_webui.apps.webui.main import ( - app as webui_app, - generate_function_chat_completion, - get_all_models as get_open_webui_models, -) -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.models import Models -from open_webui.apps.webui.models.users import UserModel, Users -from open_webui.apps.webui.utils import load_function_module_by_id -from open_webui.config import ( - CACHE_DIR, - CORS_ALLOW_ORIGIN, - DEFAULT_LOCALE, - ENABLE_ADMIN_CHAT_ACCESS, - ENABLE_ADMIN_EXPORT, - ENABLE_MODEL_FILTER, - ENABLE_OLLAMA_API, - ENABLE_OPENAI_API, - ENV, - FRONTEND_BUILD_DIR, - MODEL_FILTER_LIST, - OAUTH_PROVIDERS, - ENABLE_SEARCH_QUERY, - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, - STATIC_DIR, - TASK_MODEL, - TASK_MODEL_EXTERNAL, - TITLE_GENERATION_PROMPT_TEMPLATE, - TAGS_GENERATION_PROMPT_TEMPLATE, - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - WEBHOOK_URL, - WEBUI_AUTH, - WEBUI_NAME, - AppConfig, - reset_config, -) -from open_webui.constants import TASKS -from open_webui.env import ( - CHANGELOG, - GLOBAL_LOG_LEVEL, - SAFE_MODE, - SRC_LOG_LEVELS, - VERSION, - WEBUI_BUILD_HASH, - WEBUI_SECRET_KEY, - WEBUI_SESSION_COOKIE_SAME_SITE, - WEBUI_SESSION_COOKIE_SECURE, - WEBUI_URL, - RESET_CONFIG_ON_START, - OFFLINE_MODE, -) -from open_webui.utils.misc import ( - add_or_update_system_message, - get_last_user_message, - prepend_to_first_user_message_content, -) -from open_webui.utils.oauth import oauth_manager -from open_webui.utils.payload import convert_payload_openai_to_ollama -from open_webui.utils.response import ( - convert_response_ollama_to_openai, - convert_streaming_response_ollama_to_openai, -) -from open_webui.utils.security_headers import SecurityHeadersMiddleware -from open_webui.utils.task import ( - moa_response_generation_template, - tags_generation_template, - search_query_generation_template, - emoji_generation_template, - title_generation_template, - tools_function_calling_generation_template, -) -from open_webui.utils.tools import get_tools -from open_webui.utils.utils import ( - decode_token, - get_admin_user, - get_current_user, - get_http_authorization_cred, - get_verified_user, -) - -if SAFE_MODE: - print("SAFE MODE ENABLED") - Functions.deactivate_all_functions() - -logging.basicConfig(stream=sys.stdout, level=logging.ERROR) -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["MAIN"]) - - -class SPAStaticFiles(StaticFiles): - async def get_response(self, path: str, scope): - try: - return await super().get_response(path, scope) - except (HTTPException, StarletteHTTPException) as ex: - if ex.status_code == 404: - return await super().get_response("index.html", scope) - else: - raise ex - - -print( - rf""" - ___ __ __ _ _ _ ___ - / _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _| -| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || | -| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || | - \___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___| - |_| - - -v{VERSION} - building the best open-source AI user interface. -{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""} -https://github.com/open-webui/open-webui -""" -) - - -@asynccontextmanager -async def lifespan(app: FastAPI): - if RESET_CONFIG_ON_START: - reset_config() - - asyncio.create_task(periodic_usage_pool_cleanup()) - yield - - -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan -) - -app.state.config = AppConfig() - -app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API -app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API - -app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - -app.state.config.WEBHOOK_URL = WEBHOOK_URL - -app.state.config.TASK_MODEL = TASK_MODEL -app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL -app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE -app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE -app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE -) -app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY -app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE -) - -app.state.MODELS = {} - - -################################## -# -# ChatCompletion Middleware -# -################################## - - -def get_task_model_id(default_model_id): - # Set the task model - task_model_id = default_model_id - # Check if the user has a custom task model and use that model - if app.state.MODELS[task_model_id]["owned_by"] == "ollama": - if ( - app.state.config.TASK_MODEL - and app.state.config.TASK_MODEL in app.state.MODELS - ): - task_model_id = app.state.config.TASK_MODEL - else: - if ( - app.state.config.TASK_MODEL_EXTERNAL - and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS - ): - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - - return task_model_id - - -def get_filter_function_ids(model): - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - filter_ids.sort(key=get_priority) - return filter_ids - - -async def chat_completion_filter_functions_handler(body, model, extra_params): - skip_files = None - - filter_ids = get_filter_function_ids(model) - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module - - # Check if the function has a file_handler variable - if hasattr(function_module, "file_handler"): - skip_files = function_module.file_handler - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "inlet"): - continue - - try: - inlet = function_module.inlet - - # Get the signature of the function - sig = inspect.signature(inlet) - params = {"body": body} | { - k: v - for k, v in { - **extra_params, - "__model__": model, - "__id__": filter_id, - }.items() - if k in sig.parameters - } - - if "__user__" in params and hasattr(function_module, "UserValves"): - try: - params["__user__"]["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, params["__user__"]["id"] - ) - ) - except Exception as e: - print(e) - - if inspect.iscoroutinefunction(inlet): - body = await inlet(**params) - else: - body = inlet(**params) - - except Exception as e: - print(f"Error: {e}") - raise e - - if skip_files and "files" in body.get("metadata", {}): - del body["metadata"]["files"] - - return body, {} - - -def get_tools_function_calling_payload(messages, task_model_id, content): - user_message = get_last_user_message(messages) - history = "\n".join( - f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" - for message in messages[::-1][:4] - ) - - prompt = f"History:\n{history}\nQuery: {user_message}" - - return { - "model": task_model_id, - "messages": [ - {"role": "system", "content": content}, - {"role": "user", "content": f"Query: {prompt}"}, - ], - "stream": False, - "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, - } - - -async def get_content_from_response(response) -> Optional[str]: - content = None - if hasattr(response, "body_iterator"): - async for chunk in response.body_iterator: - data = json.loads(chunk.decode("utf-8")) - content = data["choices"][0]["message"]["content"] - - # Cleanup any remaining background tasks if necessary - if response.background is not None: - await response.background() - else: - content = response["choices"][0]["message"]["content"] - return content - - -async def chat_completion_tools_handler( - body: dict, user: UserModel, extra_params: dict -) -> tuple[dict, dict]: - # If tool_ids field is present, call the functions - metadata = body.get("metadata", {}) - - tool_ids = metadata.get("tool_ids", None) - log.debug(f"{tool_ids=}") - if not tool_ids: - return body, {} - - skip_files = False - contexts = [] - citations = [] - - task_model_id = get_task_model_id(body["model"]) - tools = get_tools( - webui_app, - tool_ids, - user, - { - **extra_params, - "__model__": app.state.MODELS[task_model_id], - "__messages__": body["messages"], - "__files__": metadata.get("files", []), - }, - ) - log.info(f"{tools=}") - - specs = [tool["spec"] for tool in tools.values()] - tools_specs = json.dumps(specs) - - if app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "": - template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - else: - template = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text.""" - - tools_function_calling_prompt = tools_function_calling_generation_template( - template, tools_specs - ) - log.info(f"{tools_function_calling_prompt=}") - payload = get_tools_function_calling_payload( - body["messages"], task_model_id, tools_function_calling_prompt - ) - - try: - payload = filter_pipeline(payload, user) - except Exception as e: - raise e - - try: - response = await generate_chat_completions(form_data=payload, user=user) - log.debug(f"{response=}") - content = await get_content_from_response(response) - log.debug(f"{content=}") - - if not content: - return body, {} - - try: - content = content[content.find("{") : content.rfind("}") + 1] - if not content: - raise Exception("No JSON object found in the response") - - result = json.loads(content) - - tool_function_name = result.get("name", None) - if tool_function_name not in tools: - return body, {} - - tool_function_params = result.get("parameters", {}) - - try: - required_params = ( - tools[tool_function_name] - .get("spec", {}) - .get("parameters", {}) - .get("required", []) - ) - tool_function = tools[tool_function_name]["callable"] - tool_function_params = { - k: v - for k, v in tool_function_params.items() - if k in required_params - } - tool_output = await tool_function(**tool_function_params) - - except Exception as e: - tool_output = str(e) - - if tools[tool_function_name]["citation"]: - citations.append( - { - "source": { - "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" - }, - "document": [tool_output], - "metadata": [{"source": tool_function_name}], - } - ) - if tools[tool_function_name]["file_handler"]: - skip_files = True - - if isinstance(tool_output, str): - contexts.append(tool_output) - except Exception as e: - log.exception(f"Error: {e}") - content = None - except Exception as e: - log.exception(f"Error: {e}") - content = None - - log.debug(f"tool_contexts: {contexts}") - - if skip_files and "files" in body.get("metadata", {}): - del body["metadata"]["files"] - - return body, {"contexts": contexts, "citations": citations} - - -async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: - contexts = [] - citations = [] - - if files := body.get("metadata", {}).get("files", None): - contexts, citations = get_rag_context( - files=files, - messages=body["messages"], - embedding_function=retrieval_app.state.EMBEDDING_FUNCTION, - k=retrieval_app.state.config.TOP_K, - reranking_function=retrieval_app.state.sentence_transformer_rf, - r=retrieval_app.state.config.RELEVANCE_THRESHOLD, - hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH, - ) - - log.debug(f"rag_contexts: {contexts}, citations: {citations}") - - return body, {"contexts": contexts, "citations": citations} - - -def is_chat_completion_request(request): - return request.method == "POST" and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] - ) - - -async def get_body_and_model_and_user(request): - # Read the original request body - body = await request.body() - body_str = body.decode("utf-8") - body = json.loads(body_str) if body_str else {} - - model_id = body["model"] - if model_id not in app.state.MODELS: - raise Exception("Model not found") - model = app.state.MODELS[model_id] - - user = get_current_user( - request, - get_http_authorization_cred(request.headers.get("Authorization")), - ) - - return body, model, user - - -class ChatCompletionMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - if not is_chat_completion_request(request): - return await call_next(request) - log.debug(f"request.url.path: {request.url.path}") - - try: - body, model, user = await get_body_and_model_and_user(request) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - metadata = { - "chat_id": body.pop("chat_id", None), - "message_id": body.pop("id", None), - "session_id": body.pop("session_id", None), - "tool_ids": body.get("tool_ids", None), - "files": body.get("files", None), - } - body["metadata"] = metadata - - extra_params = { - "__event_emitter__": get_event_emitter(metadata), - "__event_call__": get_event_call(metadata), - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - } - - # Initialize data_items to store additional data to be sent to the client - # Initialize contexts and citation - data_items = [] - contexts = [] - citations = [] - - try: - body, flags = await chat_completion_filter_functions_handler( - body, model, extra_params - ) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - metadata = { - **metadata, - "tool_ids": body.pop("tool_ids", None), - "files": body.pop("files", None), - } - body["metadata"] = metadata - - try: - body, flags = await chat_completion_tools_handler(body, user, extra_params) - contexts.extend(flags.get("contexts", [])) - citations.extend(flags.get("citations", [])) - except Exception as e: - log.exception(e) - - try: - body, flags = await chat_completion_files_handler(body) - contexts.extend(flags.get("contexts", [])) - citations.extend(flags.get("citations", [])) - except Exception as e: - log.exception(e) - - # If context is not empty, insert it into the messages - if len(contexts) > 0: - context_string = "/n".join(contexts).strip() - prompt = get_last_user_message(body["messages"]) - - if prompt is None: - raise Exception("No user message found") - if ( - retrieval_app.state.config.RELEVANCE_THRESHOLD == 0 - and context_string.strip() == "" - ): - log.debug( - f"With a 0 relevancy threshold for RAG, the context cannot be empty" - ) - - # Workaround for Ollama 2.0+ system prompt issue - # TODO: replace with add_or_update_system_message - if model["owned_by"] == "ollama": - body["messages"] = prepend_to_first_user_message_content( - rag_template( - retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], - ) - else: - body["messages"] = add_or_update_system_message( - rag_template( - retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], - ) - - # If there are citations, add them to the data_items - if len(citations) > 0: - data_items.append({"citations": citations}) - - modified_body_bytes = json.dumps(body).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], - ] - - response = await call_next(request) - if not isinstance(response, StreamingResponse): - return response - - content_type = response.headers["Content-Type"] - is_openai = "text/event-stream" in content_type - is_ollama = "application/x-ndjson" in content_type - if not is_openai and not is_ollama: - return response - - def wrap_item(item): - return f"data: {item}\n\n" if is_openai else f"{item}\n" - - async def stream_wrapper(original_generator, data_items): - for item in data_items: - yield wrap_item(json.dumps(item)) - - async for data in original_generator: - yield data - - return StreamingResponse( - stream_wrapper(response.body_iterator, data_items), - headers=dict(response.headers), - ) - - async def _receive(self, body: bytes): - return {"type": "http.request", "body": body, "more_body": False} - - -app.add_middleware(ChatCompletionMiddleware) - - -################################## -# -# Pipeline Middleware -# -################################## - - -def get_sorted_filters(model_id): - filters = [ - model - for model in app.state.MODELS.values() - if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) - ] - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) - return sorted_filters - - -def filter_pipeline(payload, user): - user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} - model_id = payload["model"] - sorted_filters = get_sorted_filters(model_id) - - model = app.state.MODELS[model_id] - - if "pipeline" in model: - sorted_filters.append(model) - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key == "": - continue - - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": payload, - }, - ) - - r.raise_for_status() - payload = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - res = r.json() - if "detail" in res: - raise Exception(r.status_code, res["detail"]) - - return payload - - -class PipelineMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - if not is_chat_completion_request(request): - return await call_next(request) - - log.debug(f"request.url.path: {request.url.path}") - - # Read the original request body - body = await request.body() - # Decode body to string - body_str = body.decode("utf-8") - # Parse string to JSON - data = json.loads(body_str) if body_str else {} - - try: - user = get_current_user( - request, - get_http_authorization_cred(request.headers["Authorization"]), - ) - except KeyError as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"detail": "Not authenticated"}, - ) - - try: - data = filter_pipeline(data, user) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - modified_body_bytes = json.dumps(data).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], - ] - - response = await call_next(request) - return response - - async def _receive(self, body: bytes): - return {"type": "http.request", "body": body, "more_body": False} - - -app.add_middleware(PipelineMiddleware) - - -from urllib.parse import urlencode, parse_qs, urlparse - - -class RedirectMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - # Check if the request is a GET request - if request.method == "GET": - path = request.url.path - query_params = dict(parse_qs(urlparse(str(request.url)).query)) - - # Check for the specific watch path and the presence of 'v' parameter - if path.endswith("/watch") and "v" in query_params: - video_id = query_params["v"][0] # Extract the first 'v' parameter - encoded_video_id = urlencode({"youtube": video_id}) - redirect_url = f"/?{encoded_video_id}" - return RedirectResponse(url=redirect_url) - - # Proceed with the normal flow of other requests - response = await call_next(request) - return response - - -# Add the middleware to the app -app.add_middleware(RedirectMiddleware) - - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.add_middleware(SecurityHeadersMiddleware) - - -@app.middleware("http") -async def commit_session_after_request(request: Request, call_next): - response = await call_next(request) - log.debug("Commit session after request") - Session.commit() - return response - - -@app.middleware("http") -async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - else: - pass - - start_time = int(time.time()) - response = await call_next(request) - process_time = int(time.time()) - start_time - response.headers["X-Process-Time"] = str(process_time) - - return response - - -@app.middleware("http") -async def update_embedding_function(request: Request, call_next): - response = await call_next(request) - if "/embedding/update" in request.url.path: - webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION - return response - - -@app.middleware("http") -async def inspect_websocket(request: Request, call_next): - if ( - "/ws/socket.io" in request.url.path - and request.query_params.get("transport") == "websocket" - ): - upgrade = (request.headers.get("Upgrade") or "").lower() - connection = (request.headers.get("Connection") or "").lower().split(",") - # Check that there's the correct headers for an upgrade, else reject the connection - # This is to work around this upstream issue: https://github.com/miguelgrinberg/python-engineio/issues/367 - if upgrade != "websocket" or "upgrade" not in connection: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": "Invalid WebSocket upgrade request"}, - ) - return await call_next(request) - - -app.mount("/ws", socket_app) -app.mount("/ollama", ollama_app) -app.mount("/openai", openai_app) - -app.mount("/images/api/v1", images_app) -app.mount("/audio/api/v1", audio_app) -app.mount("/retrieval/api/v1", retrieval_app) - -app.mount("/api/v1", webui_app) - - -webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION - - -async def get_all_models(): - # TODO: Optimize this function - open_webui_models = [] - openai_models = [] - ollama_models = [] - - if app.state.config.ENABLE_OPENAI_API: - openai_models = await get_openai_models() - openai_models = openai_models["data"] - - if app.state.config.ENABLE_OLLAMA_API: - ollama_models = await get_ollama_models() - ollama_models = [ - { - "id": model["model"], - "name": model["name"], - "object": "model", - "created": int(time.time()), - "owned_by": "ollama", - "ollama": model, - } - for model in ollama_models["models"] - ] - - open_webui_models = await get_open_webui_models() - - models = open_webui_models + openai_models + ollama_models - - # If there are no models, return an empty list - if len([model for model in models if model["owned_by"] != "arena"]) == 0: - return [] - - global_action_ids = [ - function.id for function in Functions.get_global_action_functions() - ] - enabled_action_ids = [ - function.id - for function in Functions.get_functions_by_type("action", active_only=True) - ] - - custom_models = Models.get_all_models() - for custom_model in custom_models: - if custom_model.base_model_id is None: - for model in models: - if ( - custom_model.id == model["id"] - or custom_model.id == model["id"].split(":")[0] - ): - model["name"] = custom_model.name - model["info"] = custom_model.model_dump() - - action_ids = [] - if "info" in model and "meta" in model["info"]: - action_ids.extend(model["info"]["meta"].get("actionIds", [])) - - model["action_ids"] = action_ids - else: - owned_by = "openai" - pipe = None - action_ids = [] - - for model in models: - if ( - custom_model.base_model_id == model["id"] - or custom_model.base_model_id == model["id"].split(":")[0] - ): - owned_by = model["owned_by"] - if "pipe" in model: - pipe = model["pipe"] - break - - if custom_model.meta: - meta = custom_model.meta.model_dump() - if "actionIds" in meta: - action_ids.extend(meta["actionIds"]) - - models.append( - { - "id": custom_model.id, - "name": custom_model.name, - "object": "model", - "created": custom_model.created_at, - "owned_by": owned_by, - "info": custom_model.model_dump(), - "preset": True, - **({"pipe": pipe} if pipe is not None else {}), - "action_ids": action_ids, - } - ) - - for model in models: - action_ids = [] - if "action_ids" in model: - action_ids = model["action_ids"] - del model["action_ids"] - - action_ids = action_ids + global_action_ids - action_ids = list(set(action_ids)) - action_ids = [ - action_id for action_id in action_ids if action_id in enabled_action_ids - ] - - model["actions"] = [] - for action_id in action_ids: - action = Functions.get_function_by_id(action_id) - if action is None: - raise Exception(f"Action not found: {action_id}") - - if action_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[action_id] - else: - function_module, _, _ = load_function_module_by_id(action_id) - webui_app.state.FUNCTIONS[action_id] = function_module - - __webui__ = False - if hasattr(function_module, "__webui__"): - __webui__ = function_module.__webui__ - - if hasattr(function_module, "actions"): - actions = function_module.actions - model["actions"].extend( - [ - { - "id": f"{action_id}.{_action['id']}", - "name": _action.get( - "name", f"{action.name} ({_action['id']})" - ), - "description": action.meta.description, - "icon_url": _action.get( - "icon_url", action.meta.manifest.get("icon_url", None) - ), - **({"__webui__": __webui__} if __webui__ else {}), - } - for _action in actions - ] - ) - else: - model["actions"].append( - { - "id": action_id, - "name": action.name, - "description": action.meta.description, - "icon_url": action.meta.manifest.get("icon_url", None), - **({"__webui__": __webui__} if __webui__ else {}), - } - ) - - app.state.MODELS = {model["id"]: model for model in models} - webui_app.state.MODELS = app.state.MODELS - - return models - - -@app.get("/api/models") -async def get_models(user=Depends(get_verified_user)): - models = await get_all_models() - - # Filter out filter pipelines - models = [ - model - for model in models - if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" - ] - - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models = list( - filter( - lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, - models, - ) - ) - return {"data": models} - - return {"data": models} - - -@app.post("/api/chat/completions") -async def generate_chat_completions( - form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False -): - model_id = form_data["model"] - - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Model not found", - ) - - model = app.state.MODELS[model_id] - - if model["owned_by"] == "arena": - model_ids = model.get("info", {}).get("meta", {}).get("model_ids") - filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") - if model_ids and filter_mode == "exclude": - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" - and not model.get("info", {}).get("meta", {}).get("hidden", False) - and model["id"] not in model_ids - ] - - selected_model_id = None - if isinstance(model_ids, list) and model_ids: - selected_model_id = random.choice(model_ids) - else: - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" - and not model.get("info", {}).get("meta", {}).get("hidden", False) - ] - selected_model_id = random.choice(model_ids) - - form_data["model"] = selected_model_id - - if form_data.get("stream") == True: - - async def stream_wrapper(stream): - yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" - async for chunk in stream: - yield chunk - - response = await generate_chat_completions( - form_data, user, bypass_filter=True - ) - return StreamingResponse( - stream_wrapper(response.body_iterator), media_type="text/event-stream" - ) - else: - return { - **( - await generate_chat_completions(form_data, user, bypass_filter=True) - ), - "selected_model_id": selected_model_id, - } - if model.get("pipe"): - return await generate_function_chat_completion(form_data, user=user) - if model["owned_by"] == "ollama": - # Using /ollama/api/chat endpoint - form_data = convert_payload_openai_to_ollama(form_data) - form_data = GenerateChatCompletionForm(**form_data) - response = await generate_ollama_chat_completion( - form_data=form_data, user=user, bypass_filter=True - ) - if form_data.stream: - response.headers["content-type"] = "text/event-stream" - return StreamingResponse( - convert_streaming_response_ollama_to_openai(response), - headers=dict(response.headers), - ) - else: - return convert_response_ollama_to_openai(response) - else: - return await generate_openai_chat_completion(form_data, user=user) - - -@app.post("/api/chat/completed") -async def chat_completed(form_data: dict, user=Depends(get_verified_user)): - data = form_data - model_id = data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - model = app.state.MODELS[model_id] - - sorted_filters = get_sorted_filters(model_id) - if "pipeline" in model: - sorted_filters = [model] + sorted_filters - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/outlet", - headers=headers, - json={ - "user": { - "id": user.id, - "name": user.name, - "email": user.email, - "role": user.role, - }, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except Exception: - pass - - else: - pass - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel to include vavles - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - # Sort filter_ids by priority, using the get_priority function - filter_ids.sort(key=get_priority) - - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "outlet"): - continue - try: - outlet = function_module.outlet - - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data - - -@app.post("/api/chat/actions/{action_id}") -async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): - if "." in action_id: - action_id, sub_action_id = action_id.split(".") - else: - sub_action_id = None - - action = Functions.get_function_by_id(action_id) - if not action: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Action not found", - ) - - data = form_data - model_id = data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - model = app.state.MODELS[model_id] - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - if action_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[action_id] - else: - function_module, _, _ = load_function_module_by_id(action_id) - webui_app.state.FUNCTIONS[action_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(action_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - - if hasattr(function_module, "action"): - try: - action = function_module.action - - # Get the signature of the function - sig = inspect.signature(action) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": sub_action_id if sub_action_id is not None else action_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - action_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(action): - data = await action(**params) - else: - data = action(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data - - -################################## -# -# Task Endpoints -# -################################## - - -# TODO: Refactor task API endpoints below into a separate file - - -@app.get("/api/task/config") -async def get_task_config(user=Depends(get_verified_user)): - return { - "TASK_MODEL": app.state.config.TASK_MODEL, - "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, - "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, - "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, - "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY, - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, - "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - } - - -class TaskConfigForm(BaseModel): - TASK_MODEL: Optional[str] - TASK_MODEL_EXTERNAL: Optional[str] - TITLE_GENERATION_PROMPT_TEMPLATE: str - TAGS_GENERATION_PROMPT_TEMPLATE: str - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str - ENABLE_SEARCH_QUERY: bool - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str - - -@app.post("/api/task/config/update") -async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)): - app.state.config.TASK_MODEL = form_data.TASK_MODEL - app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL - app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( - form_data.TITLE_GENERATION_PROMPT_TEMPLATE - ) - app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( - form_data.TAGS_GENERATION_PROMPT_TEMPLATE - ) - - app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( - form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE - ) - app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY - app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( - form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - ) - - return { - "TASK_MODEL": app.state.config.TASK_MODEL, - "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, - "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, - "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, - "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY, - "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - } - - -@app.post("/api/task/title/completions") -async def generate_title(form_data: dict, user=Depends(get_verified_user)): - print("generate_title") - - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) - print(task_model_id) - - model = app.state.MODELS[task_model_id] - - if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": - template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE - else: - template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. - -Examples of titles: -📉 Stock Market Trends -🍪 Perfect Chocolate Chip Recipe -Evolution of Music Streaming -Remote Work Productivity Tips -Artificial Intelligence in Healthcare -🎮 Video Game Development Insights - - -{{MESSAGES:END:2}} -""" - - content = title_generation_template( - template, - form_data["messages"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - **( - {"max_tokens": 50} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 50, - } - ), - "chat_id": form_data.get("chat_id", None), - "metadata": {"task": str(TASKS.TITLE_GENERATION), "task_body": form_data}, - } - log.debug(payload) - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/tags/completions") -async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): - print("generate_chat_tags") - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) - print(task_model_id) - - if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": - template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE - else: - template = """### Task: -Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags. - -### Guidelines: -- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) -- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation -- If content is too short (less than 3 messages) or too diverse, use only ["General"] -- Use the chat's primary language; default to English if multilingual -- Prioritize accuracy over specificity - -### Output: -JSON format: { "tags": ["tag1", "tag2", "tag3"] } - -### Chat History: - -{{MESSAGES:END:6}} -""" - - content = tags_generation_template( - template, form_data["messages"], {"name": user.name} - ) - - print("content", content) - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data}, - } - log.debug(payload) - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/query/completions") -async def generate_search_query(form_data: dict, user=Depends(get_verified_user)): - print("generate_search_query") - if not app.state.config.ENABLE_SEARCH_QUERY: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Search query generation is disabled", - ) - - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) - print(task_model_id) - - model = app.state.MODELS[task_model_id] - - if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "": - template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE - else: - template = """Given the user's message and interaction history, decide if a web search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. Today's date is {{CURRENT_DATE}}. - -User Message: -{{prompt:end:4000}} - -Interaction History: -{{MESSAGES:END:6}} - -Search Query:""" - - content = search_query_generation_template( - template, form_data["messages"], {"name": user.name} - ) - - print("content", content) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - **( - {"max_tokens": 30} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 30, - } - ), - "metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data}, - } - log.debug(payload) - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/emoji/completions") -async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): - print("generate_emoji") - - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) - print(task_model_id) - - model = app.state.MODELS[task_model_id] - - template = ''' -Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). - -Message: """{{prompt}}""" -''' - content = emoji_generation_template( - template, - form_data["prompt"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - **( - {"max_tokens": 4} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 4, - } - ), - "chat_id": form_data.get("chat_id", None), - "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data}, - } - log.debug(payload) - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/moa/completions") -async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)): - print("generate_moa_response") - - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) - print(task_model_id) - - model = app.state.MODELS[task_model_id] - - template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" - -Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. - -Responses from models: {{responses}}""" - - content = moa_response_generation_template( - template, - form_data["prompt"], - form_data["responses"], - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": form_data.get("stream", False), - "chat_id": form_data.get("chat_id", None), - "metadata": { - "task": str(TASKS.MOA_RESPONSE_GENERATION), - "task_body": form_data, - }, - } - log.debug(payload) - - try: - payload = filter_pipeline(payload, user) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -################################## -# -# Pipelines Endpoints -# -################################## - - -# TODO: Refactor pipelines API endpoints below into a separate file - - -@app.get("/api/pipelines/list") -async def get_pipelines_list(user=Depends(get_admin_user)): - responses = await get_openai_models(raw=True) - - print(responses) - urlIdxs = [ - idx - for idx, response in enumerate(responses) - if response is not None and "pipelines" in response - ] - - return { - "data": [ - { - "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx], - "idx": urlIdx, - } - for urlIdx in urlIdxs - ] - } - - -@app.post("/api/pipelines/upload") -async def upload_pipeline( - urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) -): - print("upload_pipeline", urlIdx, file.filename) - # Check if the uploaded file is a python file - if not (file.filename and file.filename.endswith(".py")): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Only Python (.py) files are allowed.", - ) - - upload_folder = f"{CACHE_DIR}/pipelines" - os.makedirs(upload_folder, exist_ok=True) - file_path = os.path.join(upload_folder, file.filename) - - r = None - try: - # Save the uploaded file - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - - with open(file_path, "rb") as f: - files = {"file": f} - r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - status_code = status.HTTP_404_NOT_FOUND - if r is not None: - status_code = r.status_code - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=status_code, - detail=detail, - ) - finally: - # Ensure the file is deleted after the upload is completed or on failure - if os.path.exists(file_path): - os.remove(file_path) - - -class AddPipelineForm(BaseModel): - url: str - urlIdx: int - - -@app.post("/api/pipelines/add") -async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)): - r = None - try: - urlIdx = form_data.urlIdx - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/pipelines/add", headers=headers, json={"url": form_data.url} - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -class DeletePipelineForm(BaseModel): - id: str - urlIdx: int - - -@app.delete("/api/pipelines/delete") -async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)): - r = None - try: - urlIdx = form_data.urlIdx - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.delete( - f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id} - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines") -async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/pipelines", headers=headers) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines/{pipeline_id}/valves") -async def get_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines/{pipeline_id}/valves/spec") -async def get_pipeline_valves_spec( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.post("/api/pipelines/{pipeline_id}/valves/update") -async def update_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - form_data: dict, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{pipeline_id}/valves/update", - headers=headers, - json={**form_data}, - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -################################## -# -# Config Endpoints -# -################################## - - -@app.get("/api/config") -async def get_app_config(request: Request): - user = None - if "token" in request.cookies: - token = request.cookies.get("token") - data = decode_token(token) - if data is not None and "id" in data: - user = Users.get_user_by_id(data["id"]) - - return { - "status": True, - "name": WEBUI_NAME, - "version": VERSION, - "default_locale": str(DEFAULT_LOCALE), - "oauth": { - "providers": { - name: config.get("name", name) - for name, config in OAUTH_PROVIDERS.items() - } - }, - "features": { - "auth": WEBUI_AUTH, - "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), - "enable_signup": webui_app.state.config.ENABLE_SIGNUP, - "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, - **( - { - "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH, - "enable_image_generation": images_app.state.config.ENABLED, - "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, - "enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING, - "enable_admin_export": ENABLE_ADMIN_EXPORT, - "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, - } - if user is not None - else {} - ), - }, - **( - { - "default_models": webui_app.state.config.DEFAULT_MODELS, - "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, - "audio": { - "tts": { - "engine": audio_app.state.config.TTS_ENGINE, - "voice": audio_app.state.config.TTS_VOICE, - "split_on": audio_app.state.config.TTS_SPLIT_ON, - }, - "stt": { - "engine": audio_app.state.config.STT_ENGINE, - }, - }, - "file": { - "max_size": retrieval_app.state.config.FILE_MAX_SIZE, - "max_count": retrieval_app.state.config.FILE_MAX_COUNT, - }, - "permissions": {**webui_app.state.config.USER_PERMISSIONS}, - } - if user is not None - else {} - ), - } - - -@app.get("/api/config/model/filter") -async def get_model_filter_config(user=Depends(get_admin_user)): - return { - "enabled": app.state.config.ENABLE_MODEL_FILTER, - "models": app.state.config.MODEL_FILTER_LIST, - } - - -class ModelFilterConfigForm(BaseModel): - enabled: bool - models: list[str] - - -@app.post("/api/config/model/filter") -async def update_model_filter_config( - form_data: ModelFilterConfigForm, user=Depends(get_admin_user) -): - app.state.config.ENABLE_MODEL_FILTER = form_data.enabled - app.state.config.MODEL_FILTER_LIST = form_data.models - - return { - "enabled": app.state.config.ENABLE_MODEL_FILTER, - "models": app.state.config.MODEL_FILTER_LIST, - } - - -# TODO: webhook endpoint should be under config endpoints - - -@app.get("/api/webhook") -async def get_webhook_url(user=Depends(get_admin_user)): - return { - "url": app.state.config.WEBHOOK_URL, - } - - -class UrlForm(BaseModel): - url: str - - -@app.post("/api/webhook") -async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): - app.state.config.WEBHOOK_URL = form_data.url - webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL - return {"url": app.state.config.WEBHOOK_URL} - - -@app.get("/api/version") -async def get_app_version(): - return { - "version": VERSION, - } - - -@app.get("/api/changelog") -async def get_app_changelog(): - return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} - - -@app.get("/api/version/updates") -async def get_app_latest_release_version(): - if OFFLINE_MODE: - log.debug( - f"Offline mode is enabled, returning current version as latest version" - ) - return {"current": VERSION, "latest": VERSION} - try: - timeout = aiohttp.ClientTimeout(total=1) - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get( - "https://api.github.com/repos/open-webui/open-webui/releases/latest" - ) as response: - response.raise_for_status() - data = await response.json() - latest_version = data["tag_name"] - - return {"current": VERSION, "latest": latest_version[1:]} - except Exception as e: - log.debug(e) - return {"current": VERSION, "latest": VERSION} - - -############################ -# OAuth Login & Callback -############################ - -# SessionMiddleware is used by authlib for oauth -if len(OAUTH_PROVIDERS) > 0: - app.add_middleware( - SessionMiddleware, - secret_key=WEBUI_SECRET_KEY, - session_cookie="oui-session", - same_site=WEBUI_SESSION_COOKIE_SAME_SITE, - https_only=WEBUI_SESSION_COOKIE_SECURE, - ) - - -@app.get("/oauth/{provider}/login") -async def oauth_login(provider: str, request: Request): - return await oauth_manager.handle_login(provider, request) - - -# OAuth login logic is as follows: -# 1. Attempt to find a user with matching subject ID, tied to the provider -# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth -# - This is considered insecure in general, as OAuth providers do not always verify email addresses -# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user -# - Email addresses are considered unique, so we fail registration if the email address is already taken -@app.get("/oauth/{provider}/callback") -async def oauth_callback(provider: str, request: Request, response: Response): - return await oauth_manager.handle_callback(provider, request, response) - - -@app.get("/manifest.json") -async def get_manifest_json(): - return { - "name": WEBUI_NAME, - "short_name": WEBUI_NAME, - "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.", - "start_url": "/", - "display": "standalone", - "background_color": "#343541", - "orientation": "any", - "icons": [ - { - "src": "/static/logo.png", - "type": "image/png", - "sizes": "500x500", - "purpose": "any", - }, - { - "src": "/static/logo.png", - "type": "image/png", - "sizes": "500x500", - "purpose": "maskable", - }, - ], - } - - -@app.get("/opensearch.xml") -async def get_opensearch_xml(): - xml_content = rf""" - - {WEBUI_NAME} - Search {WEBUI_NAME} - UTF-8 - {WEBUI_URL}/static/favicon.png - - {WEBUI_URL} - - """ - return Response(content=xml_content, media_type="application/xml") - - -@app.get("/health") -async def healthcheck(): - return {"status": True} - - -@app.get("/health/db") -async def healthcheck_with_db(): - Session.execute(text("SELECT 1;")).all() - return {"status": True} - - -app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") -app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") - - -if os.path.exists(FRONTEND_BUILD_DIR): - mimetypes.add_type("text/javascript", ".js") - app.mount( - "/", - SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True), - name="spa-static-files", - ) -else: - log.warning( - f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only." - ) +import asyncio +import inspect +import json +import logging +import mimetypes +import os +import shutil +import sys +import time +import random +from contextlib import asynccontextmanager +from typing import Optional + +from aiocache import cached +import aiohttp +import requests +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Request, + UploadFile, + status, +) +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, RedirectResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel +from sqlalchemy import text +from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.sessions import SessionMiddleware +from starlette.responses import Response, StreamingResponse + +from open_webui.apps.audio.main import app as audio_app +from open_webui.apps.images.main import app as images_app +from open_webui.apps.ollama.main import ( + app as ollama_app, + get_all_models as get_ollama_models, + generate_chat_completion as generate_ollama_chat_completion, + GenerateChatCompletionForm, +) +from open_webui.apps.openai.main import ( + app as openai_app, + generate_chat_completion as generate_openai_chat_completion, + get_all_models as get_openai_models, + get_all_models_responses as get_openai_models_responses, +) +from open_webui.apps.retrieval.main import app as retrieval_app +from open_webui.apps.retrieval.utils import get_rag_context, rag_template +from open_webui.apps.socket.main import ( + app as socket_app, + periodic_usage_pool_cleanup, + get_event_call, + get_event_emitter, +) +from open_webui.apps.webui.internal.db import Session +from open_webui.apps.webui.main import ( + app as webui_app, + generate_function_chat_completion, + get_all_models as get_open_webui_models, +) +from open_webui.apps.webui.models.functions import Functions +from open_webui.apps.webui.models.models import Models +from open_webui.apps.webui.models.users import UserModel, Users +from open_webui.apps.webui.utils import load_function_module_by_id +from open_webui.config import ( + CACHE_DIR, + CORS_ALLOW_ORIGIN, + DEFAULT_LOCALE, + ENABLE_ADMIN_CHAT_ACCESS, + ENABLE_ADMIN_EXPORT, + ENABLE_OLLAMA_API, + ENABLE_OPENAI_API, + ENABLE_TAGS_GENERATION, + ENV, + FRONTEND_BUILD_DIR, + OAUTH_PROVIDERS, + STATIC_DIR, + TASK_MODEL, + TASK_MODEL_EXTERNAL, + ENABLE_SEARCH_QUERY_GENERATION, + ENABLE_RETRIEVAL_QUERY_GENERATION, + QUERY_GENERATION_PROMPT_TEMPLATE, + DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, + TITLE_GENERATION_PROMPT_TEMPLATE, + TAGS_GENERATION_PROMPT_TEMPLATE, + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + WEBHOOK_URL, + WEBUI_AUTH, + WEBUI_NAME, + AppConfig, + reset_config, +) +from open_webui.constants import TASKS +from open_webui.env import ( + CHANGELOG, + GLOBAL_LOG_LEVEL, + SAFE_MODE, + SRC_LOG_LEVELS, + VERSION, + WEBUI_BUILD_HASH, + WEBUI_SECRET_KEY, + WEBUI_SESSION_COOKIE_SAME_SITE, + WEBUI_SESSION_COOKIE_SECURE, + WEBUI_URL, + RESET_CONFIG_ON_START, + OFFLINE_MODE, +) +from open_webui.utils.misc import ( + add_or_update_system_message, + get_last_user_message, + prepend_to_first_user_message_content, +) +from open_webui.utils.oauth import oauth_manager +from open_webui.utils.payload import convert_payload_openai_to_ollama +from open_webui.utils.response import ( + convert_response_ollama_to_openai, + convert_streaming_response_ollama_to_openai, +) +from open_webui.utils.security_headers import SecurityHeadersMiddleware +from open_webui.utils.task import ( + moa_response_generation_template, + tags_generation_template, + query_generation_template, + emoji_generation_template, + title_generation_template, + tools_function_calling_generation_template, +) +from open_webui.utils.tools import get_tools +from open_webui.utils.utils import ( + decode_token, + get_admin_user, + get_current_user, + get_http_authorization_cred, + get_verified_user, +) +from open_webui.utils.access_control import has_access + +if SAFE_MODE: + print("SAFE MODE ENABLED") + Functions.deactivate_all_functions() + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +class SPAStaticFiles(StaticFiles): + async def get_response(self, path: str, scope): + try: + return await super().get_response(path, scope) + except (HTTPException, StarletteHTTPException) as ex: + if ex.status_code == 404: + return await super().get_response("index.html", scope) + else: + raise ex + + +print( + rf""" + ___ __ __ _ _ _ ___ + / _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _| +| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || | +| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || | + \___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___| + |_| + + +v{VERSION} - building the best open-source AI user interface. +{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""} +https://github.com/open-webui/open-webui +""" +) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + if RESET_CONFIG_ON_START: + reset_config() + + asyncio.create_task(periodic_usage_pool_cleanup()) + yield + + +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, + lifespan=lifespan, +) + +app.state.config = AppConfig() + +app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API +app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API + +app.state.config.WEBHOOK_URL = WEBHOOK_URL + +app.state.config.TASK_MODEL = TASK_MODEL +app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL + +app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE + +app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION +app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE + + +app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION +app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION +app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE + +app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE +) + +################################## +# +# ChatCompletion Middleware +# +################################## + + +def get_filter_function_ids(model): + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [function.id for function in Functions.get_global_filter_functions()] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + + filter_ids.sort(key=get_priority) + return filter_ids + + +async def chat_completion_filter_functions_handler(body, model, extra_params): + skip_files = None + + filter_ids = get_filter_function_ids(model) + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + webui_app.state.FUNCTIONS[filter_id] = function_module + + # Check if the function has a file_handler variable + if hasattr(function_module, "file_handler"): + skip_files = function_module.file_handler + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + if not hasattr(function_module, "inlet"): + continue + + try: + inlet = function_module.inlet + + # Get the signature of the function + sig = inspect.signature(inlet) + params = {"body": body} | { + k: v + for k, v in { + **extra_params, + "__model__": model, + "__id__": filter_id, + }.items() + if k in sig.parameters + } + + if "__user__" in params and hasattr(function_module, "UserValves"): + try: + params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, params["__user__"]["id"] + ) + ) + except Exception as e: + print(e) + + if inspect.iscoroutinefunction(inlet): + body = await inlet(**params) + else: + body = inlet(**params) + + except Exception as e: + print(f"Error: {e}") + raise e + + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] + + return body, {} + + +def get_tools_function_calling_payload(messages, task_model_id, content): + user_message = get_last_user_message(messages) + history = "\n".join( + f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" + for message in messages[::-1][:4] + ) + + prompt = f"History:\n{history}\nQuery: {user_message}" + + return { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, + } + + +async def get_content_from_response(response) -> Optional[str]: + content = None + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + return content + + +def get_task_model_id( + default_model_id: str, task_model: str, task_model_external: str, models +) -> str: + # Set the task model + task_model_id = default_model_id + # Check if the user has a custom task model and use that model + if models[task_model_id]["owned_by"] == "ollama": + if task_model and task_model in models: + task_model_id = task_model + else: + if task_model_external and task_model_external in models: + task_model_id = task_model_external + + return task_model_id + + +async def chat_completion_tools_handler( + body: dict, user: UserModel, models, extra_params: dict +) -> tuple[dict, dict]: + # If tool_ids field is present, call the functions + metadata = body.get("metadata", {}) + + tool_ids = metadata.get("tool_ids", None) + log.debug(f"{tool_ids=}") + if not tool_ids: + return body, {} + + skip_files = False + contexts = [] + citations = [] + + task_model_id = get_task_model_id( + body["model"], + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + tools = get_tools( + webui_app, + tool_ids, + user, + { + **extra_params, + "__model__": models[task_model_id], + "__messages__": body["messages"], + "__files__": metadata.get("files", []), + }, + ) + log.info(f"{tools=}") + + specs = [tool["spec"] for tool in tools.values()] + tools_specs = json.dumps(specs) + + if app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "": + template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + else: + template = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text.""" + + tools_function_calling_prompt = tools_function_calling_generation_template( + template, tools_specs + ) + log.info(f"{tools_function_calling_prompt=}") + payload = get_tools_function_calling_payload( + body["messages"], task_model_id, tools_function_calling_prompt + ) + + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + raise e + + try: + response = await generate_chat_completions(form_data=payload, user=user) + log.debug(f"{response=}") + content = await get_content_from_response(response) + log.debug(f"{content=}") + + if not content: + return body, {} + + try: + content = content[content.find("{") : content.rfind("}") + 1] + if not content: + raise Exception("No JSON object found in the response") + + result = json.loads(content) + + tool_function_name = result.get("name", None) + if tool_function_name not in tools: + return body, {} + + tool_function_params = result.get("parameters", {}) + + try: + required_params = ( + tools[tool_function_name] + .get("spec", {}) + .get("parameters", {}) + .get("required", []) + ) + tool_function = tools[tool_function_name]["callable"] + tool_function_params = { + k: v + for k, v in tool_function_params.items() + if k in required_params + } + tool_output = await tool_function(**tool_function_params) + + except Exception as e: + tool_output = str(e) + + if tools[tool_function_name]["citation"]: + citations.append( + { + "source": { + "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + }, + "document": [tool_output], + "metadata": [{"source": tool_function_name}], + } + ) + if tools[tool_function_name]["file_handler"]: + skip_files = True + + if isinstance(tool_output, str): + contexts.append(tool_output) + except Exception as e: + log.exception(f"Error: {e}") + content = None + except Exception as e: + log.exception(f"Error: {e}") + content = None + + log.debug(f"tool_contexts: {contexts}") + + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] + + return body, {"contexts": contexts, "citations": citations} + + +async def chat_completion_files_handler( + body: dict, user: UserModel +) -> tuple[dict, dict[str, list]]: + contexts = [] + citations = [] + + try: + queries_response = await generate_queries( + { + "model": body["model"], + "messages": body["messages"], + "type": "retrieval", + }, + user, + ) + queries_response = queries_response["choices"][0]["message"]["content"] + + try: + queries_response = json.loads(queries_response) + except Exception as e: + queries_response = {"queries": []} + + queries = queries_response.get("queries", []) + except Exception as e: + queries = [] + + if len(queries) == 0: + queries = [get_last_user_message(body["messages"])] + + print(f"{queries=}") + + if files := body.get("metadata", {}).get("files", None): + contexts, citations = get_rag_context( + files=files, + queries=queries, + embedding_function=retrieval_app.state.EMBEDDING_FUNCTION, + k=retrieval_app.state.config.TOP_K, + reranking_function=retrieval_app.state.sentence_transformer_rf, + r=retrieval_app.state.config.RELEVANCE_THRESHOLD, + hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH, + ) + + log.debug(f"rag_contexts: {contexts}, citations: {citations}") + + return body, {"contexts": contexts, "citations": citations} + + +def is_chat_completion_request(request): + return request.method == "POST" and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] + ) + + +async def get_body_and_model_and_user(request, models): + # Read the original request body + body = await request.body() + body_str = body.decode("utf-8") + body = json.loads(body_str) if body_str else {} + + model_id = body["model"] + if model_id not in models: + raise Exception("Model not found") + model = models[model_id] + + user = get_current_user( + request, + get_http_authorization_cred(request.headers.get("Authorization")), + ) + + return body, model, user + + +class ChatCompletionMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + if not is_chat_completion_request(request): + return await call_next(request) + log.debug(f"request.url.path: {request.url.path}") + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + try: + body, model, user = await get_body_and_model_and_user(request, models) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + model_info = Models.get_model_by_id(model["id"]) + if user.role == "user": + if model.get("arena"): + if not has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + if not model_info: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"detail": "Model not found"}, + ) + elif not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"detail": "User does not have access to the model"}, + ) + + metadata = { + "chat_id": body.pop("chat_id", None), + "message_id": body.pop("id", None), + "session_id": body.pop("session_id", None), + "tool_ids": body.get("tool_ids", None), + "files": body.get("files", None), + } + body["metadata"] = metadata + + extra_params = { + "__event_emitter__": get_event_emitter(metadata), + "__event_call__": get_event_call(metadata), + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + } + + # Initialize data_items to store additional data to be sent to the client + # Initialize contexts and citation + data_items = [] + contexts = [] + citations = [] + + try: + body, flags = await chat_completion_filter_functions_handler( + body, model, extra_params + ) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + tool_ids = body.pop("tool_ids", None) + files = body.pop("files", None) + + metadata = { + **metadata, + "tool_ids": tool_ids, + "files": files, + } + body["metadata"] = metadata + + try: + body, flags = await chat_completion_tools_handler( + body, user, models, extra_params + ) + contexts.extend(flags.get("contexts", [])) + citations.extend(flags.get("citations", [])) + except Exception as e: + log.exception(e) + + try: + body, flags = await chat_completion_files_handler(body, user) + contexts.extend(flags.get("contexts", [])) + citations.extend(flags.get("citations", [])) + except Exception as e: + log.exception(e) + + # If context is not empty, insert it into the messages + if len(contexts) > 0: + context_string = "/n".join(contexts).strip() + prompt = get_last_user_message(body["messages"]) + + if prompt is None: + raise Exception("No user message found") + if ( + retrieval_app.state.config.RELEVANCE_THRESHOLD == 0 + and context_string.strip() == "" + ): + log.debug( + f"With a 0 relevancy threshold for RAG, the context cannot be empty" + ) + + # Workaround for Ollama 2.0+ system prompt issue + # TODO: replace with add_or_update_system_message + if model["owned_by"] == "ollama": + body["messages"] = prepend_to_first_user_message_content( + rag_template( + retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt + ), + body["messages"], + ) + else: + body["messages"] = add_or_update_system_message( + rag_template( + retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt + ), + body["messages"], + ) + + # If there are citations, add them to the data_items + if len(citations) > 0: + data_items.append({"citations": citations}) + + modified_body_bytes = json.dumps(body).encode("utf-8") + # Replace the request body with the modified one + request._body = modified_body_bytes + # Set custom header to ensure content-length matches new body length + request.headers.__dict__["_list"] = [ + (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), + *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], + ] + + response = await call_next(request) + if not isinstance(response, StreamingResponse): + return response + + content_type = response.headers["Content-Type"] + is_openai = "text/event-stream" in content_type + is_ollama = "application/x-ndjson" in content_type + if not is_openai and not is_ollama: + return response + + def wrap_item(item): + return f"data: {item}\n\n" if is_openai else f"{item}\n" + + async def stream_wrapper(original_generator, data_items): + for item in data_items: + yield wrap_item(json.dumps(item)) + + async for data in original_generator: + yield data + + return StreamingResponse( + stream_wrapper(response.body_iterator, data_items), + headers=dict(response.headers), + ) + + async def _receive(self, body: bytes): + return {"type": "http.request", "body": body, "more_body": False} + + +app.add_middleware(ChatCompletionMiddleware) + + +################################## +# +# Pipeline Middleware +# +################################## + + +def get_sorted_filters(model_id, models): + filters = [ + model + for model in models.values() + if "pipeline" in model + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) + ] + sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + return sorted_filters + + +def filter_pipeline(payload, user, models): + user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} + model_id = payload["model"] + + sorted_filters = get_sorted_filters(model_id, models) + model = models[model_id] + + if "pipeline" in model: + sorted_filters.append(model) + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + if key == "": + continue + + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json={ + "user": user, + "body": payload, + }, + ) + + r.raise_for_status() + payload = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + res = r.json() + if "detail" in res: + raise Exception(r.status_code, res["detail"]) + + return payload + + +class PipelineMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + if not is_chat_completion_request(request): + return await call_next(request) + + log.debug(f"request.url.path: {request.url.path}") + + # Read the original request body + body = await request.body() + # Decode body to string + body_str = body.decode("utf-8") + # Parse string to JSON + data = json.loads(body_str) if body_str else {} + + try: + user = get_current_user( + request, + get_http_authorization_cred(request.headers["Authorization"]), + ) + except KeyError as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Not authenticated"}, + ) + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + try: + data = filter_pipeline(data, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + modified_body_bytes = json.dumps(data).encode("utf-8") + # Replace the request body with the modified one + request._body = modified_body_bytes + # Set custom header to ensure content-length matches new body length + request.headers.__dict__["_list"] = [ + (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), + *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], + ] + + response = await call_next(request) + return response + + async def _receive(self, body: bytes): + return {"type": "http.request", "body": body, "more_body": False} + + +app.add_middleware(PipelineMiddleware) + + +from urllib.parse import urlencode, parse_qs, urlparse + + +class RedirectMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # Check if the request is a GET request + if request.method == "GET": + path = request.url.path + query_params = dict(parse_qs(urlparse(str(request.url)).query)) + + # Check for the specific watch path and the presence of 'v' parameter + if path.endswith("/watch") and "v" in query_params: + video_id = query_params["v"][0] # Extract the first 'v' parameter + encoded_video_id = urlencode({"youtube": video_id}) + redirect_url = f"/?{encoded_video_id}" + return RedirectResponse(url=redirect_url) + + # Proceed with the normal flow of other requests + response = await call_next(request) + return response + + +# Add the middleware to the app +app.add_middleware(RedirectMiddleware) + + +app.add_middleware( + CORSMiddleware, + allow_origins=CORS_ALLOW_ORIGIN, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.add_middleware(SecurityHeadersMiddleware) + + +@app.middleware("http") +async def commit_session_after_request(request: Request, call_next): + response = await call_next(request) + log.debug("Commit session after request") + Session.commit() + return response + + +@app.middleware("http") +async def check_url(request: Request, call_next): + start_time = int(time.time()) + request.state.enable_api_key = webui_app.state.config.ENABLE_API_KEY + response = await call_next(request) + process_time = int(time.time()) - start_time + response.headers["X-Process-Time"] = str(process_time) + return response + + +@app.middleware("http") +async def update_embedding_function(request: Request, call_next): + response = await call_next(request) + if "/embedding/update" in request.url.path: + webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION + return response + + +@app.middleware("http") +async def inspect_websocket(request: Request, call_next): + if ( + "/ws/socket.io" in request.url.path + and request.query_params.get("transport") == "websocket" + ): + upgrade = (request.headers.get("Upgrade") or "").lower() + connection = (request.headers.get("Connection") or "").lower().split(",") + # Check that there's the correct headers for an upgrade, else reject the connection + # This is to work around this upstream issue: https://github.com/miguelgrinberg/python-engineio/issues/367 + if upgrade != "websocket" or "upgrade" not in connection: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "Invalid WebSocket upgrade request"}, + ) + return await call_next(request) + + +app.mount("/ws", socket_app) +app.mount("/ollama", ollama_app) +app.mount("/openai", openai_app) + +app.mount("/images/api/v1", images_app) +app.mount("/audio/api/v1", audio_app) +app.mount("/retrieval/api/v1", retrieval_app) + +app.mount("/api/v1", webui_app) + +webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION + + +async def get_all_base_models(): + open_webui_models = [] + openai_models = [] + ollama_models = [] + + if app.state.config.ENABLE_OPENAI_API: + openai_models = await get_openai_models() + openai_models = openai_models["data"] + + if app.state.config.ENABLE_OLLAMA_API: + ollama_models = await get_ollama_models() + ollama_models = [ + { + "id": model["model"], + "name": model["name"], + "object": "model", + "created": int(time.time()), + "owned_by": "ollama", + "ollama": model, + } + for model in ollama_models["models"] + ] + + open_webui_models = await get_open_webui_models() + + models = open_webui_models + openai_models + ollama_models + return models + + +@cached(ttl=1) +async def get_all_models(): + models = await get_all_base_models() + + # If there are no models, return an empty list + if len([model for model in models if not model.get("arena", False)]) == 0: + return [] + + global_action_ids = [ + function.id for function in Functions.get_global_action_functions() + ] + enabled_action_ids = [ + function.id + for function in Functions.get_functions_by_type("action", active_only=True) + ] + + custom_models = Models.get_all_models() + for custom_model in custom_models: + if custom_model.base_model_id is None: + for model in models: + if ( + custom_model.id == model["id"] + or custom_model.id == model["id"].split(":")[0] + ): + if custom_model.is_active: + model["name"] = custom_model.name + model["info"] = custom_model.model_dump() + + action_ids = [] + if "info" in model and "meta" in model["info"]: + action_ids.extend( + model["info"]["meta"].get("actionIds", []) + ) + + model["action_ids"] = action_ids + else: + models.remove(model) + + elif custom_model.is_active and ( + custom_model.id not in [model["id"] for model in models] + ): + owned_by = "openai" + pipe = None + action_ids = [] + + for model in models: + if ( + custom_model.base_model_id == model["id"] + or custom_model.base_model_id == model["id"].split(":")[0] + ): + owned_by = model["owned_by"] + if "pipe" in model: + pipe = model["pipe"] + break + + if custom_model.meta: + meta = custom_model.meta.model_dump() + if "actionIds" in meta: + action_ids.extend(meta["actionIds"]) + + models.append( + { + "id": f"{custom_model.id}", + "name": custom_model.name, + "object": "model", + "created": custom_model.created_at, + "owned_by": owned_by, + "info": custom_model.model_dump(), + "preset": True, + **({"pipe": pipe} if pipe is not None else {}), + "action_ids": action_ids, + } + ) + + # Process action_ids to get the actions + def get_action_items_from_module(function, module): + actions = [] + if hasattr(module, "actions"): + actions = module.actions + return [ + { + "id": f"{function.id}.{action['id']}", + "name": action.get("name", f"{function.name} ({action['id']})"), + "description": function.meta.description, + "icon_url": action.get( + "icon_url", function.meta.manifest.get("icon_url", None) + ), + } + for action in actions + ] + else: + return [ + { + "id": function.id, + "name": function.name, + "description": function.meta.description, + "icon_url": function.meta.manifest.get("icon_url", None), + } + ] + + def get_function_module_by_id(function_id): + if function_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[function_id] + else: + function_module, _, _ = load_function_module_by_id(function_id) + webui_app.state.FUNCTIONS[function_id] = function_module + + for model in models: + action_ids = [ + action_id + for action_id in list(set(model.pop("action_ids", []) + global_action_ids)) + if action_id in enabled_action_ids + ] + + model["actions"] = [] + for action_id in action_ids: + action_function = Functions.get_function_by_id(action_id) + if action_function is None: + raise Exception(f"Action not found: {action_id}") + + function_module = get_function_module_by_id(action_id) + model["actions"].extend( + get_action_items_from_module(action_function, function_module) + ) + return models + + +@app.get("/api/models") +async def get_models(user=Depends(get_verified_user)): + models = await get_all_models() + + # Filter out filter pipelines + models = [ + model + for model in models + if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" + ] + + # Filter out models that the user does not have access to + if user.role == "user": + filtered_models = [] + for model in models: + if model.get("arena"): + if has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + filtered_models.append(model) + continue + + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if user.id == model_info.user_id or has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + models = filtered_models + + return {"data": models} + + +@app.get("/api/models/base") +async def get_base_models(user=Depends(get_admin_user)): + models = await get_all_base_models() + + # Filter out arena models + models = [model for model in models if not model.get("arena", False)] + return {"data": models} + + +@app.post("/api/chat/completions") +async def generate_chat_completions( + form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False +): + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = models[model_id] + + # Check if user has access to the model + if not bypass_filter and user.role == "user": + if model.get("arena"): + if not has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + model_info = Models.get_model_by_id(model_id) + if not model_info: + raise HTTPException( + status_code=404, + detail="Model not found", + ) + elif not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + + if model["owned_by"] == "arena": + model_ids = model.get("info", {}).get("meta", {}).get("model_ids") + filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") + if model_ids and filter_mode == "exclude": + model_ids = [ + model["id"] + for model in await get_all_models() + if model.get("owned_by") != "arena" and model["id"] not in model_ids + ] + + selected_model_id = None + if isinstance(model_ids, list) and model_ids: + selected_model_id = random.choice(model_ids) + else: + model_ids = [ + model["id"] + for model in await get_all_models() + if model.get("owned_by") != "arena" + ] + selected_model_id = random.choice(model_ids) + + form_data["model"] = selected_model_id + + if form_data.get("stream") == True: + + async def stream_wrapper(stream): + yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" + async for chunk in stream: + yield chunk + + response = await generate_chat_completions( + form_data, user, bypass_filter=True + ) + return StreamingResponse( + stream_wrapper(response.body_iterator), media_type="text/event-stream" + ) + else: + return { + **( + await generate_chat_completions(form_data, user, bypass_filter=True) + ), + "selected_model_id": selected_model_id, + } + + if model.get("pipe"): + # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter + return await generate_function_chat_completion( + form_data, user=user, models=models + ) + if model["owned_by"] == "ollama": + # Using /ollama/api/chat endpoint + form_data = convert_payload_openai_to_ollama(form_data) + form_data = GenerateChatCompletionForm(**form_data) + response = await generate_ollama_chat_completion( + form_data=form_data, user=user, bypass_filter=bypass_filter + ) + if form_data.stream: + response.headers["content-type"] = "text/event-stream" + return StreamingResponse( + convert_streaming_response_ollama_to_openai(response), + headers=dict(response.headers), + ) + else: + return convert_response_ollama_to_openai(response) + else: + return await generate_openai_chat_completion( + form_data, user=user, bypass_filter=bypass_filter + ) + + +@app.post("/api/chat/completed") +async def chat_completed(form_data: dict, user=Depends(get_verified_user)): + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + data = form_data + model_id = data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = models[model_id] + sorted_filters = get_sorted_filters(model_id, models) + if "pipeline" in model: + sorted_filters = [model] + sorted_filters + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/outlet", + headers=headers, + json={ + "user": { + "id": user.id, + "name": user.name, + "email": user.email, + "role": user.role, + }, + "body": data, + }, + ) + + r.raise_for_status() + data = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + if "detail" in res: + return JSONResponse( + status_code=r.status_code, + content=res, + ) + except Exception: + pass + + else: + pass + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel to include vavles + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [function.id for function in Functions.get_global_filter_functions()] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + + # Sort filter_ids by priority, using the get_priority function + filter_ids.sort(key=get_priority) + + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + webui_app.state.FUNCTIONS[filter_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + if not hasattr(function_module, "outlet"): + continue + try: + outlet = function_module.outlet + + # Get the signature of the function + sig = inspect.signature(outlet) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": filter_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(outlet): + data = await outlet(**params) + else: + data = outlet(**params) + + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + return data + + +@app.post("/api/chat/actions/{action_id}") +async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): + if "." in action_id: + action_id, sub_action_id = action_id.split(".") + else: + sub_action_id = None + + action = Functions.get_function_by_id(action_id) + if not action: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Action not found", + ) + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + data = form_data + model_id = data["model"] + + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + model = models[model_id] + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + if action_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[action_id] + else: + function_module, _, _ = load_function_module_by_id(action_id) + webui_app.state.FUNCTIONS[action_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(action_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + + if hasattr(function_module, "action"): + try: + action = function_module.action + + # Get the signature of the function + sig = inspect.signature(action) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": sub_action_id if sub_action_id is not None else action_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + action_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(action): + data = await action(**params) + else: + data = action(**params) + + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + return data + + +################################## +# +# Task Endpoints +# +################################## + + +# TODO: Refactor task API endpoints below into a separate file + + +@app.get("/api/task/config") +async def get_task_config(user=Depends(get_verified_user)): + return { + "TASK_MODEL": app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + } + + +class TaskConfigForm(BaseModel): + TASK_MODEL: Optional[str] + TASK_MODEL_EXTERNAL: Optional[str] + TITLE_GENERATION_PROMPT_TEMPLATE: str + TAGS_GENERATION_PROMPT_TEMPLATE: str + ENABLE_TAGS_GENERATION: bool + ENABLE_SEARCH_QUERY_GENERATION: bool + ENABLE_RETRIEVAL_QUERY_GENERATION: bool + QUERY_GENERATION_PROMPT_TEMPLATE: str + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str + + +@app.post("/api/task/config/update") +async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)): + app.state.config.TASK_MODEL = form_data.TASK_MODEL + app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL + app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( + form_data.TITLE_GENERATION_PROMPT_TEMPLATE + ) + app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( + form_data.TAGS_GENERATION_PROMPT_TEMPLATE + ) + app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION + app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ( + form_data.ENABLE_SEARCH_QUERY_GENERATION + ) + app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ( + form_data.ENABLE_RETRIEVAL_QUERY_GENERATION + ) + + app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = ( + form_data.QUERY_GENERATION_PROMPT_TEMPLATE + ) + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + ) + + return { + "TASK_MODEL": app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + } + + +@app.post("/api/task/title/completions") +async def generate_title(form_data: dict, user=Depends(get_verified_user)): + print("generate_title") + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + print(task_model_id) + + model = models[task_model_id] + + if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": + template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE + else: + template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. + +Examples of titles: +📉 Stock Market Trends +🍪 Perfect Chocolate Chip Recipe +Evolution of Music Streaming +Remote Work Productivity Tips +Artificial Intelligence in Healthcare +🎮 Video Game Development Insights + + +{{MESSAGES:END:2}} +""" + + content = title_generation_template( + template, + form_data["messages"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + **( + {"max_tokens": 50} + if models[task_model_id]["owned_by"] == "ollama" + else { + "max_completion_tokens": 50, + } + ), + "chat_id": form_data.get("chat_id", None), + "metadata": {"task": str(TASKS.TITLE_GENERATION), "task_body": form_data}, + } + log.debug(payload) + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + +@app.post("/api/task/tags/completions") +async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): + print("generate_chat_tags") + if not app.state.config.ENABLE_TAGS_GENERATION: + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"detail": "Tags generation is disabled"}, + ) + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + print(task_model_id) + + if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": + template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE + else: + template = """### Task: +Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags. + +### Guidelines: +- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) +- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation +- If content is too short (less than 3 messages) or too diverse, use only ["General"] +- Use the chat's primary language; default to English if multilingual +- Prioritize accuracy over specificity + +### Output: +JSON format: { "tags": ["tag1", "tag2", "tag3"] } + +### Chat History: + +{{MESSAGES:END:6}} +""" + + content = tags_generation_template( + template, form_data["messages"], {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data}, + } + log.debug(payload) + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + +@app.post("/api/task/queries/completions") +async def generate_queries(form_data: dict, user=Depends(get_verified_user)): + print("generate_queries") + type = form_data.get("type") + if type == "web_search": + if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Search query generation is disabled", + ) + elif type == "retrieval": + if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Query generation is disabled", + ) + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + print(task_model_id) + + model = models[task_model_id] + + if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "": + template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE + + content = query_generation_template( + template, form_data["messages"], {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data}, + } + log.debug(payload) + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + +@app.post("/api/task/emoji/completions") +async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): + print("generate_emoji") + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + print(task_model_id) + + model = models[task_model_id] + + template = ''' +Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). + +Message: """{{prompt}}""" +''' + content = emoji_generation_template( + template, + form_data["prompt"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + **( + {"max_tokens": 4} + if models[task_model_id]["owned_by"] == "ollama" + else { + "max_completion_tokens": 4, + } + ), + "chat_id": form_data.get("chat_id", None), + "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data}, + } + log.debug(payload) + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + +@app.post("/api/task/moa/completions") +async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)): + print("generate_moa_response") + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + print(task_model_id) + + model = models[task_model_id] + + template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" + +Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. + +Responses from models: {{responses}}""" + + content = moa_response_generation_template( + template, + form_data["prompt"], + form_data["responses"], + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": form_data.get("stream", False), + "chat_id": form_data.get("chat_id", None), + "metadata": { + "task": str(TASKS.MOA_RESPONSE_GENERATION), + "task_body": form_data, + }, + } + log.debug(payload) + + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + +################################## +# +# Pipelines Endpoints +# +################################## + + +# TODO: Refactor pipelines API endpoints below into a separate file + + +@app.get("/api/pipelines/list") +async def get_pipelines_list(user=Depends(get_admin_user)): + responses = await get_openai_models_responses() + + print(responses) + urlIdxs = [ + idx + for idx, response in enumerate(responses) + if response is not None and "pipelines" in response + ] + + return { + "data": [ + { + "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx], + "idx": urlIdx, + } + for urlIdx in urlIdxs + ] + } + + +@app.post("/api/pipelines/upload") +async def upload_pipeline( + urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) +): + print("upload_pipeline", urlIdx, file.filename) + # Check if the uploaded file is a python file + if not (file.filename and file.filename.endswith(".py")): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only Python (.py) files are allowed.", + ) + + upload_folder = f"{CACHE_DIR}/pipelines" + os.makedirs(upload_folder, exist_ok=True) + file_path = os.path.join(upload_folder, file.filename) + + r = None + try: + # Save the uploaded file + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + + with open(file_path, "rb") as f: + files = {"file": f} + r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + status_code = status.HTTP_404_NOT_FOUND + if r is not None: + status_code = r.status_code + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=status_code, + detail=detail, + ) + finally: + # Ensure the file is deleted after the upload is completed or on failure + if os.path.exists(file_path): + os.remove(file_path) + + +class AddPipelineForm(BaseModel): + url: str + urlIdx: int + + +@app.post("/api/pipelines/add") +async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)): + r = None + try: + urlIdx = form_data.urlIdx + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/pipelines/add", headers=headers, json={"url": form_data.url} + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + + +class DeletePipelineForm(BaseModel): + id: str + urlIdx: int + + +@app.delete("/api/pipelines/delete") +async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)): + r = None + try: + urlIdx = form_data.urlIdx + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.delete( + f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id} + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + + +@app.get("/api/pipelines") +async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)): + r = None + try: + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.get(f"{url}/pipelines", headers=headers) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + + +@app.get("/api/pipelines/{pipeline_id}/valves") +async def get_pipeline_valves( + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), +): + r = None + try: + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + + +@app.get("/api/pipelines/{pipeline_id}/valves/spec") +async def get_pipeline_valves_spec( + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), +): + r = None + try: + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + + +@app.post("/api/pipelines/{pipeline_id}/valves/update") +async def update_pipeline_valves( + urlIdx: Optional[int], + pipeline_id: str, + form_data: dict, + user=Depends(get_admin_user), +): + r = None + try: + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{pipeline_id}/valves/update", + headers=headers, + json={**form_data}, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + + +################################## +# +# Config Endpoints +# +################################## + + +@app.get("/api/config") +async def get_app_config(request: Request): + user = None + if "token" in request.cookies: + token = request.cookies.get("token") + try: + data = decode_token(token) + except Exception as e: + log.debug(e) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token", + ) + if data is not None and "id" in data: + user = Users.get_user_by_id(data["id"]) + + onboarding = False + if user is None: + user_count = Users.get_num_users() + onboarding = user_count == 0 + + return { + **({"onboarding": True} if onboarding else {}), + "status": True, + "name": WEBUI_NAME, + "version": VERSION, + "default_locale": str(DEFAULT_LOCALE), + "oauth": { + "providers": { + name: config.get("name", name) + for name, config in OAUTH_PROVIDERS.items() + } + }, + "features": { + "auth": WEBUI_AUTH, + "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), + "enable_ldap": webui_app.state.config.ENABLE_LDAP, + "enable_api_key": webui_app.state.config.ENABLE_API_KEY, + "enable_signup": webui_app.state.config.ENABLE_SIGNUP, + "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, + **( + { + "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH, + "enable_image_generation": images_app.state.config.ENABLED, + "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, + "enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING, + "enable_admin_export": ENABLE_ADMIN_EXPORT, + "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, + } + if user is not None + else {} + ), + }, + **( + { + "default_models": webui_app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + "audio": { + "tts": { + "engine": audio_app.state.config.TTS_ENGINE, + "voice": audio_app.state.config.TTS_VOICE, + "split_on": audio_app.state.config.TTS_SPLIT_ON, + }, + "stt": { + "engine": audio_app.state.config.STT_ENGINE, + }, + }, + "file": { + "max_size": retrieval_app.state.config.FILE_MAX_SIZE, + "max_count": retrieval_app.state.config.FILE_MAX_COUNT, + }, + "permissions": {**webui_app.state.config.USER_PERMISSIONS}, + } + if user is not None + else {} + ), + } + + +# TODO: webhook endpoint should be under config endpoints + + +@app.get("/api/webhook") +async def get_webhook_url(user=Depends(get_admin_user)): + return { + "url": app.state.config.WEBHOOK_URL, + } + + +class UrlForm(BaseModel): + url: str + + +@app.post("/api/webhook") +async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): + app.state.config.WEBHOOK_URL = form_data.url + webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL + return {"url": app.state.config.WEBHOOK_URL} + + +@app.get("/api/version") +async def get_app_version(): + return { + "version": VERSION, + } + + +@app.get("/api/changelog") +async def get_app_changelog(): + return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} + + +@app.get("/api/version/updates") +async def get_app_latest_release_version(): + if OFFLINE_MODE: + log.debug( + f"Offline mode is enabled, returning current version as latest version" + ) + return {"current": VERSION, "latest": VERSION} + try: + timeout = aiohttp.ClientTimeout(total=1) + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.get( + "https://api.github.com/repos/open-webui/open-webui/releases/latest" + ) as response: + response.raise_for_status() + data = await response.json() + latest_version = data["tag_name"] + + return {"current": VERSION, "latest": latest_version[1:]} + except Exception as e: + log.debug(e) + return {"current": VERSION, "latest": VERSION} + + +############################ +# OAuth Login & Callback +############################ + +# SessionMiddleware is used by authlib for oauth +if len(OAUTH_PROVIDERS) > 0: + app.add_middleware( + SessionMiddleware, + secret_key=WEBUI_SECRET_KEY, + session_cookie="oui-session", + same_site=WEBUI_SESSION_COOKIE_SAME_SITE, + https_only=WEBUI_SESSION_COOKIE_SECURE, + ) + + +@app.get("/oauth/{provider}/login") +async def oauth_login(provider: str, request: Request): + return await oauth_manager.handle_login(provider, request) + + +# OAuth login logic is as follows: +# 1. Attempt to find a user with matching subject ID, tied to the provider +# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth +# - This is considered insecure in general, as OAuth providers do not always verify email addresses +# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user +# - Email addresses are considered unique, so we fail registration if the email address is already taken +@app.get("/oauth/{provider}/callback") +async def oauth_callback(provider: str, request: Request, response: Response): + return await oauth_manager.handle_callback(provider, request, response) + + +@app.get("/manifest.json") +async def get_manifest_json(): + return { + "name": WEBUI_NAME, + "short_name": WEBUI_NAME, + "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.", + "start_url": "/", + "display": "standalone", + "background_color": "#343541", + "orientation": "natural", + "icons": [ + { + "src": "/static/logo.png", + "type": "image/png", + "sizes": "500x500", + "purpose": "any", + }, + { + "src": "/static/logo.png", + "type": "image/png", + "sizes": "500x500", + "purpose": "maskable", + }, + ], + } + + +@app.get("/opensearch.xml") +async def get_opensearch_xml(): + xml_content = rf""" + + {WEBUI_NAME} + Search {WEBUI_NAME} + UTF-8 + {WEBUI_URL}/static/favicon.png + + {WEBUI_URL} + + """ + return Response(content=xml_content, media_type="application/xml") + + +@app.get("/health") +async def healthcheck(): + return {"status": True} + + +@app.get("/health/db") +async def healthcheck_with_db(): + Session.execute(text("SELECT 1;")).all() + return {"status": True} + + +app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") +app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") + + +if os.path.exists(FRONTEND_BUILD_DIR): + mimetypes.add_type("text/javascript", ".js") + app.mount( + "/", + SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True), + name="spa-static-files", + ) +else: + log.warning( + f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only." + ) diff --git a/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py new file mode 100644 index 0000000000000000000000000000000000000000..e4deb99df1ef05cfdc8fbe8fae3dc5a9530b5e83 --- /dev/null +++ b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py @@ -0,0 +1,85 @@ +"""Add group table + +Revision ID: 922e7a387820 +Revises: 4ace53fd72c8 +Create Date: 2024-11-14 03:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa + +revision = "922e7a387820" +down_revision = "4ace53fd72c8" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "group", + sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column("user_id", sa.Text(), nullable=True), + sa.Column("name", sa.Text(), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("data", sa.JSON(), nullable=True), + sa.Column("meta", sa.JSON(), nullable=True), + sa.Column("permissions", sa.JSON(), nullable=True), + sa.Column("user_ids", sa.JSON(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + ) + + # Add 'access_control' column to 'model' table + op.add_column( + "model", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + # Add 'is_active' column to 'model' table + op.add_column( + "model", + sa.Column( + "is_active", + sa.Boolean(), + nullable=False, + server_default=sa.sql.expression.true(), + ), + ) + + # Add 'access_control' column to 'knowledge' table + op.add_column( + "knowledge", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + # Add 'access_control' column to 'prompt' table + op.add_column( + "prompt", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + # Add 'access_control' column to 'tools' table + op.add_column( + "tool", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + +def downgrade(): + op.drop_table("group") + + # Drop 'access_control' column from 'model' table + op.drop_column("model", "access_control") + + # Drop 'is_active' column from 'model' table + op.drop_column("model", "is_active") + + # Drop 'access_control' column from 'knowledge' table + op.drop_column("knowledge", "access_control") + + # Drop 'access_control' column from 'prompt' table + op.drop_column("prompt", "access_control") + + # Drop 'access_control' column from 'tools' table + op.drop_column("tool", "access_control") diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index 7bea0bdae5f79174fc8b869bc381562567f1da81..28379bcc95ba6b41cd963cafe504e513836d3698 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -51,7 +51,10 @@ class StorageProvider: try: self.s3_client.upload_file(file_path, self.bucket_name, filename) - return open(file_path, "rb").read(), file_path + return ( + open(file_path, "rb").read(), + "s3://" + self.bucket_name + "/" + filename, + ) except ClientError as e: raise RuntimeError(f"Error uploading file to S3: {e}") diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py new file mode 100644 index 0000000000000000000000000000000000000000..9f90b04c2c70b8bbc2184aec4202810369de6c63 --- /dev/null +++ b/backend/open_webui/utils/access_control.py @@ -0,0 +1,95 @@ +from typing import Optional, Union, List, Dict, Any +from open_webui.apps.webui.models.groups import Groups +import json + + +def get_permissions( + user_id: str, + default_permissions: Dict[str, Any], +) -> Dict[str, Any]: + """ + Get all permissions for a user by combining the permissions of all groups the user is a member of. + If a permission is defined in multiple groups, the most permissive value is used (True > False). + Permissions are nested in a dict with the permission key as the key and a boolean as the value. + """ + + def combine_permissions( + permissions: Dict[str, Any], group_permissions: Dict[str, Any] + ) -> Dict[str, Any]: + """Combine permissions from multiple groups by taking the most permissive value.""" + for key, value in group_permissions.items(): + if isinstance(value, dict): + if key not in permissions: + permissions[key] = {} + permissions[key] = combine_permissions(permissions[key], value) + else: + if key not in permissions: + permissions[key] = value + else: + permissions[key] = permissions[key] or value + return permissions + + user_groups = Groups.get_groups_by_member_id(user_id) + + # deep copy default permissions to avoid modifying the original dict + permissions = json.loads(json.dumps(default_permissions)) + + for group in user_groups: + group_permissions = group.permissions + permissions = combine_permissions(permissions, group_permissions) + + return permissions + + +def has_permission( + user_id: str, + permission_key: str, + default_permissions: Dict[str, bool] = {}, +) -> bool: + """ + Check if a user has a specific permission by checking the group permissions + and falls back to default permissions if not found in any group. + + Permission keys can be hierarchical and separated by dots ('.'). + """ + + def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool: + """Traverse permissions dict using a list of keys (from dot-split permission_key).""" + for key in keys: + if key not in permissions: + return False # If any part of the hierarchy is missing, deny access + permissions = permissions[key] # Go one level deeper + + return bool(permissions) # Return the boolean at the final level + + permission_hierarchy = permission_key.split(".") + + # Retrieve user group permissions + user_groups = Groups.get_groups_by_member_id(user_id) + + for group in user_groups: + group_permissions = group.permissions + if get_permission(group_permissions, permission_hierarchy): + return True + + # Check default permissions afterwards if the group permissions don't allow it + return get_permission(default_permissions, permission_hierarchy) + + +def has_access( + user_id: str, + type: str = "write", + access_control: Optional[dict] = None, +) -> bool: + if access_control is None: + return type == "read" + + user_groups = Groups.get_groups_by_member_id(user_id) + user_group_ids = [group.id for group in user_groups] + permission_access = access_control.get(type, {}) + permitted_group_ids = permission_access.get("group_ids", []) + permitted_user_ids = permission_access.get("user_ids", []) + + return user_id in permitted_user_ids or any( + group_id in permitted_group_ids for group_id in user_group_ids + ) diff --git a/backend/open_webui/utils/pdf_generator.py b/backend/open_webui/utils/pdf_generator.py index daaf7fe91f24ac8378d866a505568ca37f81e6b5..47f1cfa62d79b79d11fa71c15eb8566c93f8863e 100644 --- a/backend/open_webui/utils/pdf_generator.py +++ b/backend/open_webui/utils/pdf_generator.py @@ -54,18 +54,18 @@ class PDFGenerator: html_content = markdown(content, extensions=["pymdownx.extra"]) html_message = f""" -
- {date_str} -
-

- {role.title()} - {model} -

-
-
- {html_content} -
-
+
{date_str}
+
+
+

+ {role.title()} + {model} +

+
+
+                    {content}
+                
+
""" return html_message diff --git a/backend/open_webui/utils/security_headers.py b/backend/open_webui/utils/security_headers.py index a24c5131dacc94a434be37323757891f148fa1b3..0091f3efb966741449582301ae10e32ea9b37b90 100644 --- a/backend/open_webui/utils/security_headers.py +++ b/backend/open_webui/utils/security_headers.py @@ -20,6 +20,7 @@ def set_security_headers() -> Dict[str, str]: This function reads specific environment variables and uses their values to set corresponding security headers. The headers that can be set are: - cache-control + - permissions-policy - strict-transport-security - referrer-policy - x-content-type-options @@ -38,6 +39,7 @@ def set_security_headers() -> Dict[str, str]: header_setters = { "CACHE_CONTROL": set_cache_control, "HSTS": set_hsts, + "PERMISSIONS_POLICY": set_permissions_policy, "REFERRER_POLICY": set_referrer, "XCONTENT_TYPE": set_xcontent_type, "XDOWNLOAD_OPTIONS": set_xdownload_options, @@ -73,6 +75,15 @@ def set_xframe(value: str): return {"X-Frame-Options": value} +# Set Permissions-Policy response header +def set_permissions_policy(value: str): + pattern = r"^(?:(accelerometer|autoplay|camera|clipboard-read|clipboard-write|fullscreen|geolocation|gyroscope|magnetometer|microphone|midi|payment|picture-in-picture|sync-xhr|usb|xr-spatial-tracking)=\((self)?\),?)*$" + match = re.match(pattern, value, re.IGNORECASE) + if not match: + value = "none" + return {"Permissions-Policy": value} + + # Set Referrer-Policy response header def set_referrer(value: str): pattern = r"^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$" diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 040a33d85ef23c6898341c5a759a4b2635eebebb..cd28587ccc3dfe02c0debeb922e4b388fda864df 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -163,7 +163,7 @@ def emoji_generation_template( return template -def search_query_generation_template( +def query_generation_template( template: str, messages: list[dict], user: Optional[dict] = None ) -> str: prompt = get_last_user_message(messages) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 1bce2aefa335926b3852ddcecc8fd06c4428f262..f11dbd039b25398ee4d393dc7cb0890017788764 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -4,7 +4,7 @@ from typing import Awaitable, Callable, get_type_hints from open_webui.apps.webui.models.tools import Tools from open_webui.apps.webui.models.users import UserModel -from open_webui.apps.webui.utils import load_toolkit_module_by_id +from open_webui.apps.webui.utils import load_tools_module_by_id from open_webui.utils.schemas import json_schema_to_model log = logging.getLogger(__name__) @@ -32,15 +32,16 @@ def apply_extra_params_to_tool_function( def get_tools( webui_app, tool_ids: list[str], user: UserModel, extra_params: dict ) -> dict[str, dict]: - tools = {} + tools_dict = {} + for tool_id in tool_ids: - toolkit = Tools.get_tool_by_id(tool_id) - if toolkit is None: + tools = Tools.get_tool_by_id(tool_id) + if tools is None: continue module = webui_app.state.TOOLS.get(tool_id, None) if module is None: - module, _ = load_toolkit_module_by_id(tool_id) + module, _ = load_tools_module_by_id(tool_id) webui_app.state.TOOLS[tool_id] = module extra_params["__id__"] = tool_id @@ -53,11 +54,19 @@ def get_tools( **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) ) - for spec in toolkit.specs: + for spec in tools.specs: # TODO: Fix hack for OpenAI API for val in spec.get("parameters", {}).get("properties", {}).values(): if val["type"] == "str": val["type"] = "string" + + # Remove internal parameters + spec["parameters"]["properties"] = { + key: val + for key, val in spec["parameters"]["properties"].items() + if not key.startswith("__") + } + function_name = spec["name"] # convert to function that takes only model params and inserts custom params @@ -77,13 +86,14 @@ def get_tools( } # TODO: if collision, prepend toolkit name - if function_name in tools: - log.warning(f"Tool {function_name} already exists in another toolkit!") - log.warning(f"Collision between {toolkit} and {tool_id}.") - log.warning(f"Discarding {toolkit}.{function_name}") + if function_name in tools_dict: + log.warning(f"Tool {function_name} already exists in another tools!") + log.warning(f"Collision between {tools} and {tool_id}.") + log.warning(f"Discarding {tools}.{function_name}") else: - tools[function_name] = tool_dict - return tools + tools_dict[function_name] = tool_dict + + return tools_dict def doc_to_dict(docstring): diff --git a/backend/open_webui/utils/utils.py b/backend/open_webui/utils/utils.py index a61a2c8f1dae1286cf01cfe2d02cafc1af0d0200..9bc2e30abc7ff26e0a9c47b6bece7aa12dbc4176 100644 --- a/backend/open_webui/utils/utils.py +++ b/backend/open_webui/utils/utils.py @@ -1,12 +1,15 @@ import logging import uuid +import jwt + from datetime import UTC, datetime, timedelta -from typing import Optional, Union +from typing import Optional, Union, List, Dict -import jwt from open_webui.apps.webui.models.users import Users + from open_webui.constants import ERROR_MESSAGES from open_webui.env import WEBUI_SECRET_KEY + from fastapi import Depends, HTTPException, Request, Response, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from passlib.context import CryptContext @@ -88,10 +91,21 @@ def get_current_user( # auth by api key if token.startswith("sk-"): + if not request.state.enable_api_key: + raise HTTPException( + status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED + ) return get_current_user_by_api_key(token) # auth by jwt token - data = decode_token(token) + try: + data = decode_token(token) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token", + ) + if data is not None and "id" in data: user = Users.get_user_by_id(data["id"]) if user is None: diff --git a/backend/requirements.txt b/backend/requirements.txt index 6bb220920e8892da2ff76e1d8f2ac4f512e93eae..62587496cb3946503683831a007ff570c7b9150b 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,7 +1,7 @@ fastapi==0.111.0 uvicorn[standard]==0.30.6 pydantic==2.9.2 -python-multipart==0.0.9 +python-multipart==0.0.17 Flask==3.0.3 Flask-Cors==5.0.0 @@ -13,18 +13,20 @@ passlib[bcrypt]==1.7.4 requests==2.32.3 aiohttp==3.10.8 async-timeout +aiocache sqlalchemy==2.0.32 alembic==1.13.2 peewee==3.17.6 peewee-migrate==1.12.2 psycopg2-binary==2.9.9 +pgvector==0.3.5 PyMySQL==1.1.1 bcrypt==4.2.0 pymongo redis -boto3==1.35.0 +boto3==1.35.53 argon2-cffi==23.1.0 APScheduler==3.10.4 @@ -35,14 +37,15 @@ anthropic google-generativeai==0.7.2 tiktoken -langchain==0.2.15 -langchain-community==0.2.12 +langchain==0.3.5 +langchain-community==0.3.3 langchain-chroma==0.1.4 fake-useragent==1.5.1 chromadb==0.5.15 -pymilvus==2.4.7 +pymilvus==2.4.9 qdrant-client~=1.12.0 +opensearch-py==2.7.1 sentence-transformers==3.2.0 colbert-ai==0.2.21 @@ -51,7 +54,7 @@ einops==0.8.0 ftfy==6.2.3 pypdf==4.3.1 -xhtml2pdf==0.2.16 +fpdf2==2.7.9 pymdown-extensions==10.11.2 docx2txt==0.8 python-pptx==1.0.0 @@ -65,11 +68,11 @@ pyxlsb==1.0.10 xlrd==2.0.1 validators==0.33.0 psutil +sentencepiece +soundfile==0.12.1 opencv-python-headless==4.10.0.84 rapidocr-onnxruntime==1.3.24 - -fpdf2==2.7.9 rank-bm25==0.2.2 faster-whisper==1.0.3 @@ -84,7 +87,7 @@ pytube==15.0.0 extract_msg pydub -duckduckgo-search~=6.2.13 +duckduckgo-search~=6.3.5 ## Tests docker~=7.1.0 @@ -92,3 +95,6 @@ pytest~=8.3.2 pytest-docker~=3.1.1 googleapis-common-protos==1.63.2 + +## LDAP +ldap3==2.9.1 diff --git a/backend/start.sh b/backend/start.sh index 129dcf8a4df95e6bea553349f85b273743b0b662..a945acb62e931cf1e0643d7bdcb5a414b4d71704 100755 --- a/backend/start.sh +++ b/backend/start.sh @@ -54,6 +54,4 @@ if [ -n "$SPACE_ID" ]; then export WEBUI_URL=${SPACE_HOST} fi -export GLOBAL_LOG_LEVEL="ERROR" - -GLOBAL_LOG_LEVEL="ERROR" WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn open_webui.main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*' +WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn open_webui.main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*' diff --git a/package-lock.json b/package-lock.json index 148493d22b27aa9ac98408e2ab245264872b4ed0..7267c10c658ec91ab5020901e44cc48194251e2e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,18 +1,19 @@ { "name": "open-webui", - "version": "0.3.35", + "version": "0.4.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.3.35", + "version": "0.4.1", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", "@codemirror/language-data": "^6.5.1", "@codemirror/theme-one-dark": "^6.1.2", "@huggingface/transformers": "^3.0.0", + "@mediapipe/tasks-vision": "^0.10.17", "@pyscript/core": "^0.4.32", "@sveltejs/adapter-node": "^2.0.0", "@xyflow/svelte": "^0.1.19", @@ -1749,6 +1750,11 @@ "@lezer/lr": "^1.4.0" } }, + "node_modules/@mediapipe/tasks-vision": { + "version": "0.10.17", + "resolved": "https://registry.npmjs.org/@mediapipe/tasks-vision/-/tasks-vision-0.10.17.tgz", + "integrity": "sha512-CZWV/q6TTe8ta61cZXjfnnHsfWIdFhms03M9T7Cnd5y2mdpylJM0rF1qRq+wsQVRMLz1OYPVEBU9ph2Bx8cxrg==" + }, "node_modules/@melt-ui/svelte": { "version": "0.76.0", "resolved": "https://registry.npmjs.org/@melt-ui/svelte/-/svelte-0.76.0.tgz", @@ -3993,9 +3999,10 @@ "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==" }, "node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "license": "MIT", "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", diff --git a/package.json b/package.json index 232e0883d82b09b1973544bbbb806a98da688da6..f8aa9d37b890a494ddddfb6f54208695c3b78c0c 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.3.35", + "version": "0.4.1", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -53,6 +53,7 @@ "@codemirror/language-data": "^6.5.1", "@codemirror/theme-one-dark": "^6.1.2", "@huggingface/transformers": "^3.0.0", + "@mediapipe/tasks-vision": "^0.10.17", "@pyscript/core": "^0.4.32", "@sveltejs/adapter-node": "^2.0.0", "@xyflow/svelte": "^0.1.19", diff --git a/pyproject.toml b/pyproject.toml index 3cf8dfd8116b7906138aa3c830c294a1a457a7a5..1ae65981239235c30e1105be914e08de9b0ae75f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "fastapi==0.111.0", "uvicorn[standard]==0.30.6", "pydantic==2.9.2", - "python-multipart==0.0.9", + "python-multipart==0.0.17", "Flask==3.0.3", "Flask-Cors==5.0.0", @@ -21,18 +21,20 @@ dependencies = [ "requests==2.32.3", "aiohttp==3.10.8", "async-timeout", + "aiocache", "sqlalchemy==2.0.32", "alembic==1.13.2", "peewee==3.17.6", "peewee-migrate==1.12.2", "psycopg2-binary==2.9.9", + "pgvector==0.3.5", "PyMySQL==1.1.1", "bcrypt==4.2.0", "pymongo", "redis", - "boto3==1.35.0", + "boto3==1.35.53", "argon2-cffi==23.1.0", "APScheduler==3.10.4", @@ -42,13 +44,15 @@ dependencies = [ "google-generativeai==0.7.2", "tiktoken", - "langchain==0.2.15", - "langchain-community==0.2.12", + "langchain==0.3.5", + "langchain-community==0.3.3", "langchain-chroma==0.1.4", "fake-useragent==1.5.1", - "chromadb==0.5.9", - "pymilvus==2.4.7", + "chromadb==0.5.15", + "pymilvus==2.4.9", + "qdrant-client~=1.12.0", + "opensearch-py==2.7.1", "sentence-transformers==3.2.0", "colbert-ai==0.2.21", @@ -56,7 +60,7 @@ dependencies = [ "ftfy==6.2.3", "pypdf==4.3.1", - "xhtml2pdf==0.2.16", + "fpdf2==2.7.9", "pymdown-extensions==10.11.2", "docx2txt==0.8", "python-pptx==1.0.0", @@ -70,11 +74,11 @@ dependencies = [ "xlrd==2.0.1", "validators==0.33.0", "psutil", + "sentencepiece", + "soundfile==0.12.1", "opencv-python-headless==4.10.0.84", "rapidocr-onnxruntime==1.3.24", - - "fpdf2==2.7.9", "rank-bm25==0.2.2", "faster-whisper==1.0.3", @@ -89,13 +93,15 @@ dependencies = [ "extract_msg", "pydub", - "duckduckgo-search~=6.2.13", + "duckduckgo-search~=6.3.5", "docker~=7.1.0", "pytest~=8.3.2", "pytest-docker~=3.1.1", - "googleapis-common-protos==1.63.2" + "googleapis-common-protos==1.63.2", + + "ldap3==2.9.1" ] readme = "README.md" requires-python = ">= 3.11, < 3.12.0a1" diff --git a/src/app.css b/src/app.css index d7f2d0e5488a9a9038c3fb3251d768c095037319..a0c94ab2f20bce077ec5eb8d89533ea6d5589d22 100644 --- a/src/app.css +++ b/src/app.css @@ -16,6 +16,12 @@ font-display: swap; } +@font-face { + font-family: 'InstrumentSerif'; + src: url('/assets/fonts/InstrumentSerif-Regular.ttf'); + font-display: swap; +} + html { word-break: break-word; } @@ -26,6 +32,10 @@ code { width: auto; } +.font-secondary { + font-family: 'InstrumentSerif', sans-serif; +} + math { margin-top: 1rem; } diff --git a/src/lib/apis/auths/index.ts b/src/lib/apis/auths/index.ts index d093528f4b68a475588ce040fe2a258cabf43c04..160248f3babbcb2e707d010e1bf1abc7b6e27d84 100644 --- a/src/lib/apis/auths/index.ts +++ b/src/lib/apis/auths/index.ts @@ -110,6 +110,150 @@ export const getSessionUser = async (token: string) => { return res; }; +export const ldapUserSignIn = async (user: string, password: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/ldap`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + credentials: 'include', + body: JSON.stringify({ + user: user, + password: password + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getLdapConfig = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/admin/config/ldap`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateLdapConfig = async (token: string = '', enable_ldap: boolean) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/admin/config/ldap`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + enable_ldap: enable_ldap + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getLdapServer = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/admin/config/ldap/server`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateLdapServer = async (token: string = '', body: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/admin/config/ldap/server`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify(body) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const userSignIn = async (email: string, password: string) => { let error = null; diff --git a/src/lib/apis/groups/index.ts b/src/lib/apis/groups/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..a9c22c22a302fd3e79b35390cb3b224f4b89828c --- /dev/null +++ b/src/lib/apis/groups/index.ts @@ -0,0 +1,162 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const createNewGroup = async (token: string, group: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...group + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getGroups = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getGroupById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateGroupById = async (token: string, id: string, group: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...group + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteGroupById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/delete`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index a9ea060c261c96d103b12c76f2e2bacbc7f0fd67..5432ae4880c12f5dc62045ae703beca58ac975ea 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,9 +1,8 @@ import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; -export const getModels = async (token: string = '') => { +export const getModels = async (token: string = '', base: boolean = false) => { let error = null; - - const res = await fetch(`${WEBUI_BASE_URL}/api/models`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/models${base ? '/base' : ''}`, { method: 'GET', headers: { Accept: 'application/json', @@ -16,8 +15,8 @@ export const getModels = async (token: string = '') => { return res.json(); }) .catch((err) => { - console.log(err); error = err; + console.log(err); return null; }); @@ -26,26 +25,10 @@ export const getModels = async (token: string = '') => { } let models = res?.data ?? []; - models = models .filter((models) => models) // Sort the models .sort((a, b) => { - // Check if models have position property - const aHasPosition = a.info?.meta?.position !== undefined; - const bHasPosition = b.info?.meta?.position !== undefined; - - // If both a and b have the position property - if (aHasPosition && bHasPosition) { - return a.info.meta.position - b.info.meta.position; - } - - // If only a has the position property, it should come first - if (aHasPosition) return -1; - - // If only b has the position property, it should come first - if (bHasPosition) return 1; - // Compare case-insensitively by name for models without position property const lowerA = a.name.toLowerCase(); const lowerB = b.name.toLowerCase(); @@ -365,15 +348,16 @@ export const generateEmoji = async ( return null; }; -export const generateSearchQuery = async ( +export const generateQueries = async ( token: string = '', model: string, messages: object[], - prompt: string + prompt: string, + type?: string = 'web_search' ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -383,7 +367,8 @@ export const generateSearchQuery = async ( body: JSON.stringify({ model: model, messages: messages, - prompt: prompt + prompt: prompt, + type: type }) }) .then(async (res) => { @@ -402,7 +387,39 @@ export const generateSearchQuery = async ( throw error; } - return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt; + try { + // Step 1: Safely extract the response string + const response = res?.choices[0]?.message?.content ?? ''; + + // Step 2: Attempt to fix common JSON format issues like single quotes + const sanitizedResponse = response.replace(/['‘’`]/g, '"'); // Convert single quotes to double quotes for valid JSON + + // Step 3: Find the relevant JSON block within the response + const jsonStartIndex = sanitizedResponse.indexOf('{'); + const jsonEndIndex = sanitizedResponse.lastIndexOf('}'); + + // Step 4: Check if we found a valid JSON block (with both `{` and `}`) + if (jsonStartIndex !== -1 && jsonEndIndex !== -1) { + const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1); + + // Step 5: Parse the JSON block + const parsed = JSON.parse(jsonResponse); + + // Step 6: If there's a "queries" key, return the queries array; otherwise, return an empty array + if (parsed && parsed.queries) { + return Array.isArray(parsed.queries) ? parsed.queries : []; + } else { + return []; + } + } + + // If no valid JSON block found, return an empty array + return []; + } catch (e) { + // Catch and safely return empty array on any parsing errors + console.error('Failed to parse response: ', e); + return []; + } }; export const generateMoACompletion = async ( diff --git a/src/lib/apis/knowledge/index.ts b/src/lib/apis/knowledge/index.ts index 22a1de819f7856a2a8d05818fcd419020dd6453e..0840b4e329ed1917a74fa3703373d755cb635d44 100644 --- a/src/lib/apis/knowledge/index.ts +++ b/src/lib/apis/knowledge/index.ts @@ -1,6 +1,11 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const createNewKnowledge = async (token: string, name: string, description: string) => { +export const createNewKnowledge = async ( + token: string, + name: string, + description: string, + accessControl: null | object +) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/create`, { @@ -12,7 +17,8 @@ export const createNewKnowledge = async (token: string, name: string, descriptio }, body: JSON.stringify({ name: name, - description: description + description: description, + access_control: accessControl }) }) .then(async (res) => { @@ -32,7 +38,7 @@ export const createNewKnowledge = async (token: string, name: string, descriptio return res; }; -export const getKnowledgeItems = async (token: string = '') => { +export const getKnowledgeBases = async (token: string = '') => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/`, { @@ -63,6 +69,37 @@ export const getKnowledgeItems = async (token: string = '') => { return res; }; +export const getKnowledgeBaseList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/list`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getKnowledgeById = async (token: string, id: string) => { let error = null; @@ -99,6 +136,7 @@ type KnowledgeUpdateForm = { name?: string; description?: string; data?: object; + access_control?: null | object; }; export const updateKnowledgeById = async (token: string, id: string, form: KnowledgeUpdateForm) => { @@ -114,7 +152,8 @@ export const updateKnowledgeById = async (token: string, id: string, form: Knowl body: JSON.stringify({ name: form?.name ? form.name : undefined, description: form?.description ? form.description : undefined, - data: form?.data ? form.data : undefined + data: form?.data ? form.data : undefined, + access_control: form.access_control }) }) .then(async (res) => { diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index e36f9e55d9451a3804fc56eb022e7158cb7b7edc..8fec6829a2fd8c45551bad2d3b6eacb5717c6cc9 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -1,9 +1,71 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const addNewModel = async (token: string, model: object) => { +export const getModels = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/models/add`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getBaseModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/base`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const createNewModel = async (token: string, model: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/create`, { method: 'POST', headers: { Accept: 'application/json', @@ -29,10 +91,13 @@ export const addNewModel = async (token: string, model: object) => { return res; }; -export const getModelInfos = async (token: string = '') => { +export const getModelById = async (token: string, id: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/models`, { + const searchParams = new URLSearchParams(); + searchParams.append('id', id); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}`, { method: 'GET', headers: { Accept: 'application/json', @@ -49,6 +114,7 @@ export const getModelInfos = async (token: string = '') => { }) .catch((err) => { error = err; + console.log(err); return null; }); @@ -60,14 +126,14 @@ export const getModelInfos = async (token: string = '') => { return res; }; -export const getModelById = async (token: string, id: string) => { +export const toggleModelById = async (token: string, id: string) => { let error = null; const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models?${searchParams.toString()}`, { - method: 'GET', + const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}/toggle`, { + method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', @@ -101,7 +167,7 @@ export const updateModelById = async (token: string, id: string, model: object) const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models/update?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -137,7 +203,39 @@ export const deleteModelById = async (token: string, id: string) => { const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models/delete?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}/delete`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteAllModels = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/delete/all`, { method: 'DELETE', headers: { Accept: 'application/json', diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index c4083a9c71f54b7a9ed074dd7c5e3eabde8f5714..bac8e7e11bc8509a02b9145c6423ea0c68d436ee 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -1,5 +1,40 @@ import { OLLAMA_API_BASE_URL } from '$lib/constants'; +export const verifyOllamaConnection = async ( + token: string = '', + url: string = '', + key: string = '' +) => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/verify`, { + method: 'POST', + headers: { + Accept: 'application/json', + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + url, + key + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = `Ollama: ${err?.error?.message ?? 'Network Problem'}`; + return []; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getOllamaConfig = async (token: string = '') => { let error = null; @@ -32,7 +67,13 @@ export const getOllamaConfig = async (token: string = '') => { return res; }; -export const updateOllamaConfig = async (token: string = '', enable_ollama_api: boolean) => { +type OllamaConfig = { + ENABLE_OLLAMA_API: boolean; + OLLAMA_BASE_URLS: string[]; + OLLAMA_API_CONFIGS: object; +}; + +export const updateOllamaConfig = async (token: string = '', config: OllamaConfig) => { let error = null; const res = await fetch(`${OLLAMA_API_BASE_URL}/config/update`, { @@ -43,7 +84,7 @@ export const updateOllamaConfig = async (token: string = '', enable_ollama_api: ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - enable_ollama_api: enable_ollama_api + ...config }) }) .then(async (res) => { @@ -166,10 +207,10 @@ export const getOllamaVersion = async (token: string, urlIdx?: number) => { return res?.version ?? false; }; -export const getOllamaModels = async (token: string = '') => { +export const getOllamaModels = async (token: string = '', urlIdx: null | number = null) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/api/tags`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/tags${urlIdx !== null ? `/${urlIdx}` : ''}`, { method: 'GET', headers: { Accept: 'application/json', diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 5866f4fd846fb64a7dbf2f0dbecc337be24a5247..bad96716c6401860593770fa8d6d3bd39e03a99a 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -32,7 +32,14 @@ export const getOpenAIConfig = async (token: string = '') => { return res; }; -export const updateOpenAIConfig = async (token: string = '', enable_openai_api: boolean) => { +type OpenAIConfig = { + ENABLE_OPENAI_API: boolean; + OPENAI_API_BASE_URLS: string[]; + OPENAI_API_KEYS: string[]; + OPENAI_API_CONFIGS: object; +}; + +export const updateOpenAIConfig = async (token: string = '', config: OpenAIConfig) => { let error = null; const res = await fetch(`${OPENAI_API_BASE_URL}/config/update`, { @@ -43,7 +50,7 @@ export const updateOpenAIConfig = async (token: string = '', enable_openai_api: ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - enable_openai_api: enable_openai_api + ...config }) }) .then(async (res) => { @@ -231,41 +238,39 @@ export const getOpenAIModels = async (token: string, urlIdx?: number) => { return res; }; -export const getOpenAIModelsDirect = async ( - base_url: string = 'https://api.openai.com/v1', - api_key: string = '' +export const verifyOpenAIConnection = async ( + token: string = '', + url: string = 'https://api.openai.com/v1', + key: string = '' ) => { let error = null; - const res = await fetch(`${base_url}/models`, { - method: 'GET', + const res = await fetch(`${OPENAI_API_BASE_URL}/verify`, { + method: 'POST', headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${api_key}` - } + Accept: 'application/json', + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + url, + key + }) }) .then(async (res) => { if (!res.ok) throw await res.json(); return res.json(); }) .catch((err) => { - console.log(err); error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`; - return null; + return []; }); if (error) { throw error; } - const models = Array.isArray(res) ? res : (res?.data ?? null); - - return models - .map((model) => ({ id: model.id, name: model.name ?? model.id, external: true })) - .filter((model) => (base_url.includes('openai') ? model.name.includes('gpt') : true)) - .sort((a, b) => { - return a.name.localeCompare(b.name); - }); + return res; }; export const generateOpenAIChatCompletion = async ( diff --git a/src/lib/apis/prompts/index.ts b/src/lib/apis/prompts/index.ts index 20f5d9d77fbe7163d2edc50e95ce5a4f5b4d271e..992e8bc99943bd0272305992e003827eafbdd3ec 100644 --- a/src/lib/apis/prompts/index.ts +++ b/src/lib/apis/prompts/index.ts @@ -1,11 +1,13 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const createNewPrompt = async ( - token: string, - command: string, - title: string, - content: string -) => { +type PromptItem = { + command: string; + title: string; + content: string; + access_control: null | object; +}; + +export const createNewPrompt = async (token: string, prompt: PromptItem) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/create`, { @@ -16,9 +18,8 @@ export const createNewPrompt = async ( authorization: `Bearer ${token}` }, body: JSON.stringify({ - command: `/${command}`, - title: title, - content: content + ...prompt, + command: `/${prompt.command}` }) }) .then(async (res) => { @@ -69,6 +70,37 @@ export const getPrompts = async (token: string = '') => { return res; }; +export const getPromptList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/list`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getPromptByCommand = async (token: string, command: string) => { let error = null; @@ -101,15 +133,10 @@ export const getPromptByCommand = async (token: string, command: string) => { return res; }; -export const updatePromptByCommand = async ( - token: string, - command: string, - title: string, - content: string -) => { +export const updatePromptByCommand = async (token: string, prompt: PromptItem) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/command/${command}/update`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/command/${prompt.command}/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -117,9 +144,8 @@ export const updatePromptByCommand = async ( authorization: `Bearer ${token}` }, body: JSON.stringify({ - command: `/${command}`, - title: title, - content: content + ...prompt, + command: `/${prompt.command}` }) }) .then(async (res) => { diff --git a/src/lib/apis/tools/index.ts b/src/lib/apis/tools/index.ts index 877116a8643844114f70a367850d7ee6acbd410d..9bf3af672237dbce1a4c863ccdf00a3b2b0cf7fe 100644 --- a/src/lib/apis/tools/index.ts +++ b/src/lib/apis/tools/index.ts @@ -62,6 +62,37 @@ export const getTools = async (token: string = '') => { return res; }; +export const getToolList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/list`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const exportTools = async (token: string = '') => { let error = null; diff --git a/src/lib/apis/users/index.ts b/src/lib/apis/users/index.ts index 669a32965709af0008e565885c3fa98a8e80f0fc..c2f6155da562801f6927b7cf915c68014df50b17 100644 --- a/src/lib/apis/users/index.ts +++ b/src/lib/apis/users/index.ts @@ -1,10 +1,10 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; import { getUserPosition } from '$lib/utils'; -export const getUserPermissions = async (token: string) => { +export const getUserGroups = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/users/permissions/user`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/users/groups`, { method: 'GET', headers: { 'Content-Type': 'application/json', @@ -28,10 +28,37 @@ export const getUserPermissions = async (token: string) => { return res; }; -export const updateUserPermissions = async (token: string, permissions: object) => { +export const getUserDefaultPermissions = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/users/permissions/user`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/users/default/permissions`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateUserDefaultPermissions = async (token: string, permissions: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/users/default/permissions`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/src/lib/components/ChangelogModal.svelte b/src/lib/components/ChangelogModal.svelte index 55ea0f8dba6fcd9918e3b4de717cf305abd504c6..3fedfb3cb2d8adfb2367a61437287311391ccb0e 100644 --- a/src/lib/components/ChangelogModal.svelte +++ b/src/lib/components/ChangelogModal.svelte @@ -22,7 +22,7 @@ }); - +
@@ -59,7 +59,7 @@
-
+
{#if changelog} {#each Object.keys(changelog) as version} @@ -111,7 +111,7 @@ await updateUserSettings(localStorage.token, { ui: $settings }); show = false; }} - class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg" + class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full" > {$i18n.t("Okay, Let's Go!")} diff --git a/src/lib/components/OnBoarding.svelte b/src/lib/components/OnBoarding.svelte new file mode 100644 index 0000000000000000000000000000000000000000..4cef95aec29b448506b726dbacda4f233ec70c48 --- /dev/null +++ b/src/lib/components/OnBoarding.svelte @@ -0,0 +1,78 @@ + + +{#if show} +
+
+
+
+ logo +
+
+
+ + + +
+ +
+ +
+
+
+ + +
{$i18n.t(`wherever you are`)}
+ + +
+
+ +
{$i18n.t(`Get started`)}
+
+
+ + + + + +{/if} diff --git a/src/lib/components/admin/Evaluations.svelte b/src/lib/components/admin/Evaluations.svelte index 13d51ab8c04d819177f15dcee55017983f4e8716..69501b2fc35492e172d2f4d69618c808dbea7bc2 100644 --- a/src/lib/components/admin/Evaluations.svelte +++ b/src/lib/components/admin/Evaluations.svelte @@ -1,677 +1,100 @@ - {#if loaded} -
-
-
- {$i18n.t('Leaderboard')} -
- -
- - {rankedModels.length} +
+
- -
- -
-
- -
- { - loadEmbeddingModel(); - }} - /> -
-
-
-
- -
- {#if loadingLeaderboard} -
-
- +
+ + +
-
- {/if} - {#if (rankedModels ?? []).length === 0} -
- {$i18n.t('No models found')} -
- {:else} - {$i18n.t('Leaderboard')} + + + - - - - - - - - - - {#each rankedModels as model, modelIdx (model.id)} - - - - - - - - - - {/each} - -
- {$i18n.t('RK')} - - {$i18n.t('Model')} - - {$i18n.t('Rating')} - - {$i18n.t('Won')} - - {$i18n.t('Lost')} -
-
- {model?.rating !== '-' ? modelIdx + 1 : '-'} -
-
-
-
- {model.name} -
- -
- {model.name} -
-
-
- {model.rating} - -
- {#if model.stats.won === '-'} - - - {:else} - - {model.stats.won} - {/if} -
-
-
- {#if model.stats.lost === '-'} - - - {:else} - - {model.stats.lost} - {/if} -
-
- {/if} -
- -
-
-
- ⓘ {$i18n.t( - 'The evaluation leaderboard is based on the Elo rating system and is updated in real-time.' - )} -
- {$i18n.t( - 'The leaderboard is currently in beta, and we may adjust the rating calculations as we refine the algorithm.' - )} -
-
- -
- -
-
- {$i18n.t('Feedback History')} - -
- - {feedbacks.length} -
- -
-
- - - -
+ + +
+
{$i18n.t('Feedbacks')}
+
-
- -
- {#if (feedbacks ?? []).length === 0} -
- {$i18n.t('No feedbacks found')} -
- {:else} - - - - - - - - - - - - - - - {#each paginatedFeedbacks as feedback (feedback.id)} - - - - - - - - - - - {/each} - -
- {$i18n.t('User')} - - {$i18n.t('Models')} - - {$i18n.t('Result')} - - {$i18n.t('Updated At')} -
-
- -
- {feedback?.user?.name} -
-
-
-
-
-
- {#if feedback.data?.sibling_model_ids} -
- {feedback.data?.model_id} -
- - -
- {#if feedback.data.sibling_model_ids.length > 2} - - {feedback.data.sibling_model_ids.slice(0, 2).join(', ')}, {$i18n.t( - 'and {{COUNT}} more', - { COUNT: feedback.data.sibling_model_ids.length - 2 } - )} - {:else} - {feedback.data.sibling_model_ids.join(', ')} - {/if} -
-
- {:else} -
- {feedback.data?.model_id} -
- {/if} -
-
-
-
- {#if feedback.data.rating.toString() === '1'} - - {:else if feedback.data.rating.toString() === '0'} - - {:else if feedback.data.rating.toString() === '-1'} - - {/if} -
-
- {dayjs(feedback.updated_at * 1000).fromNow()} - - { - deleteFeedbackHandler(feedback.id); - }} - > - - -
- {/if} -
- - {#if feedbacks.length > 0} -
-
- {$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')} -
- -
- - - -
+
+ {#if selectedTab === 'leaderboard'} + + {:else if selectedTab === 'feedbacks'} + + {/if}
- {/if} - - {#if feedbacks.length > 10} - - {/if} - -
+
{/if} diff --git a/src/lib/components/admin/Evaluations/Feedbacks.svelte b/src/lib/components/admin/Evaluations/Feedbacks.svelte new file mode 100644 index 0000000000000000000000000000000000000000..d465ab97a6646a9017d650ca31ed6d55e8b37783 --- /dev/null +++ b/src/lib/components/admin/Evaluations/Feedbacks.svelte @@ -0,0 +1,283 @@ + + +
+
+ {$i18n.t('Feedback History')} + +
+ + {feedbacks.length} +
+ +
+
+ + + +
+
+
+ +
+ {#if (feedbacks ?? []).length === 0} +
+ {$i18n.t('No feedbacks found')} +
+ {:else} + + + + + + + + + + + + + + + + {#each paginatedFeedbacks as feedback (feedback.id)} + + + + + + + + + + + {/each} + +
+ {$i18n.t('User')} + + {$i18n.t('Models')} + + {$i18n.t('Result')} + + {$i18n.t('Updated At')} +
+
+ +
+ {feedback?.user?.name} +
+
+
+
+
+
+ {#if feedback.data?.sibling_model_ids} +
+ {feedback.data?.model_id} +
+ + +
+ {#if feedback.data.sibling_model_ids.length > 2} + + {feedback.data.sibling_model_ids.slice(0, 2).join(', ')}, {$i18n.t( + 'and {{COUNT}} more', + { COUNT: feedback.data.sibling_model_ids.length - 2 } + )} + {:else} + {feedback.data.sibling_model_ids.join(', ')} + {/if} +
+
+ {:else} +
+ {feedback.data?.model_id} +
+ {/if} +
+
+
+
+ {#if feedback.data.rating.toString() === '1'} + + {:else if feedback.data.rating.toString() === '0'} + + {:else if feedback.data.rating.toString() === '-1'} + + {/if} +
+
+ {dayjs(feedback.updated_at * 1000).fromNow()} + + { + deleteFeedbackHandler(feedback.id); + }} + > + + +
+ {/if} +
+ +{#if feedbacks.length > 0} +
+
+ {$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')} +
+ +
+ + + +
+
+{/if} + +{#if feedbacks.length > 10} + +{/if} diff --git a/src/lib/components/admin/Evaluations/Leaderboard.svelte b/src/lib/components/admin/Evaluations/Leaderboard.svelte new file mode 100644 index 0000000000000000000000000000000000000000..42ff171d74240c8991ea94529f0c36ce397e2960 --- /dev/null +++ b/src/lib/components/admin/Evaluations/Leaderboard.svelte @@ -0,0 +1,410 @@ + + +
+
+
+ {$i18n.t('Leaderboard')} +
+ +
+ + {rankedModels.length} +
+ +
+ +
+
+ +
+ { + loadEmbeddingModel(); + }} + /> +
+
+
+
+ +
+ {#if loadingLeaderboard} +
+
+ +
+
+ {/if} + {#if (rankedModels ?? []).length === 0} +
+ {$i18n.t('No models found')} +
+ {:else} + + + + + + + + + + + + {#each rankedModels as model, modelIdx (model.id)} + + + + + + + + + + {/each} + +
+ {$i18n.t('RK')} + + {$i18n.t('Model')} + + {$i18n.t('Rating')} + + {$i18n.t('Won')} + + {$i18n.t('Lost')} +
+
+ {model?.rating !== '-' ? modelIdx + 1 : '-'} +
+
+
+
+ {model.name} +
+ +
+ {model.name} +
+
+
+ {model.rating} + +
+ {#if model.stats.won === '-'} + - + {:else} + + {model.stats.won} + {/if} +
+
+
+ {#if model.stats.lost === '-'} + - + {:else} + + {model.stats.lost} + {/if} +
+
+ {/if} +
+ +
+
+
+ ⓘ {$i18n.t( + 'The evaluation leaderboard is based on the Elo rating system and is updated in real-time.' + )} +
+ {$i18n.t( + 'The leaderboard is currently in beta, and we may adjust the rating calculations as we refine the algorithm.' + )} +
+
diff --git a/src/lib/components/admin/Functions.svelte b/src/lib/components/admin/Functions.svelte new file mode 100644 index 0000000000000000000000000000000000000000..c12fa2073230443d216472bd254e5500d25c644e --- /dev/null +++ b/src/lib/components/admin/Functions.svelte @@ -0,0 +1,542 @@ + + + + + {$i18n.t('Functions')} | {$WEBUI_NAME} + + + +
+
+
+ {$i18n.t('Functions')} +
+ {filteredItems.length} +
+
+ +
+
+
+ +
+ +
+ +
+ + + +
+
+
+ +
+ {#each filteredItems as func} +
+ +
+
+
+
+ {func.type} +
+ + {#if func?.meta?.manifest?.version} +
+ v{func?.meta?.manifest?.version ?? ''} +
+ {/if} + +
+ {func.name} +
+
+ +
+
{func.id}
+ +
+ {func.meta.description} +
+
+
+
+
+
+ {#if shiftKey} + + + + {:else} + {#if func?.meta?.manifest?.funding_url ?? false} + + + + {/if} + + + + + + { + goto(`/admin/functions/edit?id=${encodeURIComponent(func.id)}`); + }} + shareHandler={() => { + shareHandler(func); + }} + cloneHandler={() => { + cloneHandler(func); + }} + exportHandler={() => { + exportHandler(func); + }} + deleteHandler={async () => { + selectedFunction = func; + showDeleteConfirm = true; + }} + toggleGlobalHandler={() => { + if (['filter', 'action'].includes(func.type)) { + toggleGlobalHandler(func); + } + }} + onClose={() => {}} + > + + + {/if} + +
+ + { + toggleFunctionById(localStorage.token, func.id); + models.set(await getModels(localStorage.token)); + }} + /> + +
+
+
+ {/each} +
+ + + +
+
+ { + console.log(importFiles); + showConfirm = true; + }} + /> + + + + +
+
+ +{#if $config?.features.enable_community_sharing} + +{/if} + + { + deleteHandler(selectedFunction); + }} +> +
+ {$i18n.t('This will delete')} {selectedFunction.name}. +
+
+ + + { + await tick(); + models.set(await getModels(localStorage.token)); + }} +/> + + { + const reader = new FileReader(); + reader.onload = async (event) => { + const _functions = JSON.parse(event.target.result); + console.log(_functions); + + for (const func of _functions) { + const res = await createNewFunction(localStorage.token, func).catch((error) => { + toast.error(error); + return null; + }); + } + + toast.success($i18n.t('Functions imported successfully')); + functions.set(await getFunctions(localStorage.token)); + models.set(await getModels(localStorage.token)); + }; + + reader.readAsText(importFiles[0]); + }} +> +
+
+
Please carefully review the following warnings:
+ +
    +
  • {$i18n.t('Functions allow arbitrary code execution.')}
  • +
  • {$i18n.t('Do not install functions from sources you do not fully trust.')}
  • +
+
+ +
+ {$i18n.t( + 'I acknowledge that I have read and I understand the implications of my action. I am aware of the risks associated with executing arbitrary code and I have verified the trustworthiness of the source.' + )} +
+
+
diff --git a/src/lib/components/admin/Functions/FunctionEditor.svelte b/src/lib/components/admin/Functions/FunctionEditor.svelte new file mode 100644 index 0000000000000000000000000000000000000000..2fef4daec9063c7d6af658e4aad9a5e202510199 --- /dev/null +++ b/src/lib/components/admin/Functions/FunctionEditor.svelte @@ -0,0 +1,430 @@ + + +
+
+
{ + if (edit) { + submitHandler(); + } else { + showConfirm = true; + } + }} + > + +
+
+
+ + { + submitHandler(); + }} +> +
+
+
{$i18n.t('Please carefully review the following warnings:')}
+ +
    +
  • {$i18n.t('Functions allow arbitrary code execution.')}
  • +
  • {$i18n.t('Do not install functions from sources you do not fully trust.')}
  • +
+
+ +
+ {$i18n.t( + 'I acknowledge that I have read and I understand the implications of my action. I am aware of the risks associated with executing arbitrary code and I have verified the trustworthiness of the source.' + )} +
+
+
diff --git a/src/lib/components/admin/Functions/FunctionMenu.svelte b/src/lib/components/admin/Functions/FunctionMenu.svelte new file mode 100644 index 0000000000000000000000000000000000000000..0b9ba39dd293a3502f264b36f476e6c784885a8d --- /dev/null +++ b/src/lib/components/admin/Functions/FunctionMenu.svelte @@ -0,0 +1,138 @@ + + + { + if (e.detail === false) { + onClose(); + } + }} +> + + + + +
+ + {#if ['filter', 'action'].includes(func.type)} +
+
+ + +
{$i18n.t('Global')}
+
+ +
+ +
+
+ +
+ {/if} + + { + editHandler(); + }} + > + + + + +
{$i18n.t('Edit')}
+
+ + { + shareHandler(); + }} + > + +
{$i18n.t('Share')}
+
+ + { + cloneHandler(); + }} + > + + +
{$i18n.t('Clone')}
+
+ + { + exportHandler(); + }} + > + + +
{$i18n.t('Export')}
+
+ +
+ + { + deleteHandler(); + }} + > + +
{$i18n.t('Delete')}
+
+
+
+
diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index 63b15f902fcecf9b65e72d207dcbfbee39d7f148..53fe7e2c6d72c35e71cbe46f6db3f35fbe62234a 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -2,11 +2,11 @@ import { getContext, tick, onMount } from 'svelte'; import { toast } from 'svelte-sonner'; + import { config } from '$lib/stores'; + import { getBackendConfig } from '$lib/apis'; import Database from './Settings/Database.svelte'; import General from './Settings/General.svelte'; - import Users from './Settings/Users.svelte'; - import Pipelines from './Settings/Pipelines.svelte'; import Audio from './Settings/Audio.svelte'; import Images from './Settings/Images.svelte'; @@ -15,8 +15,7 @@ import Connections from './Settings/Connections.svelte'; import Documents from './Settings/Documents.svelte'; import WebSearch from './Settings/WebSearch.svelte'; - import { config } from '$lib/stores'; - import { getBackendConfig } from '$lib/apis'; + import ChartBar from '../icons/ChartBar.svelte'; import DocumentChartBar from '../icons/DocumentChartBar.svelte'; import Evaluations from './Settings/Evaluations.svelte'; @@ -39,16 +38,16 @@ }); -
+
- -
-
+
{#if selectedTab === 'general'} { @@ -361,12 +336,6 @@ await config.set(await getBackendConfig()); }} /> - {:else if selectedTab === 'users'} - { - toast.success($i18n.t('Settings saved successfully!')); - }} - /> {:else if selectedTab === 'connections'} { diff --git a/src/lib/components/admin/Settings/Audio.svelte b/src/lib/components/admin/Settings/Audio.svelte index 9042187e209d5afc6c49d5bcdbc1b4baa9936b3d..d817c83c40f4aca2e90cd8c0258c9809ac970023 100644 --- a/src/lib/components/admin/Settings/Audio.svelte +++ b/src/lib/components/admin/Settings/Audio.svelte @@ -181,7 +181,7 @@
+ @@ -333,7 +334,7 @@
+ {:else if TTS_ENGINE === 'transformers'} +
+
{$i18n.t('TTS Model')}
+
+
+ + + + +
+
+
+ {$i18n.t(`Open WebUI uses SpeechT5 and CMU Arctic speaker embeddings.`)} + + To learn more about SpeechT5, + + + {$i18n.t(`click here`, { + name: 'SpeechT5' + })}. + + To see the available CMU Arctic speaker embeddings, + + {$i18n.t(`click here`)}. + +
+
{:else if TTS_ENGINE === 'openai'}
diff --git a/src/lib/components/admin/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte index 9a690384768443ed128d3b2936d27e2072d5ec04..c2468e5fbab848b1255b8d93df7ca18f9fa81301 100644 --- a/src/lib/components/admin/Settings/Connections.svelte +++ b/src/lib/components/admin/Settings/Connections.svelte @@ -1,31 +1,23 @@ + + + +
{ updateOpenAIHandler(); - updateOllamaUrlsHandler(); + updateOllamaHandler(); dispatch('save'); }} > -
+
{#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null} -
+
{$i18n.t('OpenAI API')}
-
- { - updateOpenAIConfig(localStorage.token, ENABLE_OPENAI_API); - }} - /> +
+
+ { + updateOpenAIHandler(); + }} + /> +
{#if ENABLE_OPENAI_API} -
- {#each OPENAI_API_BASE_URLS as url, idx} -
-
- - - {#if pipelineUrls[url]} -
- - - - - - - -
- {/if} -
- - + +
+
+
{$i18n.t('Manage OpenAI API Connections')}
+ + + + +
+ +
+ {#each OPENAI_API_BASE_URLS as url, idx} + { + updateOpenAIHandler(); + }} + onDelete={() => { + OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.filter( + (url, urlIdx) => idx !== urlIdx + ); + OPENAI_API_KEYS = OPENAI_API_KEYS.filter((key, keyIdx) => idx !== keyIdx); + }} /> -
- {#if idx === 0} - - {:else} - - {/if} -
- -
- - - -
-
-
- {$i18n.t('WebUI will make requests to')} - '{url}/models' -
- {/each} + {/each} +
{/if}
-
+
-
-
+
+
{$i18n.t('Ollama API')}
{ - updateOllamaConfig(localStorage.token, ENABLE_OLLAMA_API); - - if (OLLAMA_BASE_URLS.length === 0) { - OLLAMA_BASE_URLS = ['']; - } + updateOllamaHandler(); }} />
+ {#if ENABLE_OLLAMA_API} -
-
- {#each OLLAMA_BASE_URLS as url, idx} -
- +
+ +
+
+
{$i18n.t('Manage Ollama API Connections')}
+ + + + +
-
- {#if idx === 0} - - {:else} - - {/if} -
- -
- - - -
-
- {/each} +
+
+ {#each OLLAMA_BASE_URLS as url, idx} + { + updateOllamaHandler(); + }} + onDelete={() => { + OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url, urlIdx) => idx !== urlIdx); + }} + /> + {/each} +
-
-
- {$i18n.t('Trouble accessing Ollama?')} - - {$i18n.t('Click here for help.')} - +
+ {$i18n.t('Trouble accessing Ollama?')} + + {$i18n.t('Click here for help.')} + +
{/if}
diff --git a/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte b/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte new file mode 100644 index 0000000000000000000000000000000000000000..39e8fd0348e05858565bf364c3545ff0dafe856d --- /dev/null +++ b/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte @@ -0,0 +1,365 @@ + + + +
+
+
+ {#if edit} + {$i18n.t('Edit Connection')} + {:else} + {$i18n.t('Add Connection')} + {/if} +
+ +
+ +
+
+ { + e.preventDefault(); + submitHandler(); + }} + > +
+
+
+
{$i18n.t('URL')}
+ +
+ +
+
+ + + + + +
+ + + +
+
+ +
+
+
{$i18n.t('Key')}
+ +
+ +
+
+ +
+
{$i18n.t('Prefix ID')}
+ +
+ + + +
+
+
+ +
+ +
+
+
{$i18n.t('Model IDs')}
+
+ + {#if modelIds.length > 0} +
+ {#each modelIds as modelId, modelIdx} +
+
+ {modelId} +
+
+ +
+
+ {/each} +
+ {:else} +
+ {#if ollama} + {$i18n.t('Leave empty to include all models from "{{URL}}/api/tags" endpoint', { + URL: url + })} + {:else} + {$i18n.t('Leave empty to include all models from "{{URL}}/models" endpoint', { + URL: url + })} + {/if} +
+ {/if} +
+ +
+ +
+ + +
+ +
+
+
+ +
+ {#if edit} + + {/if} + + +
+ +
+
+
+
diff --git a/src/lib/components/admin/Settings/Connections/ManageOllamaModal.svelte b/src/lib/components/admin/Settings/Connections/ManageOllamaModal.svelte new file mode 100644 index 0000000000000000000000000000000000000000..979d42cd2b844e9f1339b4f9f0765abadc94d804 --- /dev/null +++ b/src/lib/components/admin/Settings/Connections/ManageOllamaModal.svelte @@ -0,0 +1,1054 @@ + + + { + deleteModelHandler(); + }} +/> + + +
+
+
+
+ {$i18n.t('Manage Ollama')} +
+ +
+ + + +
+
+ +
+ +
+ {#if !loading} +
+
+
+ {#if updateModelId} +
+ Updating "{updateModelId}" {updateProgress ? `(${updateProgress}%)` : ''} +
+ {/if} + +
+
+ {$i18n.t('Pull a model from Ollama.com')} +
+
+
+ +
+ +
+ +
+ {$i18n.t('To access the available model names for downloading,')} + {$i18n.t('click here.')} +
+ + {#if Object.keys($MODEL_DOWNLOAD_POOL).length > 0} + {#each Object.keys($MODEL_DOWNLOAD_POOL) as model} + {#if 'pullProgress' in $MODEL_DOWNLOAD_POOL[model]} +
+
{model}
+
+
+
+
+ {$MODEL_DOWNLOAD_POOL[model].pullProgress ?? 0}% +
+
+ + + + +
+ {#if 'digest' in $MODEL_DOWNLOAD_POOL[model]} +
+ {$MODEL_DOWNLOAD_POOL[model].digest} +
+ {/if} +
+
+ {/if} + {/each} + {/if} +
+ +
+
{$i18n.t('Delete a model')}
+
+
+ +
+ +
+
+ +
+
{$i18n.t('Create a model')}
+
+
+ + +