diff --git a/.github/workflows/deploy-to-hf-spaces.yml b/.github/workflows/deploy-to-hf-spaces.yml index 3f9aa67f36e2f00b1665d0aba2a3a09107dbac7b..aa8bbcfceec2acdb5a617498b369b8c672bf93a9 100644 --- a/.github/workflows/deploy-to-hf-spaces.yml +++ b/.github/workflows/deploy-to-hf-spaces.yml @@ -44,7 +44,7 @@ jobs: echo "---" >> temp_readme.md cat README.md >> temp_readme.md mv temp_readme.md README.md - + - name: Configure git run: | git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com" diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 41c19d32d94bc8532cf021ea9307325f308b504e..8d9dd7aebc1aaa95f0d85953d82cdefd9b2fa6af 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -15,6 +15,13 @@ jobs: name: Run Cypress Integration Tests runs-on: ubuntu-latest steps: + - name: Maximize build space + uses: AdityaGarg8/remove-unwanted-software@v4.1 + with: + remove-android: 'true' + remove-haskell: 'true' + remove-codeql: 'true' + - name: Checkout Repository uses: actions/checkout@v4 diff --git a/CHANGELOG.md b/CHANGELOG.md index d0b175fbfbb243df8a31751ef914753451255264..1b6bdd98fb3dc7b2d2d190d9be0bfa2770f81f67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.3.13] - 2024-08-14 + +### Added + +- **🎨 Enhanced Markdown Rendering**: Significant improvements in rendering markdown, ensuring smooth and reliable display of LaTeX and Mermaid charts, enhancing user experience with more robust visual content. 
+- **🔄 Auto-Install Tools & Functions Python Dependencies**: For 'Tools' and 'Functions', Open WebUI now automatically installs extra Python requirements specified in the frontmatter, streamlining setup processes and customization. +- **🌀 OAuth Email Claim Customization**: Introduced an 'OAUTH_EMAIL_CLAIM' variable to allow customization of the default "email" claim within OAuth configurations, providing greater flexibility in authentication processes. +- **📶 Websocket Reconnection**: Enhanced reliability with the capability to automatically reconnect when a websocket is closed, ensuring consistent and stable communication. +- **🤳 Haptic Feedback on Supported Devices**: Android devices now support haptic feedback for an immersive tactile experience during certain interactions. + +### Fixed + +- **🛠️ ComfyUI Performance Improvement**: Addressed an issue causing FastAPI to stall when ComfyUI image generation was active; now runs in a separate thread to prevent UI unresponsiveness. +- **🔀 Session Handling**: Fixed an issue mandating session_id on client-side to ensure smoother session management and transitions. +- **🖋️ Minor Bug Fixes and Format Corrections**: Various minor fixes including typo corrections, backend formatting improvements, and test amendments enhancing overall system stability and performance. + +### Changed + +- **🚀 Migration to SvelteKit 2**: Upgraded the underlying framework to SvelteKit version 2, offering enhanced speed, better code structure, and improved deployment capabilities. +- **🧹 General Cleanup and Refactoring**: Performed broad cleanup and refactoring across the platform, improving code efficiency and maintaining high standards of code health. +- **🚧 Integration Testing Improvements**: Modified how Cypress integration tests detect chat messages and updated sharing tests for better reliability and accuracy. 
+- **📁 Standardized '.safetensors' File Extension**: Renamed the '.sft' file extension to '.safetensors' for ComfyUI workflows, standardizing file formats across the platform. + +### Removed + +- **🗑️ Deprecated Frontend Functions**: Removed frontend functions that were migrated to backend to declutter the codebase and reduce redundancy. + ## [0.3.12] - 2024-08-07 ### Added diff --git a/Dockerfile b/Dockerfile index 3be9875b3fe654cdd3570086b76f56610a790ea2..8078bf0eacbf433e43f25634cf0ff4315f6b6aaf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,7 @@ # Initialize device type args # use build args in the docker build commmand with --build-arg="BUILDARG=true" ARG USE_CUDA=false -ARG USE_OLLAMA=true +ARG USE_OLLAMA=false # Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default) ARG USE_CUDA_VER=cu121 # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers diff --git a/README.md b/README.md index bdb4820fa1e5d0700cac97d5da7df57c7bbacfcd..b4848608176681076ef845e9e1a854d171252eb9 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,12 @@ --- -title: OpenOllama +title: Open WebUI emoji: 🐳 colorFrom: purple colorTo: gray sdk: docker app_port: 8080 -license: apache-2.0 --- - -# OpenOllama 👋 +# Open WebUI (Formerly Ollama WebUI) 👋   @@ -25,7 +23,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-  -## Key Features of OpenOllama ⭐ +## Key Features of Open WebUI ⭐ - 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images. @@ -210,4 +208,4 @@ If you have any questions, suggestions, or need assistance, please open an issue --- -Created by [Timothy J. Baek](https://github.com/tjbck) - Let's make Open WebUI even more amazing together! 💪 \ No newline at end of file +Created by [Timothy J. 
Baek](https://github.com/tjbck) - Let's make Open WebUI even more amazing together! 💪 diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 167db77bae81b01688a2c9c7abd363c7e47dee22..20519b59b168f35f8873c8fbaf8b542a16cf24e8 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel -from typing import List + import uuid import requests import hashlib @@ -244,7 +244,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): res = r.json() if "error" in res: error_detail = f"External: {res['error']['message']}" - except: + except Exception: error_detail = f"External: {e}" raise HTTPException( @@ -299,7 +299,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): res = r.json() if "error" in res: error_detail = f"External: {res['error']['message']}" - except: + except Exception: error_detail = f"External: {e}" raise HTTPException( @@ -353,7 +353,7 @@ def transcribe( try: model = WhisperModel(**whisper_kwargs) - except: + except Exception: log.warning( "WhisperModel initialization failed, attempting download with local_files_only=False" ) @@ -421,7 +421,7 @@ def transcribe( res = r.json() if "error" in res: error_detail = f"External: {res['error']['message']}" - except: + except Exception: error_detail = f"External: {e}" raise HTTPException( @@ -438,7 +438,7 @@ def transcribe( ) -def get_available_models() -> List[dict]: +def get_available_models() -> list[dict]: if app.state.config.TTS_ENGINE == "openai": return [{"id": "tts-1"}, {"id": "tts-1-hd"}] elif app.state.config.TTS_ENGINE == "elevenlabs": @@ -466,7 +466,7 @@ async def get_models(user=Depends(get_verified_user)): return {"models": get_available_models()} -def get_available_voices() -> List[dict]: +def get_available_voices() -> list[dict]: if app.state.config.TTS_ENGINE 
== "openai": return [ {"name": "alloy", "id": "alloy"}, diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 4239f3f457197f105277e974914fbd1ac9ae1d80..d2f5ddd5d6c4de23b0fdb991edc4698320d5c11f 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -94,7 +94,7 @@ app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP def get_automatic1111_api_auth(): - if app.state.config.AUTOMATIC1111_API_AUTH == None: + if app.state.config.AUTOMATIC1111_API_AUTH is None: return "" else: auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8") @@ -145,28 +145,30 @@ async def get_engine_url(user=Depends(get_admin_user)): async def update_engine_url( form_data: EngineUrlUpdateForm, user=Depends(get_admin_user) ): - if form_data.AUTOMATIC1111_BASE_URL == None: + if form_data.AUTOMATIC1111_BASE_URL is None: app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL else: url = form_data.AUTOMATIC1111_BASE_URL.strip("/") try: r = requests.head(url) + r.raise_for_status() app.state.config.AUTOMATIC1111_BASE_URL = url except Exception as e: - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - if form_data.COMFYUI_BASE_URL == None: + if form_data.COMFYUI_BASE_URL is None: app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL else: url = form_data.COMFYUI_BASE_URL.strip("/") try: r = requests.head(url) + r.raise_for_status() app.state.config.COMFYUI_BASE_URL = url except Exception as e: - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - if form_data.AUTOMATIC1111_API_AUTH == None: + if form_data.AUTOMATIC1111_API_AUTH is None: app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH else: app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH @@ -514,7 +516,7 @@ async def image_generations( 
data = ImageGenerationPayload(**data) - res = comfyui_generate_image( + res = await comfyui_generate_image( app.state.config.MODEL, data, user.id, diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index 6c37f0c49771b3f15c2f37532fc9f537dc9e10e1..f11dca57c5fd6ae623ec282492dfb0254ac062e8 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -1,5 +1,5 @@ +import asyncio import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) -import uuid import json import urllib.request import urllib.parse @@ -170,7 +170,7 @@ FLUX_DEFAULT_PROMPT = """ }, "10": { "inputs": { - "vae_name": "ae.sft" + "vae_name": "ae.safetensors" }, "class_type": "VAELoader" }, @@ -184,7 +184,7 @@ FLUX_DEFAULT_PROMPT = """ }, "12": { "inputs": { - "unet_name": "flux1-dev.sft", + "unet_name": "flux1-dev.safetensors", "weight_dtype": "default" }, "class_type": "UNETLoader" @@ -328,7 +328,7 @@ class ImageGenerationPayload(BaseModel): flux_fp8_clip: Optional[bool] = None -def comfyui_generate_image( +async def comfyui_generate_image( model: str, payload: ImageGenerationPayload, client_id, base_url ): ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") @@ -397,7 +397,9 @@ def comfyui_generate_image( return None try: - images = get_images(ws, comfyui_prompt, client_id, base_url) + images = await asyncio.to_thread( + get_images, ws, comfyui_prompt, client_id, base_url + ) except Exception as e: log.exception(f"Error while receiving images: {e}") images = None diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 442d99ff26f1015d1b469da8a6ce1d13d01b5b13..03a8e198ee81bfa1b689708e373a883d893b4171 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -1,47 +1,36 @@ from fastapi import ( FastAPI, Request, - Response, HTTPException, Depends, - status, UploadFile, File, - BackgroundTasks, ) from fastapi.middleware.cors import 
CORSMiddleware from fastapi.responses import StreamingResponse -from fastapi.concurrency import run_in_threadpool from pydantic import BaseModel, ConfigDict import os import re -import copy import random import requests import json -import uuid import aiohttp import asyncio import logging import time from urllib.parse import urlparse -from typing import Optional, List, Union +from typing import Optional, Union from starlette.background import BackgroundTask from apps.webui.models.models import Models -from apps.webui.models.users import Users from constants import ERROR_MESSAGES from utils.utils import ( - decode_token, - get_current_user, get_verified_user, get_admin_user, ) -from utils.task import prompt_template - from config import ( SRC_LOG_LEVELS, @@ -53,7 +42,12 @@ from config import ( UPLOAD_DIR, AppConfig, ) -from utils.misc import calculate_sha256, add_or_update_system_message +from utils.misc import ( + calculate_sha256, + apply_model_params_to_body_ollama, + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) @@ -120,7 +114,7 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)): class UrlUpdateForm(BaseModel): - urls: List[str] + urls: list[str] @app.post("/urls/update") @@ -183,7 +177,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True): res = await r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -238,7 +232,7 @@ async def get_all_models(): async def get_ollama_tags( url_idx: Optional[int] = None, user=Depends(get_verified_user) ): - if url_idx == None: + if url_idx is None: models = await get_all_models() if app.state.config.ENABLE_MODEL_FILTER: @@ -269,7 +263,7 @@ async def get_ollama_tags( res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: 
{e}" raise HTTPException( @@ -282,8 +276,7 @@ async def get_ollama_tags( @app.get("/api/version/{url_idx}") async def get_ollama_versions(url_idx: Optional[int] = None): if app.state.config.ENABLE_OLLAMA_API: - if url_idx == None: - + if url_idx is None: # returns lowest version tasks = [ fetch_url(f"{url}/api/version") @@ -323,7 +316,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -346,8 +339,6 @@ async def pull_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - r = None - # Admin should be able to pull models from any source payload = {**form_data.model_dump(exclude_none=True), "insecure": True} @@ -367,7 +358,7 @@ async def push_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - if url_idx == None: + if url_idx is None: if form_data.name in app.state.MODELS: url_idx = app.state.MODELS[form_data.name]["urls"][0] else: @@ -417,7 +408,7 @@ async def copy_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - if url_idx == None: + if url_idx is None: if form_data.source in app.state.MODELS: url_idx = app.state.MODELS[form_data.source]["urls"][0] else: @@ -428,13 +419,13 @@ async def copy_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + r = requests.request( + method="POST", + url=f"{url}/api/copy", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) try: - r = requests.request( - method="POST", - url=f"{url}/api/copy", - data=form_data.model_dump_json(exclude_none=True).encode(), - ) r.raise_for_status() log.debug(f"r.text: {r.text}") @@ -448,7 +439,7 @@ async def copy_model( res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -464,7 +455,7 @@ async def delete_model( 
url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - if url_idx == None: + if url_idx is None: if form_data.name in app.state.MODELS: url_idx = app.state.MODELS[form_data.name]["urls"][0] else: @@ -476,12 +467,12 @@ async def delete_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + r = requests.request( + method="DELETE", + url=f"{url}/api/delete", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) try: - r = requests.request( - method="DELETE", - url=f"{url}/api/delete", - data=form_data.model_dump_json(exclude_none=True).encode(), - ) r.raise_for_status() log.debug(f"r.text: {r.text}") @@ -495,7 +486,7 @@ async def delete_model( res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -516,12 +507,12 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + r = requests.request( + method="POST", + url=f"{url}/api/show", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) try: - r = requests.request( - method="POST", - url=f"{url}/api/show", - data=form_data.model_dump_json(exclude_none=True).encode(), - ) r.raise_for_status() return r.json() @@ -533,7 +524,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -556,7 +547,7 @@ async def generate_embeddings( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - if url_idx == None: + if url_idx is None: model = form_data.model if ":" not in model: @@ -573,12 +564,12 @@ async def generate_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", 
+ data=form_data.model_dump_json(exclude_none=True).encode(), + ) try: - r = requests.request( - method="POST", - url=f"{url}/api/embeddings", - data=form_data.model_dump_json(exclude_none=True).encode(), - ) r.raise_for_status() return r.json() @@ -590,7 +581,7 @@ async def generate_embeddings( res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -603,10 +594,9 @@ def generate_ollama_embeddings( form_data: GenerateEmbeddingsForm, url_idx: Optional[int] = None, ): - log.info(f"generate_ollama_embeddings {form_data}") - if url_idx == None: + if url_idx is None: model = form_data.model if ":" not in model: @@ -623,12 +613,12 @@ def generate_ollama_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) try: - r = requests.request( - method="POST", - url=f"{url}/api/embeddings", - data=form_data.model_dump_json(exclude_none=True).encode(), - ) r.raise_for_status() data = r.json() @@ -638,7 +628,7 @@ def generate_ollama_embeddings( if "embedding" in data: return data["embedding"] else: - raise "Something went wrong :/" + raise Exception("Something went wrong :/") except Exception as e: log.exception(e) error_detail = "Open WebUI: Server Connection Error" @@ -647,16 +637,16 @@ def generate_ollama_embeddings( res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" - raise error_detail + raise Exception(error_detail) class GenerateCompletionForm(BaseModel): model: str prompt: str - images: Optional[List[str]] = None + images: Optional[list[str]] = None format: Optional[str] = None options: Optional[dict] = None system: Optional[str] = None @@ -674,8 +664,7 @@ async def generate_completion( url_idx: Optional[int] = None, 
user=Depends(get_verified_user), ): - - if url_idx == None: + if url_idx is None: model = form_data.model if ":" not in model: @@ -700,12 +689,12 @@ async def generate_completion( class ChatMessage(BaseModel): role: str content: str - images: Optional[List[str]] = None + images: Optional[list[str]] = None class GenerateChatCompletionForm(BaseModel): model: str - messages: List[ChatMessage] + messages: list[ChatMessage] format: Optional[str] = None options: Optional[dict] = None template: Optional[str] = None @@ -713,6 +702,18 @@ class GenerateChatCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None +def get_ollama_url(url_idx: Optional[int], model: str): + if url_idx is None: + if model not in app.state.MODELS: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), + ) + url_idx = random.choice(app.state.MODELS[model]["urls"]) + url = app.state.config.OLLAMA_BASE_URLS[url_idx] + return url + + @app.post("/api/chat") @app.post("/api/chat/{url_idx}") async def generate_chat_completion( @@ -720,12 +721,7 @@ async def generate_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - - log.debug( - "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( - form_data.model_dump_json(exclude_none=True).encode() - ) - ) + log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=") payload = { **form_data.model_dump(exclude_none=True, exclude=["metadata"]), @@ -740,185 +736,21 @@ async def generate_chat_completion( if model_info.base_model_id: payload["model"] = model_info.base_model_id - model_info.params = model_info.params.model_dump() + params = model_info.params.model_dump() - if model_info.params: + if params: if payload.get("options") is None: payload["options"] = {} - if ( - model_info.params.get("mirostat", None) - and payload["options"].get("mirostat") is None - ): - payload["options"]["mirostat"] = model_info.params.get("mirostat", None) - - if ( - 
model_info.params.get("mirostat_eta", None) - and payload["options"].get("mirostat_eta") is None - ): - payload["options"]["mirostat_eta"] = model_info.params.get( - "mirostat_eta", None - ) - - if ( - model_info.params.get("mirostat_tau", None) - and payload["options"].get("mirostat_tau") is None - ): - payload["options"]["mirostat_tau"] = model_info.params.get( - "mirostat_tau", None - ) - - if ( - model_info.params.get("num_ctx", None) - and payload["options"].get("num_ctx") is None - ): - payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None) - - if ( - model_info.params.get("num_batch", None) - and payload["options"].get("num_batch") is None - ): - payload["options"]["num_batch"] = model_info.params.get( - "num_batch", None - ) - - if ( - model_info.params.get("num_keep", None) - and payload["options"].get("num_keep") is None - ): - payload["options"]["num_keep"] = model_info.params.get("num_keep", None) - - if ( - model_info.params.get("repeat_last_n", None) - and payload["options"].get("repeat_last_n") is None - ): - payload["options"]["repeat_last_n"] = model_info.params.get( - "repeat_last_n", None - ) - - if ( - model_info.params.get("frequency_penalty", None) - and payload["options"].get("frequency_penalty") is None - ): - payload["options"]["repeat_penalty"] = model_info.params.get( - "frequency_penalty", None - ) - - if ( - model_info.params.get("temperature", None) is not None - and payload["options"].get("temperature") is None - ): - payload["options"]["temperature"] = model_info.params.get( - "temperature", None - ) - - if ( - model_info.params.get("seed", None) is not None - and payload["options"].get("seed") is None - ): - payload["options"]["seed"] = model_info.params.get("seed", None) - - if ( - model_info.params.get("stop", None) - and payload["options"].get("stop") is None - ): - payload["options"]["stop"] = ( - [ - bytes(stop, "utf-8").decode("unicode_escape") - for stop in model_info.params["stop"] - ] - if 
model_info.params.get("stop", None) - else None - ) - - if ( - model_info.params.get("tfs_z", None) - and payload["options"].get("tfs_z") is None - ): - payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None) - - if ( - model_info.params.get("max_tokens", None) - and payload["options"].get("max_tokens") is None - ): - payload["options"]["num_predict"] = model_info.params.get( - "max_tokens", None - ) - - if ( - model_info.params.get("top_k", None) - and payload["options"].get("top_k") is None - ): - payload["options"]["top_k"] = model_info.params.get("top_k", None) - - if ( - model_info.params.get("top_p", None) - and payload["options"].get("top_p") is None - ): - payload["options"]["top_p"] = model_info.params.get("top_p", None) - - if ( - model_info.params.get("min_p", None) - and payload["options"].get("min_p") is None - ): - payload["options"]["min_p"] = model_info.params.get("min_p", None) - - if ( - model_info.params.get("use_mmap", None) - and payload["options"].get("use_mmap") is None - ): - payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None) - - if ( - model_info.params.get("use_mlock", None) - and payload["options"].get("use_mlock") is None - ): - payload["options"]["use_mlock"] = model_info.params.get( - "use_mlock", None - ) - - if ( - model_info.params.get("num_thread", None) - and payload["options"].get("num_thread") is None - ): - payload["options"]["num_thread"] = model_info.params.get( - "num_thread", None - ) - - system = model_info.params.get("system", None) - if system: - system = prompt_template( - system, - **( - { - "user_name": user.name, - "user_location": ( - user.info.get("location") if user.info else None - ), - } - if user - else {} - ), + payload["options"] = apply_model_params_to_body_ollama( + params, payload["options"] ) + payload = apply_model_system_prompt_to_body(params, payload, user) - if payload.get("messages"): - payload["messages"] = add_or_update_system_message( - system, payload["messages"] 
- ) - - if url_idx == None: - if ":" not in payload["model"]: - payload["model"] = f"{payload['model']}:latest" - - if payload["model"] in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") log.debug(payload) @@ -940,7 +772,7 @@ class OpenAIChatMessage(BaseModel): class OpenAIChatCompletionForm(BaseModel): model: str - messages: List[OpenAIChatMessage] + messages: list[OpenAIChatMessage] model_config = ConfigDict(extra="allow") @@ -952,83 +784,28 @@ async def generate_openai_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - form_data = OpenAIChatCompletionForm(**form_data) - payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])} - + completion_form = OpenAIChatCompletionForm(**form_data) + payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])} if "metadata" in payload: del payload["metadata"] - model_id = form_data.model + model_id = completion_form.model model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id - model_info.params = model_info.params.model_dump() + params = model_info.params.model_dump() - if model_info.params: - payload["temperature"] = model_info.params.get("temperature", None) - payload["top_p"] = model_info.params.get("top_p", None) - payload["max_tokens"] = model_info.params.get("max_tokens", None) - payload["frequency_penalty"] = model_info.params.get( - "frequency_penalty", None - ) - payload["seed"] = model_info.params.get("seed", None) - payload["stop"] = ( - [ - bytes(stop, "utf-8").decode("unicode_escape") - for stop in 
model_info.params["stop"] - ] - if model_info.params.get("stop", None) - else None - ) + if params: + payload = apply_model_params_to_body_openai(params, payload) + payload = apply_model_system_prompt_to_body(params, payload, user) - system = model_info.params.get("system", None) - - if system: - system = prompt_template( - system, - **( - { - "user_name": user.name, - "user_location": ( - user.info.get("location") if user.info else None - ), - } - if user - else {} - ), - ) - # Check if the payload already has a system message - # If not, add a system message to the payload - if payload.get("messages"): - for message in payload["messages"]: - if message.get("role") == "system": - message["content"] = system + message["content"] - break - else: - payload["messages"].insert( - 0, - { - "role": "system", - "content": system, - }, - ) + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - if url_idx == None: - if ":" not in payload["model"]: - payload["model"] = f"{payload['model']}:latest" - - if payload["model"] in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") return await post_streaming_url( @@ -1044,7 +821,7 @@ async def get_openai_models( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - if url_idx == None: + if url_idx is None: models = await get_all_models() if app.state.config.ENABLE_MODEL_FILTER: @@ -1099,7 +876,7 @@ async def get_openai_models( res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -1125,7 +902,6 @@ def parse_huggingface_url(hf_url): path_components = parsed_url.path.split("/") # Extract the desired output - 
user_repo = "/".join(path_components[1:3]) model_file = path_components[-1] return model_file @@ -1190,7 +966,6 @@ async def download_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - allowed_hosts = ["https://huggingface.co/", "https://github.com/"] if not any(form_data.url.startswith(host) for host in allowed_hosts): @@ -1199,7 +974,7 @@ async def download_model( detail="Invalid file_url. Only URLs from allowed hosts are permitted.", ) - if url_idx == None: + if url_idx is None: url_idx = 0 url = app.state.config.OLLAMA_BASE_URLS[url_idx] @@ -1222,7 +997,7 @@ def upload_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - if url_idx == None: + if url_idx is None: url_idx = 0 ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 948ee7c46ab0c7228da30701b1580d9260cd9450..d344c662225eab565481374a4d5095a52520d5c7 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -17,7 +17,10 @@ from utils.utils import ( get_verified_user, get_admin_user, ) -from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body +from utils.misc import ( + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) from config import ( SRC_LOG_LEVELS, @@ -30,7 +33,7 @@ from config import ( MODEL_FILTER_LIST, AppConfig, ) -from typing import List, Optional, Literal, overload +from typing import Optional, Literal, overload import hashlib @@ -86,11 +89,11 @@ async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user class UrlsUpdateForm(BaseModel): - urls: List[str] + urls: list[str] class KeysUpdateForm(BaseModel): - keys: List[str] + keys: list[str] @app.get("/urls") @@ -368,7 +371,7 @@ async def generate_chat_completion( payload["model"] = model_info.base_model_id params = model_info.params.model_dump() - payload = apply_model_params_to_body(params, payload) + payload = 
apply_model_params_to_body_openai(params, payload) payload = apply_model_system_prompt_to_body(params, payload, user) model = app.state.MODELS[payload.get("model")] diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index dc6b8830efec53d1d153ecf165852a1c3820b85a..f9788556bc0682e5527f36a3ab5060108f2a2139 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -13,7 +13,7 @@ import os, shutil, logging, re from datetime import datetime from pathlib import Path -from typing import List, Union, Sequence, Iterator, Any +from typing import Union, Sequence, Iterator, Any from chromadb.utils.batch_utils import create_batches from langchain_core.documents import Document @@ -376,7 +376,7 @@ async def update_reranking_config( try: app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model - update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True + update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True) return { "status": True, @@ -439,7 +439,7 @@ class ChunkParamUpdateForm(BaseModel): class YoutubeLoaderConfig(BaseModel): - language: List[str] + language: list[str] translation: Optional[str] = None @@ -642,7 +642,7 @@ def query_doc_handler( class QueryCollectionsForm(BaseModel): - collection_names: List[str] + collection_names: list[str] query: str k: Optional[int] = None r: Optional[float] = None @@ -1021,7 +1021,7 @@ class TikaLoader: self.file_path = file_path self.mime_type = mime_type - def load(self) -> List[Document]: + def load(self) -> list[Document]: with open(self.file_path, "rb") as f: data = f.read() @@ -1185,7 +1185,7 @@ def store_doc( f.close() f = open(file_path, "rb") - if collection_name == None: + if collection_name is None: collection_name = calculate_sha256(f)[:63] f.close() @@ -1238,7 +1238,7 @@ def process_doc( f = open(file_path, "rb") collection_name = form_data.collection_name - if collection_name == None: + if collection_name is None: collection_name = calculate_sha256(f)[:63] f.close() @@ 
-1296,7 +1296,7 @@ def store_text( ): collection_name = form_data.collection_name - if collection_name == None: + if collection_name is None: collection_name = calculate_sha256_string(form_data.content) result = store_text_in_vector_db( @@ -1339,7 +1339,7 @@ def scan_docs_dir(user=Depends(get_admin_user)): sanitized_filename = sanitize_filename(filename) doc = Documents.get_doc_by_name(sanitized_filename) - if doc == None: + if doc is None: doc = Documents.insert_new_doc( user.id, DocumentForm( diff --git a/backend/apps/rag/search/brave.py b/backend/apps/rag/search/brave.py index 76ad1fb4731d2e572027cec32f65267a3f9f3675..681caa97612681053dd15abd82f85596734311aa 100644 --- a/backend/apps/rag/search/brave.py +++ b/backend/apps/rag/search/brave.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import Optional import requests from apps.rag.search.main import SearchResult, get_filtered_results @@ -10,7 +10,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) def search_brave( - api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None + api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None ) -> list[SearchResult]: """Search using Brave's Search API and return the results as a list of SearchResult objects. 
diff --git a/backend/apps/rag/search/duckduckgo.py b/backend/apps/rag/search/duckduckgo.py index f0cc2a71035d5279d9c184f2c20787d41b16de52..e994ef47a9efa3fe3132665920dc8c34e733a73a 100644 --- a/backend/apps/rag/search/duckduckgo.py +++ b/backend/apps/rag/search/duckduckgo.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import Optional from apps.rag.search.main import SearchResult, get_filtered_results from duckduckgo_search import DDGS from config import SRC_LOG_LEVELS @@ -9,7 +9,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) def search_duckduckgo( - query: str, count: int, filter_list: Optional[List[str]] = None + query: str, count: int, filter_list: Optional[list[str]] = None ) -> list[SearchResult]: """ Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects. @@ -18,7 +18,7 @@ def search_duckduckgo( count (int): The number of results to return Returns: - List[SearchResult]: A list of search results + list[SearchResult]: A list of search results """ # Use the DDGS context manager to create a DDGS object with DDGS() as ddgs: diff --git a/backend/apps/rag/search/google_pse.py b/backend/apps/rag/search/google_pse.py index 0c78512e74ef82d391dad426c42e55e3a25a148c..7fedb3dad9759ec7243bfed55f8e9cfad07a7296 100644 --- a/backend/apps/rag/search/google_pse.py +++ b/backend/apps/rag/search/google_pse.py @@ -1,6 +1,6 @@ import json import logging -from typing import List, Optional +from typing import Optional import requests from apps.rag.search.main import SearchResult, get_filtered_results @@ -15,7 +15,7 @@ def search_google_pse( search_engine_id: str, query: str, count: int, - filter_list: Optional[List[str]] = None, + filter_list: Optional[list[str]] = None, ) -> list[SearchResult]: """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects. 
diff --git a/backend/apps/rag/search/jina_search.py b/backend/apps/rag/search/jina_search.py index 65f9ad68fe7984fa0568ac6b32367a1cf8632c26..8d1c582a1e79c713ddff4b0246f1426a5ab6127c 100644 --- a/backend/apps/rag/search/jina_search.py +++ b/backend/apps/rag/search/jina_search.py @@ -17,7 +17,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]: count (int): The number of results to return Returns: - List[SearchResult]: A list of search results + list[SearchResult]: A list of search results """ jina_search_endpoint = "https://s.jina.ai/" headers = { diff --git a/backend/apps/rag/search/searxng.py b/backend/apps/rag/search/searxng.py index 6e545e994e8659c63a82cf0fce48b4724e042d84..94bed2857bb9df38c2ff65989e50581719a7e3c7 100644 --- a/backend/apps/rag/search/searxng.py +++ b/backend/apps/rag/search/searxng.py @@ -1,7 +1,7 @@ import logging import requests -from typing import List, Optional +from typing import Optional from apps.rag.search.main import SearchResult, get_filtered_results from config import SRC_LOG_LEVELS @@ -14,9 +14,9 @@ def search_searxng( query_url: str, query: str, count: int, - filter_list: Optional[List[str]] = None, + filter_list: Optional[list[str]] = None, **kwargs, -) -> List[SearchResult]: +) -> list[SearchResult]: """ Search a SearXNG instance for a given query and return the results as a list of SearchResult objects. @@ -31,10 +31,10 @@ def search_searxng( language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string. safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate). time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''. - categories: (Optional[List[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided. 
+ categories: (Optional[list[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided. Returns: - List[SearchResult]: A list of SearchResults sorted by relevance score in descending order. + list[SearchResult]: A list of SearchResults sorted by relevance score in descending order. Raise: requests.exceptions.RequestException: If a request error occurs during the search process. diff --git a/backend/apps/rag/search/serper.py b/backend/apps/rag/search/serper.py index b278a4df15a1322a03dad373518043cf2c003eab..e71fbb6283f18e91006010429c4361f2de11a175 100644 --- a/backend/apps/rag/search/serper.py +++ b/backend/apps/rag/search/serper.py @@ -1,6 +1,6 @@ import json import logging -from typing import List, Optional +from typing import Optional import requests from apps.rag.search.main import SearchResult, get_filtered_results @@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) def search_serper( - api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None + api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None ) -> list[SearchResult]: """Search using serper.dev's API and return the results as a list of SearchResult objects. diff --git a/backend/apps/rag/search/serply.py b/backend/apps/rag/search/serply.py index 24b249b739425b9ad941c336f3daa76f46272297..28c15fd78854d7b440a116060afb1382b1b2d492 100644 --- a/backend/apps/rag/search/serply.py +++ b/backend/apps/rag/search/serply.py @@ -1,6 +1,6 @@ import json import logging -from typing import List, Optional +from typing import Optional import requests from urllib.parse import urlencode @@ -19,7 +19,7 @@ def search_serply( limit: int = 10, device_type: str = "desktop", proxy_location: str = "US", - filter_list: Optional[List[str]] = None, + filter_list: Optional[list[str]] = None, ) -> list[SearchResult]: """Search using serper.dev's API and return the results as a list of SearchResult objects. 
diff --git a/backend/apps/rag/search/serpstack.py b/backend/apps/rag/search/serpstack.py index 64b0f117d906414714d56470271ca13a0a9d4bb9..5c19bd1342043bfa0420b53d3a5fde37f1ef6f10 100644 --- a/backend/apps/rag/search/serpstack.py +++ b/backend/apps/rag/search/serpstack.py @@ -1,6 +1,6 @@ import json import logging -from typing import List, Optional +from typing import Optional import requests from apps.rag.search.main import SearchResult, get_filtered_results @@ -14,7 +14,7 @@ def search_serpstack( api_key: str, query: str, count: int, - filter_list: Optional[List[str]] = None, + filter_list: Optional[list[str]] = None, https_enabled: bool = True, ) -> list[SearchResult]: """Search using serpstack.com's and return the results as a list of SearchResult objects. diff --git a/backend/apps/rag/search/tavily.py b/backend/apps/rag/search/tavily.py index b15d6ef9d5b55aa2eee6f77212e6561f1fbaf4fa..ed4ab6e08407d4e0ddfe3ee3c06b11f6c05500d3 100644 --- a/backend/apps/rag/search/tavily.py +++ b/backend/apps/rag/search/tavily.py @@ -17,7 +17,7 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]: query (str): The query to search for Returns: - List[SearchResult]: A list of search results + list[SearchResult]: A list of search results """ url = "https://api.tavily.com/search" data = {"query": query, "api_key": api_key} diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index fde89b0697414a54a5fdabf4cdfb853c12e11c70..034f71292c7ac4994bc95ec08a644f4d2c194e88 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -2,7 +2,7 @@ import os import logging import requests -from typing import List, Union +from typing import Union from apps.ollama.main import ( generate_ollama_embeddings, @@ -142,7 +142,7 @@ def merge_and_sort_query_results(query_results, k, reverse=False): def query_collection( - collection_names: List[str], + collection_names: list[str], query: str, embedding_function, k: int, @@ -157,13 +157,13 @@ def 
query_collection( embedding_function=embedding_function, ) results.append(result) - except: + except Exception: pass return merge_and_sort_query_results(results, k=k) def query_collection_with_hybrid_search( - collection_names: List[str], + collection_names: list[str], query: str, embedding_function, k: int, @@ -182,7 +182,7 @@ def query_collection_with_hybrid_search( r=r, ) results.append(result) - except: + except Exception: pass return merge_and_sort_query_results(results, k=k, reverse=True) @@ -411,7 +411,7 @@ class ChromaRetriever(BaseRetriever): query: str, *, run_manager: CallbackManagerForRetrieverRun, - ) -> List[Document]: + ) -> list[Document]: query_embeddings = self.embedding_function(query) results = self.collection.query( diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index a0b9f50085fd39a5743ee1dcf0ddcbcbe9b57b47..dddf3fbb2a127db8ef36e6060dfa3adf1fb0b4cf 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id from utils.misc import ( openai_chat_chunk_message_template, openai_chat_completion_message_template, - apply_model_params_to_body, + apply_model_params_to_body_openai, apply_model_system_prompt_to_body, ) @@ -46,6 +46,7 @@ from config import ( AppConfig, OAUTH_USERNAME_CLAIM, OAUTH_PICTURE_CLAIM, + OAUTH_EMAIL_CLAIM, ) from apps.socket.main import get_event_call, get_event_emitter @@ -84,6 +85,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM +app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM app.state.MODELS = {} app.state.TOOLS = {} @@ -289,7 +291,7 @@ async def generate_function_chat_completion(form_data, user): form_data["model"] = model_info.base_model_id params = model_info.params.model_dump() - form_data = apply_model_params_to_body(params, form_data) + form_data = 
apply_model_params_to_body_openai(params, form_data) form_data = apply_model_system_prompt_to_body(params, form_data, user) pipe_id = get_pipe_id(form_data) diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index bcea4a367ae2d328fd69e770363fcc400083c866..3cbe8c887579047d992d17a092473cba492fcc58 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -140,7 +140,7 @@ class AuthsTable: return None else: return None - except: + except Exception: return None def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: @@ -152,7 +152,7 @@ class AuthsTable: try: user = Users.get_user_by_api_key(api_key) return user if user else None - except: + except Exception: return False def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: @@ -163,7 +163,7 @@ class AuthsTable: if auth: user = Users.get_user_by_id(auth.id) return user - except: + except Exception: return None def update_user_password_by_id(self, id: str, new_password: str) -> bool: @@ -174,7 +174,7 @@ class AuthsTable: ) db.commit() return True if result == 1 else False - except: + except Exception: return False def update_email_by_id(self, id: str, email: str) -> bool: @@ -183,7 +183,7 @@ class AuthsTable: result = db.query(Auth).filter_by(id=id).update({"email": email}) db.commit() return True if result == 1 else False - except: + except Exception: return False def delete_auth_by_id(self, id: str) -> bool: @@ -200,7 +200,7 @@ class AuthsTable: return True else: return False - except: + except Exception: return False diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index d504b18c3fb33416eb36fa3704f6126c9499561e..be77595ecad394ff9e7c44d081d25093affd2bbc 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, ConfigDict -from typing import List, Union, Optional +from typing import Union, 
Optional import json import uuid @@ -164,7 +164,7 @@ class ChatTable: db.refresh(chat) return self.get_chat_by_id(chat.share_id) - except: + except Exception: return None def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: @@ -175,7 +175,7 @@ class ChatTable: db.commit() return True - except: + except Exception: return False def update_chat_share_id_by_id( @@ -189,7 +189,7 @@ class ChatTable: db.commit() db.refresh(chat) return ChatModel.model_validate(chat) - except: + except Exception: return None def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: @@ -201,7 +201,7 @@ class ChatTable: db.commit() db.refresh(chat) return ChatModel.model_validate(chat) - except: + except Exception: return None def archive_all_chats_by_user_id(self, user_id: str) -> bool: @@ -210,12 +210,12 @@ class ChatTable: db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) db.commit() return True - except: + except Exception: return False def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 - ) -> List[ChatModel]: + ) -> list[ChatModel]: with get_db() as db: all_chats = ( @@ -233,7 +233,7 @@ class ChatTable: include_archived: bool = False, skip: int = 0, limit: int = 50, - ) -> List[ChatModel]: + ) -> list[ChatModel]: with get_db() as db: query = db.query(Chat).filter_by(user_id=user_id) if not include_archived: @@ -251,7 +251,7 @@ class ChatTable: include_archived: bool = False, skip: int = 0, limit: int = -1, - ) -> List[ChatTitleIdResponse]: + ) -> list[ChatTitleIdResponse]: with get_db() as db: query = db.query(Chat).filter_by(user_id=user_id) if not include_archived: @@ -279,8 +279,8 @@ class ChatTable: ] def get_chat_list_by_chat_ids( - self, chat_ids: List[str], skip: int = 0, limit: int = 50 - ) -> List[ChatModel]: + self, chat_ids: list[str], skip: int = 0, limit: int = 50 + ) -> list[ChatModel]: with get_db() as db: all_chats = ( db.query(Chat) @@ -297,7 +297,7 @@ class ChatTable: chat = db.get(Chat, id) 
return ChatModel.model_validate(chat) - except: + except Exception: return None def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: @@ -319,10 +319,10 @@ class ChatTable: chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() return ChatModel.model_validate(chat) - except: + except Exception: return None - def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: + def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]: with get_db() as db: all_chats = ( @@ -332,7 +332,7 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: + def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]: with get_db() as db: all_chats = ( @@ -342,7 +342,7 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: + def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: with get_db() as db: all_chats = ( @@ -360,7 +360,7 @@ class ChatTable: db.commit() return True and self.delete_shared_chat_by_chat_id(id) - except: + except Exception: return False def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: @@ -371,7 +371,7 @@ class ChatTable: db.commit() return True and self.delete_shared_chat_by_chat_id(id) - except: + except Exception: return False def delete_chats_by_user_id(self, user_id: str) -> bool: @@ -385,7 +385,7 @@ class ChatTable: db.commit() return True - except: + except Exception: return False def delete_shared_chats_by_user_id(self, user_id: str) -> bool: @@ -400,7 +400,7 @@ class ChatTable: db.commit() return True - except: + except Exception: return False diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index ac8655da9ce4ee2fd50299aaaf1a9a6bdba86b58..4157c2c95f8288ffee05fb13a56595ec44ea6148 100644 --- a/backend/apps/webui/models/documents.py 
+++ b/backend/apps/webui/models/documents.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, ConfigDict -from typing import List, Optional +from typing import Optional import time import logging @@ -93,7 +93,7 @@ class DocumentsTable: return DocumentModel.model_validate(result) else: return None - except: + except Exception: return None def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: @@ -102,10 +102,10 @@ class DocumentsTable: document = db.query(Document).filter_by(name=name).first() return DocumentModel.model_validate(document) if document else None - except: + except Exception: return None - def get_docs(self) -> List[DocumentModel]: + def get_docs(self) -> list[DocumentModel]: with get_db() as db: return [ @@ -160,7 +160,7 @@ class DocumentsTable: db.query(Document).filter_by(name=name).delete() db.commit() return True - except: + except Exception: return False diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index 16272f24ad11f9fda4047f2c2721123cdea3a2aa..2de5c33b599dc608d2240c634afb7546dd8bd26d 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, ConfigDict -from typing import List, Union, Optional +from typing import Union, Optional import time import logging @@ -90,10 +90,10 @@ class FilesTable: try: file = db.get(File, id) return FileModel.model_validate(file) - except: + except Exception: return None - def get_files(self) -> List[FileModel]: + def get_files(self) -> list[FileModel]: with get_db() as db: return [FileModel.model_validate(file) for file in db.query(File).all()] @@ -107,7 +107,7 @@ class FilesTable: db.commit() return True - except: + except Exception: return False def delete_all_files(self) -> bool: @@ -119,7 +119,7 @@ class FilesTable: db.commit() return True - except: + except Exception: return False diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 
cb73da69449482a35264e688748422323880583b..3afdc1ea9185029b5ea46b9526d0e24e7efc444e 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, ConfigDict -from typing import List, Union, Optional +from typing import Union, Optional import time import logging @@ -122,10 +122,10 @@ class FunctionsTable: function = db.get(Function, id) return FunctionModel.model_validate(function) - except: + except Exception: return None - def get_functions(self, active_only=False) -> List[FunctionModel]: + def get_functions(self, active_only=False) -> list[FunctionModel]: with get_db() as db: if active_only: @@ -141,7 +141,7 @@ class FunctionsTable: def get_functions_by_type( self, type: str, active_only=False - ) -> List[FunctionModel]: + ) -> list[FunctionModel]: with get_db() as db: if active_only: @@ -157,7 +157,7 @@ class FunctionsTable: for function in db.query(Function).filter_by(type=type).all() ] - def get_global_filter_functions(self) -> List[FunctionModel]: + def get_global_filter_functions(self) -> list[FunctionModel]: with get_db() as db: return [ @@ -167,7 +167,7 @@ class FunctionsTable: .all() ] - def get_global_action_functions(self) -> List[FunctionModel]: + def get_global_action_functions(self) -> list[FunctionModel]: with get_db() as db: return [ FunctionModel.model_validate(function) @@ -198,7 +198,7 @@ class FunctionsTable: db.commit() db.refresh(function) return self.get_function_by_id(id) - except: + except Exception: return None def get_user_valves_by_id_and_user_id( @@ -256,7 +256,7 @@ class FunctionsTable: ) db.commit() return self.get_function_by_id(id) - except: + except Exception: return None def deactivate_all_functions(self) -> Optional[bool]: @@ -271,7 +271,7 @@ class FunctionsTable: ) db.commit() return True - except: + except Exception: return None def delete_function_by_id(self, id: str) -> bool: @@ -281,7 +281,7 @@ class FunctionsTable: db.commit() return True - 
except: + except Exception: return False diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 02d4b6924986465612fd2d32be0499ffd3ff3f93..41bb11ccf47396436cea64405ccd5e2e4a723b36 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, ConfigDict -from typing import List, Union, Optional +from typing import Union, Optional from sqlalchemy import Column, String, BigInteger, Text @@ -80,25 +80,25 @@ class MemoriesTable: ) db.commit() return self.get_memory_by_id(id) - except: + except Exception: return None - def get_memories(self) -> List[MemoryModel]: + def get_memories(self) -> list[MemoryModel]: with get_db() as db: try: memories = db.query(Memory).all() return [MemoryModel.model_validate(memory) for memory in memories] - except: + except Exception: return None - def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: + def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]: with get_db() as db: try: memories = db.query(Memory).filter_by(user_id=user_id).all() return [MemoryModel.model_validate(memory) for memory in memories] - except: + except Exception: return None def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: @@ -107,7 +107,7 @@ class MemoriesTable: try: memory = db.get(Memory, id) return MemoryModel.model_validate(memory) - except: + except Exception: return None def delete_memory_by_id(self, id: str) -> bool: @@ -119,7 +119,7 @@ class MemoriesTable: return True - except: + except Exception: return False def delete_memories_by_user_id(self, user_id: str) -> bool: @@ -130,7 +130,7 @@ class MemoriesTable: db.commit() return True - except: + except Exception: return False def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: @@ -141,7 +141,7 @@ class MemoriesTable: db.commit() return True - except: + except Exception: return False diff --git 
a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 8277d1d0bace4a4aacb30c7c689668ab22035764..616beb2a9bbfde18bb019581dedcbfa86d929de1 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -137,7 +137,7 @@ class ModelsTable: print(e) return None - def get_all_models(self) -> List[ModelModel]: + def get_all_models(self) -> list[ModelModel]: with get_db() as db: return [ModelModel.model_validate(model) for model in db.query(Model).all()] @@ -146,7 +146,7 @@ class ModelsTable: with get_db() as db: model = db.get(Model, id) return ModelModel.model_validate(model) - except: + except Exception: return None def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: @@ -175,7 +175,7 @@ class ModelsTable: db.commit() return True - except: + except Exception: return False diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index b8467b63164f537a6e669c5c078a95eb2d162339..942f64a43567bee4003e8c928065b91529824412 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, ConfigDict -from typing import List, Optional +from typing import Optional import time from sqlalchemy import String, Column, BigInteger, Text @@ -79,10 +79,10 @@ class PromptsTable: prompt = db.query(Prompt).filter_by(command=command).first() return PromptModel.model_validate(prompt) - except: + except Exception: return None - def get_prompts(self) -> List[PromptModel]: + def get_prompts(self) -> list[PromptModel]: with get_db() as db: return [ @@ -101,7 +101,7 @@ class PromptsTable: prompt.timestamp = int(time.time()) db.commit() return PromptModel.model_validate(prompt) - except: + except Exception: return None def delete_prompt_by_command(self, command: str) -> bool: @@ -112,7 +112,7 @@ class PromptsTable: db.commit() return True - except: + except Exception: return False diff --git 
a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 7285b6fe245fc39099040d0a4f568017eee09797..7ce06cb60b257b18560231aff3eca11814e39425 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, ConfigDict -from typing import List, Optional +from typing import Optional import json import uuid @@ -69,11 +69,11 @@ class ChatIdTagForm(BaseModel): class TagChatIdsResponse(BaseModel): - chat_ids: List[str] + chat_ids: list[str] class ChatTagsResponse(BaseModel): - tags: List[str] + tags: list[str] class TagTable: @@ -109,7 +109,7 @@ class TagTable: self, user_id: str, form_data: ChatIdTagForm ) -> Optional[ChatIdTagModel]: tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id) - if tag == None: + if tag is None: tag = self.insert_new_tag(form_data.tag_name, user_id) id = str(uuid.uuid4()) @@ -132,10 +132,10 @@ class TagTable: return ChatIdTagModel.model_validate(result) else: return None - except: + except Exception: return None - def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: + def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: with get_db() as db: tag_names = [ chat_id_tag.tag_name @@ -159,7 +159,7 @@ class TagTable: def get_tags_by_chat_id_and_user_id( self, chat_id: str, user_id: str - ) -> List[TagModel]: + ) -> list[TagModel]: with get_db() as db: tag_names = [ @@ -184,7 +184,7 @@ class TagTable: def get_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str - ) -> List[ChatIdTagModel]: + ) -> list[ChatIdTagModel]: with get_db() as db: return [ diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index 685ce6fcfbd3b721482a4bf62888402fc73a9137..c8c56fb9740098998b4177a5e3e241903ff7daa0 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, ConfigDict -from typing import List, Optional +from typing 
import Optional import time import logging from sqlalchemy import String, Column, BigInteger, Text @@ -45,7 +45,7 @@ class ToolModel(BaseModel): user_id: str name: str content: str - specs: List[dict] + specs: list[dict] meta: ToolMeta updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -81,7 +81,7 @@ class ToolValves(BaseModel): class ToolsTable: def insert_new_tool( - self, user_id: str, form_data: ToolForm, specs: List[dict] + self, user_id: str, form_data: ToolForm, specs: list[dict] ) -> Optional[ToolModel]: with get_db() as db: @@ -115,10 +115,10 @@ class ToolsTable: tool = db.get(Tool, id) return ToolModel.model_validate(tool) - except: + except Exception: return None - def get_tools(self) -> List[ToolModel]: + def get_tools(self) -> list[ToolModel]: with get_db() as db: return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] @@ -141,7 +141,7 @@ class ToolsTable: ) db.commit() return self.get_tool_by_id(id) - except: + except Exception: return None def get_user_valves_by_id_and_user_id( @@ -196,7 +196,7 @@ class ToolsTable: tool = db.query(Tool).get(id) db.refresh(tool) return ToolModel.model_validate(tool) - except: + except Exception: return None def delete_tool_by_id(self, id: str) -> bool: @@ -206,7 +206,7 @@ class ToolsTable: db.commit() return True - except: + except Exception: return False diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 2f30cda0230292c86ab1a6fb89010ae9d419b610..36dfa4f85573d8cab45ba2cd1ef37402cac92eab 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, ConfigDict, parse_obj_as -from typing import List, Union, Optional +from typing import Union, Optional import time from sqlalchemy import String, Column, BigInteger, Text @@ -125,7 +125,7 @@ class UsersTable: user = db.query(User).filter_by(api_key=api_key).first() return UserModel.model_validate(user) - except: + 
except Exception: return None def get_user_by_email(self, email: str) -> Optional[UserModel]: @@ -134,7 +134,7 @@ class UsersTable: user = db.query(User).filter_by(email=email).first() return UserModel.model_validate(user) - except: + except Exception: return None def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: @@ -143,10 +143,10 @@ class UsersTable: user = db.query(User).filter_by(oauth_sub=sub).first() return UserModel.model_validate(user) - except: + except Exception: return None - def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: + def get_users(self, skip: int = 0, limit: int = 50) -> list[UserModel]: with get_db() as db: users = ( db.query(User) @@ -164,7 +164,7 @@ class UsersTable: with get_db() as db: user = db.query(User).order_by(User.created_at).first() return UserModel.model_validate(user) - except: + except Exception: return None def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: @@ -174,7 +174,7 @@ class UsersTable: db.commit() user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) - except: + except Exception: return None def update_user_profile_image_url_by_id( @@ -189,7 +189,7 @@ class UsersTable: user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) - except: + except Exception: return None def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: @@ -203,7 +203,7 @@ class UsersTable: user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) - except: + except Exception: return None def update_user_oauth_sub_by_id( @@ -216,7 +216,7 @@ class UsersTable: user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) - except: + except Exception: return None def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: @@ -245,7 +245,7 @@ class UsersTable: return True else: return False - except: + except Exception: return False def 
update_user_api_key_by_id(self, id: str, api_key: str) -> str: @@ -254,7 +254,7 @@ class UsersTable: result = db.query(User).filter_by(id=id).update({"api_key": api_key}) db.commit() return True if result == 1 else False - except: + except Exception: return False def get_user_api_key_by_id(self, id: str) -> Optional[str]: diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py index 6e89722d354af49d629ddc03662c820076bf7d20..6621e73372bf2e5623366dae4d4a4a028553c670 100644 --- a/backend/apps/webui/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -1,6 +1,6 @@ from fastapi import Depends, Request, HTTPException, status from datetime import datetime, timedelta -from typing import List, Union, Optional +from typing import Union, Optional from utils.utils import get_verified_user, get_admin_user from fastapi import APIRouter from pydantic import BaseModel @@ -40,8 +40,8 @@ router = APIRouter() ############################ -@router.get("/", response_model=List[ChatTitleIdResponse]) -@router.get("/list", response_model=List[ChatTitleIdResponse]) +@router.get("/", response_model=list[ChatTitleIdResponse]) +@router.get("/list", response_model=list[ChatTitleIdResponse]) async def get_session_user_chat_list( user=Depends(get_verified_user), page: Optional[int] = None ): @@ -80,7 +80,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user ############################ -@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse]) +@router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_user_id( user_id: str, user=Depends(get_admin_user), @@ -119,7 +119,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): ############################ -@router.get("/all", response_model=List[ChatResponse]) +@router.get("/all", response_model=list[ChatResponse]) async def get_user_chats(user=Depends(get_verified_user)): 
return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) @@ -132,7 +132,7 @@ async def get_user_chats(user=Depends(get_verified_user)): ############################ -@router.get("/all/archived", response_model=List[ChatResponse]) +@router.get("/all/archived", response_model=list[ChatResponse]) async def get_user_archived_chats(user=Depends(get_verified_user)): return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) @@ -145,7 +145,7 @@ async def get_user_archived_chats(user=Depends(get_verified_user)): ############################ -@router.get("/all/db", response_model=List[ChatResponse]) +@router.get("/all/db", response_model=list[ChatResponse]) async def get_all_user_chats_in_db(user=Depends(get_admin_user)): if not ENABLE_ADMIN_EXPORT: raise HTTPException( @@ -163,7 +163,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): ############################ -@router.get("/archived", response_model=List[ChatTitleIdResponse]) +@router.get("/archived", response_model=list[ChatTitleIdResponse]) async def get_archived_session_user_chat_list( user=Depends(get_verified_user), skip: int = 0, limit: int = 50 ): @@ -216,7 +216,7 @@ class TagNameForm(BaseModel): limit: Optional[int] = 50 -@router.post("/tags", response_model=List[ChatTitleIdResponse]) +@router.post("/tags", response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_tag_name( form_data: TagNameForm, user=Depends(get_verified_user) ): @@ -241,7 +241,7 @@ async def get_user_chat_list_by_tag_name( ############################ -@router.get("/tags/all", response_model=List[TagModel]) +@router.get("/tags/all", response_model=list[TagModel]) async def get_all_tags(user=Depends(get_verified_user)): try: tags = Tags.get_tags_by_user_id(user.id) @@ -417,7 +417,7 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)): ############################ -@router.get("/{id}/tags", response_model=List[TagModel]) 
+@router.get("/{id}/tags", response_model=list[TagModel]) async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) diff --git a/backend/apps/webui/routers/configs.py b/backend/apps/webui/routers/configs.py index 39e435013541d40c1f058be8b783595d08515bf4..68c6873742e5e7c9b1ab53aff542fba4184cb5fe 100644 --- a/backend/apps/webui/routers/configs.py +++ b/backend/apps/webui/routers/configs.py @@ -1,7 +1,7 @@ from fastapi import Response, Request from fastapi import Depends, FastAPI, HTTPException, status from datetime import datetime, timedelta -from typing import List, Union +from typing import Union from fastapi import APIRouter from pydantic import BaseModel @@ -29,12 +29,12 @@ class SetDefaultModelsForm(BaseModel): class PromptSuggestion(BaseModel): - title: List[str] + title: list[str] content: str class SetDefaultSuggestionsForm(BaseModel): - suggestions: List[PromptSuggestion] + suggestions: list[PromptSuggestion] ############################ @@ -50,7 +50,7 @@ async def set_global_default_models( return request.app.state.config.DEFAULT_MODELS -@router.post("/default/suggestions", response_model=List[PromptSuggestion]) +@router.post("/default/suggestions", response_model=list[PromptSuggestion]) async def set_global_default_suggestions( request: Request, form_data: SetDefaultSuggestionsForm, @@ -67,10 +67,10 @@ async def set_global_default_suggestions( class SetBannersForm(BaseModel): - banners: List[BannerModel] + banners: list[BannerModel] -@router.post("/banners", response_model=List[BannerModel]) +@router.post("/banners", response_model=list[BannerModel]) async def set_banners( request: Request, form_data: SetBannersForm, @@ -81,7 +81,7 @@ async def set_banners( return request.app.state.config.BANNERS -@router.get("/banners", response_model=List[BannerModel]) +@router.get("/banners", response_model=list[BannerModel]) async def get_banners( request: Request, user=Depends(get_verified_user), 
diff --git a/backend/apps/webui/routers/documents.py b/backend/apps/webui/routers/documents.py index 2299b2fee3d5fd8e9a65c15ab564e05f2c256039..3bb2aa15b520c00565e37536c4f7ce47e4c7fdb8 100644 --- a/backend/apps/webui/routers/documents.py +++ b/backend/apps/webui/routers/documents.py @@ -1,6 +1,6 @@ from fastapi import Depends, FastAPI, HTTPException, status from datetime import datetime, timedelta -from typing import List, Union, Optional +from typing import Union, Optional from fastapi import APIRouter from pydantic import BaseModel @@ -24,7 +24,7 @@ router = APIRouter() ############################ -@router.get("/", response_model=List[DocumentResponse]) +@router.get("/", response_model=list[DocumentResponse]) async def get_documents(user=Depends(get_verified_user)): docs = [ DocumentResponse( @@ -46,7 +46,7 @@ async def get_documents(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[DocumentResponse]) async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)): doc = Documents.get_doc_by_name(form_data.name) - if doc == None: + if doc is None: doc = Documents.insert_new_doc(user.id, form_data) if doc: @@ -102,7 +102,7 @@ class TagItem(BaseModel): class TagDocumentForm(BaseModel): name: str - tags: List[dict] + tags: list[dict] @router.post("/doc/tags", response_model=Optional[DocumentResponse]) diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py index 99fb923a12ccbb034c80979dc28f8fceec323a51..ba571fc71385a37c5993a5d78950548b357d6fde 100644 --- a/backend/apps/webui/routers/files.py +++ b/backend/apps/webui/routers/files.py @@ -11,7 +11,7 @@ from fastapi import ( from datetime import datetime, timedelta -from typing import List, Union, Optional +from typing import Union, Optional from pathlib import Path from fastapi import APIRouter @@ -104,7 +104,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): ############################ -@router.get("/", 
response_model=List[FileModel]) +@router.get("/", response_model=list[FileModel]) async def list_files(user=Depends(get_verified_user)): files = Files.get_files() return files diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py index eb5216b202b0858b87cfcda16fc1ed9e4c3cd96c..f40d28264544fa06f9f3f2ec36dc0194abaae647 100644 --- a/backend/apps/webui/routers/functions.py +++ b/backend/apps/webui/routers/functions.py @@ -1,6 +1,6 @@ from fastapi import Depends, FastAPI, HTTPException, status, Request from datetime import datetime, timedelta -from typing import List, Union, Optional +from typing import Union, Optional from fastapi import APIRouter from pydantic import BaseModel @@ -30,7 +30,7 @@ router = APIRouter() ############################ -@router.get("/", response_model=List[FunctionResponse]) +@router.get("/", response_model=list[FunctionResponse]) async def get_functions(user=Depends(get_verified_user)): return Functions.get_functions() @@ -40,7 +40,7 @@ async def get_functions(user=Depends(get_verified_user)): ############################ -@router.get("/export", response_model=List[FunctionModel]) +@router.get("/export", response_model=list[FunctionModel]) async def get_functions(user=Depends(get_admin_user)): return Functions.get_functions() @@ -63,7 +63,7 @@ async def create_new_function( form_data.id = form_data.id.lower() function = Functions.get_function_by_id(form_data.id) - if function == None: + if function is None: function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py") try: with open(function_path, "w") as function_file: @@ -235,7 +235,7 @@ async def delete_function_by_id( function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") try: os.remove(function_path) - except: + except Exception: pass return result diff --git a/backend/apps/webui/routers/memories.py b/backend/apps/webui/routers/memories.py index 2c473ebe8f6389f3512841f0020abab5d4dc9187..a7b5474f0adb8e825bf5614633064e4b317735e7 100644 --- 
a/backend/apps/webui/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -1,7 +1,7 @@ from fastapi import Response, Request from fastapi import Depends, FastAPI, HTTPException, status from datetime import datetime, timedelta -from typing import List, Union, Optional +from typing import Union, Optional from fastapi import APIRouter from pydantic import BaseModel @@ -30,7 +30,7 @@ async def get_embeddings(request: Request): ############################ -@router.get("/", response_model=List[MemoryModel]) +@router.get("/", response_model=list[MemoryModel]) async def get_memories(user=Depends(get_verified_user)): return Memories.get_memories_by_user_id(user.id) diff --git a/backend/apps/webui/routers/models.py b/backend/apps/webui/routers/models.py index eeae9e1c41a7608b5fb256bb46d9e5de25af1738..8faeed7a64779a70ed40471799367088df2b651c 100644 --- a/backend/apps/webui/routers/models.py +++ b/backend/apps/webui/routers/models.py @@ -1,6 +1,6 @@ from fastapi import Depends, FastAPI, HTTPException, status, Request from datetime import datetime, timedelta -from typing import List, Union, Optional +from typing import Union, Optional from fastapi import APIRouter from pydantic import BaseModel @@ -18,7 +18,7 @@ router = APIRouter() ########################### -@router.get("/", response_model=List[ModelResponse]) +@router.get("/", response_model=list[ModelResponse]) async def get_models(user=Depends(get_verified_user)): return Models.get_all_models() diff --git a/backend/apps/webui/routers/prompts.py b/backend/apps/webui/routers/prompts.py index c674590e95998979f8f7749901fa9a699b9cb8f6..39d79362af53542bdec9205c1bbcb8ae7713e13d 100644 --- a/backend/apps/webui/routers/prompts.py +++ b/backend/apps/webui/routers/prompts.py @@ -1,6 +1,6 @@ from fastapi import Depends, FastAPI, HTTPException, status from datetime import datetime, timedelta -from typing import List, Union, Optional +from typing import Union, Optional from fastapi import APIRouter from pydantic import 
BaseModel @@ -18,7 +18,7 @@ router = APIRouter() ############################ -@router.get("/", response_model=List[PromptModel]) +@router.get("/", response_model=list[PromptModel]) async def get_prompts(user=Depends(get_verified_user)): return Prompts.get_prompts() @@ -31,7 +31,7 @@ async def get_prompts(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[PromptModel]) async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)): prompt = Prompts.get_prompt_by_command(form_data.command) - if prompt == None: + if prompt is None: prompt = Prompts.insert_new_prompt(user.id, form_data) if prompt: diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index 7e60fe4d1ec52d1eaee9c9d62e0417a677ea913f..d6da7ae92289ac4ed18d113af57fee5116dcf320 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -1,5 +1,5 @@ from fastapi import Depends, HTTPException, status, Request -from typing import List, Optional +from typing import Optional from fastapi import APIRouter @@ -27,7 +27,7 @@ router = APIRouter() ############################ -@router.get("/", response_model=List[ToolResponse]) +@router.get("/", response_model=list[ToolResponse]) async def get_toolkits(user=Depends(get_verified_user)): toolkits = [toolkit for toolkit in Tools.get_tools()] return toolkits @@ -38,7 +38,7 @@ async def get_toolkits(user=Depends(get_verified_user)): ############################ -@router.get("/export", response_model=List[ToolModel]) +@router.get("/export", response_model=list[ToolModel]) async def get_toolkits(user=Depends(get_admin_user)): toolkits = [toolkit for toolkit in Tools.get_tools()] return toolkits diff --git a/backend/apps/webui/routers/users.py b/backend/apps/webui/routers/users.py index 9627f0b06779bd2486ee41cb846583ec613ae3c1..543757275a9b46626564f320980473cce2eeeadc 100644 --- a/backend/apps/webui/routers/users.py +++ b/backend/apps/webui/routers/users.py @@ 
-1,7 +1,7 @@ from fastapi import Response, Request from fastapi import Depends, FastAPI, HTTPException, status from datetime import datetime, timedelta -from typing import List, Union, Optional +from typing import Union, Optional from fastapi import APIRouter from pydantic import BaseModel @@ -39,7 +39,7 @@ router = APIRouter() ############################ -@router.get("/", response_model=List[UserModel]) +@router.get("/", response_model=list[UserModel]) async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)): return Users.get_users(skip, limit) diff --git a/backend/apps/webui/routers/utils.py b/backend/apps/webui/routers/utils.py index 4ffe748b0bebc1a02c2a77f3c379abb347a9a612..7a3c3393248b363929d92d55631013c2b8efe4e8 100644 --- a/backend/apps/webui/routers/utils.py +++ b/backend/apps/webui/routers/utils.py @@ -17,7 +17,7 @@ from utils.misc import calculate_sha256, get_gravatar_url from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR, ENABLE_ADMIN_EXPORT from constants import ERROR_MESSAGES -from typing import List + router = APIRouter() @@ -57,7 +57,7 @@ async def get_html_from_markdown( class ChatForm(BaseModel): title: str - messages: List[dict] + messages: list[dict] @router.post("/pdf") diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py index 96d2b29ebfa67a5d2d27feb43b8434041cbc04e2..bf5ebedeb7bf3fc87838f56a8faf77c63cac759a 100644 --- a/backend/apps/webui/utils.py +++ b/backend/apps/webui/utils.py @@ -1,6 +1,8 @@ from importlib import util import os import re +import sys +import subprocess from config import TOOLS_DIR, FUNCTIONS_DIR @@ -52,6 +54,7 @@ def load_toolkit_module_by_id(toolkit_id): frontmatter = extract_frontmatter(toolkit_path) try: + install_frontmatter_requirements(frontmatter.get("requirements", "")) spec.loader.exec_module(module) print(f"Loaded module: {module.__name__}") if hasattr(module, "Tools"): @@ -73,6 +76,7 @@ def load_function_module_by_id(function_id): frontmatter = 
extract_frontmatter(function_path) try: + install_frontmatter_requirements(frontmatter.get("requirements", "")) spec.loader.exec_module(module) print(f"Loaded module: {module.__name__}") if hasattr(module, "Pipe"): @@ -88,3 +92,13 @@ def load_function_module_by_id(function_id): # Move the file to the error folder os.rename(function_path, f"{function_path}.error") raise e + + +def install_frontmatter_requirements(requirements): + if requirements: + req_list = [req.strip() for req in requirements.split(",")] + for req in req_list: + print(f"Installing requirement: {req}") + subprocess.check_call([sys.executable, "-m", "pip", "install", req]) + else: + print("No requirements found in frontmatter.") diff --git a/backend/config.py b/backend/config.py index 5a095666ed4e82c86cc176183aa98141c22cbbd3..07ee06a58c2895857f4232cc2a7623eec84ac69f 100644 --- a/backend/config.py +++ b/backend/config.py @@ -87,14 +87,14 @@ class EndpointFilter(logging.Filter): logging.getLogger("uvicorn.access").addFilter(EndpointFilter()) -WEBUI_NAME = os.environ.get("WEBUI_NAME", "Transform.AI") -# if WEBUI_NAME != "Open WebUI": -# WEBUI_NAME += " (Open WebUI)" +WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") +if WEBUI_NAME != "Open WebUI": + WEBUI_NAME += " (Open WebUI)" WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000") -# WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" -WEBUI_FAVICON_URL = "/static/favicon.png" +WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" + #################################### # ENV (dev,test,prod) @@ -104,7 +104,7 @@ ENV = os.environ.get("ENV", "dev") try: PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text()) -except: +except Exception: try: PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")} except importlib.metadata.PackageNotFoundError: @@ -137,7 +137,7 @@ try: with open(str(changelog_path.absolute()), "r", encoding="utf8") as file: changelog_content = file.read() -except: +except Exception: 
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode() @@ -202,12 +202,12 @@ if RESET_CONFIG_ON_START: os.remove(f"{DATA_DIR}/config.json") with open(f"{DATA_DIR}/config.json", "w") as f: f.write("{}") - except: + except Exception: pass try: CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text()) -except: +except Exception: CONFIG_DATA = {} @@ -433,6 +433,12 @@ OAUTH_PICTURE_CLAIM = PersistentConfig( os.environ.get("OAUTH_PICTURE_CLAIM", "picture"), ) +OAUTH_EMAIL_CLAIM = PersistentConfig( + "OAUTH_EMAIL_CLAIM", + "oauth.oidc.email_claim", + os.environ.get("OAUTH_EMAIL_CLAIM", "email"), +) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() @@ -514,7 +520,6 @@ if CUSTOM_NAME: data = r.json() if r.ok: if "logo" in data: - WEBUI_FAVICON_URL = url = ( f"https://api.openwebui.com{data['logo']}" if data["logo"][0] == "/" @@ -642,7 +647,7 @@ if AIOHTTP_CLIENT_TIMEOUT == "": else: try: AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT) - except: + except Exception: AIOHTTP_CLIENT_TIMEOUT = 300 @@ -722,7 +727,7 @@ try: OPENAI_API_KEY = OPENAI_API_KEYS.value[ OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1") ] -except: +except Exception: pass OPENAI_API_BASE_URL = "https://api.openai.com/v1" @@ -1038,7 +1043,7 @@ RAG_EMBEDDING_MODEL = PersistentConfig( "rag.embedding_model", os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"), ) -log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"), +log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}") RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" @@ -1060,7 +1065,7 @@ RAG_RERANKING_MODEL = PersistentConfig( os.environ.get("RAG_RERANKING_MODEL", ""), ) if RAG_RERANKING_MODEL.value != "": - log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"), + log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}") RAG_RERANKING_MODEL_AUTO_UPDATE = ( 
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true" diff --git a/backend/data/litellm/config.yaml b/backend/data/litellm/config.yaml index 7f792ff6baf0a834456c26c0344582275732683f..7d9d2b72304cd82b7621891affb2d79a0303cd72 100644 --- a/backend/data/litellm/config.yaml +++ b/backend/data/litellm/config.yaml @@ -1,6 +1,4 @@ general_settings: {} -litellm_settings: - success_callback: ["langfuse"] - failure_callback: ["langfuse"] +litellm_settings: {} model_list: [] router_settings: {} diff --git a/backend/main.py b/backend/main.py index 51b0933d69919bafd55dbc970c948161ea1584e9..d8ce5f5d7879db9331c6a36f74b3dd9476fe0767 100644 --- a/backend/main.py +++ b/backend/main.py @@ -51,7 +51,7 @@ from apps.webui.internal.db import Session from pydantic import BaseModel -from typing import List, Optional +from typing import Optional from apps.webui.models.auths import Auths from apps.webui.models.models import Models @@ -1883,7 +1883,7 @@ async def get_pipeline_valves( res = r.json() if "detail" in res: detail = res["detail"] - except: + except Exception: pass raise HTTPException( @@ -2027,7 +2027,7 @@ async def get_model_filter_config(user=Depends(get_admin_user)): class ModelFilterConfigForm(BaseModel): enabled: bool - models: List[str] + models: list[str] @app.post("/api/config/model/filter") @@ -2158,7 +2158,8 @@ async def oauth_callback(provider: str, request: Request, response: Response): log.warning(f"OAuth callback failed, sub is missing: {user_data}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) provider_sub = f"{provider}@{sub}" - email = user_data.get("email", "").lower() + email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM + email = user_data.get(email_claim, "").lower() # We currently mandate that email addresses are provided if not email: log.warning(f"OAuth callback failed, email is missing: {user_data}") @@ -2263,7 +2264,7 @@ async def get_manifest_json(): "display": "standalone", "background_color": "#343541", "orientation": 
"portrait-primary", - "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "1000x1000"}], + "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}], } diff --git a/backend/requirements.txt b/backend/requirements.txt index e8466a649a12ccd1bcc5d872ce05503f402010a2..6ef299b5fab415204ec52015aed39a8104ef6f57 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -11,7 +11,7 @@ python-jose==3.3.0 passlib[bcrypt]==1.7.4 requests==2.32.3 -aiohttp==3.9.5 +aiohttp==3.10.2 sqlalchemy==2.0.31 alembic==1.13.2 @@ -34,12 +34,12 @@ anthropic google-generativeai==0.7.2 tiktoken -langchain==0.2.11 +langchain==0.2.12 langchain-community==0.2.10 langchain-chroma==0.1.2 fake-useragent==1.5.1 -chromadb==0.5.4 +chromadb==0.5.5 sentence-transformers==3.0.1 pypdf==4.3.1 docx2txt==0.8 @@ -62,11 +62,11 @@ rank-bm25==0.2.2 faster-whisper==1.0.2 -PyJWT[crypto]==2.8.0 +PyJWT[crypto]==2.9.0 authlib==1.3.1 black==24.8.0 -langfuse==2.39.2 +langfuse==2.43.3 youtube-transcript-api==0.6.2 pytube==15.0.0 @@ -76,5 +76,5 @@ duckduckgo-search~=6.2.1 ## Tests docker~=7.1.0 -pytest~=8.2.2 +pytest~=8.3.2 pytest-docker~=3.1.1 diff --git a/backend/start.sh b/backend/start.sh index 16a004e45c266413787def1f731ea36da996247f..0a5c48e8c424de1a14e1dd3124aa08cf38179e7c 100755 --- a/backend/start.sh +++ b/backend/start.sh @@ -30,7 +30,6 @@ if [[ "${USE_CUDA_DOCKER,,}" == "true" ]]; then export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib" fi - # Check if SPACE_ID is set, if so, configure for space if [ -n "$SPACE_ID" ]; then echo "Configuring for HuggingFace Space deployment" diff --git a/backend/static/favicon.png b/backend/static/favicon.png index d16cbe245179bdb150a67b86f5a43b99ebf6cc30..2b2074780847581edf9cf2ed0d2e9ebd8ff08c56 100644 Binary files a/backend/static/favicon.png and b/backend/static/favicon.png differ diff --git a/backend/static/logo.png 
b/backend/static/logo.png index d16cbe245179bdb150a67b86f5a43b99ebf6cc30..519af1db620dbf4de3694660dae7abd7392f0b3c 100644 Binary files a/backend/static/logo.png and b/backend/static/logo.png differ diff --git a/backend/utils/logo.png b/backend/utils/logo.png index d16cbe245179bdb150a67b86f5a43b99ebf6cc30..519af1db620dbf4de3694660dae7abd7392f0b3c 100644 Binary files a/backend/utils/logo.png and b/backend/utils/logo.png differ diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 25dd4dd5b66c5de02f800dca7d12cc82e242c4b1..2eed58f41eba362cc736bad30bf7750624313c6b 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -2,14 +2,14 @@ from pathlib import Path import hashlib import re from datetime import timedelta -from typing import Optional, List, Tuple +from typing import Optional, Callable import uuid import time from utils.task import prompt_template -def get_last_user_message_item(messages: List[dict]) -> Optional[dict]: +def get_last_user_message_item(messages: list[dict]) -> Optional[dict]: for message in reversed(messages): if message["role"] == "user": return message @@ -26,7 +26,7 @@ def get_content_from_message(message: dict) -> Optional[str]: return None -def get_last_user_message(messages: List[dict]) -> Optional[str]: +def get_last_user_message(messages: list[dict]) -> Optional[str]: message = get_last_user_message_item(messages) if message is None: return None @@ -34,31 +34,31 @@ def get_last_user_message(messages: List[dict]) -> Optional[str]: return get_content_from_message(message) -def get_last_assistant_message(messages: List[dict]) -> Optional[str]: +def get_last_assistant_message(messages: list[dict]) -> Optional[str]: for message in reversed(messages): if message["role"] == "assistant": return get_content_from_message(message) return None -def get_system_message(messages: List[dict]) -> Optional[dict]: +def get_system_message(messages: list[dict]) -> Optional[dict]: for message in messages: if message["role"] == "system": 
return message return None -def remove_system_message(messages: List[dict]) -> List[dict]: +def remove_system_message(messages: list[dict]) -> list[dict]: return [message for message in messages if message["role"] != "system"] -def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]: +def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]]: return get_system_message(messages), remove_system_message(messages) def prepend_to_first_user_message_content( - content: str, messages: List[dict] -) -> List[dict]: + content: str, messages: list[dict] +) -> list[dict]: for message in messages: if message["role"] == "user": if isinstance(message["content"], list): @@ -71,7 +71,7 @@ def prepend_to_first_user_message_content( return messages -def add_or_update_system_message(content: str, messages: List[dict]): +def add_or_update_system_message(content: str, messages: list[dict]): """ Adds a new system message at the beginning of the messages list or updates the existing system message at the beginning. 
@@ -135,10 +135,21 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di # inplace function: form_data is modified -def apply_model_params_to_body(params: dict, form_data: dict) -> dict: +def apply_model_params_to_body( + params: dict, form_data: dict, mappings: dict[str, Callable] +) -> dict: if not params: return form_data + for key, cast_func in mappings.items(): + if (value := params.get(key)) is not None: + form_data[key] = cast_func(value) + + return form_data + + +# inplace function: form_data is modified +def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: mappings = { "temperature": float, "top_p": int, @@ -147,10 +158,40 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict: "seed": lambda x: x, "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], } + return apply_model_params_to_body(params, form_data, mappings) + + +def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: + opts = [ + "temperature", + "top_p", + "seed", + "mirostat", + "mirostat_eta", + "mirostat_tau", + "num_ctx", + "num_batch", + "num_keep", + "repeat_last_n", + "tfs_z", + "top_k", + "min_p", + "use_mmap", + "use_mlock", + "num_thread", + "num_gpu", + ] + mappings = {i: lambda x: x for i in opts} + form_data = apply_model_params_to_body(params, form_data, mappings) + + name_differences = { + "max_tokens": "num_predict", + "frequency_penalty": "repeat_penalty", + } - for key, cast_func in mappings.items(): - if (value := params.get(key)) is not None: - form_data[key] = cast_func(value) + for key, value in name_differences.items(): + if (param := params.get(key, None)) is not None: + form_data[value] = param return form_data diff --git a/backend/utils/tools.py b/backend/utils/tools.py index 3e5d82fd6d15255e1f89faab2ec91b037eb88116..eac36b5d90bdfc85f4a492c96b864e7d6df2993e 100644 --- a/backend/utils/tools.py +++ b/backend/utils/tools.py @@ -1,5 +1,5 @@ import inspect 
-from typing import get_type_hints, List, Dict, Any +from typing import get_type_hints def doc_to_dict(docstring): @@ -16,7 +16,7 @@ def doc_to_dict(docstring): return ret_dict -def get_tools_specs(tools) -> List[dict]: +def get_tools_specs(tools) -> list[dict]: function_list = [ {"name": func, "function": getattr(tools, func)} for func in dir(tools) diff --git a/cypress/e2e/chat.cy.ts b/cypress/e2e/chat.cy.ts index ddb33d6c06b54bcbe562241c73569c710b900948..20be9755a4fe97b660c904d7bd9d6bfd3bcebbe4 100644 --- a/cypress/e2e/chat.cy.ts +++ b/cypress/e2e/chat.cy.ts @@ -38,9 +38,10 @@ describe('Settings', () => { // User's message should be visible cy.get('.chat-user').should('exist'); // Wait for the response - cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received - .find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received - .should('exist'); + // .chat-assistant is created after the first token is received + cy.get('.chat-assistant', { timeout: 10_000 }).should('exist'); + // Generation Info is created after the stop token is received + cy.get('div[aria-label="Generation Info"]', { timeout: 120_000 }).should('exist'); }); it('user can share chat', () => { @@ -57,21 +58,24 @@ describe('Settings', () => { // User's message should be visible cy.get('.chat-user').should('exist'); // Wait for the response - cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received - .find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received - .should('exist'); + // .chat-assistant is created after the first token is received + cy.get('.chat-assistant', { timeout: 10_000 }).should('exist'); + // Generation Info is created after the stop token is received + cy.get('div[aria-label="Generation Info"]', { timeout: 120_000 }).should('exist'); // spy 
on requests const spy = cy.spy(); - cy.intercept('GET', '/api/v1/chats/*', spy); + cy.intercept('POST', '/api/v1/chats/**/share', spy); // Open context menu cy.get('#chat-context-menu-button').click(); // Click share button cy.get('#chat-share-button').click(); // Check if the share dialog is visible cy.get('#copy-and-share-chat-button').should('exist'); - cy.wrap({}, { timeout: 5000 }).should(() => { - // Check if the request was made twice (once for to replace chat object and once more due to change event) - expect(spy).to.be.callCount(2); + // Click the copy button + cy.get('#copy-and-share-chat-button').click(); + cy.wrap({}, { timeout: 5_000 }).should(() => { + // Check if the share request was made + expect(spy).to.be.callCount(1); }); }); @@ -89,9 +93,10 @@ describe('Settings', () => { // User's message should be visible cy.get('.chat-user').should('exist'); // Wait for the response - cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received - .find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received - .should('exist'); + // .chat-assistant is created after the first token is received + cy.get('.chat-assistant', { timeout: 10_000 }).should('exist'); + // Generation Info is created after the stop token is received + cy.get('div[aria-label="Generation Info"]', { timeout: 120_000 }).should('exist'); // Click on the generate image button cy.get('[aria-label="Generate Image"]').click(); // Wait for image to be visible diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 325964b1a94a94b3a296a4da8e84525bc24d3fb8..ec8a79bbcea4ed9b62faa312e1ca6041b4426f3c 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -22,7 +22,6 @@ Noticed something off? Have an idea? 
Check our [Issues tab](https://github.com/o > [!IMPORTANT] > > - **Template Compliance:** Please be aware that failure to follow the provided issue template, or not providing the requested information at all, will likely result in your issue being closed without further consideration. This approach is critical for maintaining the manageability and integrity of issue tracking. -> > - **Detail is Key:** To ensure your issue is understood and can be effectively addressed, it's imperative to include comprehensive details. Descriptions should be clear, including steps to reproduce, expected outcomes, and actual results. Lack of sufficient detail may hinder our ability to resolve your issue. ### 🧭 Scope of Support diff --git a/package-lock.json b/package-lock.json index b7e714ed78c391043a5c290a24a5b0434fe54117..aa813a4dbbbfc02c8b5a950b8714dea63331f821 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.3.12", + "version": "0.3.13", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.3.12", + "version": "0.3.13", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", @@ -18,6 +18,7 @@ "codemirror": "^6.0.1", "crc-32": "^1.2.2", "dayjs": "^1.11.10", + "dompurify": "^3.1.6", "eventsource-parser": "^1.1.2", "file-saver": "^2.0.5", "fuse.js": "^7.0.0", @@ -29,6 +30,7 @@ "js-sha256": "^0.10.1", "katex": "^0.16.9", "marked": "^9.1.0", + "marked-katex-extension": "^5.1.1", "mermaid": "^10.9.1", "pyodide": "^0.26.1", "socket.io-client": "^4.2.0", @@ -1544,6 +1546,11 @@ "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", "dev": true }, + "node_modules/@types/katex": { + "version": "0.16.7", + "resolved": "https://registry.npmjs.org/@types/katex/-/katex-0.16.7.tgz", + "integrity": "sha512-HMwFiRujE5PjrgwHQ25+bsLJgowjGjm5Z8FVSf0N6PwgJrwxH0QxzHYDcKsTfV3wva0vzrpqMTJS2jXPr5BMEQ==" 
+ }, "node_modules/@types/mdast": { "version": "3.0.15", "resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-3.0.15.tgz", @@ -3912,9 +3919,9 @@ } }, "node_modules/dompurify": { - "version": "3.1.5", - "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.1.5.tgz", - "integrity": "sha512-lwG+n5h8QNpxtyrJW/gJWckL+1/DQiYMX8f7t8Z2AZTPw1esVrqjI63i7Zc2Gz0aKzLVMYC1V1PL/ky+aY/NgA==" + "version": "3.1.6", + "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.1.6.tgz", + "integrity": "sha512-cTOAhc36AalkjtBpfG6O8JimdTMWNXjiePT2xQH/ppBGi/4uIpmj8eKyIkMJErXWARyINV/sB38yf8JCLF5pbQ==" }, "node_modules/domutils": { "version": "3.1.0", @@ -6036,6 +6043,18 @@ "node": ">= 16" } }, + "node_modules/marked-katex-extension": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/marked-katex-extension/-/marked-katex-extension-5.1.1.tgz", + "integrity": "sha512-piquiCyZpZ1aiocoJlJkRXr+hkk5UI4xw9GhRZiIAAgvX5rhzUDSJ0seup1JcsgueC8MLNDuqe5cRcAzkFE42Q==", + "dependencies": { + "@types/katex": "^0.16.7" + }, + "peerDependencies": { + "katex": ">=0.16 <0.17", + "marked": ">=4 <15" + } + }, "node_modules/matcher-collection": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/matcher-collection/-/matcher-collection-2.0.1.tgz", diff --git a/package.json b/package.json index 0c7a8518ac618d1ba572452a0c8ca4c79fbb2d8f..fef2cbaef675bc662d6f2cdc0804dc1914209bfe 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.3.12", + "version": "0.3.13", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -59,6 +59,7 @@ "codemirror": "^6.0.1", "crc-32": "^1.2.2", "dayjs": "^1.11.10", + "dompurify": "^3.1.6", "eventsource-parser": "^1.1.2", "file-saver": "^2.0.5", "fuse.js": "^7.0.0", @@ -70,6 +71,7 @@ "js-sha256": "^0.10.1", "katex": "^0.16.9", "marked": "^9.1.0", + "marked-katex-extension": "^5.1.1", "mermaid": "^10.9.1", "pyodide": "^0.26.1", "socket.io-client": "^4.2.0", diff --git 
a/pyproject.toml b/pyproject.toml index 0b7af7f18576360a43f811f57cedbb29befd43e2..159bce0727fff0052b84639fb75d6f636bc1dd0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "open-webui" -description = "Open WebUI (Formerly Ollama WebUI)" +description = "Open WebUI" authors = [ { name = "Timothy Jaeryang Baek", email = "tim@openwebui.com" } ] @@ -19,7 +19,7 @@ dependencies = [ "passlib[bcrypt]==1.7.4", "requests==2.32.3", - "aiohttp==3.9.5", + "aiohttp==3.10.2", "sqlalchemy==2.0.31", "alembic==1.13.2", @@ -41,12 +41,12 @@ dependencies = [ "google-generativeai==0.7.2", "tiktoken", - "langchain==0.2.11", + "langchain==0.2.12", "langchain-community==0.2.10", "langchain-chroma==0.1.2", "fake-useragent==1.5.1", - "chromadb==0.5.4", + "chromadb==0.5.5", "sentence-transformers==3.0.1", "pypdf==4.3.1", "docx2txt==0.8", @@ -69,11 +69,11 @@ dependencies = [ "faster-whisper==1.0.2", - "PyJWT[crypto]==2.8.0", + "PyJWT[crypto]==2.9.0", "authlib==1.3.1", "black==24.8.0", - "langfuse==2.39.2", + "langfuse==2.43.3", "youtube-transcript-api==0.6.2", "pytube==15.0.0", diff --git a/requirements-dev.lock b/requirements-dev.lock index da1f66fccea4ba3e51950f0ce4630b349bde6f56..6b3f5185127ebf08015779e19e086016b8496f19 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -10,7 +10,9 @@ # universal: false -e file:. 
-aiohttp==3.9.5 +aiohappyeyeballs==2.3.5 + # via aiohttp +aiohttp==3.10.2 # via langchain # via langchain-community # via open-webui @@ -84,9 +86,9 @@ chardet==5.2.0 charset-normalizer==3.3.2 # via requests # via unstructured-client -chroma-hnswlib==0.7.5 +chroma-hnswlib==0.7.6 # via chromadb -chromadb==0.5.4 +chromadb==0.5.5 # via langchain-chroma # via open-webui click==8.1.7 @@ -269,7 +271,7 @@ jsonpointer==2.4 # via jsonpatch kubernetes==29.0.0 # via chromadb -langchain==0.2.11 +langchain==0.2.12 # via langchain-community # via open-webui langchain-chroma==0.1.2 @@ -285,7 +287,7 @@ langchain-text-splitters==0.2.0 # via langchain langdetect==1.0.9 # via unstructured -langfuse==2.39.2 +langfuse==2.43.3 # via open-webui langsmith==0.1.96 # via langchain @@ -491,7 +493,7 @@ pydub==0.25.1 # via open-webui pygments==2.18.0 # via rich -pyjwt==2.8.0 +pyjwt==2.9.0 # via open-webui pymongo==4.8.0 # via open-webui diff --git a/requirements.lock b/requirements.lock index da1f66fccea4ba3e51950f0ce4630b349bde6f56..6b3f5185127ebf08015779e19e086016b8496f19 100644 --- a/requirements.lock +++ b/requirements.lock @@ -10,7 +10,9 @@ # universal: false -e file:. 
-aiohttp==3.9.5 +aiohappyeyeballs==2.3.5 + # via aiohttp +aiohttp==3.10.2 # via langchain # via langchain-community # via open-webui @@ -84,9 +86,9 @@ chardet==5.2.0 charset-normalizer==3.3.2 # via requests # via unstructured-client -chroma-hnswlib==0.7.5 +chroma-hnswlib==0.7.6 # via chromadb -chromadb==0.5.4 +chromadb==0.5.5 # via langchain-chroma # via open-webui click==8.1.7 @@ -269,7 +271,7 @@ jsonpointer==2.4 # via jsonpatch kubernetes==29.0.0 # via chromadb -langchain==0.2.11 +langchain==0.2.12 # via langchain-community # via open-webui langchain-chroma==0.1.2 @@ -285,7 +287,7 @@ langchain-text-splitters==0.2.0 # via langchain langdetect==1.0.9 # via unstructured -langfuse==2.39.2 +langfuse==2.43.3 # via open-webui langsmith==0.1.96 # via langchain @@ -491,7 +493,7 @@ pydub==0.25.1 # via open-webui pygments==2.18.0 # via rich -pyjwt==2.8.0 +pyjwt==2.9.0 # via open-webui pymongo==4.8.0 # via open-webui diff --git a/run-ollama-docker.sh b/run-ollama-docker.sh index 3164df63de98dbb74ba7e0b1ab451e8217c83a80..c2a025bea3fa88beab7c0e6640de682dc8169c02 100644 --- a/run-ollama-docker.sh +++ b/run-ollama-docker.sh @@ -8,15 +8,11 @@ read -r -p "Do you want ollama in Docker with GPU support? 
(y/n): " use_gpu docker rm -f ollama || true docker pull ollama/ollama:latest -# docker_args="-d -v ollama:/root/.ollama -p $host_port:$container_port --name ollama ollama/ollama" -docker_args="-d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main" -# docker_args="-d -p 3000:8080 -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:ollama" +docker_args="-d -v ollama:/root/.ollama -p $host_port:$container_port --name ollama ollama/ollama" -# if [ "$use_gpu" = "y" ]; then -# docker_args="--gpus=all $docker_args" -# fi - -docker_args="$docker_args" +if [ "$use_gpu" = "y" ]; then + docker_args="--gpus=all $docker_args" +fi docker run $docker_args diff --git a/src/app.html b/src/app.html index 5d48e1d7e8532479cb3cd18f38d7871318b60571..718f7e194c733fafa9f87389213221700af64408 100644 --- a/src/app.html +++ b/src/app.html @@ -1,4 +1,4 @@ - +
diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 084d2d5f18a54dc11d73da5beba5c2dd93f700c5..c4c449156c1a6c1a9058329c7488e5967bdb3e7f 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -1,5 +1,4 @@ import { OLLAMA_API_BASE_URL } from '$lib/constants'; -import { titleGenerationTemplate } from '$lib/utils'; export const getOllamaConfig = async (token: string = '') => { let error = null; @@ -203,55 +202,6 @@ export const getOllamaModels = async (token: string = '') => { }); }; -// TODO: migrate to backend -export const generateTitle = async ( - token: string = '', - template: string, - model: string, - prompt: string -) => { - let error = null; - - template = titleGenerationTemplate(template, prompt); - - console.log(template); - - const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - Authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - model: model, - prompt: template, - stream: false, - options: { - // Restrict the number of tokens generated to 50 - num_predict: 50 - } - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } - return null; - }); - - if (error) { - throw error; - } - - return res?.response.replace(/["']/g, '') ?? 
'New Chat'; -}; - export const generatePrompt = async (token: string = '', model: string, conversation: string) => { let error = null; diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 2a52ebb3209ef80577732718646c4fe17759ef72..2bb11d12a74e48933536509d9fb16800d2efc53b 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -1,6 +1,4 @@ import { OPENAI_API_BASE_URL } from '$lib/constants'; -import { titleGenerationTemplate } from '$lib/utils'; -import { type Model, models, settings } from '$lib/stores'; export const getOpenAIConfig = async (token: string = '') => { let error = null; @@ -260,7 +258,7 @@ export const getOpenAIModelsDirect = async ( throw error; } - const models = Array.isArray(res) ? res : res?.data ?? null; + const models = Array.isArray(res) ? res : (res?.data ?? null); return models .map((model) => ({ id: model.id, name: model.name ?? model.id, external: true })) @@ -330,126 +328,3 @@ export const synthesizeOpenAISpeech = async ( return res; }; - -export const generateTitle = async ( - token: string = '', - template: string, - model: string, - prompt: string, - chat_id?: string, - url: string = OPENAI_API_BASE_URL -) => { - let error = null; - - template = titleGenerationTemplate(template, prompt); - - console.log(template); - - const res = await fetch(`${url}/chat/completions`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - Authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - model: model, - messages: [ - { - role: 'user', - content: template - } - ], - stream: false, - // Restricting the max tokens to 50 to avoid long titles - max_tokens: 50, - ...(chat_id && { chat_id: chat_id }), - title: true - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } - return null; - }); - - if (error) { - throw 
error; - } - - return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat'; -}; - -export const generateSearchQuery = async ( - token: string = '', - model: string, - previousMessages: string[], - prompt: string, - url: string = OPENAI_API_BASE_URL -): Promise