File size: 3,553 Bytes
b39c0ba
 
d57efd6
 
b39c0ba
 
 
 
 
 
 
d57efd6
 
 
 
b39c0ba
 
d57efd6
 
 
 
 
 
 
 
 
 
 
 
0767396
 
 
 
d57efd6
 
 
 
 
 
 
 
 
 
b39c0ba
 
 
 
 
 
d57efd6
b39c0ba
 
 
d57efd6
 
 
 
 
 
0767396
d57efd6
0767396
 
 
 
d57efd6
 
 
 
 
b39c0ba
d57efd6
 
b39c0ba
 
d57efd6
 
 
 
 
 
 
 
 
b39c0ba
0767396
b39c0ba
0767396
 
 
 
d57efd6
 
 
b39c0ba
 
 
 
 
 
d57efd6
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import logging
from typing import Annotated

from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from langfuse.llama_index import LlamaIndexCallbackHandler
from sqlalchemy.orm import Session
from sse_starlette.sse import EventSourceResponse

from api.auth import check_user_authentication
from api.function import generate_streaming_completion
from api.router.user import user_dependency
from core.chat.bot_service import ChatCompletionService
from core.chat.chatstore import ChatStore
from db.database import get_db
from db.models import Session_Publisher
from service.dto import UserPromptRequest, BotResponse
from utils.utils import generate_uuid


# Every endpoint registered below is grouped under the "Bot_General" tag.
router = APIRouter(
    tags=["Bot_General"],
)

# Annotated alias: a parameter declared with this type gets a database
# session injected by FastAPI from the ``get_db`` dependency.
db_dependency = Annotated[Session, Depends(dependency=get_db)]


def get_chat_store():
    """FastAPI dependency that constructs and returns a new ``ChatStore``."""
    store = ChatStore()
    return store


@router.post("/bot_general/new")
async def create_session_general(user: user_dependency):
    """Issue a fresh session id for a general-bot conversation.

    If ``check_user_authentication`` returns a truthy response, that response
    is returned unchanged. Otherwise the result is ``{"session_id": <id>}``,
    with the id produced by ``generate_uuid``. This handler makes no database
    or chat-store call of its own.
    """
    denial = check_user_authentication(user)
    if denial:
        return denial
    return {"session_id": generate_uuid()}


@router.get("/bot/{session_id}")
async def get_session_id(
    user: user_dependency,
    session_id: str,
    chat_store: ChatStore = Depends(get_chat_store),
):
    """Return the stored chat history for ``session_id``.

    Messages are first read with ``chat_store.get_messages``. When that yields
    nothing (``None`` or an empty collection), ``get_all_messages_mongodb`` is
    used as a fallback. If ``check_user_authentication`` returns a truthy
    response, that response is returned unchanged instead.

    NOTE(review): no check is visible here that the session belongs to
    ``user``. Any authenticated caller can read any session id. Confirm
    whether ownership is enforced elsewhere.
    """
    auth_response = check_user_authentication(user)
    if auth_response:
        return auth_response

    # Primary store (Redis, according to the original author's comment).
    chat_history = chat_store.get_messages(session_id)

    # Truthiness covers both a missing (None) and an empty history, whatever
    # concrete container type the store returns.
    if not chat_history:
        chat_history = chat_store.get_all_messages_mongodb(session_id)

    return chat_history


@router.post("/bot/{session_id}")
async def bot_generator_general(
    user: user_dependency, session_id: str, user_prompt_request: UserPromptRequest
):
    """Answer a prompt inside a general-bot session.

    When ``user_prompt_request.streaming`` is truthy, the completion is
    streamed as server-sent events built from
    ``generate_streaming_completion``. Otherwise a ``BotResponse`` is returned
    carrying the content, metadata and scores from
    ``ChatCompletionService.generate_completion``. A truthy response from
    ``check_user_authentication`` is returned unchanged.
    """
    denial = check_user_authentication(user)
    if denial:
        return denial

    # NOTE(review): this handler is configured but never handed to anything
    # below, so the trace params only take effect if LlamaIndexCallbackHandler
    # registers itself on construction. Verify this. user_id is also
    # hard-coded to "guest" even though an authenticated user is available.
    tracer = LlamaIndexCallbackHandler()
    tracer.set_trace_params(user_id="guest", session_id=session_id)

    prompt = user_prompt_request.prompt

    if user_prompt_request.streaming:
        event_stream = generate_streaming_completion(prompt, session_id)
        return EventSourceResponse(event_stream)

    completion_service = ChatCompletionService(session_id, prompt)
    content, metadata, scores = completion_service.generate_completion()
    return BotResponse(content=content, metadata=metadata, scores=scores)


@router.delete("/bot/{session_id}")
async def delete_bot(
    user: user_dependency,
    db: db_dependency,
    session_id: str,
    chat_store: ChatStore = Depends(get_chat_store),
):
    """Delete a session's chat messages and its ``Session_Publisher`` row.

    Returns:
        - ``{"info": "Delete <id> successful"}`` when the row existed and has
          been removed.
        - A 404 ``JSONResponse`` when no row matches ``session_id``. The chat
          messages have already been deleted by that point.
        - A 400 ``JSONResponse`` if any step raises. The DB transaction is
          rolled back and the error is logged with its traceback.
        - The authentication helper's response, unchanged, when it is truthy.
    """
    auth_response = check_user_authentication(user)
    if auth_response:
        return auth_response

    try:
        chat_store.delete_messages(session_id)

        session = (
            db.query(Session_Publisher)
            .filter(Session_Publisher.id == session_id)
            .first()
        )
        if session is None:
            return JSONResponse(status_code=404, content="Session not found")

        db.delete(session)
        db.commit()
        return {"info": f"Delete {session_id} successful"}
    except Exception:
        # Log first, so the original failure is recorded even if rollback
        # itself fails. Then discard any partially-applied DB work so the
        # session is not left in a failed-transaction state.
        logging.getLogger(__name__).exception(
            "Failed to delete session %s", session_id
        )
        db.rollback()
        # NOTE(review): a server-side failure is more conventionally a 500;
        # 400 is kept so existing clients see the same status code.
        return JSONResponse(
            status_code=400, content="the error when deleting message"
        )