Vishaltiwari2019 commited on
Commit
9bf33d9
·
verified ·
1 Parent(s): b2ac503

Upload 5 files

Browse files
Files changed (5) hide show
  1. .env +2 -0
  2. langgraph_chain.py +101 -0
  3. main.py +19 -0
  4. requirements.txt +12 -0
  5. tools.py +41 -0
.env ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ GROQ_API_KEY=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
2
+ SERPAPI_API_KEY=e92e6c6b0f63d2352fedc24c5f5db7cc2977e075ac048e3ab916449d3b536200
langgraph_chain.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import StateGraph, END
2
+ from langchain.chains import RetrievalQA
3
+ from typing import TypedDict, Optional
4
+ from tools import llm, load_vectorstore, search_tool
5
+
6
+ # Load your vectorstore
7
+ vectorstore = load_vectorstore()
8
+
9
+ # --- TypedDict to define graph state schema ---
10
+ class GraphState(TypedDict):
11
+ question: str
12
+ pdf_answer: Optional[str]
13
+ llm_answer: Optional[str]
14
+ web_answer: Optional[str]
15
+
16
+ # --- LangGraph Node Functions ---
17
+
18
+ def pdf_qa_node(state: GraphState) -> GraphState:
19
+ query = state["question"]
20
+ qa = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
21
+ result = qa.run(query)
22
+ return {**state, "pdf_answer": result}
23
+
24
+ def check_pdf_relevance(state: GraphState) -> str:
25
+ ans = state.get("pdf_answer", "").lower()
26
+ if (
27
+ "i don't know" in ans
28
+ or "i don't have information" in ans
29
+ or "no relevant" in ans
30
+ or "not available" in ans
31
+ or len(ans.strip()) < 20
32
+ ):
33
+ return "llm_fallback"
34
+ return "respond_pdf"
35
+
36
+ def llm_fallback_node(state: GraphState) -> GraphState:
37
+ query = state["question"]
38
+ prompt = f"""You are a helpful AI assistant. The user asked a question, and no relevant documents were found.
39
+
40
+ Try your best to answer this:
41
+
42
+ Question: {query}
43
+ Answer:"""
44
+ res = llm.invoke(prompt)
45
+ return {**state, "llm_answer": res.content}
46
+
47
+ def check_llm_confidence(state: GraphState) -> str:
48
+ ans = state.get("llm_answer", "").lower()
49
+ if "i don't know" in ans or "not sure" in ans or "no idea" in ans:
50
+ return "web_search"
51
+ return "respond_llm"
52
+
53
+ def web_search_node(state: GraphState) -> GraphState:
54
+ query = state["question"]
55
+ result = search_tool(query)
56
+ return {**state, "web_answer": result}
57
+
58
+ def respond_pdf(state: GraphState) -> dict:
59
+ print("📄 Responding from PDF")
60
+ return {"answer": state["pdf_answer"]}
61
+
62
+ def respond_llm(state: GraphState) -> dict:
63
+ print("🤖 Responding from LLM")
64
+ return {"answer": state["llm_answer"]}
65
+
66
+ def respond_web(state: GraphState) -> dict:
67
+ print("🌐 Responding from Web Search")
68
+ return {"answer": state["web_answer"]}
69
+
70
+ # --- Graph Creation Function ---
71
+ def create_graph():
72
+ builder = StateGraph(GraphState) # Pass schema
73
+
74
+ builder.add_node("pdf_qa", pdf_qa_node)
75
+ builder.add_node("llm_fallback", llm_fallback_node)
76
+ builder.add_node("web_search", web_search_node)
77
+
78
+ builder.add_node("respond_pdf", respond_pdf)
79
+ builder.add_node("respond_llm", respond_llm)
80
+ builder.add_node("respond_web", respond_web)
81
+
82
+ builder.set_entry_point("pdf_qa")
83
+
84
+ builder.add_conditional_edges("pdf_qa", check_pdf_relevance, {
85
+ "respond_pdf": "respond_pdf",
86
+ "llm_fallback": "llm_fallback"
87
+ })
88
+
89
+ builder.add_conditional_edges("llm_fallback", check_llm_confidence, {
90
+ "respond_llm": "respond_llm",
91
+ "web_search": "web_search"
92
+ })
93
+
94
+ builder.add_edge("web_search", "respond_web")
95
+
96
+ # Set all end nodes
97
+ builder.add_edge("respond_pdf", END)
98
+ builder.add_edge("respond_llm", END)
99
+ builder.add_edge("respond_web", END)
100
+
101
+ return builder.compile()
main.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from pydantic import BaseModel
3
+ from langgraph_chain import create_graph
4
+ from fastapi import Form
5
+
6
+ app = FastAPI()
7
+ graph = create_graph()
8
+
9
+ class Query(BaseModel):
10
+ question: str
11
+
12
+ @app.get("/")
13
+ def read_root():
14
+ return {"Hello": "World"}
15
+
16
+ @app.post("/ask")
17
+ def ask_q(question: str = Form(...)):
18
+ result = graph.invoke({"question": question})
19
+ return {"response": result}
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain
2
+ langgraph
3
+ openai
4
+ groq
5
+ faiss-cpu
6
+ sentence-transformers
7
+ pypdf
8
+ python-dotenv
9
+ fastapi
10
+ uvicorn
11
+ serpapi
12
+ langchain_community
tools.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langchain.chat_models import ChatOpenAI
4
+ from langchain.document_loaders import PyPDFLoader
5
+ from langchain.text_splitter import CharacterTextSplitter
6
+ from langchain.vectorstores import FAISS
7
+ from langchain.embeddings import HuggingFaceEmbeddings
8
+ import serpapi
9
+
10
+ load_dotenv()
11
+
12
+ # LLM (Groq + LLaMA3)
13
+ llm = ChatOpenAI(
14
+ model="llama3-8b-8192",
15
+ openai_api_base="https://api.groq.com/openai/v1",
16
+ openai_api_key=os.environ["GROQ_API_KEY"]
17
+ )
18
+
19
+ # Embeddings (HuggingFace)
20
+ embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
21
+
22
+ # Load PDFs and create FAISS vectorstore
23
+ def load_vectorstore(pdf_dir="pdfs/"):
24
+ docs = []
25
+ for file in os.listdir(pdf_dir):
26
+ if file.endswith(".pdf"):
27
+ loader = PyPDFLoader(os.path.join(pdf_dir, file))
28
+ docs.extend(loader.load())
29
+ splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
30
+ chunks = splitter.split_documents(docs)
31
+ return FAISS.from_documents(chunks, embedding=embeddings)
32
+
33
+ # Custom Web Search tool using SerpAPI
34
+ def search_tool(query: str):
35
+ client = serpapi.Client(api_key=os.getenv("SERPAPI_API_KEY"))
36
+ search = client.search({
37
+ "engine": "google",
38
+ "q": query,
39
+ })
40
+ results = dict(search)
41
+ return results["organic_results"][0]["snippet"] # Return the snippet or any part of the result