Hammad712 commited on
Commit
1702b26
·
verified ·
1 Parent(s): e8a3f1b

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +131 -0
main.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import zipfile
3
+ import tempfile
4
+
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel
7
+
8
+ from langchain_community.vectorstores import FAISS
9
+ from langchain_huggingface import HuggingFaceEmbeddings
10
+ from langchain_groq import ChatGroq
11
+ from langchain.chains import RetrievalQA
12
+ from langchain.prompts import PromptTemplate
13
+
14
+ app = FastAPI()
15
+
16
+ # === Globals ===
17
+ llm = None
18
+ embeddings = None
19
+ vectorstore = None
20
+ retriever = None
21
+ chain = None
22
+
23
+
24
+ class QueryRequest(BaseModel):
25
+ question: str
26
+
27
+
28
+ def _unpack_faiss(src: str, dest_dir: str) -> str:
29
+ """
30
+ If src is a .zip, unzip it into dest_dir and return
31
+ the path to the extracted FAISS folder. Otherwise
32
+ assume src is already a folder and return it.
33
+ """
34
+ if zipfile.is_zipfile(src):
35
+ with zipfile.ZipFile(src, "r") as zf:
36
+ zf.extractall(dest_dir)
37
+ # if there’s exactly one subfolder, use it
38
+ items = os.listdir(dest_dir)
39
+ if len(items) == 1 and os.path.isdir(os.path.join(dest_dir, items[0])):
40
+ return os.path.join(dest_dir, items[0])
41
+ return dest_dir
42
+ else:
43
+ # src is already a directory
44
+ return src
45
+
46
+
47
+ @app.on_event("startup")
48
+ def load_components():
49
+ global llm, embeddings, vectorstore, retriever, chain
50
+
51
+ # --- 1) Initialize LLM & Embeddings ---
52
+ api_key = os.getenv("api_key")
53
+ llm = ChatGroq(
54
+ model="meta-llama/llama-4-scout-17b-16e-instruct",
55
+ temperature=0,
56
+ max_tokens=1024,
57
+ api_key=api_key,
58
+ )
59
+
60
+ embeddings = HuggingFaceEmbeddings(
61
+ model_name="intfloat/multilingual-e5-large",
62
+ model_kwargs={"device": "cpu"},
63
+ encode_kwargs={"normalize_embeddings": True},
64
+ )
65
+
66
+ # --- 2) Load & merge two FAISS indexes ---
67
+ # Paths to your two vectorstores (could be .zip or folders)
68
+ src1 = "faiss_index.zip"
69
+ src2 = "faiss_index_extra.zip"
70
+
71
+ # Temporary dirs for extraction
72
+ tmp1 = tempfile.mkdtemp()
73
+ tmp2 = tempfile.mkdtemp()
74
+
75
+ # Unpack and load each
76
+ path1 = _unpack_faiss(src1, tmp1)
77
+ vs1 = FAISS.load_local(path1, embeddings, allow_dangerous_deserialization=True)
78
+
79
+ path2 = _unpack_faiss(src2, tmp2)
80
+ vs2 = FAISS.load_local(path2, embeddings, allow_dangerous_deserialization=True)
81
+
82
+ # Merge vs2 into vs1
83
+ vs1.merge_from(vs2)
84
+
85
+ # Assign the merged store to our global
86
+ vectorstore = vs1
87
+
88
+ # --- 3) Build retriever & QA chain ---
89
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
90
+
91
+ prompt = PromptTemplate(
92
+ template="""
93
+ You are an expert assistant on Islamic knowledge.
94
+ Use **only** the information in the “Retrieved context” to answer the user’s question.
95
+ Do **not** add any outside information, personal opinions, or conjecture—if the answer is not contained in the context, reply with “لا أعلم”.
96
+ Be concise, accurate, and directly address the user’s question.
97
+
98
+ Retrieved context:
99
+ {context}
100
+
101
+ User’s question:
102
+ {question}
103
+
104
+ Your response:
105
+ """,
106
+ input_variables=["context", "question"],
107
+ )
108
+
109
+ chain = RetrievalQA.from_chain_type(
110
+ llm=llm,
111
+ chain_type="stuff",
112
+ retriever=retriever,
113
+ return_source_documents=False,
114
+ chain_type_kwargs={"prompt": prompt},
115
+ )
116
+
117
+ print("✅ Loaded and merged both FAISS indexes, QA chain is ready.")
118
+
119
+
120
+ @app.get("/")
121
+ def root():
122
+ return {"message": "Arabic Hadith Finder API is up..."}
123
+
124
+
125
+ @app.post("/query")
126
+ def query(request: QueryRequest):
127
+ try:
128
+ result = chain.invoke({"query": request.question})
129
+ return {"answer": result["result"]}
130
+ except Exception as e:
131
+ raise HTTPException(status_code=500, detail=str(e))