# chatbot/main.py
import os
import zipfile
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

app = FastAPI()
# === Globals ===
llm = None
embeddings = None
vectorstore = None
retriever = None
chain = None


class QueryRequest(BaseModel):
    question: str
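    # Example request body (illustrative): {"question": "What is zakat?"}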


@app.on_event("startup")
def load_components():
    global llm, embeddings, vectorstore, retriever, chain
    try:
        # Ensure the API key is provided
        api_key = os.getenv("API_KEY")
        if not api_key:
            logger.error("API_KEY environment variable is not set or empty.")
            raise RuntimeError("API_KEY environment variable is not set or empty.")
        logger.info("API_KEY is set.")

        # 1) Initialize the LLM and embeddings
        llm = ChatGroq(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            temperature=0,
            max_tokens=1024,
            api_key=api_key,
        )
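        # Smoke-test sketch (illustrative, not part of the original flow):
        #   llm.invoke("ping").content  # -> the model's text reply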
        embeddings = HuggingFaceEmbeddings(
            model_name="intfloat/multilingual-e5-large",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )
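        # With normalize_embeddings=True, inner-product and cosine similarity
        # coincide, which suits the FAISS lookups below. Quick check
        # (illustrative): embeddings.embed_query("test") -> list[float]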

        # 2) Unzip and load both FAISS vectorstores
        for zip_name, dir_name in [
            ("faiss_index.zip", "faiss_index"),
            ("faiss_index(1).zip", "faiss_index_extra"),
        ]:
            if not os.path.exists(dir_name):
                with zipfile.ZipFile(zip_name, "r") as z:
                    z.extractall(dir_name)
                logger.info(f"Unzipped {zip_name} to {dir_name}.")
            else:
                logger.info(f"Directory {dir_name} already exists.")
        vs1 = FAISS.load_local(
            "faiss_index",
            embeddings,
            allow_dangerous_deserialization=True,
        )
        logger.info("FAISS index 1 loaded.")
        vs2 = FAISS.load_local(
            "faiss_index_extra",
            embeddings,
            allow_dangerous_deserialization=True,
        )
        logger.info("FAISS index 2 loaded.")

        # 3) Merge them
        vs1.merge_from(vs2)
        vectorstore = vs1
        logger.info("Merged FAISS indexes into a single vectorstore.")

        # 4) Create the retriever and QA chain
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        prompt = PromptTemplate(
            template="""
You are an expert assistant on Islamic knowledge.
Use **only** the information in the “Retrieved context” to answer general questions related to Islam.
Do **not** add any outside information, personal opinions, or conjecture; if the answer is not contained in the context, reply with "I don't know".
Be concise, accurate, and directly address the user’s question. Always cite the reference your answer is drawn from.

Retrieved context:
{context}

User’s question:
{question}

Your response:
""",
            input_variables=["context", "question"],
        )
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=False,
            chain_type_kwargs={"prompt": prompt},
        )
        logger.info("QA chain ready.")
    except Exception:
        logger.error("Error loading components", exc_info=True)
        raise


@app.get("/")
def root():
    return {"message": "API is up and running!"}


@app.post("/query")
def query(request: QueryRequest):
    try:
        logger.info("Received query: %s", request.question)
        result = chain.invoke({"query": request.question})
        logger.info("Query processed successfully.")
        return {"answer": result.get("result")}
    except Exception as e:
        logger.error("Error processing query", exc_info=True)
        # Return detailed error for debugging
        raise HTTPException(status_code=500, detail=str(e))
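

# Minimal local-run sketch (assumes uvicorn is installed; the hosting Space may
# launch the app differently). Example request once the server is up:
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What are the pillars of Islam?"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)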