|
import time |
|
import logging |
|
import sys |
|
import os |
|
import base64 |
|
|
|
import asyncio |
|
from aiocache import cached |
|
from typing import Any, Optional |
|
import random |
|
import json |
|
import html |
|
import inspect |
|
import re |
|
import ast |
|
|
|
from uuid import uuid4 |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
|
|
from fastapi import Request |
|
from fastapi import BackgroundTasks |
|
|
|
from starlette.responses import Response, StreamingResponse |
|
|
|
|
|
from open_webui.models.chats import Chats |
|
from open_webui.models.users import Users |
|
from open_webui.socket.main import ( |
|
get_event_call, |
|
get_event_emitter, |
|
get_active_status_by_user_id, |
|
) |
|
from open_webui.routers.tasks import ( |
|
generate_queries, |
|
generate_title, |
|
generate_image_prompt, |
|
generate_chat_tags, |
|
) |
|
from open_webui.routers.retrieval import process_web_search, SearchForm |
|
from open_webui.routers.images import image_generations, GenerateImageForm |
|
from open_webui.routers.pipelines import ( |
|
process_pipeline_inlet_filter, |
|
process_pipeline_outlet_filter, |
|
) |
|
|
|
from open_webui.utils.webhook import post_webhook |
|
|
|
|
|
from open_webui.models.users import UserModel |
|
from open_webui.models.functions import Functions |
|
from open_webui.models.models import Models |
|
|
|
from open_webui.retrieval.utils import get_sources_from_files |
|
|
|
|
|
from open_webui.utils.chat import generate_chat_completion |
|
from open_webui.utils.task import ( |
|
get_task_model_id, |
|
rag_template, |
|
tools_function_calling_generation_template, |
|
) |
|
from open_webui.utils.misc import ( |
|
deep_update, |
|
get_message_list, |
|
add_or_update_system_message, |
|
add_or_update_user_message, |
|
get_last_user_message, |
|
get_last_assistant_message, |
|
prepend_to_first_user_message_content, |
|
convert_logit_bias_input_to_json, |
|
) |
|
from open_webui.utils.tools import get_tools |
|
from open_webui.utils.plugin import load_function_module_by_id |
|
from open_webui.utils.filter import ( |
|
get_sorted_filter_ids, |
|
process_filter_functions, |
|
) |
|
from open_webui.utils.code_interpreter import execute_code_jupyter |
|
|
|
from open_webui.tasks import create_task |
|
|
|
from open_webui.config import ( |
|
CACHE_DIR, |
|
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, |
|
DEFAULT_CODE_INTERPRETER_PROMPT, |
|
) |
|
from open_webui.env import ( |
|
SRC_LOG_LEVELS, |
|
GLOBAL_LOG_LEVEL, |
|
BYPASS_MODEL_ACCESS_CONTROL, |
|
ENABLE_REALTIME_CHAT_SAVE, |
|
) |
|
from open_webui.constants import TASKS |
|
|
|
|
|
# Route all log output to stdout at the globally configured level, then
# tighten this module's logger to the source-specific "MAIN" level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|
|
|
|
async def chat_completion_tools_handler(
    request: Request, body: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
    """Run prompt-based ("default" mode) tool calling for a chat request.

    Asks a task model to select tool calls for the latest user query,
    executes the selected tools, and collects their string outputs as
    citation sources. Returns the (possibly modified) request ``body`` and
    a flags dict of the form ``{"sources": [...]}``.
    """

    async def get_content_from_response(response) -> Optional[str]:
        # The task model may answer with a streaming response or a plain
        # completion dict; extract the message content from either.
        content = None
        if hasattr(response, "body_iterator"):
            async for chunk in response.body_iterator:
                data = json.loads(chunk.decode("utf-8"))
                content = data["choices"][0]["message"]["content"]

            # Clean up any background task attached to the streaming response.
            if response.background is not None:
                await response.background()
        else:
            content = response["choices"][0]["message"]["content"]
        return content

    def get_tools_function_calling_payload(messages, task_model_id, content):
        # Build a non-streaming completion request showing the task model the
        # last few conversation turns plus the tool-selection system prompt.
        user_message = get_last_user_message(messages)
        history = "\n".join(
            f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
            for message in messages[::-1][:4]
        )

        prompt = f"History:\n{history}\nQuery: {user_message}"

        return {
            "model": task_model_id,
            "messages": [
                {"role": "system", "content": content},
                {"role": "user", "content": f"Query: {prompt}"},
            ],
            "stream": False,
            "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
        }

    task_model_id = get_task_model_id(
        body["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    skip_files = False
    sources = []

    # JSON specs of the available tools are embedded into the prompt template.
    specs = [tool["spec"] for tool in tools.values()]
    tools_specs = json.dumps(specs)

    if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    tools_function_calling_prompt = tools_function_calling_generation_template(
        template, tools_specs
    )
    log.info(f"{tools_function_calling_prompt=}")
    payload = get_tools_function_calling_payload(
        body["messages"], task_model_id, tools_function_calling_prompt
    )

    try:
        response = await generate_chat_completion(request, form_data=payload, user=user)
        log.debug(f"{response=}")
        content = await get_content_from_response(response)
        log.debug(f"{content=}")

        if not content:
            return body, {}

        try:
            # The model is expected to answer with a JSON object; strip any
            # surrounding prose by slicing from the first "{" to the last "}".
            content = content[content.find("{") : content.rfind("}") + 1]
            if not content:
                raise Exception("No JSON object found in the response")

            result = json.loads(content)

            async def tool_call_handler(tool_call):
                nonlocal skip_files

                log.debug(f"{tool_call=}")

                tool_function_name = tool_call.get("name", None)
                if tool_function_name not in tools:
                    # Model selected a tool we don't know; nothing to execute.
                    return body, {}

                tool_function_params = tool_call.get("parameters", {})

                try:
                    # Only forward parameters the tool spec marks as required;
                    # anything else the model invented is dropped.
                    required_params = (
                        tools[tool_function_name]
                        .get("spec", {})
                        .get("parameters", {})
                        .get("required", [])
                    )
                    tool_function = tools[tool_function_name]["callable"]
                    tool_function_params = {
                        k: v
                        for k, v in tool_function_params.items()
                        if k in required_params
                    }
                    tool_output = await tool_function(**tool_function_params)

                except Exception as e:
                    # Surface tool failures as plain-text output.
                    tool_output = str(e)

                if isinstance(tool_output, str):
                    if tools[tool_function_name]["citation"]:
                        # Citation-enabled tools get a named source entry.
                        sources.append(
                            {
                                "source": {
                                    "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                                },
                                "document": [tool_output],
                                "metadata": [
                                    {
                                        "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                                    }
                                ],
                            }
                        )
                    else:
                        sources.append(
                            {
                                "source": {},
                                "document": [tool_output],
                                "metadata": [
                                    {
                                        "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                                    }
                                ],
                            }
                        )

                if tools[tool_function_name]["file_handler"]:
                    # Tool takes over file handling; attached files are
                    # removed from the request metadata below.
                    skip_files = True

            if result.get("tool_calls"):
                for tool_call in result.get("tool_calls"):
                    await tool_call_handler(tool_call)
            else:
                # Single tool call returned as a bare object.
                await tool_call_handler(result)

        except Exception as e:
            log.exception(f"Error: {e}")
            content = None
    except Exception as e:
        log.exception(f"Error: {e}")
        content = None

    log.debug(f"tool_contexts: {sources}")

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {"sources": sources}
|
|
|
|
|
async def chat_web_search_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Augment a chat request with web-search results.

    Generates search queries from the conversation, runs each through the
    configured search backend, and appends the resulting collections/docs to
    ``form_data["files"]`` so the retrieval pipeline picks them up.
    Progress is reported via ``__event_emitter__`` status events.
    """
    event_emitter = extra_params["__event_emitter__"]
    await event_emitter(
        {
            "type": "status",
            "data": {
                "action": "web_search",
                "description": "Generating search query",
                "done": False,
            },
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    queries = []
    try:
        res = await generate_queries(
            request,
            {
                "model": form_data["model"],
                "messages": messages,
                "prompt": user_message,
                "type": "web_search",
            },
            user,
        )

        response = res["choices"][0]["message"]["content"]

        try:
            # Expect a JSON object like {"queries": [...]}; slice out the
            # first {...} span in case the model added surrounding prose.
            bracket_start = response.find("{")
            bracket_end = response.rfind("}") + 1

            if bracket_start == -1 or bracket_end == -1:
                raise Exception("No JSON object found in the response")

            response = response[bracket_start:bracket_end]
            queries = json.loads(response)
            queries = queries.get("queries", [])
        except Exception as e:
            # Unparseable output: treat the raw model response as the query.
            queries = [response]

    except Exception as e:
        log.exception(e)
        # Query generation failed entirely; fall back to the user's message.
        queries = [user_message]

    if len(queries) == 0:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": "No search query generated",
                    "done": True,
                },
            }
        )
        return form_data

    all_results = []

    for searchQuery in queries:
        # NOTE(review): the literal "{{searchQuery}}" placeholder appears to
        # be interpolated client-side from the "query" field — confirm with
        # the frontend before changing.
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": 'Searching "{{searchQuery}}"',
                    "query": searchQuery,
                    "done": False,
                },
            }
        )

        try:
            results = await process_web_search(
                request,
                SearchForm(
                    **{
                        "query": searchQuery,
                    }
                ),
                user=user,
            )

            if results:
                all_results.append(results)
                files = form_data.get("files", [])

                if results.get("collection_name"):
                    # Indexed results: reference the stored collection.
                    files.append(
                        {
                            "collection_name": results["collection_name"],
                            "name": searchQuery,
                            "type": "web_search",
                            "urls": results["filenames"],
                        }
                    )
                elif results.get("docs"):
                    # Raw documents returned directly (no collection stored).
                    files.append(
                        {
                            "docs": results.get("docs", []),
                            "name": searchQuery,
                            "type": "web_search",
                            "urls": results["filenames"],
                        }
                    )

                form_data["files"] = files
        except Exception as e:
            log.exception(e)
            await event_emitter(
                {
                    "type": "status",
                    "data": {
                        "action": "web_search",
                        "description": 'Error searching "{{searchQuery}}"',
                        "query": searchQuery,
                        "done": True,
                        "error": True,
                    },
                }
            )

    if all_results:
        # Collect every result URL for the final summary status event.
        urls = []
        for results in all_results:
            if "filenames" in results:
                urls.extend(results["filenames"])

        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": "Searched {{count}} sites",
                    "urls": urls,
                    "done": True,
                },
            }
        )
    else:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": "No search results found",
                    "done": True,
                    "error": True,
                },
            }
        )

    return form_data
|
|
|
|
|
async def chat_image_generation_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Generate an image for the chat and stream it back to the client.

    Optionally rewrites the last user message into an image prompt, calls
    the image backend, emits each generated image as a message event, and
    appends a system message describing the outcome so the model can
    acknowledge it.
    """
    __event_emitter__ = extra_params["__event_emitter__"]
    await __event_emitter__(
        {
            "type": "status",
            "data": {"description": "Generating an image", "done": False},
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    prompt = user_message
    negative_prompt = ""

    if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
        try:
            res = await generate_image_prompt(
                request,
                {
                    "model": form_data["model"],
                    "messages": messages,
                },
                user,
            )

            response = res["choices"][0]["message"]["content"]

            try:
                # Expect a JSON object like {"prompt": "..."}; slice out the
                # first {...} span in case the model added surrounding prose.
                bracket_start = response.find("{")
                bracket_end = response.rfind("}") + 1

                if bracket_start == -1 or bracket_end == -1:
                    raise Exception("No JSON object found in the response")

                response = response[bracket_start:bracket_end]
                response = json.loads(response)
                prompt = response.get("prompt", [])
            except Exception as e:
                # Fall back to the raw user message on any parse failure.
                prompt = user_message

        except Exception as e:
            log.exception(e)
            prompt = user_message

    system_message_content = ""

    try:
        images = await image_generations(
            request=request,
            form_data=GenerateImageForm(**{"prompt": prompt}),
            user=user,
        )

        await __event_emitter__(
            {
                "type": "status",
                "data": {"description": "Generated an image", "done": True},
            }
        )

        # Emit each image as a markdown image link. Previously this emitted
        # only f"\n" (the loop variable was unused), so nothing visible
        # reached the client.
        for image in images:
            await __event_emitter__(
                {
                    "type": "message",
                    "data": {"content": f"![Generated Image]({image['url']})\n"},
                }
            )

        system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
    except Exception as e:
        log.exception(e)
        await __event_emitter__(
            {
                "type": "status",
                "data": {
                    "description": f"An error occurred while generating an image",
                    "done": True,
                },
            }
        )

        system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"

    if system_message_content:
        form_data["messages"] = add_or_update_system_message(
            system_message_content, form_data["messages"]
        )

    return form_data
|
|
|
|
|
async def chat_completion_files_handler(
    request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    """Retrieve RAG context for files attached to the request metadata.

    Generates retrieval queries from the conversation (falling back to the
    last user message), then runs the blocking retrieval pipeline in a
    worker thread. Returns the untouched ``body`` plus a flags dict of the
    form ``{"sources": [...]}``.
    """
    sources = []

    if files := body.get("metadata", {}).get("files", None):
        queries = []
        try:
            queries_response = await generate_queries(
                request,
                {
                    "model": body["model"],
                    "messages": body["messages"],
                    "type": "retrieval",
                },
                user,
            )
            queries_response = queries_response["choices"][0]["message"]["content"]

            try:
                # Expect a JSON object like {"queries": [...]}; slice out the
                # first {...} span in case the model added surrounding prose.
                bracket_start = queries_response.find("{")
                bracket_end = queries_response.rfind("}") + 1

                if bracket_start == -1 or bracket_end == -1:
                    raise Exception("No JSON object found in the response")

                queries_response = queries_response[bracket_start:bracket_end]
                queries_response = json.loads(queries_response)
            except Exception:
                # Unparseable output: treat the raw response as one query.
                queries_response = {"queries": [queries_response]}

            queries = queries_response.get("queries", [])
        except Exception:
            # Best-effort: query generation is optional; the fallback below
            # uses the last user message. (Narrowed from a bare `except:`,
            # which would also swallow KeyboardInterrupt/SystemExit.)
            pass

        if len(queries) == 0:
            queries = [get_last_user_message(body["messages"])]

        try:
            # Retrieval is blocking (embeddings, vector DB), so run it in a
            # thread pool to keep the event loop responsive.
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as executor:
                sources = await loop.run_in_executor(
                    executor,
                    lambda: get_sources_from_files(
                        request=request,
                        files=files,
                        queries=queries,
                        embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
                            query, user=user
                        ),
                        k=request.app.state.config.TOP_K,
                        reranking_function=request.app.state.rf,
                        r=request.app.state.config.RELEVANCE_THRESHOLD,
                        hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                        full_context=request.app.state.config.RAG_FULL_CONTEXT,
                    ),
                )
        except Exception as e:
            log.exception(e)

    log.debug(f"rag_contexts:sources: {sources}")

    return body, {"sources": sources}
|
|
|
|
|
def apply_params_to_form_data(form_data, model):
    """Merge advanced model parameters into the completion payload.

    Pops ``form_data["params"]`` and maps the entries onto the request body:
    Ollama models receive the whole dict under ``options`` (with ``format``
    and ``keep_alive`` additionally promoted to top level), while
    OpenAI-style models get each supported sampling key copied verbatim.
    Mutates and returns ``form_data``.
    """
    params = form_data.pop("params", {})

    if model.get("ollama"):
        # Ollama takes the full params dict under "options"; only format and
        # keep_alive live at the top level of the payload.
        form_data["options"] = params

        if "format" in params:
            form_data["format"] = params["format"]

        if "keep_alive" in params:
            form_data["keep_alive"] = params["keep_alive"]
    else:
        # OpenAI-style APIs: copy each supported parameter verbatim;
        # unrecognized params are intentionally dropped.
        for key in (
            "seed",
            "stop",
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "reasoning_effort",
        ):
            if key in params:
                form_data[key] = params[key]

        if "logit_bias" in params:
            try:
                form_data["logit_bias"] = json.loads(
                    convert_logit_bias_input_to_json(params["logit_bias"])
                )
            except Exception as e:
                # Use the module logger instead of print so the error lands
                # in the configured log stream.
                log.error(f"Error parsing logit_bias: {e}")

    return form_data
|
|
|
|
|
async def process_chat_payload(request, form_data, user, metadata, model):
    """Prepare a chat completion payload before it is sent to the model.

    Applies model params, merges model knowledge into attached files, runs
    pipeline and filter-function inlets, handles feature flags (web search,
    image generation, code interpreter), executes tools, and injects
    retrieved context into the messages. Returns the transformed
    ``(form_data, metadata, events)``.
    """

    form_data = apply_params_to_form_data(form_data, model)
    log.debug(f"form_data: {form_data}")

    event_emitter = get_event_emitter(metadata)
    event_call = get_event_call(metadata)

    # Context handed to filters/tools so they can emit events and inspect
    # the request, user, and model.
    extra_params = {
        "__event_emitter__": event_emitter,
        "__event_call__": event_call,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
        "__model__": model,
    }

    # Direct (single-model) connections expose their model on request.state;
    # otherwise use the app-wide model registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    task_model_id = get_task_model_id(
        form_data["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    events = []
    sources = []

    user_message = get_last_user_message(form_data["messages"])
    model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": False,
                },
            }
        )

        # Normalize knowledge entries into file descriptors; "collection_name"
        # / "collection_names" entries are legacy shapes.
        knowledge_files = []
        for item in model_knowledge:
            if item.get("collection_name"):
                knowledge_files.append(
                    {
                        "id": item.get("collection_name"),
                        "name": item.get("name"),
                        "legacy": True,
                    }
                )
            elif item.get("collection_names"):
                knowledge_files.append(
                    {
                        "name": item.get("name"),
                        "type": "collection",
                        "collection_names": item.get("collection_names"),
                        "legacy": True,
                    }
                )
            else:
                knowledge_files.append(item)

        files = form_data.get("files", [])
        files.extend(knowledge_files)
        form_data["files"] = files

    # Popped so it isn't forwarded upstream; not referenced again here.
    variables = form_data.pop("variables", None)

    try:
        form_data = await process_pipeline_inlet_filter(
            request, form_data, user, models
        )
    except Exception as e:
        raise e

    try:
        filter_functions = [
            Functions.get_function_by_id(filter_id)
            for filter_id in get_sorted_filter_ids(model)
        ]

        form_data, flags = await process_filter_functions(
            request=request,
            filter_functions=filter_functions,
            filter_type="inlet",
            form_data=form_data,
            extra_params=extra_params,
        )
    except Exception as e:
        raise Exception(f"Error: {e}")

    features = form_data.pop("features", None)
    if features:
        if "web_search" in features and features["web_search"]:
            form_data = await chat_web_search_handler(
                request, form_data, extra_params, user
            )

        if "image_generation" in features and features["image_generation"]:
            form_data = await chat_image_generation_handler(
                request, form_data, extra_params, user
            )

        if "code_interpreter" in features and features["code_interpreter"]:
            form_data["messages"] = add_or_update_user_message(
                (
                    request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE
                    if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != ""
                    else DEFAULT_CODE_INTERPRETER_PROMPT
                ),
                form_data["messages"],
            )

    tool_ids = form_data.pop("tool_ids", None)
    files = form_data.pop("files", None)

    if files:
        # Deduplicate files by their canonical JSON serialization.
        files = list({json.dumps(f, sort_keys=True): f for f in files}.values())

    metadata = {
        **metadata,
        "tool_ids": tool_ids,
        "files": files,
    }
    form_data["metadata"] = metadata

    tool_ids = metadata.get("tool_ids", None)
    log.debug(f"{tool_ids=}")

    if tool_ids:

        tools = get_tools(
            request,
            tool_ids,
            user,
            {
                **extra_params,
                "__model__": models[task_model_id],
                "__messages__": form_data["messages"],
                "__files__": metadata.get("files", []),
            },
        )
        log.info(f"{tools=}")

        if metadata.get("function_calling") == "native":
            # Native function calling: pass the tool specs to the model.
            metadata["tools"] = tools
            form_data["tools"] = [
                {"type": "function", "function": tool.get("spec", {})}
                for tool in tools.values()
            ]
        else:
            # Default mode: a task model selects and runs tools via prompt.
            try:
                form_data, flags = await chat_completion_tools_handler(
                    request, form_data, user, models, tools
                )
                sources.extend(flags.get("sources", []))

            except Exception as e:
                log.exception(e)

    try:
        form_data, flags = await chat_completion_files_handler(request, form_data, user)
        sources.extend(flags.get("sources", []))
    except Exception as e:
        log.exception(e)

    # Inject the collected RAG/tool context into the conversation.
    if len(sources) > 0:
        context_string = ""
        for source_idx, source in enumerate(sources):
            if "document" in source:
                for doc_idx, doc_context in enumerate(source["document"]):
                    context_string += f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>\n"

        context_string = context_string.strip()
        prompt = get_last_user_message(form_data["messages"])

        if prompt is None:
            raise Exception("No user message found")
        if (
            request.app.state.config.RELEVANCE_THRESHOLD == 0
            and context_string.strip() == ""
        ):
            log.debug(
                f"With a 0 relevancy threshold for RAG, the context cannot be empty"
            )

        # Ollama models get the RAG template prepended to the first user
        # message; other models get it merged into the system message.
        if model.get("owned_by") == "ollama":
            form_data["messages"] = prepend_to_first_user_message_content(
                rag_template(
                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
                ),
                form_data["messages"],
            )
        else:
            form_data["messages"] = add_or_update_system_message(
                rag_template(
                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
                ),
                form_data["messages"],
            )

    # Only named sources (e.g. citation-enabled tools) reach the client.
    sources = [source for source in sources if source.get("source", {}).get("name", "")]

    if len(sources) > 0:
        events.append({"sources": sources})

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": True,
                    "hidden": True,
                },
            }
        )

    return form_data, metadata, events
|
|
|
|
|
async def process_chat_response( |
|
request, response, form_data, user, metadata, model, events, tasks |
|
): |
|
async def background_tasks_handler(): |
|
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"]) |
|
message = message_map.get(metadata["message_id"]) if message_map else None |
|
|
|
if message: |
|
messages = get_message_list(message_map, message.get("id")) |
|
|
|
if tasks and messages: |
|
if TASKS.TITLE_GENERATION in tasks: |
|
if tasks[TASKS.TITLE_GENERATION]: |
|
res = await generate_title( |
|
request, |
|
{ |
|
"model": message["model"], |
|
"messages": messages, |
|
"chat_id": metadata["chat_id"], |
|
}, |
|
user, |
|
) |
|
|
|
if res and isinstance(res, dict): |
|
if len(res.get("choices", [])) == 1: |
|
title_string = ( |
|
res.get("choices", [])[0] |
|
.get("message", {}) |
|
.get("content", message.get("content", "New Chat")) |
|
) |
|
else: |
|
title_string = "" |
|
|
|
title_string = title_string[ |
|
title_string.find("{") : title_string.rfind("}") + 1 |
|
] |
|
|
|
try: |
|
title = json.loads(title_string).get( |
|
"title", "New Chat" |
|
) |
|
except Exception as e: |
|
title = "" |
|
|
|
if not title: |
|
title = messages[0].get("content", "New Chat") |
|
|
|
Chats.update_chat_title_by_id(metadata["chat_id"], title) |
|
|
|
await event_emitter( |
|
{ |
|
"type": "chat:title", |
|
"data": title, |
|
} |
|
) |
|
elif len(messages) == 2: |
|
title = messages[0].get("content", "New Chat") |
|
|
|
Chats.update_chat_title_by_id(metadata["chat_id"], title) |
|
|
|
await event_emitter( |
|
{ |
|
"type": "chat:title", |
|
"data": message.get("content", "New Chat"), |
|
} |
|
) |
|
|
|
if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]: |
|
res = await generate_chat_tags( |
|
request, |
|
{ |
|
"model": message["model"], |
|
"messages": messages, |
|
"chat_id": metadata["chat_id"], |
|
}, |
|
user, |
|
) |
|
|
|
if res and isinstance(res, dict): |
|
if len(res.get("choices", [])) == 1: |
|
tags_string = ( |
|
res.get("choices", [])[0] |
|
.get("message", {}) |
|
.get("content", "") |
|
) |
|
else: |
|
tags_string = "" |
|
|
|
tags_string = tags_string[ |
|
tags_string.find("{") : tags_string.rfind("}") + 1 |
|
] |
|
|
|
try: |
|
tags = json.loads(tags_string).get("tags", []) |
|
Chats.update_chat_tags_by_id( |
|
metadata["chat_id"], tags, user |
|
) |
|
|
|
await event_emitter( |
|
{ |
|
"type": "chat:tags", |
|
"data": tags, |
|
} |
|
) |
|
except Exception as e: |
|
pass |
|
|
|
event_emitter = None |
|
event_caller = None |
|
if ( |
|
"session_id" in metadata |
|
and metadata["session_id"] |
|
and "chat_id" in metadata |
|
and metadata["chat_id"] |
|
and "message_id" in metadata |
|
and metadata["message_id"] |
|
): |
|
event_emitter = get_event_emitter(metadata) |
|
event_caller = get_event_call(metadata) |
|
|
|
|
|
if not isinstance(response, StreamingResponse): |
|
if event_emitter: |
|
if "selected_model_id" in response: |
|
Chats.upsert_message_to_chat_by_id_and_message_id( |
|
metadata["chat_id"], |
|
metadata["message_id"], |
|
{ |
|
"selectedModelId": response["selected_model_id"], |
|
}, |
|
) |
|
|
|
if response.get("choices", [])[0].get("message", {}).get("content"): |
|
content = response["choices"][0]["message"]["content"] |
|
|
|
if content: |
|
|
|
await event_emitter( |
|
{ |
|
"type": "chat:completion", |
|
"data": response, |
|
} |
|
) |
|
|
|
title = Chats.get_chat_title_by_id(metadata["chat_id"]) |
|
|
|
await event_emitter( |
|
{ |
|
"type": "chat:completion", |
|
"data": { |
|
"done": True, |
|
"content": content, |
|
"title": title, |
|
}, |
|
} |
|
) |
|
|
|
|
|
Chats.upsert_message_to_chat_by_id_and_message_id( |
|
metadata["chat_id"], |
|
metadata["message_id"], |
|
{ |
|
"content": content, |
|
}, |
|
) |
|
|
|
|
|
if get_active_status_by_user_id(user.id) is None: |
|
webhook_url = Users.get_user_webhook_url_by_id(user.id) |
|
if webhook_url: |
|
post_webhook( |
|
request.app.state.WEBUI_NAME, |
|
webhook_url, |
|
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", |
|
{ |
|
"action": "chat", |
|
"message": content, |
|
"title": title, |
|
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", |
|
}, |
|
) |
|
|
|
await background_tasks_handler() |
|
|
|
return response |
|
else: |
|
return response |
|
|
|
|
|
if not any( |
|
content_type in response.headers["Content-Type"] |
|
for content_type in ["text/event-stream", "application/x-ndjson"] |
|
): |
|
return response |
|
|
|
extra_params = { |
|
"__event_emitter__": event_emitter, |
|
"__event_call__": event_caller, |
|
"__user__": { |
|
"id": user.id, |
|
"email": user.email, |
|
"name": user.name, |
|
"role": user.role, |
|
}, |
|
"__metadata__": metadata, |
|
"__request__": request, |
|
"__model__": model, |
|
} |
|
filter_functions = [ |
|
Functions.get_function_by_id(filter_id) |
|
for filter_id in get_sorted_filter_ids(model) |
|
] |
|
|
|
print(f"{filter_functions=}") |
|
|
|
|
|
if event_emitter and event_caller: |
|
task_id = str(uuid4()) |
|
model_id = form_data.get("model", "") |
|
|
|
Chats.upsert_message_to_chat_by_id_and_message_id( |
|
metadata["chat_id"], |
|
metadata["message_id"], |
|
{ |
|
"model": model_id, |
|
}, |
|
) |
|
|
|
def split_content_and_whitespace(content): |
|
content_stripped = content.rstrip() |
|
original_whitespace = ( |
|
content[len(content_stripped) :] |
|
if len(content) > len(content_stripped) |
|
else "" |
|
) |
|
return content_stripped, original_whitespace |
|
|
|
def is_opening_code_block(content): |
|
backtick_segments = content.split("```") |
|
|
|
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0 |
|
|
|
|
|
async def post_response_handler(response, events): |
|
def serialize_content_blocks(content_blocks, raw=False): |
|
content = "" |
|
|
|
for block in content_blocks: |
|
if block["type"] == "text": |
|
content = f"{content}{block['content'].strip()}\n" |
|
elif block["type"] == "tool_calls": |
|
attributes = block.get("attributes", {}) |
|
|
|
block_content = block.get("content", []) |
|
results = block.get("results", []) |
|
|
|
if results: |
|
|
|
result_display_content = "" |
|
|
|
for result in results: |
|
tool_call_id = result.get("tool_call_id", "") |
|
tool_name = "" |
|
|
|
for tool_call in block_content: |
|
if tool_call.get("id", "") == tool_call_id: |
|
tool_name = tool_call.get("function", {}).get( |
|
"name", "" |
|
) |
|
break |
|
|
|
result_display_content = f"{result_display_content}\n> {tool_name}: {result.get('content', '')}" |
|
|
|
if not raw: |
|
content = f'{content}\n<details type="tool_calls" done="true" content="{html.escape(json.dumps(block_content))}" results="{html.escape(json.dumps(results))}">\n<summary>Tool Executed</summary>\n{result_display_content}\n</details>\n' |
|
else: |
|
tool_calls_display_content = "" |
|
|
|
for tool_call in block_content: |
|
tool_calls_display_content = f"{tool_calls_display_content}\n> Executing {tool_call.get('function', {}).get('name', '')}" |
|
|
|
if not raw: |
|
content = f'{content}\n<details type="tool_calls" done="false" content="{html.escape(json.dumps(block_content))}">\n<summary>Tool Executing...</summary>\n{tool_calls_display_content}\n</details>\n' |
|
|
|
elif block["type"] == "reasoning": |
|
reasoning_display_content = "\n".join( |
|
(f"> {line}" if not line.startswith(">") else line) |
|
for line in block["content"].splitlines() |
|
) |
|
|
|
reasoning_duration = block.get("duration", None) |
|
|
|
if reasoning_duration is not None: |
|
if raw: |
|
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n' |
|
else: |
|
content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n' |
|
else: |
|
if raw: |
|
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n' |
|
else: |
|
content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n' |
|
|
|
elif block["type"] == "code_interpreter": |
|
attributes = block.get("attributes", {}) |
|
output = block.get("output", None) |
|
lang = attributes.get("lang", "") |
|
|
|
content_stripped, original_whitespace = ( |
|
split_content_and_whitespace(content) |
|
) |
|
if is_opening_code_block(content_stripped): |
|
|
|
content = ( |
|
content_stripped.rstrip("`").rstrip() |
|
+ original_whitespace |
|
) |
|
else: |
|
|
|
content = content_stripped + original_whitespace |
|
|
|
if output: |
|
output = html.escape(json.dumps(output)) |
|
|
|
if raw: |
|
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n' |
|
else: |
|
content = f'{content}\n<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n' |
|
else: |
|
if raw: |
|
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n' |
|
else: |
|
content = f'{content}\n<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n' |
|
|
|
else: |
|
block_content = str(block["content"]).strip() |
|
content = f"{content}{block['type']}: {block_content}\n" |
|
|
|
return content.strip() |
|
|
|
def convert_content_blocks_to_messages(content_blocks): |
|
messages = [] |
|
|
|
temp_blocks = [] |
|
for idx, block in enumerate(content_blocks): |
|
if block["type"] == "tool_calls": |
|
messages.append( |
|
{ |
|
"role": "assistant", |
|
"content": serialize_content_blocks(temp_blocks), |
|
"tool_calls": block.get("content"), |
|
} |
|
) |
|
|
|
results = block.get("results", []) |
|
|
|
for result in results: |
|
messages.append( |
|
{ |
|
"role": "tool", |
|
"tool_call_id": result["tool_call_id"], |
|
"content": result["content"], |
|
} |
|
) |
|
temp_blocks = [] |
|
else: |
|
temp_blocks.append(block) |
|
|
|
if temp_blocks: |
|
content = serialize_content_blocks(temp_blocks) |
|
if content: |
|
messages.append( |
|
{ |
|
"role": "assistant", |
|
"content": content, |
|
} |
|
) |
|
|
|
return messages |
|
|
|
def tag_content_handler(content_type, tags, content, content_blocks):
    """Detect custom start/end tags in streamed text and update content blocks.

    When the trailing block is plain text, scan `content` for any of the
    given start tags and open a new block of `content_type`; when the
    trailing block is already of `content_type`, scan for its recorded end
    tag and close the block.

    Args:
        content_type: Block type to open/close (e.g. "reasoning",
            "code_interpreter", "solution").
        tags: List of (start_tag, end_tag) name pairs, matched as
            <start_tag ...> / <end_tag>.
        content: Raw accumulated streamed text so far.
        content_blocks: Mutable list of block dicts; the last entry is the
            block currently being filled (mutated in place).

    Returns:
        Tuple (content, content_blocks, end_flag); end_flag is True when an
        end tag for the currently open block was found in this chunk.
    """
    end_flag = False

    def extract_attributes(tag_content):
        """Extract attributes from a tag if they exist."""
        attributes = {}
        if not tag_content:
            return attributes

        # Parse key="value" pairs; values must be double-quoted and non-empty.
        matches = re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content)
        for key, value in matches:
            attributes[key] = value
        return attributes

    if content_blocks[-1]["type"] == "text":
        # Look for an opening tag of any of the configured pairs.
        for start_tag, end_tag in tags:
            # Matches e.g. <think> or <think type="...">; group(1) captures the
            # raw attribute string when present.
            start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>"
            match = re.search(start_tag_pattern, content)
            if match:
                attr_content = (
                    match.group(1) if match.group(1) else ""
                )
                attributes = extract_attributes(
                    attr_content
                )

                # Text before/after the opening tag within `content`.
                before_tag = content[
                    : match.start()
                ]
                after_tag = content[
                    match.end() :
                ]

                # Strip the tag (and everything after it) from the current
                # text block.
                content_blocks[-1]["content"] = content_blocks[-1][
                    "content"
                ].replace(match.group(0) + after_tag, "")

                if before_tag:
                    content_blocks[-1]["content"] = before_tag

                # Drop the text block entirely if it ended up empty.
                if not content_blocks[-1]["content"]:
                    content_blocks.pop()

                # Open a new block of the requested type.
                content_blocks.append(
                    {
                        "type": content_type,
                        "start_tag": start_tag,
                        "end_tag": end_tag,
                        "attributes": attributes,
                        "content": "",
                        "started_at": time.time(),
                    }
                )

                # Text that streamed in after the opening tag already belongs
                # to the new block.
                if after_tag:
                    content_blocks[-1]["content"] = after_tag

                break
    elif content_blocks[-1]["type"] == content_type:
        # A block of this type is open; look for its matching end tag.
        start_tag = content_blocks[-1]["start_tag"]
        end_tag = content_blocks[-1]["end_tag"]

        end_tag_pattern = rf"<{re.escape(end_tag)}>"

        if re.search(end_tag_pattern, content):
            end_flag = True

            block_content = content_blocks[-1]["content"]

            # Remove the (possibly attributed) opening tag from the block body.
            start_tag_pattern = rf"<{re.escape(start_tag)}(.*?)>"
            block_content = re.sub(
                start_tag_pattern, "", block_content
            ).strip()

            # Split once at the end tag: [0] is the block body, [1] is
            # whatever streamed in after it.
            end_tag_regex = re.compile(end_tag_pattern, re.DOTALL)
            split_content = end_tag_regex.split(block_content, maxsplit=1)

            block_content = (
                split_content[0].strip() if split_content else ""
            )

            leftover_content = (
                split_content[1].strip() if len(split_content) > 1 else ""
            )

            if block_content:
                # Close the block and record timing information.
                content_blocks[-1]["content"] = block_content
                content_blocks[-1]["ended_at"] = time.time()
                content_blocks[-1]["duration"] = int(
                    content_blocks[-1]["ended_at"]
                    - content_blocks[-1]["started_at"]
                )

                # Start a fresh text block for what follows; code_interpreter
                # blocks skip this here (the execution loop appends it after
                # the code has been run).
                if content_type != "code_interpreter":
                    if leftover_content:
                        content_blocks.append(
                            {
                                "type": "text",
                                "content": leftover_content,
                            }
                        )
                    else:
                        content_blocks.append(
                            {
                                "type": "text",
                                "content": "",
                            }
                        )

            else:
                # Empty block: discard it and continue with a text block.
                content_blocks.pop()

                if leftover_content:
                    content_blocks.append(
                        {
                            "type": "text",
                            "content": leftover_content,
                        }
                    )
                else:
                    content_blocks.append(
                        {
                            "type": "text",
                            "content": "",
                        }
                    )

            # Clean up `content`: remove the whole tagged span (open tag,
            # body, and end tag).
            content = re.sub(
                rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>",
                "",
                content,
                flags=re.DOTALL,
            )

    return content, content_blocks, end_flag
|
|
|
message = Chats.get_message_by_id_and_message_id( |
|
metadata["chat_id"], metadata["message_id"] |
|
) |
|
|
|
tool_calls = [] |
|
|
|
last_assistant_message = None |
|
try: |
|
if form_data["messages"][-1]["role"] == "assistant": |
|
last_assistant_message = get_last_assistant_message( |
|
form_data["messages"] |
|
) |
|
except Exception as e: |
|
pass |
|
|
|
content = ( |
|
message.get("content", "") |
|
if message |
|
else last_assistant_message if last_assistant_message else "" |
|
) |
|
|
|
content_blocks = [ |
|
{ |
|
"type": "text", |
|
"content": content, |
|
} |
|
] |
|
|
|
|
|
DETECT_REASONING = True |
|
DETECT_SOLUTION = True |
|
DETECT_CODE_INTERPRETER = metadata.get("features", {}).get( |
|
"code_interpreter", False |
|
) |
|
|
|
reasoning_tags = [ |
|
("think", "/think"), |
|
("thinking", "/thinking"), |
|
("reason", "/reason"), |
|
("reasoning", "/reasoning"), |
|
("thought", "/thought"), |
|
("Thought", "/Thought"), |
|
("|begin_of_thought|", "|end_of_thought|"), |
|
] |
|
|
|
code_interpreter_tags = [("code_interpreter", "/code_interpreter")] |
|
|
|
solution_tags = [("|begin_of_solution|", "|end_of_solution|")] |
|
|
|
try: |
|
for event in events: |
|
await event_emitter( |
|
{ |
|
"type": "chat:completion", |
|
"data": event, |
|
} |
|
) |
|
|
|
|
|
Chats.upsert_message_to_chat_by_id_and_message_id( |
|
metadata["chat_id"], |
|
metadata["message_id"], |
|
{ |
|
**event, |
|
}, |
|
) |
|
|
|
async def stream_body_handler(response): |
|
nonlocal content |
|
nonlocal content_blocks |
|
|
|
response_tool_calls = [] |
|
|
|
async for line in response.body_iterator: |
|
line = line.decode("utf-8") if isinstance(line, bytes) else line |
|
data = line |
|
|
|
|
|
if not data.strip(): |
|
continue |
|
|
|
|
|
if not data.startswith("data:"): |
|
continue |
|
|
|
|
|
data = data[len("data:") :].strip() |
|
|
|
try: |
|
data = json.loads(data) |
|
|
|
data, _ = await process_filter_functions( |
|
request=request, |
|
filter_functions=filter_functions, |
|
filter_type="stream", |
|
form_data=data, |
|
extra_params=extra_params, |
|
) |
|
|
|
if data: |
|
if "selected_model_id" in data: |
|
model_id = data["selected_model_id"] |
|
Chats.upsert_message_to_chat_by_id_and_message_id( |
|
metadata["chat_id"], |
|
metadata["message_id"], |
|
{ |
|
"selectedModelId": model_id, |
|
}, |
|
) |
|
else: |
|
choices = data.get("choices", []) |
|
if not choices: |
|
usage = data.get("usage", {}) |
|
if usage: |
|
await event_emitter( |
|
{ |
|
"type": "chat:completion", |
|
"data": { |
|
"usage": usage, |
|
}, |
|
} |
|
) |
|
continue |
|
|
|
delta = choices[0].get("delta", {}) |
|
delta_tool_calls = delta.get("tool_calls", None) |
|
|
|
if delta_tool_calls: |
|
for delta_tool_call in delta_tool_calls: |
|
tool_call_index = delta_tool_call.get( |
|
"index" |
|
) |
|
|
|
if tool_call_index is not None: |
|
if ( |
|
len(response_tool_calls) |
|
<= tool_call_index |
|
): |
|
response_tool_calls.append( |
|
delta_tool_call |
|
) |
|
else: |
|
delta_name = delta_tool_call.get( |
|
"function", {} |
|
).get("name") |
|
delta_arguments = ( |
|
delta_tool_call.get( |
|
"function", {} |
|
).get("arguments") |
|
) |
|
|
|
if delta_name: |
|
response_tool_calls[ |
|
tool_call_index |
|
]["function"][ |
|
"name" |
|
] += delta_name |
|
|
|
if delta_arguments: |
|
response_tool_calls[ |
|
tool_call_index |
|
]["function"][ |
|
"arguments" |
|
] += delta_arguments |
|
|
|
value = delta.get("content") |
|
|
|
reasoning_content = delta.get("reasoning_content") |
|
if reasoning_content: |
|
if ( |
|
not content_blocks |
|
or content_blocks[-1]["type"] != "reasoning" |
|
): |
|
reasoning_block = { |
|
"type": "reasoning", |
|
"start_tag": "think", |
|
"end_tag": "/think", |
|
"attributes": { |
|
"type": "reasoning_content" |
|
}, |
|
"content": "", |
|
"started_at": time.time(), |
|
} |
|
content_blocks.append(reasoning_block) |
|
else: |
|
reasoning_block = content_blocks[-1] |
|
|
|
reasoning_block["content"] += reasoning_content |
|
|
|
data = { |
|
"content": serialize_content_blocks( |
|
content_blocks |
|
) |
|
} |
|
|
|
if value: |
|
if ( |
|
content_blocks |
|
and content_blocks[-1]["type"] |
|
== "reasoning" |
|
and content_blocks[-1] |
|
.get("attributes", {}) |
|
.get("type") |
|
== "reasoning_content" |
|
): |
|
reasoning_block = content_blocks[-1] |
|
reasoning_block["ended_at"] = time.time() |
|
reasoning_block["duration"] = int( |
|
reasoning_block["ended_at"] |
|
- reasoning_block["started_at"] |
|
) |
|
|
|
content_blocks.append( |
|
{ |
|
"type": "text", |
|
"content": "", |
|
} |
|
) |
|
|
|
content = f"{content}{value}" |
|
if not content_blocks: |
|
content_blocks.append( |
|
{ |
|
"type": "text", |
|
"content": "", |
|
} |
|
) |
|
|
|
content_blocks[-1]["content"] = ( |
|
content_blocks[-1]["content"] + value |
|
) |
|
|
|
if DETECT_REASONING: |
|
content, content_blocks, _ = ( |
|
tag_content_handler( |
|
"reasoning", |
|
reasoning_tags, |
|
content, |
|
content_blocks, |
|
) |
|
) |
|
|
|
if DETECT_CODE_INTERPRETER: |
|
content, content_blocks, end = ( |
|
tag_content_handler( |
|
"code_interpreter", |
|
code_interpreter_tags, |
|
content, |
|
content_blocks, |
|
) |
|
) |
|
|
|
if end: |
|
break |
|
|
|
if DETECT_SOLUTION: |
|
content, content_blocks, _ = ( |
|
tag_content_handler( |
|
"solution", |
|
solution_tags, |
|
content, |
|
content_blocks, |
|
) |
|
) |
|
|
|
if ENABLE_REALTIME_CHAT_SAVE: |
|
|
|
Chats.upsert_message_to_chat_by_id_and_message_id( |
|
metadata["chat_id"], |
|
metadata["message_id"], |
|
{ |
|
"content": serialize_content_blocks( |
|
content_blocks |
|
), |
|
}, |
|
) |
|
else: |
|
data = { |
|
"content": serialize_content_blocks( |
|
content_blocks |
|
), |
|
} |
|
|
|
await event_emitter( |
|
{ |
|
"type": "chat:completion", |
|
"data": data, |
|
} |
|
) |
|
except Exception as e: |
|
done = "data: [DONE]" in line |
|
if done: |
|
pass |
|
else: |
|
log.debug("Error: ", e) |
|
continue |
|
|
|
if content_blocks: |
|
|
|
if content_blocks[-1]["type"] == "text": |
|
content_blocks[-1]["content"] = content_blocks[-1][ |
|
"content" |
|
].strip() |
|
|
|
if not content_blocks[-1]["content"]: |
|
content_blocks.pop() |
|
|
|
if not content_blocks: |
|
content_blocks.append( |
|
{ |
|
"type": "text", |
|
"content": "", |
|
} |
|
) |
|
|
|
if response_tool_calls: |
|
tool_calls.append(response_tool_calls) |
|
|
|
if response.background: |
|
await response.background() |
|
|
|
await stream_body_handler(response) |
|
|
|
MAX_TOOL_CALL_RETRIES = 5 |
|
tool_call_retries = 0 |
|
|
|
while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES: |
|
tool_call_retries += 1 |
|
|
|
response_tool_calls = tool_calls.pop(0) |
|
|
|
content_blocks.append( |
|
{ |
|
"type": "tool_calls", |
|
"content": response_tool_calls, |
|
} |
|
) |
|
|
|
await event_emitter( |
|
{ |
|
"type": "chat:completion", |
|
"data": { |
|
"content": serialize_content_blocks(content_blocks), |
|
}, |
|
} |
|
) |
|
|
|
tools = metadata.get("tools", {}) |
|
|
|
results = [] |
|
for tool_call in response_tool_calls: |
|
tool_call_id = tool_call.get("id", "") |
|
tool_name = tool_call.get("function", {}).get("name", "") |
|
|
|
tool_function_params = {} |
|
try: |
|
|
|
tool_function_params = ast.literal_eval( |
|
tool_call.get("function", {}).get("arguments", "{}") |
|
) |
|
except Exception as e: |
|
log.debug(e) |
|
|
|
tool_result = None |
|
|
|
if tool_name in tools: |
|
tool = tools[tool_name] |
|
spec = tool.get("spec", {}) |
|
|
|
try: |
|
required_params = spec.get("parameters", {}).get( |
|
"required", [] |
|
) |
|
tool_function = tool["callable"] |
|
tool_function_params = { |
|
k: v |
|
for k, v in tool_function_params.items() |
|
if k in required_params |
|
} |
|
tool_result = await tool_function( |
|
**tool_function_params |
|
) |
|
except Exception as e: |
|
tool_result = str(e) |
|
|
|
results.append( |
|
{ |
|
"tool_call_id": tool_call_id, |
|
"content": tool_result, |
|
} |
|
) |
|
|
|
content_blocks[-1]["results"] = results |
|
|
|
content_blocks.append( |
|
{ |
|
"type": "text", |
|
"content": "", |
|
} |
|
) |
|
|
|
await event_emitter( |
|
{ |
|
"type": "chat:completion", |
|
"data": { |
|
"content": serialize_content_blocks(content_blocks), |
|
}, |
|
} |
|
) |
|
|
|
try: |
|
res = await generate_chat_completion( |
|
request, |
|
{ |
|
"model": model_id, |
|
"stream": True, |
|
"tools": form_data["tools"], |
|
"messages": [ |
|
*form_data["messages"], |
|
*convert_content_blocks_to_messages(content_blocks), |
|
], |
|
}, |
|
user, |
|
) |
|
|
|
if isinstance(res, StreamingResponse): |
|
await stream_body_handler(res) |
|
else: |
|
break |
|
except Exception as e: |
|
log.debug(e) |
|
break |
|
|
|
if DETECT_CODE_INTERPRETER: |
|
MAX_RETRIES = 5 |
|
retries = 0 |
|
|
|
while ( |
|
content_blocks[-1]["type"] == "code_interpreter" |
|
and retries < MAX_RETRIES |
|
): |
|
await event_emitter( |
|
{ |
|
"type": "chat:completion", |
|
"data": { |
|
"content": serialize_content_blocks(content_blocks), |
|
}, |
|
} |
|
) |
|
|
|
retries += 1 |
|
log.debug(f"Attempt count: {retries}") |
|
|
|
output = "" |
|
try: |
|
if content_blocks[-1]["attributes"].get("type") == "code": |
|
code = content_blocks[-1]["content"] |
|
|
|
if ( |
|
request.app.state.config.CODE_INTERPRETER_ENGINE |
|
== "pyodide" |
|
): |
|
output = await event_caller( |
|
{ |
|
"type": "execute:python", |
|
"data": { |
|
"id": str(uuid4()), |
|
"code": code, |
|
"session_id": metadata.get( |
|
"session_id", None |
|
), |
|
}, |
|
} |
|
) |
|
elif ( |
|
request.app.state.config.CODE_INTERPRETER_ENGINE |
|
== "jupyter" |
|
): |
|
output = await execute_code_jupyter( |
|
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, |
|
code, |
|
( |
|
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN |
|
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH |
|
== "token" |
|
else None |
|
), |
|
( |
|
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD |
|
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH |
|
== "password" |
|
else None |
|
), |
|
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, |
|
) |
|
else: |
|
output = { |
|
"stdout": "Code interpreter engine not configured." |
|
} |
|
|
|
log.debug(f"Code interpreter output: {output}") |
|
|
|
if isinstance(output, dict): |
|
stdout = output.get("stdout", "") |
|
|
|
if isinstance(stdout, str): |
|
stdoutLines = stdout.split("\n") |
|
for idx, line in enumerate(stdoutLines): |
|
if "data:image/png;base64" in line: |
|
id = str(uuid4()) |
|
|
|
|
|
os.makedirs( |
|
os.path.join(CACHE_DIR, "images"), |
|
exist_ok=True, |
|
) |
|
|
|
image_path = os.path.join( |
|
CACHE_DIR, |
|
f"images/{id}.png", |
|
) |
|
|
|
with open(image_path, "wb") as f: |
|
f.write( |
|
base64.b64decode( |
|
line.split(",")[1] |
|
) |
|
) |
|
|
|
stdoutLines[idx] = ( |
|
f"" |
|
) |
|
|
|
output["stdout"] = "\n".join(stdoutLines) |
|
|
|
result = output.get("result", "") |
|
|
|
if isinstance(result, str): |
|
resultLines = result.split("\n") |
|
for idx, line in enumerate(resultLines): |
|
if "data:image/png;base64" in line: |
|
id = str(uuid4()) |
|
|
|
|
|
os.makedirs( |
|
os.path.join(CACHE_DIR, "images"), |
|
exist_ok=True, |
|
) |
|
|
|
image_path = os.path.join( |
|
CACHE_DIR, |
|
f"images/{id}.png", |
|
) |
|
|
|
with open(image_path, "wb") as f: |
|
f.write( |
|
base64.b64decode( |
|
line.split(",")[1] |
|
) |
|
) |
|
|
|
resultLines[idx] = ( |
|
f"" |
|
) |
|
|
|
output["result"] = "\n".join(resultLines) |
|
except Exception as e: |
|
output = str(e) |
|
|
|
content_blocks[-1]["output"] = output |
|
|
|
content_blocks.append( |
|
{ |
|
"type": "text", |
|
"content": "", |
|
} |
|
) |
|
|
|
await event_emitter( |
|
{ |
|
"type": "chat:completion", |
|
"data": { |
|
"content": serialize_content_blocks(content_blocks), |
|
}, |
|
} |
|
) |
|
|
|
log.info(f"content_blocks={content_blocks}") |
|
log.info( |
|
f"serialize_content_blocks={serialize_content_blocks(content_blocks)}" |
|
) |
|
|
|
try: |
|
res = await generate_chat_completion( |
|
request, |
|
{ |
|
"model": model_id, |
|
"stream": True, |
|
"messages": [ |
|
*form_data["messages"], |
|
{ |
|
"role": "assistant", |
|
"content": serialize_content_blocks( |
|
content_blocks, raw=True |
|
), |
|
}, |
|
], |
|
}, |
|
user, |
|
) |
|
|
|
if isinstance(res, StreamingResponse): |
|
await stream_body_handler(res) |
|
else: |
|
break |
|
except Exception as e: |
|
log.debug(e) |
|
break |
|
|
|
title = Chats.get_chat_title_by_id(metadata["chat_id"]) |
|
data = { |
|
"done": True, |
|
"content": serialize_content_blocks(content_blocks), |
|
"title": title, |
|
} |
|
|
|
if not ENABLE_REALTIME_CHAT_SAVE: |
|
|
|
Chats.upsert_message_to_chat_by_id_and_message_id( |
|
metadata["chat_id"], |
|
metadata["message_id"], |
|
{ |
|
"content": serialize_content_blocks(content_blocks), |
|
}, |
|
) |
|
|
|
|
|
if get_active_status_by_user_id(user.id) is None: |
|
webhook_url = Users.get_user_webhook_url_by_id(user.id) |
|
if webhook_url: |
|
post_webhook( |
|
request.app.state.WEBUI_NAME, |
|
webhook_url, |
|
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", |
|
{ |
|
"action": "chat", |
|
"message": content, |
|
"title": title, |
|
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", |
|
}, |
|
) |
|
|
|
await event_emitter( |
|
{ |
|
"type": "chat:completion", |
|
"data": data, |
|
} |
|
) |
|
|
|
await background_tasks_handler() |
|
except asyncio.CancelledError: |
|
log.warning("Task was cancelled!") |
|
await event_emitter({"type": "task-cancelled"}) |
|
|
|
if not ENABLE_REALTIME_CHAT_SAVE: |
|
|
|
Chats.upsert_message_to_chat_by_id_and_message_id( |
|
metadata["chat_id"], |
|
metadata["message_id"], |
|
{ |
|
"content": serialize_content_blocks(content_blocks), |
|
}, |
|
) |
|
|
|
if response.background is not None: |
|
await response.background() |
|
|
|
|
|
task_id, _ = create_task(post_response_handler(response, events)) |
|
return {"status": True, "task_id": task_id} |
|
|
|
else: |
|
|
|
async def stream_wrapper(original_generator, events):
    """Yield queued events followed by the upstream stream.

    Each queued event is framed as an SSE "data:" line; every item —
    queued event or upstream chunk — is passed through the registered
    "stream" filter functions first, and falsy results are dropped.
    """
    # Flush pre-queued events ahead of the model output.
    for event in events:
        filtered_event, _ = await process_filter_functions(
            request=request,
            filter_functions=filter_functions,
            filter_type="stream",
            form_data=event,
            extra_params=extra_params,
        )

        if filtered_event:
            yield f"data: {json.dumps(filtered_event)}\n\n"

    # Relay the original body, chunk by chunk.
    async for chunk in original_generator:
        filtered_chunk, _ = await process_filter_functions(
            request=request,
            filter_functions=filter_functions,
            filter_type="stream",
            form_data=chunk,
            extra_params=extra_params,
        )

        if filtered_chunk:
            yield filtered_chunk
|
|
|
return StreamingResponse( |
|
stream_wrapper(response.body_iterator, events), |
|
headers=dict(response.headers), |
|
background=response.background, |
|
) |
|
|