import os
import zipfile
import tempfile

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

app = FastAPI()

# === Globals ===
llm = None
embeddings = None
vectorstore = None
retriever = None
chain = None


class QueryRequest(BaseModel):
    question: str


def _unpack_faiss(src: str, dest_dir: str) -> str:
    """
    If src is a .zip, unzip it into dest_dir and return
    the path to the extracted FAISS folder. Otherwise
    assume src is already a folder and return it.
    """
    if zipfile.is_zipfile(src):
        with zipfile.ZipFile(src, "r") as zf:
            zf.extractall(dest_dir)
        # if there’s exactly one subfolder, use it
        items = os.listdir(dest_dir)
        if len(items) == 1 and os.path.isdir(os.path.join(dest_dir, items[0])):
            return os.path.join(dest_dir, items[0])
        return dest_dir
    else:
        # src is already a directory
        return src


@app.on_event("startup")
def load_components():
    global llm, embeddings, vectorstore, retriever, chain

    # --- 1) Initialize LLM & Embeddings ---
    # The Groq API key is read from the `api_key` environment variable.
    api_key = os.getenv("api_key")
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0,
        max_tokens=1024,
        api_key=api_key,
    )

    embeddings = HuggingFaceEmbeddings(
        model_name="intfloat/multilingual-e5-large",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )

    # --- 2) Load & merge two FAISS indexes ---
    # Paths to the two vectorstores (each may be a .zip archive or an
    # already-extracted folder); see _build_index_example() below for one way
    # such an index can be produced.
    src1 = "faiss_index.zip"
    src2 = "faiss_index_extra.zip"

    # Temporary dirs for extraction
    tmp1 = tempfile.mkdtemp()
    tmp2 = tempfile.mkdtemp()

    # Unpack and load each
    path1 = _unpack_faiss(src1, tmp1)
    vs1 = FAISS.load_local(path1, embeddings, allow_dangerous_deserialization=True)

    path2 = _unpack_faiss(src2, tmp2)
    vs2 = FAISS.load_local(path2, embeddings, allow_dangerous_deserialization=True)

    # Merge vs2 into vs1
    vs1.merge_from(vs2)

    # Assign the merged store to our global
    vectorstore = vs1

    # --- 3) Build retriever & QA chain ---
    retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

    prompt = PromptTemplate(
        template="""
You are an expert assistant on Islamic knowledge.  
Use **only** the information in the “Retrieved context” to answer the user’s question.  
Do **not** add any outside information, personal opinions, or conjecture—if the answer is not contained in the context, reply with “لا أعلم”.  
Be concise, accurate, and directly address the user’s question.

Retrieved context:
{context}

User’s question:
{question}

Your response:
""",
        input_variables=["context", "question"],
    )

    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": prompt},
    )

    print("✅ Loaded and merged both FAISS indexes, QA chain is ready.")

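
# Illustrative sketch, not part of the API and never called by this app: one
# way the `faiss_index.zip` artifacts loaded at startup could have been
# produced. The document contents and file names below are assumptions for
# illustration only.
def _build_index_example() -> None:
    """Show how a FAISS index compatible with _unpack_faiss() can be built."""
    import shutil

    from langchain_core.documents import Document

    docs = [
        Document(page_content="...", metadata={"source": "hadith_collection.txt"}),
    ]
    # Reuses the global `embeddings` initialized in load_components().
    index = FAISS.from_documents(docs, embeddings)
    index.save_local("faiss_index")  # writes index.faiss and index.pkl
    shutil.make_archive("faiss_index", "zip", "faiss_index")  # -> faiss_index.zip
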

@app.get("/")
def root():
    return {"message": "Arabic Hadith Finder API is up..."}


@app.post("/query")
def query(request: QueryRequest):
    # Guard against requests that arrive before the startup hook has finished.
    if chain is None:
        raise HTTPException(status_code=503, detail="QA chain is not ready yet.")
    try:
        result = chain.invoke({"query": request.question})
        return {"answer": result["result"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
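

# Local-run sketch: assumes `uvicorn` is installed and this module is saved as
# `app.py`; adjust the "app:app" import string otherwise. Example request once
# the server is running:
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"question": "..."}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=8000)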