dsmultimedika's picture
Improve the code bot development
d57efd6
raw
history blame
4.81 kB
from typing import Annotated
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from service.dto import UserPromptRequest, BotResponse
from core.chat.chatstore import ChatStore
from db.database import get_db
from db.models import Metadata, Session_Publisher
from db.models import Session as SessionModel
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import select
from api.function import (
generate_streaming_completion,
generate_completion_non_streaming,
)
from datetime import datetime
from api.router.user import user_dependency
from sse_starlette.sse import EventSourceResponse
from utils.utils import generate_uuid
from langfuse.llama_index import LlamaIndexCallbackHandler
# The endpoints below are registered on this router and grouped under the
# "Bot_One" tag in the generated OpenAPI docs.
router = APIRouter(
    tags=["Bot_One"],
)

# Reusable parameter annotation: injects a SQLAlchemy ``Session`` obtained
# from the ``get_db`` dependency.
db_dependency = Annotated[Session, Depends(get_db)]
def get_chat_store():
    """Construct and return a new ``ChatStore`` instance.

    NOTE(review): not referenced elsewhere in this chunk; presumably used as
    a FastAPI dependency by other modules -- confirm.
    """
    store = ChatStore()
    return store
@router.post("/bot_one/{metadata_id}")
async def create_bot_one(user: user_dependency, db: db_dependency, metadata_id: int):
    """Create a new chat session for the authenticated user on one metadata record.

    Args:
        user: Authenticated user mapping from ``user_dependency``; ``None``
            when authentication failed. Its ``"id"`` entry becomes the
            session owner.
        db: SQLAlchemy session injected via ``get_db``.
        metadata_id: Id of the metadata record the new session is bound to.

    Returns:
        A dict containing the new ``session_id`` on success, a 401
        ``JSONResponse`` when unauthenticated, or a 500 ``JSONResponse`` when
        generating or persisting the session fails.
    """
    if user is None:
        return JSONResponse(status_code=401, content="Authentication Failed")

    try:
        # Generate a new session ID (UUID) and persist the session row.
        session_id = generate_uuid()
        new_session = Session_Publisher(
            id=session_id,
            user_id=user.get("id"),
            metadata_id=metadata_id,
        )
        db.add(new_session)
        db.commit()
    except Exception as e:
        # Roll back so the injected session is not left inside a failed
        # transaction after an unsuccessful flush/commit.
        db.rollback()
        return JSONResponse(
            status_code=500,
            content=f"An unexpected error occurred while creating session id: {str(e)}",
        )

    return {
        "status": "session id created successfully",
        # "statur" (sic) is the original, misspelled key; kept so existing
        # clients that read it keep working.
        "statur": "session id created successfully",
        "session_id": session_id,
    }
@router.post("/bot_one/{metadata_id}/{session_id}")
async def generator_bot_one(
    user: user_dependency,
    db: db_dependency,
    metadata_id: int,
    session_id: str,
    user_prompt_request: UserPromptRequest,
):
    """Answer one prompt inside an existing session scoped to a metadata record.

    The session must belong to the caller and must have been created for
    ``metadata_id``. The matching metadata title(s) are passed to the
    non-streaming completion.

    Args:
        user: Authenticated user mapping from ``user_dependency`` (``None``
            when authentication failed); ``"id"`` and ``"username"`` are read.
        db: SQLAlchemy session injected via ``get_db``.
        metadata_id: Metadata record the session is scoped to.
        session_id: Id of a ``Session_Publisher`` row (as created by
            ``create_bot_one``).
        user_prompt_request: Carries the ``prompt`` text and the
            ``streaming`` flag.

    Returns:
        An ``EventSourceResponse`` when ``streaming`` is set, otherwise a
        ``BotResponse``. Returns a ``JSONResponse`` with 401 (unauthenticated),
        404 (no matching session) or 500 (lookup failure).
    """
    if user is None:
        return JSONResponse(status_code=401, content="Authentication Failed")

    langfuse_callback_handler = LlamaIndexCallbackHandler()
    langfuse_callback_handler.set_trace_params(
        user_id=user.get("username"), session_id=session_id
    )
    # NOTE(review): the handler is configured but not attached to anything in
    # this function -- confirm it is registered globally, otherwise traces
    # for this request are not recorded.

    # Fetch the metadata title(s), but only when the session belongs to this
    # user AND was created for this metadata record. The join must relate the
    # two tables; filtering Metadata.id in the ON clause alone would accept a
    # session created for a different metadata record.
    try:
        query = (
            select(Metadata.title)
            .join(Session_Publisher, Session_Publisher.metadata_id == Metadata.id)
            .where(
                Metadata.id == metadata_id,
                Session_Publisher.user_id == user.get("id"),
                Session_Publisher.id == session_id,
            )
        )
        titles = db.execute(query).scalars().all()
    except SQLAlchemyError as e:
        return JSONResponse(status_code=500, content=f"Database error: {str(e)}")
    except Exception as e:
        return JSONResponse(
            status_code=500, content=f"An unexpected error occurred: {str(e)}"
        )

    if not titles:
        # Unknown session, someone else's session, or a session bound to a
        # different metadata record: refuse rather than answer on it.
        return JSONResponse(status_code=404, content="Session not found")

    if user_prompt_request.streaming:
        # NOTE(review): unlike the non-streaming branch this passes neither
        # ``titles`` nor ``type_bot`` and does not bump ``updated_at``; the
        # argument order also differs -- confirm against
        # generate_streaming_completion's signature.
        return EventSourceResponse(
            generate_streaming_completion(
                user_prompt_request.prompt,
                session_id,
            )
        )

    response, metadata, scores = generate_completion_non_streaming(
        session_id, user_prompt_request.prompt, titles, type_bot="specific"
    )

    # Record activity on the session; session listings are ordered by
    # updated_at.
    existing_session = (
        db.query(Session_Publisher).filter(Session_Publisher.id == session_id).first()
    )
    if existing_session is not None:
        existing_session.updated_at = datetime.now()
        db.commit()

    return BotResponse(
        content=response,
        metadata=metadata,
        scores=scores,
    )
@router.get("/bot_one/{metadata_id}")
# Legacy path without the separating slash (e.g. "/bot_one5"); kept so
# existing clients keep working, but hidden from the OpenAPI schema.
@router.get("/bot_one{metadata_id}", include_in_schema=False)
async def get_all_session_bot_one(
    user: user_dependency, db: db_dependency, metadata_id: int
):
    """List the caller's sessions for one metadata record, most recent first.

    Args:
        user: Authenticated user mapping from ``user_dependency``; ``None``
            when authentication failed.
        db: SQLAlchemy session injected via ``get_db``.
        metadata_id: Metadata record whose sessions are listed.

    Returns:
        A list of ``{"id": ..., "updated_at": str(...)}`` dicts ordered by
        ``updated_at`` descending. Sessions without a timestamp come last,
        with ``"updated_at"`` rendered as ``"None"``. Returns a 401
        ``JSONResponse`` when unauthenticated, or a 500 ``JSONResponse`` when
        the query fails.
    """
    if user is None:
        return JSONResponse(status_code=401, content="Authentication Failed")

    try:
        # Query the session IDs based on the user ID and metadata record.
        query = select(Session_Publisher.id, Session_Publisher.updated_at).where(
            Session_Publisher.user_id == user.get("id"),
            Session_Publisher.metadata_id == metadata_id,
        )
        sessions = db.execute(query).all()
    except Exception as e:
        print(f"An error occurred while fetching session IDs: {e}")
        return JSONResponse(status_code=500, content="Error retrieving session IDs")

    # Sort on the datetime values themselves instead of round-tripping through
    # strings. A NULL timestamp previously became "None", which
    # datetime.fromisoformat rejects, failing the whole listing; such rows now
    # sort last.
    ordered = sorted(
        sessions,
        key=lambda row: (row.updated_at is not None, row.updated_at or datetime.min),
        reverse=True,
    )
    return [{"id": row.id, "updated_at": str(row.updated_at)} for row in ordered]