# chatbot/main.py
import os
import zipfile
import tempfile
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

app = FastAPI()

# === Globals ===
llm = None
embeddings = None
vectorstore = None
retriever = None
chain = None


class QueryRequest(BaseModel):
    question: str
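
# Example payload for POST /query (illustrative):
#   {"question": "..."}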


def _unpack_faiss(src_path: str, extract_to: str) -> str:
    """
    If src_path is a .zip, unzip it into extract_to and return the directory
    containing the .faiss file. If it's already a folder, just return it.
    """
    # 1) ZIP case
    if src_path.lower().endswith(".zip"):
        if not os.path.isfile(src_path):
            raise FileNotFoundError(f"Could not find zip file: {src_path}")
        with zipfile.ZipFile(src_path, "r") as zf:
            zf.extractall(extract_to)
        # Walk the extracted tree until we find any .faiss file
        for root, _, files in os.walk(extract_to):
            if any(fn.endswith(".faiss") for fn in files):
                return root
        raise RuntimeError(f"No .faiss index found inside {src_path}")
    # 2) Directory case
    if os.path.isdir(src_path):
        return src_path
    raise RuntimeError(f"Path is neither a .zip nor a directory: {src_path}")
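

# Usage sketch for _unpack_faiss (hypothetical paths): given an archive
# "faiss_index.zip" containing index.faiss / index.pkl, the call
#     _unpack_faiss("faiss_index.zip", "/tmp/work")
# extracts it under /tmp/work and returns the subdirectory holding the
# .faiss file, ready to pass to FAISS.load_local.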


@app.on_event("startup")  # note: newer FastAPI prefers lifespan handlers, but this still works
def load_components():
    global llm, embeddings, vectorstore, retriever, chain

    # --- 1) Init LLM & Embeddings ---
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0,
        max_tokens=1024,
        api_key=os.getenv("api_key"),  # Groq API key, read from the environment
    )
    embeddings = HuggingFaceEmbeddings(
        model_name="intfloat/multilingual-e5-large",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},  # unit vectors, so inner product == cosine
    )

    # --- 2) Load & merge two FAISS indexes ---
    src1 = "faiss_index.zip"
    src2 = "faiss_index_extra.zip"

    # The temp dirs only need to outlive FAISS.load_local: the index is read
    # fully into memory, so the extracted files can be cleaned up right after.
    with tempfile.TemporaryDirectory() as tmp1, tempfile.TemporaryDirectory() as tmp2:
        dir1 = _unpack_faiss(src1, tmp1)
        dir2 = _unpack_faiss(src2, tmp2)
        vs1 = FAISS.load_local(dir1, embeddings, allow_dangerous_deserialization=True)
        vs2 = FAISS.load_local(dir2, embeddings, allow_dangerous_deserialization=True)

    # Merge vs2 into vs1 so a single store serves both corpora
    vs1.merge_from(vs2)
    vectorstore = vs1

    # --- 3) Build retriever & QA chain ---
    retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
    prompt = PromptTemplate(
        template="""
You are an expert assistant on Islamic knowledge.
Use **only** the information in the “Retrieved context” to answer the user’s question.
Do **not** add any outside information, personal opinions, or conjecture; if the answer is not contained in the context, reply with “لا أعلم”.
Be concise, accurate, and directly address the user’s question.

Retrieved context:
{context}

User’s question:
{question}

Your response:
""",
        input_variables=["context", "question"],
    )
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": prompt},
    )
    print("✅ Loaded & merged both FAISS indexes, QA chain ready.")


@app.get("/")
def root():
    return {"message": "Arabic Hadith Finder API is up..."}


@app.post("/query")
def query(request: QueryRequest):
    try:
        # RetrievalQA expects its input under the "query" key and returns
        # the answer under "result"
        result = chain.invoke({"query": request.question})
        return {"answer": result["result"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
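

# Example request once the server is running (hypothetical host/port):
#   curl -X POST http://localhost:7860/query \
#        -H "Content-Type: application/json" \
#        -d '{"question": "..."}'

# Local-run sketch (assumption: on Hugging Face Spaces the server is usually
# started by the Space's own launch command instead; 7860 is the Spaces
# default port):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)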