harshpatel080503 committed
Commit 6ed6cba · verified · 1 Parent(s): d72f98a

Create rag_chain.py

Files changed (1):
  rag_chain.py +76 -0
rag_chain.py ADDED
@@ -0,0 +1,76 @@
+ # rag_chain.py
+
+ import os
+ from dotenv import load_dotenv
+ from youtube_transcript_api import YouTubeTranscriptApi
+
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.chat_models import ChatOpenAI
+ from langchain.chains import RetrievalQA
+ from langchain.memory import ConversationBufferMemory
+ from langchain.prompts import PromptTemplate
+
+ load_dotenv()
+
+ os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_ACCESS_TOKEN")
+ os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
+
+ # Hugging Face Embeddings
+ os.environ['HF_HOME'] = 'E:/Generative AI/AI Models/Embedding Models'
+ embedding = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
+
+ # OpenRouter LLM (Meta LLaMA 3.3)
+ llm = ChatOpenAI(
+     openai_api_base="https://openrouter.ai/api/v1",
+     model="meta-llama/llama-3.3-70b-instruct:free",
+ )
+
+ # Custom prompt for RAG
+ qa_prompt = PromptTemplate(
+     template="""
+ You are a helpful assistant answering questions based on YouTube video content.
+
+ Context:
+ {context}
+
+ Question:
+ {question}
+
+ Answer:""",
+     input_variables=["context", "question"],
+ )
+
+ # Fetch transcript using YouTubeTranscriptApi
+ def fetch_transcript(video_id: str) -> str:
+     transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=["en", "hi"])
+     return " ".join([t["text"] for t in transcript])
+
+ # Build RAG chain from transcript
+ def build_chain(video_id: str) -> RetrievalQA:
+     text = fetch_transcript(video_id)
+
+     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+     docs = splitter.create_documents([text])
+
+     vectorstore = FAISS.from_documents(docs, embedding)
+     retriever = vectorstore.as_retriever()
+
+     memory = ConversationBufferMemory(
+         memory_key="chat_history",
+         return_messages=True,
+         output_key="result"
+     )
+
+     qa_chain = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",
+         retriever=retriever,
+         memory=memory,
+         return_source_documents=True,
+         output_key="result",
+         chain_type_kwargs={"prompt": qa_prompt}
+     )
+
+     return qa_chain
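
A minimal usage sketch (assumptions: rag_chain.py is importable from the working directory, a .env file defines HUGGINGFACEHUB_ACCESS_TOKEN and OPENAI_API_KEY, and the video ID and question below are placeholders rather than values from this repo):

# usage_example.py — hypothetical driver script, not part of this commit
from rag_chain import build_chain

chain = build_chain("VIDEO_ID_HERE")  # placeholder: any YouTube video ID with an English or Hindi transcript

# RetrievalQA takes its input under the "query" key; with return_source_documents=True
# the result dict carries both the answer ("result") and the retrieved transcript chunks.
response = chain({"query": "What is the video about?"})

print(response["result"])
for doc in response["source_documents"]:
    print(doc.page_content[:100])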