Spaces:
Running
Running
File size: 2,919 Bytes
330ae1b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import random
import logging
import sys
from fastapi import Request
from open_webui.models.users import UserModel
from open_webui.models.models import Models
from open_webui.utils.models import check_model_access
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
from open_webui.routers.openai import embeddings as openai_embeddings
from open_webui.routers.ollama import (
embeddings as ollama_embeddings,
GenerateEmbeddingsForm,
)
from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
from open_webui.utils.response import convert_embedding_response_ollama_to_openai
# Module-level logging setup: route root logging to stdout at the globally
# configured level (basicConfig is a no-op if the root logger already has
# handlers), then give this module's logger the "MAIN" source level.
logging.basicConfig(level=GLOBAL_LOG_LEVEL, stream=sys.stdout)
log = logging.getLogger(name=__name__)
log.setLevel(level=SRC_LOG_LEVELS["MAIN"])
async def generate_embeddings(
    request: Request,
    form_data: dict,
    user: UserModel,
    bypass_filter: bool = False,
):
    """
    Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama).

    The target model is looked up by ``form_data["model"]``:

    * If the request is a direct connection (``request.state.direct`` is truthy
      AND ``request.state.model`` is present), only that model is considered
      and the per-user access check is skipped.
    * Otherwise the model comes from ``request.app.state.MODELS`` and, for
      users whose role is ``"user"``, ``check_model_access`` is enforced
      unless ``bypass_filter`` is set.

    Models whose ``owned_by`` is ``"ollama"`` are sent through the Ollama
    embeddings endpoint (payload and response converted to/from the OpenAI
    shape). All other models go to the OpenAI-compatible endpoint.

    Args:
        request (Request): The FastAPI request context.
        form_data (dict): The OpenAI-style embeddings payload. The caller's
            dict is not mutated; a shallow copy is made when metadata from
            ``request.state`` has to be merged in.
        user (UserModel): The authenticated user.
        bypass_filter (bool): If True, disables access filtering (default
            False). Forced to True when BYPASS_MODEL_ACCESS_CONTROL is set.

    Returns:
        dict: The embeddings response, following OpenAI API compatibility.

    Raises:
        Exception: "Model not found" if the requested model id is unknown.
        Any exception raised by ``check_model_access`` when access is denied.
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    # Attach extra metadata from request.state if present. request.state
    # values win on key conflicts. Work on a shallow copy so the caller's
    # dict is left untouched.
    if hasattr(request.state, "metadata"):
        form_data = dict(form_data)
        if "metadata" not in form_data:
            form_data["metadata"] = request.state.metadata
        else:
            # A present-but-None metadata value is treated as empty instead of
            # crashing the ** unpacking.
            form_data["metadata"] = {
                **(form_data["metadata"] or {}),
                **request.state.metadata,
            }

    # A direct connection is only honoured when it also supplies its own model
    # definition. The same predicate drives both model selection and the
    # access-check skip, so a bare "direct" flag can no longer resolve a model
    # from the shared registry without an access check.
    is_direct = bool(getattr(request.state, "direct", False)) and hasattr(
        request.state, "model"
    )

    if is_direct:
        models = {request.state.model["id"]: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data.get("model")
    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    # Access filtering applies only to models from the shared registry.
    if not is_direct and not bypass_filter and user.role == "user":
        check_model_access(user, model)

    # Ollama backend: convert payload to Ollama form, convert response back.
    if model.get("owned_by") == "ollama":
        ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
        response = await ollama_embeddings(
            request=request,
            form_data=GenerateEmbeddingsForm(**ollama_payload),
            user=user,
        )
        return convert_embedding_response_ollama_to_openai(response)

    # Default: OpenAI or compatible backend.
    return await openai_embeddings(
        request=request,
        form_data=form_data,
        user=user,
    )
|