# chatbot/main.py
import os
import zipfile
import tempfile
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

app = FastAPI()

# === Globals ===
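# Populated once by load_components() when the app starts.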
llm = None
embeddings = None
vectorstore = None
retriever = None
chain = None


class QueryRequest(BaseModel):
    question: str


def _unpack_faiss(src: str, dest_dir: str) -> str:
    """
    If src is a .zip, unzip it into dest_dir and return
    the path to the extracted FAISS folder. Otherwise
    assume src is already a folder and return it.
    """
    if zipfile.is_zipfile(src):
        with zipfile.ZipFile(src, "r") as zf:
            zf.extractall(dest_dir)
        # if there’s exactly one subfolder, use it
        items = os.listdir(dest_dir)
        if len(items) == 1 and os.path.isdir(os.path.join(dest_dir, items[0])):
            return os.path.join(dest_dir, items[0])
        return dest_dir
    else:
        # src is already a directory
        return src
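

# FastAPI runs this handler once at startup to build the shared LLM,
# embeddings, merged vector store, and QA chain.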
@app.on_event("startup")
def load_components():
    global llm, embeddings, vectorstore, retriever, chain

    # --- 1) Initialize LLM & Embeddings ---
    api_key = os.getenv("api_key")
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0,
        max_tokens=1024,
        api_key=api_key,
    )
    embeddings = HuggingFaceEmbeddings(
        model_name="intfloat/multilingual-e5-large",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )

    # --- 2) Load & merge two FAISS indexes ---
    # Paths to your two vectorstores (could be .zip or folders)
    src1 = "faiss_index.zip"
    src2 = "faiss_index_extra.zip"

    # Temporary dirs for extraction
    tmp1 = tempfile.mkdtemp()
    tmp2 = tempfile.mkdtemp()

    # Unpack and load each
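    # (FAISS.load_local deserializes pickled data, so
    # allow_dangerous_deserialization=True is required; only use it with
    # index files you created yourself or otherwise trust.)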
    path1 = _unpack_faiss(src1, tmp1)
    vs1 = FAISS.load_local(path1, embeddings, allow_dangerous_deserialization=True)
    path2 = _unpack_faiss(src2, tmp2)
    vs2 = FAISS.load_local(path2, embeddings, allow_dangerous_deserialization=True)

    # Merge vs2 into vs1
    vs1.merge_from(vs2)

    # Assign the merged store to our global
    vectorstore = vs1

    # --- 3) Build retriever & QA chain ---
    retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

    prompt = PromptTemplate(
        template="""
You are an expert assistant on Islamic knowledge.
Use **only** the information in the “Retrieved context” to answer the user’s question.
Do **not** add any outside information, personal opinions, or conjecture. If the answer is not contained in the context, reply with “لا أعلم”.
Be concise, accurate, and directly address the user’s question.

Retrieved context:
{context}

User’s question:
{question}

Your response:
""",
        input_variables=["context", "question"],
    )
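
    # The "stuff" chain type concatenates all retrieved documents into the
    # prompt's {context} slot before a single LLM call.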
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": prompt},
    )

    print("✅ Loaded and merged both FAISS indexes, QA chain is ready.")


@app.get("/")
def root():
    return {"message": "Arabic Hadith Finder API is up..."}


@app.post("/query")
def query(request: QueryRequest):
    try:
        # RetrievalQA takes its input under "query" and returns the answer
        # under "result".
        result = chain.invoke({"query": request.question})
        return {"answer": result["result"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
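

# Example of running and querying the service locally (a sketch; assumes this
# file is saved as main.py, the FAISS archives sit next to it, and the Groq
# key is exported in the `api_key` environment variable; port 7860 is just an
# example):
#
#   export api_key=...
#   uvicorn main:app --host 0.0.0.0 --port 7860
#
#   curl -X POST http://localhost:7860/query \
#        -H "Content-Type: application/json" \
#        -d '{"question": "..."}'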