Spaces:
Sleeping
Sleeping
from fastapi import Depends, Request, HTTPException, status | |
from datetime import datetime, timedelta | |
from typing import Union, Optional | |
from utils.utils import get_verified_user, get_admin_user | |
from fastapi import APIRouter | |
from pydantic import BaseModel | |
import json | |
import logging | |
from apps.webui.models.users import Users | |
from apps.webui.models.chats import ( | |
ChatModel, | |
ChatResponse, | |
ChatTitleForm, | |
ChatForm, | |
ChatTitleIdResponse, | |
Chats, | |
) | |
from apps.webui.models.tags import ( | |
TagModel, | |
ChatIdTagModel, | |
ChatIdTagForm, | |
ChatTagsResponse, | |
Tags, | |
) | |
from constants import ERROR_MESSAGES | |
from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_CHAT_ACCESS | |
# Router object for the chat endpoints defined in this module.
router = APIRouter()

# Module-level logger; its verbosity comes from the "MODELS" entry of
# SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
############################ | |
# GetChatList | |
############################ | |
async def get_session_user_chat_list(
    user=Depends(get_verified_user), page: Optional[int] = None
):
    """Return the (title, id) chat list of the current user.

    Args:
        user: The authenticated user, resolved by ``get_verified_user``.
        page: Optional 1-based page number. When omitted the full list is
            returned; otherwise one page of 60 entries is returned.

    Returns:
        Whatever ``Chats.get_chat_title_id_list_by_user_id`` returns.
    """
    if page is None:
        return Chats.get_chat_title_id_list_by_user_id(user.id)

    limit = 60
    # Pages are 1-based. Clamp so that page <= 0 maps to the first page instead
    # of producing a negative offset, which SQL backends may reject.
    skip = max(page - 1, 0) * limit
    return Chats.get_chat_title_id_list_by_user_id(user.id, skip=skip, limit=limit)
############################ | |
# DeleteAllChats | |
############################ | |
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
    """Delete every chat owned by the current user.

    Accounts with the "user" role may only do this when the instance-wide
    ``USER_PERMISSIONS["chat"]["deletion"]`` flag is enabled; the flag is not
    consulted for any other role.

    Raises:
        HTTPException: 401 with ``ACCESS_PROHIBITED`` when a "user"-role
            account lacks the deletion permission.
    """
    if user.role == "user":
        # Only read the permission config for regular users.
        permissions = request.app.state.config.USER_PERMISSIONS
        if not permissions["chat"]["deletion"]:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )
    return Chats.delete_chats_by_user_id(user.id)
############################ | |
# GetUserChatList | |
############################ | |
async def get_user_chat_list_by_user_id(
    user_id: str,
    user=Depends(get_admin_user),
    skip: int = 0,
    limit: int = 50,
):
    """Admin-only: list the chats (archived ones included) owned by ``user_id``.

    Only available when ``ENABLE_ADMIN_CHAT_ACCESS`` is enabled.

    Args:
        user_id: Owner whose chats are listed.
        user: The calling admin, resolved by ``get_admin_user``.
        skip: Number of entries to skip.
        limit: Maximum number of entries to return.

    Raises:
        HTTPException: 401 with ``ACCESS_PROHIBITED`` when admin chat access
            is disabled.
    """
    if ENABLE_ADMIN_CHAT_ACCESS:
        return Chats.get_chat_list_by_user_id(
            user_id, include_archived=True, skip=skip, limit=limit
        )
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
    )
############################ | |
# CreateNewChat | |
############################ | |
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
    """Create a chat for the current user and return it with its body decoded.

    Raises:
        HTTPException: 400 with the default error message if inserting or
            decoding fails; the underlying exception is logged.
    """
    try:
        new_chat = Chats.insert_new_chat(user.id, form_data)
        fields = new_chat.model_dump()
        # The stored chat body is a JSON string; expose it as a parsed object.
        fields["chat"] = json.loads(new_chat.chat)
        return ChatResponse(**fields)
    except Exception as exc:
        log.exception(exc)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )
############################ | |
# GetChats | |
############################ | |
async def get_user_chats(user=Depends(get_verified_user)):
    """Return every chat of the current user as ``ChatResponse`` objects."""

    def to_response(record):
        # Decode the JSON-encoded chat body before building the response.
        fields = record.model_dump()
        fields["chat"] = json.loads(record.chat)
        return ChatResponse(**fields)

    return [to_response(record) for record in Chats.get_chats_by_user_id(user.id)]
############################ | |
# GetArchivedChats | |
############################ | |
async def get_user_archived_chats(user=Depends(get_verified_user)):
    """Return every archived chat of the current user as ``ChatResponse`` objects."""
    archived = Chats.get_archived_chats_by_user_id(user.id)
    # dict(mapping, chat=...) copies the dumped fields and overrides "chat"
    # with the decoded JSON body.
    return [
        ChatResponse(**dict(item.model_dump(), chat=json.loads(item.chat)))
        for item in archived
    ]
############################ | |
# GetAllChatsInDB | |
############################ | |
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
    """Admin export: return every chat in the database as ``ChatResponse`` objects.

    Only available when ``ENABLE_ADMIN_EXPORT`` is enabled.

    Raises:
        HTTPException: 401 with ``ACCESS_PROHIBITED`` when admin export is
            disabled.
    """
    if not ENABLE_ADMIN_EXPORT:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    def as_response(stored):
        fields = stored.model_dump()
        fields["chat"] = json.loads(stored.chat)
        return ChatResponse(**fields)

    return list(map(as_response, Chats.get_chats()))
############################ | |
# GetArchivedChats | |
############################ | |
async def get_archived_session_user_chat_list(
    user=Depends(get_verified_user), skip: int = 0, limit: int = 50
):
    """Return one page of the current user's archived chat list.

    Args:
        user: The authenticated user, resolved by ``get_verified_user``.
        skip: Number of entries to skip.
        limit: Maximum number of entries to return.
    """
    owner_id = user.id
    return Chats.get_archived_chat_list_by_user_id(owner_id, skip, limit)
############################ | |
# ArchiveAllChats | |
############################ | |
async def archive_all_chats(user=Depends(get_verified_user)):
    """Archive every chat owned by the current user.

    Returns:
        Whatever ``Chats.archive_all_chats_by_user_id`` returns.
    """
    outcome = Chats.archive_all_chats_by_user_id(user.id)
    return outcome
############################ | |
# GetSharedChatById | |
############################ | |
async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
    """Return a shared chat with its body decoded.

    Admins, when ``ENABLE_ADMIN_CHAT_ACCESS`` is enabled, resolve ``share_id``
    through ``Chats.get_chat_by_id``. Regular users, and admins without that
    flag, resolve it through ``Chats.get_chat_by_share_id``.

    Raises:
        HTTPException: 401 with ``NOT_FOUND`` for pending users, for any role
            other than "user"/"admin", or when no chat matches.
    """
    if user.role == "pending":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
        chat = Chats.get_chat_by_id(share_id)
    elif user.role in ("user", "admin"):
        chat = Chats.get_chat_by_share_id(share_id)
    else:
        # Any other role gets no lookup. Binding `chat` here avoids an
        # UnboundLocalError (HTTP 500) and reports "not found" instead.
        chat = None

    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
############################ | |
# GetChatsByTags | |
############################ | |
class TagNameForm(BaseModel):
    # Request body for listing the current user's chats that carry a tag.
    name: str  # tag name to filter by
    skip: Union[int, None] = 0  # number of matching chats to skip
    limit: Union[int, None] = 50  # maximum number of chats to return
async def get_user_chat_list_by_tag_name(
    form_data: TagNameForm, user=Depends(get_verified_user)
):
    """List the current user's chats tagged with ``form_data.name``.

    Results are paginated with ``form_data.skip`` / ``form_data.limit``. As a
    clean-up side effect, the tag is deleted for this user when no chats are
    found for it.

    Returns:
        Whatever ``Chats.get_chat_list_by_chat_ids`` returns for the page.
    """
    chat_ids = [
        chat_id_tag.chat_id
        for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
            form_data.name, user.id
        )
    ]
    chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)

    # An empty result proves the tag is unused only on the first page of a
    # non-empty window. An empty later page just means the caller paged past
    # the end, and must not delete a tag that still has chats.
    is_first_page = not form_data.skip and form_data.limit != 0
    if not chat_ids or (not chats and is_first_page):
        Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
    return chats
############################ | |
# GetAllTags | |
############################ | |
async def get_all_tags(user=Depends(get_verified_user)):
    """Return every tag owned by the current user.

    Raises:
        HTTPException: 400 with the default error message if the lookup
            fails; the underlying exception is logged.
    """
    try:
        return Tags.get_tags_by_user_id(user.id)
    except Exception as err:
        log.exception(err)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )
############################ | |
# GetChatById | |
############################ | |
async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
    """Return one of the current user's chats by id, with its body decoded.

    Raises:
        HTTPException: 401 with ``NOT_FOUND`` if the lookup by id and user id
            finds nothing.
    """
    owned_chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not owned_chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return ChatResponse(
        **{**owned_chat.model_dump(), "chat": json.loads(owned_chat.chat)}
    )
############################ | |
# UpdateChatById | |
############################ | |
async def update_chat_by_id(
    id: str, form_data: ChatForm, user=Depends(get_verified_user)
):
    """Merge ``form_data.chat`` into one of the current user's chats and save it.

    The merge is shallow: top-level keys from ``form_data.chat`` override the
    stored ones, and stored keys absent from the form are kept.

    Raises:
        HTTPException: 401 with ``ACCESS_PROHIBITED`` if the lookup by id and
            user id finds nothing.
    """
    existing = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    merged_body = {**json.loads(existing.chat), **form_data.chat}
    saved = Chats.update_chat_by_id(id, merged_body)
    return ChatResponse(**{**saved.model_dump(), "chat": json.loads(saved.chat)})
############################ | |
# DeleteChatById | |
############################ | |
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
    """Delete a chat by id.

    Admins delete through ``Chats.delete_chat_by_id`` without further checks.
    Every other role needs the instance-wide
    ``USER_PERMISSIONS["chat"]["deletion"]`` flag, and the deletion is scoped
    to the caller's user id.

    Raises:
        HTTPException: 401 with ``ACCESS_PROHIBITED`` when a non-admin lacks
            the deletion permission.
    """
    if user.role != "admin":
        if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )
        return Chats.delete_chat_by_id_and_user_id(id, user.id)
    return Chats.delete_chat_by_id(id)
############################ | |
# CloneChat | |
############################ | |
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
    """Create a copy of one of the current user's chats.

    The copy keeps the original chat body and adds three things:
    ``originalChatId`` (the source chat id), ``branchPointMessageId`` (taken
    from ``history.currentId`` when present) and the title
    ``"Clone of <title>"``.

    Raises:
        HTTPException: 401 with the default error message if the lookup by id
            and user id finds nothing.
    """
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    chat_body = json.loads(chat.chat)
    # Tolerate chat bodies without a message history. Cloning them should not
    # fail with a KeyError (HTTP 500); the branch point is simply unknown.
    history = chat_body.get("history") or {}
    updated_chat = {
        **chat_body,
        "originalChatId": chat.id,
        "branchPointMessageId": history.get("currentId"),
        "title": f"Clone of {chat.title}",
    }
    cloned = Chats.insert_new_chat(user.id, ChatForm(chat=updated_chat))
    return ChatResponse(**{**cloned.model_dump(), "chat": json.loads(cloned.chat)})
############################ | |
# ArchiveChat | |
############################ | |
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
    """Toggle the archive state of one of the current user's chats.

    The toggle is delegated to ``Chats.toggle_chat_archive_by_id``, and the
    resulting chat is returned with its body decoded.

    Raises:
        HTTPException: 401 with the default error message if the lookup by id
            and user id finds nothing.
    """
    if not Chats.get_chat_by_id_and_user_id(id, user.id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )
    toggled = Chats.toggle_chat_archive_by_id(id)
    return ChatResponse(**{**toggled.model_dump(), "chat": json.loads(toggled.chat)})
############################ | |
# ShareChatById | |
############################ | |
async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
    """Share one of the current user's chats, or refresh its existing share.

    If the chat already has a ``share_id``, the shared copy is refreshed with
    ``Chats.update_shared_chat_by_chat_id``. Otherwise a new shared copy is
    created with ``Chats.insert_shared_chat_by_chat_id``.

    Raises:
        HTTPException: 401 with ``ACCESS_PROHIBITED`` if the lookup by id and
            user id finds nothing. 500 with the default error message if the
            shared copy could not be created or refreshed.
    """
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    if chat.share_id:
        shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
    else:
        shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)

    # Guard both paths the same way. A falsy result is reported as a 500
    # instead of crashing on `shared_chat.model_dump()`.
    if not shared_chat:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )
    return ChatResponse(
        **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
    )
############################ | |
# DeletedSharedChatById | |
############################ | |
async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
    """Stop sharing one of the current user's chats.

    Deletes the shared copy and clears the chat's ``share_id``.

    Returns:
        ``False`` if the chat has no ``share_id``. Otherwise, the result of
        deleting the shared copy combined (with ``and``) with whether clearing
        ``share_id`` returned a value.

    Raises:
        HTTPException: 401 with ``ACCESS_PROHIBITED`` if the lookup by id and
            user id finds nothing.
    """
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    if not chat.share_id:
        return False

    result = Chats.delete_shared_chat_by_chat_id(id)
    update_result = Chats.update_chat_share_id_by_id(id, None)
    return result and update_result is not None
############################ | |
# GetChatTagsById | |
############################ | |
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
    """Return the tags of one of the current user's chats.

    Raises:
        HTTPException: 401 with ``NOT_FOUND`` when the lookup returns ``None``.
    """
    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
    # Identity check: only None means "not found"; an empty result is valid.
    if tags is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return tags
############################ | |
# AddChatTagById | |
############################ | |
async def add_chat_tag_by_id(
    id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
):
    """Attach the tag ``form_data.tag_name`` to one of the current user's chats.

    Returns:
        The value returned by ``Tags.add_tag_to_chat`` when it is truthy.

    Raises:
        HTTPException: 401 with the default error message if the chat already
            carries that tag. 401 with ``NOT_FOUND`` if adding the tag fails.
    """
    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
    # NOTE(review): the lookup presumably yields tag models (TagModel) with a
    # `.name`. Testing the raw tag name against those objects never matches,
    # so duplicates slipped through. Compare names instead, falling back to
    # the item itself in case plain strings are returned.
    existing_names = {getattr(tag, "name", tag) for tag in tags or []}
    if form_data.tag_name in existing_names:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    tag = Tags.add_tag_to_chat(user.id, form_data)
    if not tag:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
    return tag
############################ | |
# DeleteChatTagById | |
############################ | |
async def delete_chat_tag_by_id(
    id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
):
    """Remove the tag ``form_data.tag_name`` from one of the current user's chats.

    Returns:
        The (truthy) result reported by the deletion.

    Raises:
        HTTPException: 401 with ``NOT_FOUND`` when the deletion reports a
            falsy result.
    """
    deleted = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
        form_data.tag_name, id, user.id
    )
    if not deleted:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return deleted
############################ | |
# DeleteAllChatTagsById | |
############################ | |
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
    """Remove every tag from one of the current user's chats.

    Returns:
        The (truthy) result reported by the deletion.

    Raises:
        HTTPException: 401 with ``NOT_FOUND`` when the deletion reports a
            falsy result.
    """
    outcome = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
    if not outcome:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return outcome