MaryamKarimi080 committed on
Commit
6ca7edb
·
verified ·
1 Parent(s): 18a6316

Update scripts/rag_chat.py

Browse files
Files changed (1) hide show
  1. scripts/rag_chat.py +42 -36
scripts/rag_chat.py CHANGED
@@ -1,36 +1,42 @@
1
- from langchain.chains import RetrievalQA
2
- from langchain_openai import ChatOpenAI
3
- from langchain_chroma import Chroma
4
- from langchain_openai import OpenAIEmbeddings
5
- from langchain.prompts import PromptTemplate
6
- from pathlib import Path
7
-
8
- BASE_DIR = Path(__file__).resolve().parent.parent
9
- DB_DIR = str(BASE_DIR / "db")
10
-
11
- def build_general_qa_chain(model_name=None):
12
- embedding = OpenAIEmbeddings(model="text-embedding-3-small")
13
- vectorstore = Chroma(persist_directory=DB_DIR, embedding_function=embedding)
14
-
15
- # Custom prompt with source attribution
16
- template = """Use the following context to answer the question.
17
- If the answer isn't found in the context, use your general knowledge but say so.
18
- Always cite your sources at the end with 'Source: <filename>' when using course materials.
19
-
20
- Context: {context}
21
- Question: {question}
22
- Helpful Answer:"""
23
-
24
- QA_PROMPT = PromptTemplate(
25
- template=template,
26
- input_variables=["context", "question"]
27
- )
28
-
29
- llm = ChatOpenAI(model_name=model_name or "gpt-4o-mini", temperature=0.0)
30
- qa_chain = RetrievalQA.from_chain_type(
31
- llm=llm,
32
- retriever=vectorstore.as_retriever(search_kwargs={"k": 4}),
33
- chain_type_kwargs={"prompt": QA_PROMPT},
34
- return_source_documents=True
35
- )
36
- return qa_chain
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from langchain.chains import RetrievalQA
4
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
5
+ from langchain_chroma import Chroma
6
+ from langchain.prompts import PromptTemplate
7
+
8
+ BASE_DIR = Path(__file__).resolve().parent.parent
9
+ DB_DIR = BASE_DIR / "db"
10
+
11
def build_general_qa_chain(model_name=None):
    """Build a RetrievalQA chain backed by the persisted Chroma vectorstore.

    When the database directory does not exist, the document-loading,
    chunk-and-embed and vectorstore-setup scripts are run first, in that
    order, before the store is opened.

    Args:
        model_name: Name of the chat model to use. Falls back to
            "gpt-4o-mini" when falsy.

    Returns:
        The chain created by ``RetrievalQA.from_chain_type``, configured to
        return the source documents alongside the answer.
    """
    db_missing = not DB_DIR.exists()
    if db_missing:
        print("📦 No DB found. Building vectorstore...")
        # Imported here so the pipeline modules are only loaded when a
        # build is actually required.
        from scripts import load_documents, chunk_and_embed, setup_vectorstore

        for pipeline_step in (load_documents, chunk_and_embed, setup_vectorstore):
            pipeline_step.main()

    embedder = OpenAIEmbeddings(model="text-embedding-3-small")
    store = Chroma(embedding_function=embedder, persist_directory=str(DB_DIR))

    # Prompt text is assembled line by line; the content is the same
    # instruction block followed by the context/question slots.
    prompt_text = (
        "Use the following context to answer the question.\n"
        "If the answer isn't found in the context, use your general knowledge but say so.\n"
        "Always cite your sources at the end with 'Source: <filename>' when using course materials.\n"
        "\n"
        "Context: {context}\n"
        "Question: {question}\n"
        "Helpful Answer:"
    )
    qa_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=prompt_text,
    )

    chat_model = ChatOpenAI(
        model_name=model_name or "gpt-4o-mini",
        temperature=0.0,
    )

    # Retrieve the 4 nearest chunks per query and hand them to the prompt.
    return RetrievalQA.from_chain_type(
        llm=chat_model,
        retriever=store.as_retriever(search_kwargs={"k": 4}),
        chain_type_kwargs={"prompt": qa_prompt},
        return_source_documents=True,
    )