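"""Task endpoints for Open WebUI.

Exposes the admin-configurable task settings (``/config``, ``/config/update``) and
the background generation tasks used around a chat: titles, tags, image prompts,
web-search/retrieval queries, autocompletion, emoji, and mixture-of-agents (MoA)
responses. Each task renders a prompt from a template, routes it to the configured
task model (or the requested model), and returns the raw chat-completion response.
"""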
from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
from fastapi.responses import JSONResponse, RedirectResponse

from pydantic import BaseModel
from typing import Optional
import logging
import re

from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
    get_task_model_id,
    title_generation_template,
    query_generation_template,
    image_prompt_generation_template,
    autocomplete_generation_template,
    tags_generation_template,
    emoji_generation_template,
    moa_response_generation_template,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS

from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.filter import (
    get_sorted_filter_ids,
    process_filter_functions,
)

from open_webui.config import (
    DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
)
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

router = APIRouter()
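
# Every task endpoint below follows the same general flow:
#   1. Resolve the set of available models (a single model for "direct" connections,
#      otherwise request.app.state.MODELS) and validate the requested model id.
#   2. Pick the model that should actually run the task via get_task_model_id(),
#      honoring TASK_MODEL / TASK_MODEL_EXTERNAL.
#   3. Render the task prompt from its template and build a chat-completion payload
#      (non-streaming for most tasks) carrying task metadata.
#   4. Run the payload through the pipeline inlet filters, then delegate to
#      generate_chat_completion().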


@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    return {
        "TASK_MODEL": request.app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
        "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
        "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
        "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }


class TaskConfigForm(BaseModel):
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    ENABLE_TITLE_GENERATION: bool
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_TAGS_GENERATION: bool
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str


@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
    request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
    request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
    request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
    )

    request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
        form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    )

    request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
        form_data.ENABLE_AUTOCOMPLETE_GENERATION
    )
    request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
        form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    )

    request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
        form_data.TAGS_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
    request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
        form_data.ENABLE_SEARCH_QUERY_GENERATION
    )
    request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
        form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
    )

    request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
        form_data.QUERY_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
        form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    )

    return {
        "TASK_MODEL": request.app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
        "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
        "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
        "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }
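
# Illustrative request body for POST /config/update (admin only); the values below
# are examples, not defaults:
#
#   {
#     "TASK_MODEL": "llama3.2:3b",
#     "TASK_MODEL_EXTERNAL": "gpt-4o-mini",
#     "ENABLE_TITLE_GENERATION": true,
#     "TITLE_GENERATION_PROMPT_TEMPLATE": "",
#     "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": "",
#     "ENABLE_AUTOCOMPLETE_GENERATION": true,
#     "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": -1,
#     "TAGS_GENERATION_PROMPT_TEMPLATE": "",
#     "ENABLE_TAGS_GENERATION": true,
#     "ENABLE_SEARCH_QUERY_GENERATION": true,
#     "ENABLE_RETRIEVAL_QUERY_GENERATION": true,
#     "QUERY_GENERATION_PROMPT_TEMPLATE": "",
#     "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": ""
#   }
#
# For the templates used in this module, an empty string makes the corresponding
# task fall back to its built-in DEFAULT_* template at call time, and a
# non-positive AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH disables the length check.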


@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_TITLE_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Title generation is disabled"},
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email}"
    )

    if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    messages = form_data["messages"]

    # Strip <details type="reasoning"> blocks emitted by reasoning models so the
    # hidden reasoning content does not leak into the title prompt.
    for message in messages:
        message["content"] = re.sub(
            r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
            "",
            message["content"],
            flags=re.S,
        ).strip()

    content = title_generation_template(
        template,
        messages,
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        # Ollama-backed models take "max_tokens"; other (OpenAI-compatible) models
        # take "max_completion_tokens".
        **(
            {"max_tokens": 1000}
            if models[task_model_id].get("owned_by") == "ollama"
            else {"max_completion_tokens": 1000}
        ),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Apply pipeline inlet filters; errors propagate to the caller.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
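
# Example (illustrative) client call, assuming this router is mounted under
# /api/v1/tasks and `token` holds a valid API key or session JWT:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8080/api/v1/tasks/title/completions",
#       headers={"Authorization": f"Bearer {token}"},
#       json={
#           "model": "llama3.2:3b",  # any model id known to the server
#           "messages": [{"role": "user", "content": "How do transformers work?"}],
#           "chat_id": "abc123",
#       },
#   )
#   # resp.json() is the raw chat-completion response from the task model.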


@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email}"
    )

    if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE

    content = tags_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TAGS_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Apply pipeline inlet filters; errors propagate to the caller.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )


@router.post("/image_prompt/completions")
async def generate_image_prompt(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating image prompt using model {task_model_id} for user {user.email}"
    )

    if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE

    content = image_prompt_generation_template(
        template,
        form_data["messages"],
        user={
            "name": user.name,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.IMAGE_PROMPT_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Apply pipeline inlet filters; errors propagate to the caller.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )


@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    type = form_data.get("type")
    if type == "web_search":
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif type == "retrieval":
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {type} queries using model {task_model_id} for user {user.email}"
    )

    if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Apply pipeline inlet filters; errors propagate to the caller.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
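
# Note: /queries/completions expects "type" to be either "web_search" or
# "retrieval". Any other (or missing) value bypasses both enable-flag checks above
# and still generates queries with the shared QUERY_GENERATION_PROMPT_TEMPLATE.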


@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    type = form_data.get("type")
    # Default to "" so the length check below tolerates a missing prompt.
    prompt = form_data.get("prompt", "")
    messages = form_data.get("messages")

    if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
        if (
            len(prompt)
            > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
        ):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
            )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    if (
        request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    ).strip() != "":
        template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(
        template, prompt, messages, type, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Apply pipeline inlet filters; errors propagate to the caller.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )


@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email}")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        # A few tokens are enough for a single emoji; Ollama-backed models take
        # "max_tokens", OpenAI-compatible ones take "max_completion_tokens".
        **(
            {"max_tokens": 4}
            if models[task_model_id].get("owned_by") == "ollama"
            else {"max_completion_tokens": 4}
        ),
        "chat_id": form_data.get("chat_id", None),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
        },
    }

    # Apply pipeline inlet filters; errors propagate to the caller.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )


@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    # Merge the original prompt and the candidate responses into a single
    # mixture-of-agents aggregation prompt.
    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "chat_id": form_data.get("chat_id", None),
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
        },
    }

    # Apply pipeline inlet filters; errors propagate to the caller.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )