|
from fastapi import FastAPI, HTTPException, Depends, File, UploadFile, Form, Response, BackgroundTasks |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
|
from fastapi.responses import StreamingResponse |
|
from pydantic import BaseModel, Field, EmailStr |
|
from typing import List, Optional, Dict, Any, Union |
|
import uuid |
|
import os |
|
import io |
|
from urllib.parse import quote_plus |
|
|
|
import shutil |
|
from datetime import datetime, timedelta |
|
from dotenv import load_dotenv |
|
import hashlib |
|
import jwt |
|
from passlib.context import CryptContext |
|
from pymongo import MongoClient |
|
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
from langchain_community.vectorstores import FAISS |
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain_core.documents import Document |
|
from langchain_groq import ChatGroq |
|
from google import genai |
|
from google.genai import types |
|
|
|
|
|
# --- MongoDB configuration, read from the process environment ---
_raw_mongo_password = os.getenv("MONGO_PASSWORD")
connection_string_template = os.getenv("CONNECTION_STRING")
if _raw_mongo_password is None or connection_string_template is None:
    # Fail fast with an actionable message instead of an opaque TypeError /
    # AttributeError from quote_plus(None) or None.replace(...).
    raise RuntimeError(
        "MONGO_PASSWORD and CONNECTION_STRING environment variables must be set"
    )

# URL-escape the password before substituting it into the connection string.
MONGO_PASSWORD = quote_plus(_raw_mongo_password)
MONGO_DATABASE_NAME = os.getenv("DATABASE_NAME")
MONGO_COLLECTION_NAME = os.getenv("COLLECTION_NAME")

# The template carries a literal "${PASSWORD}" placeholder.
MONGO_CLUSTER_URL = connection_string_template.replace("${PASSWORD}", MONGO_PASSWORD)

# Collection names; the chat collection falls back to "chat_history" when
# COLLECTION_NAME is unset or empty.
CHAT_COLLECTION = MONGO_COLLECTION_NAME or "chat_history"
USER_COLLECTION = "users"
VIDEO_COLLECTION = "videos"
|
|
|
|
|
# --- JWT settings ---
# Signing algorithm and default token lifetime (minutes) used by
# create_access_token / get_current_user.
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# Signing key comes from the environment; it is None when SECRET_KEY is unset.
SECRET_KEY = os.getenv("SECRET_KEY")
|
|
|
|
|
# Password hashing context (bcrypt scheme) shared by verify_password and
# get_password_hash.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

# Bearer-token dependency; its tokenUrl points at the "token" login route.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
|
|
# The FastAPI application object that all routes below are registered on.
app = FastAPI(
    title="RAG System API",
    description="An API for question answering based on video content with user authentication",
)

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside of development.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
# Request body for POST /transcribe.
class TranscriptionRequest(BaseModel):
    # URL handed to the transcription model as the video's file URI.
    youtube_url: str
|
|
|
# Request body for POST /query.
class QueryRequest(BaseModel):
    # The user's question.
    query: str
    # Identifier of the session (same value as the stored video_id) to query.
    session_id: str
|
|
|
# Response body for POST /query.
class QueryResponse(BaseModel):
    # Answer text produced for the query (or a fallback message).
    answer: str
    # Echo of the session the query ran against.
    session_id: str
    # Short previews of the retrieved source chunks, when available.
    source_documents: Optional[List[str]] = None
|
|
|
# Public user profile; also the response model of POST /register.
class User(BaseModel):
    # Login name; stored under a unique index in the users collection.
    username: str
    # Validated e-mail address; also stored under a unique index.
    email: EmailStr
    # Optional display name.
    full_name: Optional[str] = None
|
|
|
# User record as stored in MongoDB: the public profile plus the password hash.
class UserInDB(User):
    # Hash produced by get_password_hash; never the plain-text password.
    hashed_password: str
|
|
|
# Request body for POST /register: the public profile plus a plain password.
class UserCreate(User):
    # Plain-text password; hashed by register_user before being stored.
    password: str
|
|
|
# Response body for POST /token.
class Token(BaseModel):
    # Encoded JWT.
    access_token: str
    # Token scheme; the login route returns "bearer".
    token_type: str
|
|
|
# Claims extracted from a decoded access token.
class TokenData(BaseModel):
    # Value of the token's "sub" claim.
    username: Optional[str] = None
|
|
|
# Schema describing a stored video/session document.
class VideoData(BaseModel):
    # Unique identifier; process_transcription uses the session id here.
    video_id: str
    # Owner; the endpoints compare this against the caller's username.
    user_id: str
    # Human-readable title.
    title: str
    # Origin of the video; the endpoints store "youtube" or "upload".
    source_type: str
    # Original URL, when there is one.
    source_url: Optional[str] = None
    # Creation timestamp; defaults to datetime.utcnow() at instantiation.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    # Full transcription text.
    transcription: str
    # Size of the uploaded file in bytes, when known.
    size: Optional[int] = None
|
|
|
|
|
class MongoDB:
    """Holds the MongoDB client plus handles to the users and videos collections."""

    def __init__(self):
        """Connect using the module-level settings and create the indexes."""
        self.client = MongoClient(MONGO_CLUSTER_URL)
        self.db = self.client[MONGO_DATABASE_NAME]
        self.users = self.db[USER_COLLECTION]
        self.videos = self.db[VIDEO_COLLECTION]

        # Unique indexes back the duplicate checks done at registration time.
        for unique_field in ("username", "email"):
            self.users.create_index(unique_field, unique=True)

        self.videos.create_index("video_id", unique=True)
        # Non-unique index used when listing a user's videos.
        self.videos.create_index("user_id")

    def close(self):
        """Close the underlying client."""
        self.client.close()
|
|
|
|
|
class ChatManagement:
    """Creates and caches MongoDB-backed chat histories keyed by chat id.

    ``chat_sessions`` maps a chat id to the history object already opened for
    it in this process, so repeated lookups reuse the same object.
    """

    def __init__(self, cluster_url, database_name, collection_name):
        """Store the connection settings used for every history object.

        Args:
            cluster_url: MongoDB connection string.
            database_name: Database holding the chat collection.
            collection_name: Collection the chat messages are stored in.
        """
        self.connection_string = cluster_url
        self.database_name = database_name
        self.collection_name = collection_name
        self.chat_sessions = {}

    def _open_history(self, chat_id):
        """Build a history object for ``chat_id`` (shared by the public methods)."""
        return MongoDBChatMessageHistory(
            session_id=chat_id,
            connection_string=self.connection_string,
            database_name=self.database_name,
            collection_name=self.collection_name,
        )

    def create_new_chat(self):
        """Create a history under a fresh UUID, cache it, and return the id."""
        chat_id = str(uuid.uuid4())
        self.chat_sessions[chat_id] = self._open_history(chat_id)
        return chat_id

    def get_chat_history(self, chat_id):
        """Return the history for ``chat_id``, or None when nothing exists.

        A cached history is returned as-is. Otherwise one is opened and only
        cached/returned if it already contains messages.
        """
        if chat_id in self.chat_sessions:
            return self.chat_sessions[chat_id]

        history = self._open_history(chat_id)
        if history.messages:
            self.chat_sessions[chat_id] = history
            return history
        return None

    def initialize_chat_history(self, chat_id):
        """Return the history for ``chat_id``, opening and caching it if needed."""
        if chat_id in self.chat_sessions:
            return self.chat_sessions[chat_id]

        history = self._open_history(chat_id)
        self.chat_sessions[chat_id] = history
        return history
|
|
|
|
|
# In-memory registry: session id -> {"retriever": ..., "chat_history": ...}.
# It lives only in this process.
sessions = {}

# Shared database handle and chat-history manager used by every endpoint.
mongodb = MongoDB()
chat_manager = ChatManagement(
    MONGO_CLUSTER_URL,
    MONGO_DATABASE_NAME,
    CHAT_COLLECTION,
)

# Directory where uploaded video files are written; created at import time.
VIDEOS_DIR = "temp_videos"
os.makedirs(VIDEOS_DIR, exist_ok=True)
|
|
|
|
|
def verify_password(plain_password, hashed_password):
    """Return the result of checking ``plain_password`` against ``hashed_password``."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
def get_password_hash(password):
    """Hash ``password`` with the module's password context and return the hash."""
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode ``data`` plus an ``exp`` claim into a signed JWT.

    Args:
        data: Claims to embed; the dict is copied, not mutated.
        expires_delta: Lifetime of the token. When omitted (None), the
            module default ACCESS_TOKEN_EXPIRE_MINUTES is used.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    # Compare against None explicitly: a zero timedelta is falsy, and a
    # truthiness test would silently replace it with the 30-minute default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = data.copy()
    to_encode.update({"exp": datetime.utcnow() + expires_delta})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
def get_user(username: str):
    """Look up ``username`` in the users collection.

    Returns a ``UserInDB`` built from the stored document, or None when no
    document matches.
    """
    record = mongodb.users.find_one({"username": username})
    return UserInDB(**record) if record else None
|
|
|
def authenticate_user(username: str, password: str):
    """Return the stored user when the credentials match, otherwise False."""
    candidate = get_user(username)
    # Unknown user and wrong password are deliberately indistinguishable.
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return False
|
|
|
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer token to a stored user.

    Raises:
        HTTPException: 401 when the token cannot be decoded, carries no
            "sub" claim, or names a user that does not exist.
    """
    unauthorized = HTTPException(
        status_code=401,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        raise unauthorized

    subject: str = claims.get("sub")
    if subject is None:
        raise unauthorized
    token_data = TokenData(username=subject)

    account = get_user(username=token_data.username)
    if account is None:
        raise unauthorized
    return account
|
|
|
|
|
def init_google_client():
    """Create a ``genai.Client`` from the GOOGLE_API_KEY environment variable.

    Raises:
        ValueError: when the variable is unset or empty.
    """
    key = os.getenv("GOOGLE_API_KEY", "")
    if key:
        return genai.Client(api_key=key)
    raise ValueError("GOOGLE_API_KEY environment variable not set")
|
|
|
|
|
def get_llm():
    """Build the ChatGroq chat model used to answer questions.

    The model is "llama-3.3-70b-versatile" with temperature 0 and a
    1024-token cap; the key is read from CHATGROQ_API_KEY.

    Raises:
        ValueError: when CHATGROQ_API_KEY is unset or empty.
    """
    key = os.getenv("CHATGROQ_API_KEY", "")
    if key:
        return ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=0,
            max_tokens=1024,
            api_key=key,
        )
    raise ValueError("CHATGROQ_API_KEY environment variable not set")
|
|
|
|
|
# Process-wide cache: the embedding model is built once and then reused.
_embeddings_instance = None


def get_embeddings():
    """Return the shared HuggingFace embedding model, creating it on first use.

    The model is "BAAI/bge-small-en" on CPU with normalized embeddings.
    Earlier versions rebuilt the model on every call; the settings never
    vary, so a single instance is kept for the life of the process.
    """
    global _embeddings_instance
    if _embeddings_instance is None:
        _embeddings_instance = HuggingFaceEmbeddings(
            model_name="BAAI/bge-small-en",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )
    return _embeddings_instance
|
|
|
|
|
# System prompt for the answer-generation step. {context} and {question} are
# the template variables filled in when the prompt is rendered.
quiz_solving_prompt = '''
You are an assistant specialized in solving quizzes. Your goal is to provide accurate, concise, and contextually relevant answers.
Use the following retrieved context to answer the user's question.
If the context lacks sufficient information, respond with "I don't know." Do not make up answers or provide unverified information.

Guidelines:
1. Extract key information from the context to form a coherent response.
2. Maintain a clear and professional tone.
3. If the question requires clarification, specify it politely.

Retrieved context:
{context}

User's question:
{question}

Your response:
'''

# Chat prompt: the system instructions above followed by the user's question.
_prompt_messages = [
    ("system", quiz_solving_prompt),
    ("human", "{question}"),
]
user_prompt = ChatPromptTemplate.from_messages(_prompt_messages)
|
|
|
|
|
def create_chain(retriever):
    """Build a ConversationalRetrievalChain over ``retriever``.

    The chain uses the model from ``get_llm``, the "stuff" combine strategy
    with ``user_prompt``, and returns source documents with each answer.
    """
    chain_options = {
        "llm": get_llm(),
        "retriever": retriever,
        "return_source_documents": True,
        "chain_type": 'stuff',
        "combine_docs_chain_kwargs": {"prompt": user_prompt},
        "verbose": False,
    }
    return ConversationalRetrievalChain.from_llm(**chain_options)
|
|
|
|
|
def process_transcription(transcription, user_id, title, source_type, source_url=None, file_size=None):
    """Index a transcription and register it as a new session.

    The text is split into chunks, embedded into a FAISS store, persisted as
    a video document in MongoDB, and registered in the in-memory ``sessions``
    map together with a chat history.

    Args:
        transcription: Full transcript text.
        user_id: Owner of the video (callers pass the username).
        title: Title stored with the video.
        source_type: Origin label stored with the video.
        source_url: Original URL, if any.
        file_size: Size of the uploaded file in bytes, if any.

    Returns:
        The new session id (also stored as the video's ``video_id``).

    Raises:
        ValueError: when ``transcription`` is empty or whitespace-only.
    """
    # Reject empty input up front: with no chunks there is nothing to index,
    # and the failure would otherwise surface as an obscure vector-store error.
    if not transcription or not transcription.strip():
        raise ValueError("Transcription is empty; nothing to index")

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=20)
    all_splits = text_splitter.split_text(transcription)

    vectorstore = FAISS.from_texts(all_splits, get_embeddings())
    # Retrieve the 3 most similar chunks per query.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    session_id = str(uuid.uuid4())

    mongodb.videos.insert_one({
        "video_id": session_id,
        "user_id": user_id,
        "title": title,
        "source_type": source_type,
        "source_url": source_url,
        "created_at": datetime.utcnow(),
        "transcription": transcription,
        "size": file_size,
    })

    sessions[session_id] = {
        "retriever": retriever,
        "chat_history": chat_manager.initialize_chat_history(session_id),
    }

    return session_id
|
|
|
|
|
def save_video_file(video_id, file_path, contents):
    """Write ``contents`` (bytes) to ``file_path``, creating parent directories.

    Args:
        video_id: Identifier of the video; not used by the write itself.
        file_path: Destination path.
        contents: Raw bytes to write.
    """
    parent_dir = os.path.dirname(file_path)
    # A bare filename has no directory component; os.makedirs("") would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, "wb") as f:
        f.write(contents)
|
|
|
|
|
# POST /register: create a user account and return the public profile.
# Responds 400 when the username or e-mail is already taken.
@app.post("/register", response_model=User)
async def register_user(user: UserCreate):
    # Local import keeps this block self-contained; pymongo is already a
    # dependency of this module.
    from pymongo.errors import DuplicateKeyError

    if mongodb.users.find_one({"username": user.username}):
        raise HTTPException(status_code=400, detail="Username already registered")

    if mongodb.users.find_one({"email": user.email}):
        raise HTTPException(status_code=400, detail="Email already registered")

    # Store only the hash; the plain-text password is dropped.
    hashed_password = get_password_hash(user.password)
    user_dict = user.dict()
    del user_dict["password"]
    user_dict["hashed_password"] = hashed_password

    try:
        mongodb.users.insert_one(user_dict)
    except DuplicateKeyError:
        # A concurrent registration can slip in between the checks above and
        # the insert; the unique indexes catch it, so report it as a 400
        # rather than letting it bubble up as a 500.
        raise HTTPException(status_code=400, detail="Username or email already registered")

    return User(**user_dict)
|
|
|
# POST /token: exchange username/password form fields for a bearer token.
# Responds 401 when the credentials do not match a stored user.
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    account = authenticate_user(form_data.username, form_data.password)
    if not account:
        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # The username becomes the token's "sub" claim.
    token = create_access_token(data={"sub": account.username}, expires_delta=lifetime)
    return {"access_token": token, "token_type": "bearer"}
|
|
|
|
|
@app.post("/transcribe", response_model=Dict[str, str])
async def transcribe_video(
    request: TranscriptionRequest,
    current_user: User = Depends(get_current_user)
):
    """
    Transcribe a YouTube video and prepare the RAG system
    """
    # Any failure in the steps below is reported to the client as HTTP 500.
    try:
        gemini = init_google_client()

        # Ask the model to transcribe the video referenced by the URL.
        instruction = types.Part(text='Transcribe the Video. Write all the things described in the video')
        video_ref = types.Part(file_data=types.FileData(file_uri=request.youtube_url))
        response = gemini.models.generate_content(
            model='models/gemini-2.0-flash',
            contents=types.Content(parts=[instruction, video_ref]),
        )

        # Only the first part of the first candidate is used as the transcript.
        transcription = response.candidates[0].content.parts[0].text

        # No title is supplied by the client, so one is derived from the time.
        stamp = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
        session_id = process_transcription(
            transcription,
            current_user.username,
            f"YouTube Video - {stamp}",
            "youtube",
            request.youtube_url,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error transcribing video: {str(e)}")

    return {"session_id": session_id, "message": "YouTube video transcribed and RAG system prepared"}
|
|
|
@app.post("/upload", response_model=Dict[str, str])
async def upload_video(
    background_tasks: BackgroundTasks,
    title: str = Form(...),
    file: UploadFile = File(...),
    prompt: str = Form("Transcribe the Video. Write all the things described in the video"),
    current_user: User = Depends(get_current_user)
):
    """
    Upload a video file (max 20MB), transcribe it and prepare the RAG system
    """
    # Responds 400 for oversized or non-video uploads and 500 for any other
    # failure while transcribing or indexing.
    max_bytes = 20 * 1024 * 1024
    try:
        contents = await file.read()
        file_size = len(contents)
        if file_size > max_bytes:
            raise HTTPException(status_code=400, detail="File size exceeds 20MB limit")

        # content_type may be missing; treat that as "not a video" instead of
        # crashing on None.startswith.
        content_type = file.content_type or ""
        if not content_type.startswith('video/'):
            raise HTTPException(status_code=400, detail="File must be a video")

        client = init_google_client()

        # The video bytes are sent inline together with the prompt.
        response = client.models.generate_content(
            model='models/gemini-2.0-flash',
            contents=types.Content(
                parts=[
                    types.Part(text=prompt),
                    types.Part(
                        inline_data=types.Blob(data=contents, mime_type=content_type)
                    ),
                ]
            ),
        )

        # Only the first part of the first candidate is used as the transcript.
        transcription = response.candidates[0].content.parts[0].text

        session_id = process_transcription(
            transcription,
            current_user.username,
            title,
            "upload",
            None,
            file_size,
        )

        # Persist the raw file after the response is sent, named after the
        # session id so the download and delete routes can find it.
        file_extension = os.path.splitext(file.filename or "")[1]
        file_path = os.path.join(VIDEOS_DIR, f"{session_id}{file_extension}")
        background_tasks.add_task(save_video_file, session_id, file_path, contents)

        return {"session_id": session_id, "message": "Uploaded video transcribed and RAG system prepared"}

    except HTTPException:
        # Keep the intended status codes (the 400s above); the generic
        # handler below would otherwise turn them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing uploaded video: {str(e)}") from e
    finally:
        await file.seek(0)
|
|
|
@app.get("/download/{video_id}")
async def download_video(
    video_id: str,
    current_user: User = Depends(get_current_user)
):
    """
    Download a previously uploaded video
    """
    # Responds 404 when the video record or file is missing and 403 when the
    # video belongs to another user. YouTube-sourced videos return a JSON
    # pointer to the original URL instead of a file.
    from urllib.parse import quote  # local import: only needed for the header

    video_data = mongodb.videos.find_one({"video_id": video_id})
    if not video_data:
        raise HTTPException(status_code=404, detail="Video not found")

    if video_data["user_id"] != current_user.username:
        raise HTTPException(status_code=403, detail="Not authorized to access this video")

    if video_data["source_type"] == "youtube":
        return {"message": "This is a YouTube video. Please use the original URL to access the video.", "url": video_data["source_url"]}

    # Uploaded files are stored as "<video_id><extension>" in VIDEOS_DIR.
    video_files = [f for f in os.listdir(VIDEOS_DIR) if f.startswith(video_id)]
    if not video_files:
        raise HTTPException(status_code=404, detail="Video file not found")

    file_path = os.path.join(VIDEOS_DIR, video_files[0])

    file_extension = os.path.splitext(video_files[0])[1]
    mime_type = f"video/{file_extension[1:]}" if file_extension else "video/mp4"

    def iterfile():
        # Stream the file in 8 KiB chunks rather than loading it whole.
        with open(file_path, "rb") as f:
            while chunk := f.read(8192):
                yield chunk

    # The title is user-supplied: emit a sanitized ASCII fallback plus a
    # percent-encoded UTF-8 filename so spaces, quotes, control characters
    # and non-ASCII text cannot break or inject into the header.
    download_name = f"{video_data['title']}{file_extension}"
    ascii_fallback = "".join(
        ch if ch.isascii() and (ch.isalnum() or ch in " ._-") else "_"
        for ch in download_name
    )
    content_disposition = (
        f'attachment; filename="{ascii_fallback}"; '
        f"filename*=UTF-8''{quote(download_name, safe='')}"
    )

    return StreamingResponse(
        iterfile(),
        media_type=mime_type,
        headers={"Content-Disposition": content_disposition},
    )
|
|
|
def _pair_chat_history(messages):
    """Group a flat message list into (user, ai) content tuples; a trailing unpaired message is dropped."""
    return [
        (user_msg.content, ai_msg.content)
        for user_msg, ai_msg in zip(messages[::2], messages[1::2])
    ]


def _preview_text(text, limit=100):
    """Return ``text`` cut to ``limit`` characters with "..." appended when it was longer."""
    return text[:limit] + "..." if len(text) > limit else text


def _summarize_source_documents(documents):
    """Build short previews for retrieved documents (objects, dicts or plain strings)."""
    previews = []
    for doc in documents or []:
        try:
            if hasattr(doc, 'page_content'):
                text = doc.page_content
            elif isinstance(doc, dict) and 'page_content' in doc:
                text = doc['page_content']
            elif isinstance(doc, str):
                text = doc
            else:
                # Unknown shape: skip it rather than fail the whole response.
                continue
            previews.append(_preview_text(text))
        except Exception as doc_error:
            print(f"Error processing source document: {str(doc_error)}")
    return previews


@app.post("/query", response_model=QueryResponse)
async def query_system(
    request: QueryRequest,
    current_user: User = Depends(get_current_user)
):
    """
    Query the RAG system with a question
    """
    # Responds 404 when the session is not loaded in this process, 403 when
    # it belongs to another user, and 500 for unexpected failures. A failure
    # inside the chain itself yields a 200 with an apology message instead.
    try:
        session_id = request.session_id

        if not session_id or session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found. Please transcribe a video first.")

        video_data = mongodb.videos.find_one({"video_id": session_id})
        if not video_data or video_data["user_id"] != current_user.username:
            raise HTTPException(status_code=403, detail="Not authorized to access this session")

        retriever = sessions[session_id]["retriever"]
        chat_history = chat_manager.initialize_chat_history(session_id)
        chain = create_chain(retriever)

        # The chain expects prior turns as (question, answer) tuples.
        langchain_chat_history = _pair_chat_history(chat_history.messages or [])

        print(f"Chat history length: {len(langchain_chat_history)}")
        print(f"Query: {request.query}")

        try:
            result = chain.invoke({
                "question": request.query,
                "chat_history": langchain_chat_history
            })

            answer = result.get("answer", "I couldn't find an answer to your question.")

            # Persist the turn so later queries see it as history.
            chat_history.add_user_message(request.query)
            chat_history.add_ai_message(answer)

            return {
                "answer": answer,
                "session_id": session_id,
                "source_documents": _summarize_source_documents(result.get("source_documents")),
            }

        except Exception as chain_error:
            print(f"Chain invocation error: {str(chain_error)}")

            fallback_answer = "I apologize, but I encountered an error while processing your question. Please try rephrasing your query or asking about a different topic."

            # The failed turn is still recorded in the history.
            chat_history.add_user_message(request.query)
            chat_history.add_ai_message(fallback_answer)

            return {
                "answer": fallback_answer,
                "session_id": session_id,
                "source_documents": []
            }

    except HTTPException:
        # Let the deliberate 404/403 responses through; the generic handler
        # below would otherwise rewrite them as 500s.
        raise
    except Exception as e:
        print(f"Query system error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error querying system: {str(e)}") from e
|
|
|
|
|
@app.get("/sessions", response_model=List[Dict[str, Any]])
async def get_user_sessions(current_user: User = Depends(get_current_user)):
    """
    Get all video sessions for the current user
    """
    owned_videos = mongodb.videos.find({"user_id": current_user.username})

    def summarize(video):
        # Previews are capped at 200 characters, with "..." marking a cut.
        transcript = video["transcription"]
        if len(transcript) > 200:
            preview = transcript[:200] + "..."
        else:
            preview = transcript
        return {
            "session_id": video["video_id"],
            "title": video["title"],
            "source_type": video["source_type"],
            "created_at": video["created_at"],
            "transcription_preview": preview,
        }

    return [summarize(video) for video in owned_videos]
|
|
|
@app.get("/sessions/{session_id}", response_model=Dict[str, Any])
async def get_session_info(
    session_id: str,
    current_user: User = Depends(get_current_user)
):
    """
    Get information about a specific session
    """
    # Responds 404 for an unknown session and 403 for someone else's session.
    record = mongodb.videos.find_one({"video_id": session_id})
    if not record:
        raise HTTPException(status_code=404, detail="Session not found")
    if record["user_id"] != current_user.username:
        raise HTTPException(status_code=403, detail="Not authorized to access this session")

    # Stored messages alternate question/answer; pair them up and drop a
    # trailing unpaired message.
    exchanges = []
    history = chat_manager.get_chat_history(session_id)
    if history:
        stored = history.messages
        exchanges = [
            {"question": asked.content, "answer": replied.content}
            for asked, replied in zip(stored[::2], stored[1::2])
        ]

    transcript = record["transcription"]
    if len(transcript) > 200:
        preview = transcript[:200] + "..."
    else:
        preview = transcript

    return {
        "session_id": session_id,
        "title": record["title"],
        "source_type": record["source_type"],
        "source_url": record.get("source_url"),
        "created_at": record["created_at"],
        "transcription_preview": preview,
        "full_transcription": transcript,
        "chat_history": exchanges,
    }
|
|
|
@app.delete("/sessions/{session_id}")
async def delete_session(
    session_id: str,
    current_user: User = Depends(get_current_user)
):
    """
    Delete a session
    """
    # Responds 404 for an unknown session and 403 for someone else's session.
    video_data = mongodb.videos.find_one({"video_id": session_id})
    if not video_data:
        raise HTTPException(status_code=404, detail="Session not found")

    if video_data["user_id"] != current_user.username:
        raise HTTPException(status_code=403, detail="Not authorized to access this session")

    mongodb.videos.delete_one({"video_id": session_id})

    # Remove the stored chat messages through the history object itself, so
    # the delete uses whatever session-id key the history class writes under
    # (a hand-written {"session_id": ...} filter does not match it).
    chat_history = chat_manager.get_chat_history(session_id)
    if chat_history:
        chat_history.clear()
    # Drop the cached handle so it does not accumulate in memory.
    chat_manager.chat_sessions.pop(session_id, None)

    sessions.pop(session_id, None)

    # Best-effort removal of the uploaded file(s) for this session.
    video_files = [f for f in os.listdir(VIDEOS_DIR) if f.startswith(session_id)]
    for file in video_files:
        try:
            os.remove(os.path.join(VIDEOS_DIR, file))
        except OSError as remove_error:
            # A leftover file must not fail the request, but do not hide it.
            print(f"Could not remove video file {file}: {str(remove_error)}")

    return {"message": f"Session {session_id} deleted successfully"}
|
|
|
@app.get("/")
async def root():
    """
    API root endpoint
    """
    # Static directory of the main routes and what each one does.
    endpoint_descriptions = {
        "/register": "Register a new user",
        "/token": "Login and get access token",
        "/transcribe": "Transcribe YouTube videos",
        "/upload": "Upload and transcribe video files (max 20MB)",
        "/download/{video_id}": "Download an uploaded video",
        "/query": "Query the RAG system",
        "/sessions": "List all user sessions",
        "/sessions/{session_id}": "Get session information",
    }
    return {
        "message": "Video Transcription and QA API",
        "endpoints": endpoint_descriptions,
    }
|
|
|
@app.on_event("shutdown")
def shutdown_event():
    """Release resources when the application shuts down."""
    # Close the shared MongoDB client first.
    mongodb.close()

    # Then remove the upload directory and everything in it; errors are
    # ignored so shutdown never fails on cleanup.
    shutil.rmtree(VIDEOS_DIR, ignore_errors=True)
|
|
|
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn

    # Set before the server starts so the tokenizers library sees it.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    uvicorn.run(app, host="0.0.0.0", port=8000)