diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index 8d9dd7aebc1aaa95f0d85953d82cdefd9b2fa6af..41c19d32d94bc8532cf021ea9307325f308b504e 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -15,13 +15,6 @@ jobs:
     name: Run Cypress Integration Tests
     runs-on: ubuntu-latest
     steps:
-      - name: Maximize build space
-        uses: AdityaGarg8/remove-unwanted-software@v4.1
-        with:
-          remove-android: 'true'
-          remove-haskell: 'true'
-          remove-codeql: 'true'
-
       - name: Checkout Repository
         uses: actions/checkout@v4
 
diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py
index f4dfc8a732ac824800c855b2f46cedbb0fee47d3..06d0bcce8dc570e88c3c18fbf5539cb3db39a569 100644
--- a/backend/apps/images/main.py
+++ b/backend/apps/images/main.py
@@ -150,11 +150,10 @@ async def update_engine_url(
     else:
         url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
         try:
-            r = requests.head(url)
-            r.raise_for_status()
+            r = requests.head(url)
             app.state.config.AUTOMATIC1111_BASE_URL = url
         except Exception as e:
-            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
+            raise HTTPException(status_code=400, detail="Invalid URL provided.")
 
     if form_data.COMFYUI_BASE_URL == None:
         app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
@@ -163,10 +162,9 @@ async def update_engine_url(
 
         try:
             r = requests.head(url)
-            r.raise_for_status()
             app.state.config.COMFYUI_BASE_URL = url
         except Exception as e:
-            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
+            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 
     if form_data.AUTOMATIC1111_API_AUTH == None:
         app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py
index 94875d9595c8db062319e385142984892e4187f1..f82076809e0b9a49365d88e0459f47b9ad1a02cb 100644
--- a/backend/apps/images/utils/comfyui.py
+++ b/backend/apps/images/utils/comfyui.py
@@ -1,5 +1,6 @@
 import asyncio
 import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
+import uuid
 import json
 import urllib.request
 import urllib.parse
@@ -397,9 +398,7 @@ async def comfyui_generate_image(
         return None
 
     try:
-        images = await asyncio.to_thread(
-            get_images, ws, comfyui_prompt, client_id, base_url
-        )
+        images = await asyncio.to_thread(get_images, ws, comfyui_prompt, client_id, base_url)
     except Exception as e:
         log.exception(f"Error while receiving images: {e}")
         images = None
diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index f479ad35c0d039173b464342e0317709d8870056..442d99ff26f1015d1b469da8a6ce1d13d01b5b13 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -1,21 +1,27 @@
 from fastapi import (
     FastAPI,
     Request,
+    Response,
     HTTPException,
     Depends,
+    status,
     UploadFile,
     File,
+    BackgroundTasks,
 )
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
+from fastapi.concurrency import run_in_threadpool
 from pydantic import BaseModel, ConfigDict
 
 import os
 import re
+import copy
 import random
 import requests
 import json
+import uuid
 import aiohttp
 import asyncio
 import logging
@@ -26,11 +32,16 @@ from typing import Optional, List, Union
 from starlette.background import BackgroundTask
 
 from apps.webui.models.models import Models
+from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
+    decode_token,
+    get_current_user,
     get_verified_user,
     get_admin_user,
 )
+from utils.task import prompt_template
+
 
 from config import (
     SRC_LOG_LEVELS,
@@ -42,12 +53,7 @@ from config import (
     UPLOAD_DIR,
     AppConfig,
 )
-from utils.misc import (
-    calculate_sha256,
-    apply_model_params_to_body_ollama,
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
-)
+from utils.misc import calculate_sha256, add_or_update_system_message
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@@ -177,7 +183,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True):
                 res = await r.json()
                 if "error" in res:
                     error_detail = f"Ollama: {res['error']}"
-            except Exception:
+            except:
                 error_detail = f"Ollama: {e}"
 
         raise HTTPException(
@@ -232,7 +238,7 @@ async def get_all_models():
 async def get_ollama_tags(
     url_idx: Optional[int] = None, user=Depends(get_verified_user)
 ):
-    if url_idx is None:
+    if url_idx == None:
         models = await get_all_models()
 
         if app.state.config.ENABLE_MODEL_FILTER:
@@ -263,7 +269,7 @@ async def get_ollama_tags(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except Exception:
+        except:
             error_detail = f"Ollama: {e}"
 
         raise HTTPException(
@@ -276,7 +282,8 @@
 @app.get("/api/version/{url_idx}")
 async def get_ollama_versions(url_idx: Optional[int] = None):
     if app.state.config.ENABLE_OLLAMA_API:
-        if url_idx is None:
+        if url_idx == None:
+
             # returns lowest version
             tasks = [
                 fetch_url(f"{url}/api/version")
@@ -316,7 +323,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except Exception:
+        except:
             error_detail = f"Ollama: {e}"
 
         raise HTTPException(
@@ -339,6 +346,8 @@ async def pull_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
+    r = None
+
     # Admin should be able to pull models from any source
     payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
 
@@ -358,7 +367,7 @@ async def push_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx is None:
+    if url_idx == None:
         if form_data.name in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.name]["urls"][0]
         else:
@@ -408,7 +417,7 @@ async def copy_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx is None:
+    if url_idx == None:
         if form_data.source in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.source]["urls"][0]
         else:
@@ -419,13 +428,13 @@ async def copy_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/copy",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/copy",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
 
         log.debug(f"r.text: {r.text}")
@@ -439,7 +448,7 @@ async def copy_model(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except Exception:
+        except:
             error_detail = f"Ollama: {e}"
 
         raise HTTPException(
@@ -455,7 +464,7 @@ async def delete_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx is None:
+    if url_idx == None:
         if form_data.name in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.name]["urls"][0]
         else:
@@ -467,12 +476,12 @@ async def delete_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
-    r = requests.request(
-        method="DELETE",
-        url=f"{url}/api/delete",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="DELETE",
+            url=f"{url}/api/delete",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
 
         log.debug(f"r.text: {r.text}")
@@ -486,7 +495,7 @@ async def delete_model(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except Exception:
+        except:
             error_detail = f"Ollama: {e}"
 
         raise HTTPException(
@@ -507,12 +516,12 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/show",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/show",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
 
         return r.json()
@@ -524,7 +533,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except Exception:
+        except:
             error_detail = f"Ollama: {e}"
 
         raise HTTPException(
@@ -547,7 +556,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
+    if url_idx == None:
         model = form_data.model
 
         if ":" not in model:
@@ -564,12 +573,12 @@ async def generate_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embeddings",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
 
         return r.json()
@@ -581,7 +590,7 @@ async def generate_embeddings(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except Exception:
+        except:
             error_detail = f"Ollama: {e}"
 
         raise HTTPException(
@@ -594,9 +603,10 @@ def generate_ollama_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
 ):
+
     log.info(f"generate_ollama_embeddings {form_data}")
 
-    if url_idx is None:
+    if url_idx == None:
         model = form_data.model
 
         if ":" not in model:
@@ -613,12 +623,12 @@ def generate_ollama_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embeddings",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
 
         data = r.json()
@@ -628,7 +638,7 @@ def generate_ollama_embeddings(
         if "embedding" in data:
             return data["embedding"]
         else:
-            raise Exception("Something went wrong :/")
+            raise "Something went wrong :/"
     except Exception as e:
         log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
@@ -637,10 +647,10 @@ def generate_ollama_embeddings(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except Exception:
+        except:
             error_detail = f"Ollama: {e}"
 
-        raise Exception(error_detail)
+        raise error_detail
 
 
 class GenerateCompletionForm(BaseModel):
@@ -664,7 +674,8 @@ async def generate_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
+
+    if url_idx == None:
         model = form_data.model
 
         if ":" not in model:
@@ -702,18 +713,6 @@ class GenerateChatCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None
 
 
-def get_ollama_url(url_idx: Optional[int], model: str):
-    if url_idx is None:
-        if model not in app.state.MODELS:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
-            )
-        url_idx = random.choice(app.state.MODELS[model]["urls"])
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    return url
-
-
 @app.post("/api/chat")
 @app.post("/api/chat/{url_idx}")
 async def generate_chat_completion(
@@ -721,7 +720,12 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
+
+    log.debug(
+        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
+            form_data.model_dump_json(exclude_none=True).encode()
+        )
+    )
 
     payload = {
         **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
@@ -736,21 +740,185 @@ async def generate_chat_completion(
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id
 
-        params = model_info.params.model_dump()
+        model_info.params = model_info.params.model_dump()
 
-        if params:
+        if model_info.params:
             if payload.get("options") is None:
                 payload["options"] = {}
 
-            payload["options"] = apply_model_params_to_body_ollama(
-                params, payload["options"]
-            )
-        payload = apply_model_system_prompt_to_body(params, payload, user)
+            if (
+                model_info.params.get("mirostat", None)
+                and payload["options"].get("mirostat") is None
+            ):
+                payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
+
+            if (
+                model_info.params.get("mirostat_eta", None)
+                and payload["options"].get("mirostat_eta") is None
+            ):
+                payload["options"]["mirostat_eta"] = model_info.params.get(
+                    "mirostat_eta", None
+                )
+
+            if (
+                model_info.params.get("mirostat_tau", None)
+                and payload["options"].get("mirostat_tau") is None
+            ):
+                payload["options"]["mirostat_tau"] = model_info.params.get(
+                    "mirostat_tau", None
+                )
+
+            if (
+                model_info.params.get("num_ctx", None)
+                and payload["options"].get("num_ctx") is None
+            ):
+                payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
+
+            if (
+                model_info.params.get("num_batch", None)
+                and payload["options"].get("num_batch") is None
+            ):
+                payload["options"]["num_batch"] = model_info.params.get(
+                    "num_batch", None
+                )
+
+            if (
+                model_info.params.get("num_keep", None)
+                and payload["options"].get("num_keep") is None
+            ):
+                payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
+
+            if (
+                model_info.params.get("repeat_last_n", None)
+                and payload["options"].get("repeat_last_n") is None
+            ):
+                payload["options"]["repeat_last_n"] = model_info.params.get(
+                    "repeat_last_n", None
+                )
+
+            if (
+                model_info.params.get("frequency_penalty", None)
+                and payload["options"].get("frequency_penalty") is None
+            ):
+                payload["options"]["repeat_penalty"] = model_info.params.get(
+                    "frequency_penalty", None
+                )
+
+            if (
+                model_info.params.get("temperature", None) is not None
+                and payload["options"].get("temperature") is None
+            ):
+                payload["options"]["temperature"] = model_info.params.get(
+                    "temperature", None
+                )
+
+            if (
+                model_info.params.get("seed", None) is not None
+                and payload["options"].get("seed") is None
+            ):
+                payload["options"]["seed"] = model_info.params.get("seed", None)
+
+            if (
+                model_info.params.get("stop", None)
+                and payload["options"].get("stop") is None
+            ):
+                payload["options"]["stop"] = (
+                    [
+                        bytes(stop, "utf-8").decode("unicode_escape")
+                        for stop in model_info.params["stop"]
+                    ]
+                    if model_info.params.get("stop", None)
+                    else None
+                )
+
+            if (
+                model_info.params.get("tfs_z", None)
+                and payload["options"].get("tfs_z") is None
+            ):
+                payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
+
+            if (
+                model_info.params.get("max_tokens", None)
+                and payload["options"].get("max_tokens") is None
+            ):
+                payload["options"]["num_predict"] = model_info.params.get(
+                    "max_tokens", None
+                )
+
+            if (
+                model_info.params.get("top_k", None)
+                and payload["options"].get("top_k") is None
+            ):
+                payload["options"]["top_k"] = model_info.params.get("top_k", None)
+
+            if (
+                model_info.params.get("top_p", None)
+                and payload["options"].get("top_p") is None
+            ):
+                payload["options"]["top_p"] = model_info.params.get("top_p", None)
+
+            if (
+                model_info.params.get("min_p", None)
+                and payload["options"].get("min_p") is None
+            ):
+                payload["options"]["min_p"] = model_info.params.get("min_p", None)
+
+            if (
+                model_info.params.get("use_mmap", None)
+                and payload["options"].get("use_mmap") is None
+            ):
+                payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
+
+            if (
+                model_info.params.get("use_mlock", None)
+                and payload["options"].get("use_mlock") is None
+            ):
+                payload["options"]["use_mlock"] = model_info.params.get(
+                    "use_mlock", None
+                )
+
+            if (
+                model_info.params.get("num_thread", None)
+                and payload["options"].get("num_thread") is None
+            ):
+                payload["options"]["num_thread"] = model_info.params.get(
+                    "num_thread", None
+                )
+
+            system = model_info.params.get("system", None)
+            if system:
+                system = prompt_template(
+                    system,
+                    **(
+                        {
+                            "user_name": user.name,
+                            "user_location": (
+                                user.info.get("location") if user.info else None
+                            ),
+                        }
+                        if user
+                        else {}
+                    ),
+                )
 
-    if ":" not in payload["model"]:
-        payload["model"] = f"{payload['model']}:latest"
+                if payload.get("messages"):
+                    payload["messages"] = add_or_update_system_message(
+                        system, payload["messages"]
+                    )
+
+    if url_idx == None:
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
+
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+            )
 
-    url = get_ollama_url(url_idx, payload["model"])
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
     log.debug(payload)
@@ -784,28 +952,83 @@ async def generate_openai_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    completion_form = OpenAIChatCompletionForm(**form_data)
-    payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
+    form_data = OpenAIChatCompletionForm(**form_data)
+    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
+
     if "metadata" in payload:
         del payload["metadata"]
 
-    model_id = completion_form.model
+    model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)
 
     if model_info:
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id
 
-        params = model_info.params.model_dump()
+        model_info.params = model_info.params.model_dump()
 
-        if params:
-            payload = apply_model_params_to_body_openai(params, payload)
-            payload = apply_model_system_prompt_to_body(params, payload, user)
+        if model_info.params:
+            payload["temperature"] = model_info.params.get("temperature", None)
+            payload["top_p"] = model_info.params.get("top_p", None)
+            payload["max_tokens"] = model_info.params.get("max_tokens", None)
+            payload["frequency_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+            payload["seed"] = model_info.params.get("seed", None)
+            payload["stop"] = (
+                [
+                    bytes(stop, "utf-8").decode("unicode_escape")
+                    for stop in model_info.params["stop"]
+                ]
+                if model_info.params.get("stop", None)
+                else None
+            )
 
-    if ":" not in payload["model"]:
-        payload["model"] = f"{payload['model']}:latest"
+        system = model_info.params.get("system", None)
+
+        if system:
+            system = prompt_template(
+                system,
+                **(
+                    {
+                        "user_name": user.name,
+                        "user_location": (
+                            user.info.get("location") if user.info else None
+                        ),
+                    }
+                    if user
+                    else {}
+                ),
+            )
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = system + message["content"]
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": system,
+                        },
+                    )
 
-    url = get_ollama_url(url_idx, payload["model"])
+    if url_idx == None:
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
+
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+            )
+
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
     return await post_streaming_url(
@@ -821,7 +1044,7 @@ async def get_openai_models(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
+    if url_idx == None:
         models = await get_all_models()
 
         if app.state.config.ENABLE_MODEL_FILTER:
@@ -876,7 +1099,7 @@ async def get_openai_models(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except Exception:
+        except:
             error_detail = f"Ollama: {e}"
 
         raise HTTPException(
@@ -902,6 +1125,7 @@ def parse_huggingface_url(hf_url):
         path_components = parsed_url.path.split("/")
 
         # Extract the desired output
+        user_repo = "/".join(path_components[1:3])
         model_file = path_components[-1]
 
         return model_file
@@ -966,6 +1190,7 @@ async def download_model(
     url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
 ):
+
     allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
 
     if not any(form_data.url.startswith(host) for host in allowed_hosts):
@@ -974,7 +1199,7 @@
             detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
         )
 
-    if url_idx is None:
+    if url_idx == None:
         url_idx = 0
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -997,7 +1222,7 @@ def upload_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx is None:
+    if url_idx == None:
         url_idx = 0
     ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index 50de53a53781e3495043a96f6f46be0916527dac..948ee7c46ab0c7228da30701b1580d9260cd9450 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -17,10 +17,7 @@ from utils.utils import (
     get_verified_user,
     get_admin_user,
 )
-from utils.misc import (
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
-)
+from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
 
 from config import (
     SRC_LOG_LEVELS,
@@ -371,7 +368,7 @@ async def generate_chat_completion(
             payload["model"] = model_info.base_model_id
 
         params = model_info.params.model_dump()
-        payload = apply_model_params_to_body_openai(params, payload)
+        payload = apply_model_params_to_body(params, payload)
         payload = apply_model_system_prompt_to_body(params, payload, user)
 
     model = app.state.MODELS[payload.get("model")]
diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py
index dddf3fbb2a127db8ef36e6060dfa3adf1fb0b4cf..3c387842d8c0867a18b197643885149e094bfb29 100644
--- a/backend/apps/webui/main.py
+++ b/backend/apps/webui/main.py
@@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id
 from utils.misc import (
     openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
-    apply_model_params_to_body_openai,
+    apply_model_params_to_body,
     apply_model_system_prompt_to_body,
 )
 
@@ -291,7 +291,7 @@ async def generate_function_chat_completion(form_data, user):
             form_data["model"] = model_info.base_model_id
 
         params = model_info.params.model_dump()
-        form_data = apply_model_params_to_body_openai(params, form_data)
+        form_data = apply_model_params_to_body(params, form_data)
         form_data = apply_model_system_prompt_to_body(params, form_data, user)
 
     pipe_id = get_pipe_id(form_data)
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 6ef299b5fab415204ec52015aed39a8104ef6f57..e8466a649a12ccd1bcc5d872ce05503f402010a2 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -11,7 +11,7 @@ python-jose==3.3.0
 passlib[bcrypt]==1.7.4
 
 requests==2.32.3
-aiohttp==3.10.2
+aiohttp==3.9.5
 
 sqlalchemy==2.0.31
 alembic==1.13.2
@@ -34,12 +34,12 @@ anthropic
 google-generativeai==0.7.2
 tiktoken
 
-langchain==0.2.12
+langchain==0.2.11
 langchain-community==0.2.10
 langchain-chroma==0.1.2
 
 fake-useragent==1.5.1
-chromadb==0.5.5
+chromadb==0.5.4
 sentence-transformers==3.0.1
 pypdf==4.3.1
 docx2txt==0.8
@@ -62,11 +62,11 @@ rank-bm25==0.2.2
 
 faster-whisper==1.0.2
 
-PyJWT[crypto]==2.9.0
+PyJWT[crypto]==2.8.0
 authlib==1.3.1
 
 black==24.8.0
-langfuse==2.43.3
+langfuse==2.39.2
 youtube-transcript-api==0.6.2
 pytube==15.0.0
 
@@ -76,5 +76,5 @@ duckduckgo-search~=6.2.1
 
 ## Tests
 docker~=7.1.0
-pytest~=8.3.2
+pytest~=8.2.2
 pytest-docker~=3.1.1
diff --git a/backend/utils/misc.py b/backend/utils/misc.py
index 05830568b49ca03b9b4352cb6aecc196f491dfbb..25dd4dd5b66c5de02f800dca7d12cc82e242c4b1 100644
--- a/backend/utils/misc.py
+++ b/backend/utils/misc.py
@@ -2,7 +2,7 @@ from pathlib import Path
 import hashlib
 import re
 from datetime import timedelta
-from typing import Optional, List, Tuple, Callable
+from typing import Optional, List, Tuple
 import uuid
 import time
 
@@ -135,21 +135,10 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di
 
 
 # inplace function: form_data is modified
-def apply_model_params_to_body(
-    params: dict, form_data: dict, mappings: dict[str, Callable]
-) -> dict:
+def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
     if not params:
         return form_data
 
-    for key, cast_func in mappings.items():
-        if (value := params.get(key)) is not None:
-            form_data[key] = cast_func(value)
-
-    return form_data
-
-
-# inplace function: form_data is modified
-def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
     mappings = {
         "temperature": float,
         "top_p": int,
@@ -158,40 +147,10 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
         "seed": lambda x: x,
         "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
     }
-    return apply_model_params_to_body(params, form_data, mappings)
-
-
-def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
-    opts = [
-        "temperature",
-        "top_p",
-        "seed",
-        "mirostat",
-        "mirostat_eta",
-        "mirostat_tau",
-        "num_ctx",
-        "num_batch",
-        "num_keep",
-        "repeat_last_n",
-        "tfs_z",
-        "top_k",
-        "min_p",
-        "use_mmap",
-        "use_mlock",
-        "num_thread",
-        "num_gpu",
-    ]
-    mappings = {i: lambda x: x for i in opts}
-    form_data = apply_model_params_to_body(params, form_data, mappings)
-
-    name_differences = {
-        "max_tokens": "num_predict",
-        "frequency_penalty": "repeat_penalty",
-    }
-
-    for key, value in name_differences.items():
-        if (param := params.get(key, None)) is not None:
-            form_data[value] = param
+    for key, cast_func in mappings.items():
+        if (value := params.get(key)) is not None:
+            form_data[key] = cast_func(value)
 
     return form_data
diff --git a/cypress/e2e/chat.cy.ts b/cypress/e2e/chat.cy.ts
index 20be9755a4fe97b660c904d7bd9d6bfd3bcebbe4..ddb33d6c06b54bcbe562241c73569c710b900948 100644
--- a/cypress/e2e/chat.cy.ts
+++ b/cypress/e2e/chat.cy.ts
@@ -38,10 +38,9 @@ describe('Settings', () => {
 			// User's message should be visible
 			cy.get('.chat-user').should('exist');
 			// Wait for the response
-			// .chat-assistant is created after the first token is received
-			cy.get('.chat-assistant', { timeout: 10_000 }).should('exist');
-			// Generation Info is created after the stop token is received
-			cy.get('div[aria-label="Generation Info"]', { timeout: 120_000 }).should('exist');
+			cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received
+				.find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received
+				.should('exist');
 		});
 
 		it('user can share chat', () => {
@@ -58,24 +57,21 @@ describe('Settings', () => {
 			// User's message should be visible
 			cy.get('.chat-user').should('exist');
 			// Wait for the response
-			// .chat-assistant is created after the first token is received
-			cy.get('.chat-assistant', { timeout: 10_000 }).should('exist');
-			// Generation Info is created after the stop token is received
-			cy.get('div[aria-label="Generation Info"]', { timeout: 120_000 }).should('exist');
+			cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received
+				.find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received
+				.should('exist');
 			// spy on requests
 			const spy = cy.spy();
-			cy.intercept('POST', '/api/v1/chats/**/share', spy);
+			cy.intercept('GET', '/api/v1/chats/*', spy);
 			// Open context menu
 			cy.get('#chat-context-menu-button').click();
 			// Click share button
 			cy.get('#chat-share-button').click();
 			// Check if the share dialog is visible
 			cy.get('#copy-and-share-chat-button').should('exist');
-			// Click the copy button
-			cy.get('#copy-and-share-chat-button').click();
-			cy.wrap({}, { timeout: 5_000 }).should(() => {
-				// Check if the share request was made
-				expect(spy).to.be.callCount(1);
+			cy.wrap({}, { timeout: 5000 }).should(() => {
+				// Check if the request was made twice (once for to replace chat object and once more due to change event)
+				expect(spy).to.be.callCount(2);
 			});
 		});
 
@@ -93,10 +89,9 @@ describe('Settings', () => {
 			// User's message should be visible
 			cy.get('.chat-user').should('exist');
 			// Wait for the response
-			// .chat-assistant is created after the first token is received
-			cy.get('.chat-assistant', { timeout: 10_000 }).should('exist');
-			// Generation Info is created after the stop token is received
-			cy.get('div[aria-label="Generation Info"]', { timeout: 120_000 }).should('exist');
+			cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received
+				.find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received
+				.should('exist');
 			// Click on the generate image button
 			cy.get('[aria-label="Generate Image"]').click();
 			// Wait for image to be visible
diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md
index ec8a79bbcea4ed9b62faa312e1ca6041b4426f3c..325964b1a94a94b3a296a4da8e84525bc24d3fb8 100644
--- a/docs/CONTRIBUTING.md
+++ b/docs/CONTRIBUTING.md
@@ -22,6 +22,7 @@ Noticed something off? Have an idea? Check our [Issues tab](https://github.com/o
 > [!IMPORTANT]
 >
 > - **Template Compliance:** Please be aware that failure to follow the provided issue template, or not providing the requested information at all, will likely result in your issue being closed without further consideration. This approach is critical for maintaining the manageability and integrity of issue tracking.
+>
 > - **Detail is Key:** To ensure your issue is understood and can be effectively addressed, it's imperative to include comprehensive details. Descriptions should be clear, including steps to reproduce, expected outcomes, and actual results. Lack of sufficient detail may hinder our ability to resolve your issue.
 
 ### 🧭 Scope of Support
diff --git a/pyproject.toml b/pyproject.toml
index 1784a9b4472e5cf5443392008ed83df77aee7182..0b7af7f18576360a43f811f57cedbb29befd43e2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
     "passlib[bcrypt]==1.7.4",
 
     "requests==2.32.3",
-    "aiohttp==3.10.2",
+    "aiohttp==3.9.5",
 
     "sqlalchemy==2.0.31",
     "alembic==1.13.2",
diff --git a/src/app.html b/src/app.html
index 718f7e194c733fafa9f87389213221700af64408..5d48e1d7e8532479cb3cd18f38d7871318b60571 100644
--- a/src/app.html
+++ b/src/app.html
@@ -1,4 +1,4 @@
-
+
diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts
index b075d634b85853cca36387f00e5bd34edb8c3ec0..2a52ebb3209ef80577732718646c4fe17759ef72 100644
--- a/src/lib/apis/openai/index.ts
+++ b/src/lib/apis/openai/index.ts
@@ -260,7 +260,7 @@ export const getOpenAIModelsDirect = async (
		throw error;
	}
 
-	const models = Array.isArray(res) ? res : (res?.data ?? null);
+	const models = Array.isArray(res) ? res : res?.data ?? null;
 
	return models
		.map((model) => ({ id: model.id, name: model.name ?? model.id, external: true }))
diff --git a/src/lib/components/ChangelogModal.svelte b/src/lib/components/ChangelogModal.svelte
index 6a24ea5d9f9512bbd064f1b1b2afc1b6696bc375..48156f92473315ed60d1b814ef73dc21ee95491d 100644
--- a/src/lib/components/ChangelogModal.svelte
+++ b/src/lib/components/ChangelogModal.svelte
@@ -75,12 +75,12 @@
									class="font-semibold uppercase text-xs {section === 'added'
										? 'text-white bg-blue-600'
										: section === 'fixed'
-											? 'text-white bg-green-600'
-											: section === 'changed'
-												? 'text-white bg-yellow-600'
-												: section === 'removed'
-													? 'text-white bg-red-600'
-													: ''} w-fit px-3 rounded-full my-2.5"
+										? 'text-white bg-green-600'
+										: section === 'changed'
+										? 'text-white bg-yellow-600'
+										: section === 'removed'
+										? 'text-white bg-red-600'
+										: ''} w-fit px-3 rounded-full my-2.5"
								>
									{section}
diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte
index bac84902f9b6a6146f146eb7b049637ba6a75ccf..1b0b2c3fa8803c9e5e7f7416f514eec6b01bb6c0 100644
--- a/src/lib/components/admin/Settings/Documents.svelte
+++ b/src/lib/components/admin/Settings/Documents.svelte
@@ -112,7 +112,7 @@
						url: OpenAIUrl,
						batch_size: OpenAIBatchSize
					}
-				}
+			  }
			: {})
		}).catch(async (error) => {
			toast.error(error);
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index a290d5d3d398e0b7908d8bd0faaf3cad83f6d492..2c42c20463643e7d9959a541e2ebd9b5aeb9d40c 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -579,8 +579,8 @@
		let selectedModelIds = modelId
			? [modelId]
			: atSelectedModel !== undefined
-				? [atSelectedModel.id]
-				: selectedModels;
+			? [atSelectedModel.id]
+			: selectedModels;
 
		// Create response messages for each selected model
		const responseMessageIds = {};
@@ -739,11 +739,11 @@
							? await getAndUpdateUserLocation(localStorage.token)
							: undefined
					)}${
-						(responseMessage?.userContext ?? null)
+						responseMessage?.userContext ?? null
							? `\n\nUser Context:\n${responseMessage?.userContext ?? ''}`
							: ''
					}`
-				}
+			  }
				: undefined,
			...messages
		]
@@ -811,10 +811,10 @@
				options: {
					...(params ?? $settings.params ?? {}),
					stop:
-						(params?.stop ?? $settings?.params?.stop ?? undefined)
+						params?.stop ?? $settings?.params?.stop ?? undefined
							? (params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop).map(
									(str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"'))
-								)
+							  )
							: undefined,
					num_predict: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined,
					repeat_penalty:
@@ -1056,10 +1056,10 @@
					stream: true,
					model: model.id,
					stream_options:
-						(model.info?.meta?.capabilities?.usage ?? false)
+						model.info?.meta?.capabilities?.usage ?? false
							? {
									include_usage: true
-								}
+							  }
							: undefined,
					messages: [
						params?.system || $settings.system || (responseMessage?.userContext ?? null)
@@ -1072,11 +1072,11 @@
							? await getAndUpdateUserLocation(localStorage.token)
							: undefined
					)}${
-						(responseMessage?.userContext ?? null)
+						responseMessage?.userContext ?? null
							? `\n\nUser Context:\n${responseMessage?.userContext ?? ''}`
							: ''
					}`
-				}
+			  }
				: undefined,
			...messages
		]
@@ -1092,7 +1092,7 @@
									text:
										arr.length - 1 !== idx
											? message.content
-											: (message?.raContent ?? message.content)
+											: message?.raContent ?? message.content
								},
								...message.files
									.filter((file) => file.type === 'image')
@@ -1103,20 +1103,20 @@
										}
									}))
							]
-						}
+				  }
						: {
								content:
									arr.length - 1 !== idx
										? message.content
-										: (message?.raContent ?? message.content)
-						})
+										: message?.raContent ?? message.content
+						  })
			})),
			seed: params?.seed ?? $settings?.params?.seed ?? undefined,
			stop:
-				(params?.stop ?? $settings?.params?.stop ?? undefined)
+				params?.stop ?? $settings?.params?.stop ?? undefined
					? (params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop).map(
							(str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"'))
-						)
+					  )
					: undefined,
			temperature: params?.temperature ?? $settings?.params?.temperature ?? undefined,
			top_p: params?.top_p ?? $settings?.params?.top_p ?? undefined,
diff --git a/src/lib/components/chat/Controls/Controls.svelte b/src/lib/components/chat/Controls/Controls.svelte
index 31b58ab90e70352d1ed6222ffc1a85d2c4d3009b..69034a305a2e2b5db727e8d4b0ace765b09d4851 100644
--- a/src/lib/components/chat/Controls/Controls.svelte
+++ b/src/lib/components/chat/Controls/Controls.svelte
@@ -9,8 +9,6 @@
	import FileItem from '$lib/components/common/FileItem.svelte';
	import Collapsible from '$lib/components/common/Collapsible.svelte';
 
-	import { user } from '$lib/stores';
-
	export let models = [];
 
	export let chatFiles = [];
@@ -80,7 +78,7 @@