Spaces:
Sleeping
Sleeping
import logging
from datetime import datetime
from typing import Annotated

from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from langfuse.llama_index import LlamaIndexCallbackHandler
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from sse_starlette.sse import EventSourceResponse

from api.function import (
    generate_streaming_completion,
    generate_completion_non_streaming,
)
from api.router.user import user_dependency
from core.chat.chatstore import ChatStore
from db.database import get_db
from db.models import Metadata, Session_Publisher
from db.models import Session as SessionModel
from service.dto import UserPromptRequest, BotResponse
from utils.utils import generate_uuid
# FastAPI dependency alias: a SQLAlchemy ``Session`` obtained from ``get_db``.
db_dependency = Annotated[Session, Depends(get_db)]

# Router for the "Bot One" endpoints; every route registered on it carries
# the "Bot_One" OpenAPI tag.
# NOTE(review): no route decorators are visible on the handlers in this file;
# confirm they are registered on ``router`` elsewhere.
router = APIRouter(tags=["Bot_One"])
def get_chat_store():
    """Create and return a new ``ChatStore``.

    A fresh instance is constructed on every call; nothing is cached here.
    """
    store = ChatStore()
    return store
async def create_bot_one(user: user_dependency, db: db_dependency, metadata_id: int):
    """Create a new Bot One chat session for the authenticated user.

    Args:
        user: Authenticated user record (read via ``user.get("id")``), or
            ``None`` when authentication failed.
        db: Request-scoped SQLAlchemy session.
        metadata_id: Stored on the new ``Session_Publisher`` row.

    Returns:
        On success, a dict with a confirmation message and the generated
        ``session_id``. Otherwise a ``JSONResponse`` with status 401 when
        ``user`` is ``None``, or 500 when creating/committing the session fails.
    """
    if user is None:
        return JSONResponse(status_code=401, content="Authentication Failed")

    try:
        # Generate a new session ID (UUID) and persist it for this user.
        session_id = generate_uuid()
        new_session = Session_Publisher(
            id=session_id,
            user_id=user.get("id"),
            metadata_id=metadata_id,
        )
        db.add(new_session)
        db.commit()
    except Exception as e:
        # Roll back so the request-scoped session is not left in a failed
        # transaction state after a flush/commit error.
        db.rollback()
        logging.getLogger(__name__).exception("Failed to create Bot One session")
        return JSONResponse(
            status_code=500,
            content=f"An unexpected error occurred while creating session id: {str(e)}",
        )

    return {
        # NOTE: "statur" is a misspelling of "status"; it is kept unchanged
        # because API clients may already read this key.
        "statur": "session id created successfully",
        "session_id": session_id,
    }
async def generator_bot_one(
    user: user_dependency,
    db: db_dependency,
    metadata_id: int,
    session_id: str,
    user_prompt_request: UserPromptRequest,
):
    """Answer a prompt inside an existing Bot One chat session.

    The session must be owned by ``user`` and bound to ``metadata_id``. The
    matching ``Metadata.title`` values are passed to the non-streaming
    completion.

    Args:
        user: Authenticated user record (read via ``user.get("id")`` and
            ``user.get("username")``), or ``None`` when authentication failed.
        db: Request-scoped SQLAlchemy session.
        metadata_id: Id of the ``Metadata`` row the session belongs to.
        session_id: Id of the ``Session_Publisher`` row to chat in.
        user_prompt_request: Carries ``prompt`` and the ``streaming`` flag.

    Returns:
        An ``EventSourceResponse`` when ``user_prompt_request.streaming`` is
        truthy, otherwise a ``BotResponse``. Failures are returned as a
        ``JSONResponse``:
        - 401 when ``user`` is ``None``.
        - 404 when no such session exists for this user and metadata.
        - 500 when the title lookup fails.
    """
    if user is None:
        return JSONResponse(status_code=401, content="Authentication Failed")

    logger = logging.getLogger(__name__)

    # NOTE(review): this handler receives trace params but is not passed to
    # any call in this function. Confirm it is attached somewhere; otherwise
    # these trace params may have no effect.
    langfuse_callback_handler = LlamaIndexCallbackHandler()
    langfuse_callback_handler.set_trace_params(
        user_id=user.get("username"), session_id=session_id
    )

    # Join each session to the metadata row it was created for. A row is then
    # returned only when the session is owned by this user AND bound to
    # metadata_id.
    query = (
        select(Metadata.title)
        .join(Session_Publisher, Session_Publisher.metadata_id == Metadata.id)
        .where(
            Metadata.id == metadata_id,
            Session_Publisher.user_id == user.get("id"),
            Session_Publisher.id == session_id,
        )
    )
    try:
        titles = db.execute(query).scalars().all()
    except SQLAlchemyError as e:
        logger.exception("Database error while fetching titles for session %s", session_id)
        return JSONResponse(status_code=500, content=f"Database error: {str(e)}")
    except Exception as e:
        logger.exception("Unexpected error while fetching titles for session %s", session_id)
        return JSONResponse(
            status_code=500, content=f"An unexpected error occurred: {str(e)}"
        )

    if not titles:
        # Refuse to answer inside a session the caller does not own, or one
        # bound to a different metadata row.
        return JSONResponse(status_code=404, content="Session not found")

    if user_prompt_request.streaming:
        # NOTE(review): unlike the non-streaming path, titles are not passed
        # and the argument order is (prompt, session_id); confirm against
        # generate_streaming_completion's signature. updated_at is also not
        # refreshed for streamed answers.
        return EventSourceResponse(
            generate_streaming_completion(
                user_prompt_request.prompt,
                session_id,
            )
        )

    response, metadata, scores = generate_completion_non_streaming(
        session_id, user_prompt_request.prompt, titles, type_bot="specific"
    )

    # Bump the session's activity timestamp; the session listing is ordered
    # by updated_at.
    existing_session = (
        db.query(Session_Publisher).filter(Session_Publisher.id == session_id).first()
    )
    if existing_session is not None:
        existing_session.updated_at = datetime.now()
        try:
            db.commit()
        except SQLAlchemyError:
            # The answer was already generated; don't discard it because the
            # timestamp could not be saved.
            db.rollback()
            logger.exception("Failed to update updated_at for session %s", session_id)

    return BotResponse(
        content=response,
        metadata=metadata,
        scores=scores,
    )
async def get_all_session_bot_one(
    user: user_dependency, db: db_dependency, metadata_id: int
):
    """List the user's Bot One sessions for ``metadata_id``, newest first.

    Args:
        user: Authenticated user record (read via ``user.get("id")``), or
            ``None`` when authentication failed.
        db: Request-scoped SQLAlchemy session.
        metadata_id: Only sessions bound to this metadata are listed.

    Returns:
        A list of ``{"id": ..., "updated_at": str(updated_at)}`` dicts, sorted
        by ``updated_at`` descending. Sessions without an ``updated_at`` are
        placed last. On failure, a ``JSONResponse`` with status 401 when
        unauthenticated, or 500 when fetching/sorting fails.
    """
    if user is None:
        return JSONResponse(status_code=401, content="Authentication Failed")

    def _recency_key(row):
        # Sort on the datetime itself rather than re-parsing its string form.
        # Rows without a timestamp get (False, ...) and therefore sort last
        # under reverse=True, instead of breaking the whole request.
        has_ts = row.updated_at is not None
        return (has_ts, row.updated_at if has_ts else datetime.min)

    try:
        # Sessions owned by this user on the requested metadata.
        query = select(Session_Publisher.id, Session_Publisher.updated_at).where(
            Session_Publisher.user_id == user.get("id"),
            Session_Publisher.metadata_id == metadata_id,
        )
        rows = sorted(db.execute(query).all(), key=_recency_key, reverse=True)
        return [{"id": row.id, "updated_at": str(row.updated_at)} for row in rows]
    except Exception:
        # A failure here is server-side, so report 500 (as the other handlers do).
        logging.getLogger(__name__).exception(
            "An error occurred while fetching session IDs"
        )
        return JSONResponse(status_code=500, content="Error retrieving session IDs")