Hammad712 committed
Commit c158c4a · verified · 1 Parent(s): 524faa3

Create main.py

Files changed (1): main.py +310 -0
main.py ADDED
@@ -0,0 +1,310 @@
+ from fastapi import FastAPI, HTTPException, Body, Query, File, UploadFile, Form
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from typing import List, Optional, Dict, Any, Union
+ import uuid
+ import os
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # Import necessary libraries
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_core.documents import Document
+ from langchain_groq import ChatGroq
+ from google import genai
+ from google.genai import types
+
+ # Initialize FastAPI app
+ app = FastAPI(title="RAG System API", description="An API for question answering based on YouTube video content or uploaded video files")
+
+ # Configure CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Define request/response models
+ class TranscriptionRequest(BaseModel):
+     youtube_url: str
+
+ class QueryRequest(BaseModel):
+     query: str
+     session_id: Optional[str] = None
+
+ class QueryResponse(BaseModel):
+     answer: str
+     session_id: str
+     source_documents: Optional[List[str]] = None
+
+ # Global in-memory session store
+ sessions = {}
+
+ # Initialize Google API client
+ def init_google_client():
+     api_key = os.getenv("GOOGLE_API_KEY", "")
+     if not api_key:
+         raise ValueError("GOOGLE_API_KEY environment variable not set")
+     return genai.Client(api_key=api_key)
+
+ # Get LLM
+ def get_llm():
+     """
+     Returns the language model instance (LLM) using the ChatGroq API.
+     The LLM used is Llama 3.3 70B (the "versatile" variant).
+     """
+     api_key = os.getenv("GROQ_API_KEY", "")
+     if not api_key:
+         raise ValueError("GROQ_API_KEY environment variable not set")
+
+     llm = ChatGroq(
+         model="llama-3.3-70b-versatile",
+         temperature=0,
+         max_tokens=1024,
+         api_key=api_key,
+     )
+     return llm
+
+ # Get embeddings
+ def get_embeddings():
+     model_name = "BAAI/bge-small-en"
+     model_kwargs = {"device": "cpu"}
+     encode_kwargs = {"normalize_embeddings": True}
+     embeddings = HuggingFaceBgeEmbeddings(
+         model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
+     )
+     return embeddings
+
+ # Create prompt template
+ quiz_solving_prompt = '''
+ You are an assistant specialized in solving quizzes. Your goal is to provide accurate, concise, and contextually relevant answers.
+ Use the following retrieved context to answer the user's question.
+ If the context lacks sufficient information, respond with "I don't know." Do not make up answers or provide unverified information.
+
+ Guidelines:
+ 1. Extract key information from the context to form a coherent response.
+ 2. Maintain a clear and professional tone.
+ 3. If the question requires clarification, ask for it politely.
+
+ Retrieved context:
+ {context}
+
+ User's question:
+ {question}
+
+ Your response:
+ '''
+
+ # Create a prompt template to pass the context and user input to the chain
+ user_prompt = ChatPromptTemplate.from_messages(
+     [
+         ("system", quiz_solving_prompt),
+         ("human", "{question}"),
+     ]
+ )
+
+ # Create a conversational retrieval chain for a given retriever
+ def create_chain(retriever):
+     llm = get_llm()
+     chain = ConversationalRetrievalChain.from_llm(
+         llm=llm,
+         retriever=retriever,
+         return_source_documents=True,
+         chain_type='stuff',
+         combine_docs_chain_kwargs={"prompt": user_prompt},
+         verbose=False,
+     )
+     return chain
+
+ # Process a transcription and prepare the RAG system
+ def process_transcription(transcription):
+     # Split the transcription into chunks
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=20)
+     all_splits = text_splitter.split_text(transcription)
+
+     # Create vector store
+     embeddings = get_embeddings()
+     vectorstore = FAISS.from_texts(all_splits, embeddings)
+     retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
+
+     # Create a session ID
+     session_id = str(uuid.uuid4())
+
+     # Store session data
+     sessions[session_id] = {
+         "retriever": retriever,
+         "chat_history": [],
+         "transcription": transcription,
+     }
+
+     return session_id
+
+ @app.post("/transcribe", response_model=Dict[str, str])
+ async def transcribe_video(request: TranscriptionRequest):
+     """
+     Transcribe a YouTube video and prepare the RAG system.
+     """
+     try:
+         # Initialize Google API client
+         client = init_google_client()
+
+         # Transcribe the video
+         response = client.models.generate_content(
+             model='models/gemini-2.0-flash',
+             contents=types.Content(
+                 parts=[
+                     types.Part(text='Transcribe the Video. Write all the things described in the video'),
+                     types.Part(
+                         file_data=types.FileData(file_uri=request.youtube_url)
+                     )
+                 ]
+             )
+         )
+
+         # Get transcription text
+         transcription = response.candidates[0].content.parts[0].text
+
+         # Process transcription and get session ID
+         session_id = process_transcription(transcription)
+
+         return {"session_id": session_id, "message": "YouTube video transcribed and RAG system prepared"}
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error transcribing video: {str(e)}")
+
+ @app.post("/upload", response_model=Dict[str, str])
+ async def upload_video(
+     file: UploadFile = File(...),
+     prompt: str = Form("Transcribe the Video. Write all the things described in the video"),
+ ):
+     """
+     Upload a video file (max 20MB), transcribe it, and prepare the RAG system.
+     """
+     try:
+         # Check file size (20MB limit)
+         contents = await file.read()
+         if len(contents) > 20 * 1024 * 1024:  # 20MB in bytes
+             raise HTTPException(status_code=400, detail="File size exceeds 20MB limit")
+
+         # Check file type
+         if not file.content_type.startswith('video/'):
+             raise HTTPException(status_code=400, detail="File must be a video")
+
+         # Initialize Google API client
+         client = init_google_client()
+
+         # Transcribe the video
+         response = client.models.generate_content(
+             model='models/gemini-2.0-flash',
+             contents=types.Content(
+                 parts=[
+                     types.Part(text=prompt),
+                     types.Part(
+                         inline_data=types.Blob(data=contents, mime_type=file.content_type)
+                     )
+                 ]
+             )
+         )
+
+         # Get transcription text
+         transcription = response.candidates[0].content.parts[0].text
+
+         # Process transcription and get session ID
+         session_id = process_transcription(transcription)
+
+         return {"session_id": session_id, "message": "Uploaded video transcribed and RAG system prepared"}
+
+     except HTTPException:
+         # Propagate validation errors (size/type) without converting them to 500s
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error processing uploaded video: {str(e)}")
+     finally:
+         # Reset file pointer
+         await file.seek(0)
+
+ @app.post("/query", response_model=QueryResponse)
+ async def query_system(request: QueryRequest):
+     """
+     Query the RAG system with a question.
+     """
+     try:
+         session_id = request.session_id
+
+         # Require an existing session
+         if not session_id or session_id not in sessions:
+             raise HTTPException(status_code=404, detail="Session not found. Please transcribe a video first.")
+
+         # Get session data
+         session = sessions[session_id]
+         retriever = session["retriever"]
+         chat_history = session["chat_history"]
+
+         # Create chain
+         chain = create_chain(retriever)
+
+         # Query the chain
+         result = chain({"question": request.query, "chat_history": chat_history})
+
+         # Update chat history
+         chat_history.append((request.query, result["answer"]))
+
+         # Prepare source document previews
+         source_docs = [doc.page_content[:100] + "..." for doc in result.get("source_documents", [])]
+
+         return {
+             "answer": result["answer"],
+             "session_id": session_id,
+             "source_documents": source_docs,
+         }
+
+     except HTTPException:
+         # Propagate the 404 for a missing session without converting it to a 500
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error querying system: {str(e)}")
+
+ @app.get("/sessions/{session_id}", response_model=Dict[str, Any])
+ async def get_session_info(session_id: str):
+     """
+     Get information about a specific session.
+     """
+     if session_id not in sessions:
+         raise HTTPException(status_code=404, detail="Session not found")
+
+     session = sessions[session_id]
+
+     return {
+         "session_id": session_id,
+         "chat_history_length": len(session["chat_history"]),
+         "transcription_preview": session["transcription"][:200] + "...",
+     }
+
+ @app.delete("/sessions/{session_id}")
+ async def delete_session(session_id: str):
+     """
+     Delete a session.
+     """
+     if session_id not in sessions:
+         raise HTTPException(status_code=404, detail="Session not found")
+
+     del sessions[session_id]
+     return {"message": f"Session {session_id} deleted successfully"}
+
+ @app.get("/")
+ async def root():
+     """
+     API root endpoint.
+     """
+     return {
+         "message": "Video Transcription and QA API",
+         "endpoints": {
+             "/transcribe": "Transcribe YouTube videos",
+             "/upload": "Upload and transcribe video files (max 20MB)",
+             "/query": "Query the RAG system",
+             "/sessions/{session_id}": "Get session information",
+         },
+     }
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
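
For readers who want to exercise the API added in this commit, the following is a minimal client sketch, not part of the committed file. It assumes the server is running locally on port 8000 (as in the `__main__` block above) with GOOGLE_API_KEY and GROQ_API_KEY set in the environment or a .env file; the YouTube URL is a placeholder.

# Minimal client sketch for the API above (assumptions: server on localhost:8000,
# API keys configured on the server side, placeholder YouTube URL).
import requests

BASE_URL = "http://localhost:8000"

# 1. Transcribe a YouTube video and obtain a session ID.
resp = requests.post(
    f"{BASE_URL}/transcribe",
    json={"youtube_url": "https://www.youtube.com/watch?v=VIDEO_ID"},
)
resp.raise_for_status()
session_id = resp.json()["session_id"]

# 2. Ask a question against the transcribed content.
resp = requests.post(
    f"{BASE_URL}/query",
    json={"query": "What is the video about?", "session_id": session_id},
)
resp.raise_for_status()
print(resp.json()["answer"])

# 3. Inspect the session, then clean it up.
print(requests.get(f"{BASE_URL}/sessions/{session_id}").json())
requests.delete(f"{BASE_URL}/sessions/{session_id}")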