# Islamic-Quiz / main.py
# Commit 87668f1 (verified) by Hammad712: "Update main.py"
import logging
import os
import shutil
import tempfile
import zipfile

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
# Configure root logging at import time: INFO level, timestamped lines.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger shared by the startup hook and the request handlers in this module.
logger = logging.getLogger(name=__name__)
# ASGI application; the routes in this module are registered on it.
app = FastAPI()

# === Globals ===
# Pipeline components shared by all requests. Every one starts out as None;
# load_components() fills in the ones it builds at startup.
llm: "ChatGroq | None" = None  # chat model driving the chains
embeddings: "HuggingFaceEmbeddings | None" = None  # embedding model for the FAISS index
vectorstore: "FAISS | None" = None  # FAISS index (merged from several archives)
retriever = None  # retriever created from `vectorstore`
quiz_chain: "RetrievalQA | None" = None  # topic -> quiz chain
grade_chain = None  # declared but not built in the code shown here
class QuizRequest(BaseModel):
    """JSON body accepted by ``POST /quiz``."""

    # Topic the quiz should cover.
    question: str


class GradeRequest(BaseModel):
    """JSON body for a grading request.

    NOTE(review): no endpoint in the code shown here consumes this model.
    """

    # Question/answer pairs packed into a single string; the exact layout is
    # not defined in this file -- confirm with the client.
    question: str
# Prompt for the quiz chain. Kept at module level so the template text stays
# flush-left and is not altered by any function indentation.
_QUIZ_PROMPT_TEMPLATE = """
Generate a quiz on the topic "{question}" using **only** the information in the "Retrieved context".
Include clear questions and multiple-choice options (A, B, C, D). Also provide the answers of the questions with them.
If context is insufficient, reply with "I don't know".
Retrieved context:
{context}
Quiz topic:
{question}
Quiz:
"""

# (zip archive, extraction directory) for each FAISS index to load. The first
# index is the base; every later one is merged into it, in this order.
_FAISS_INDEX_ARCHIVES = [
    ("faiss_index.zip", "faiss_index"),
    ("faiss_index(1).zip", "faiss_index_extra"),
]


def _read_api_key():
    """Return the Groq API key from the environment.

    ``api_key`` is read first (the variable this service has always used).
    ``API_KEY`` -- the name the error message tells operators to set -- is
    accepted as a fallback, because variable names are case-sensitive on POSIX.

    Raises:
        RuntimeError: if neither variable holds a non-empty value.
    """
    api_key = os.getenv("api_key") or os.getenv("API_KEY")
    if not api_key:
        logger.error("API_KEY environment variable is not set or empty.")
        raise RuntimeError("API_KEY environment variable is not set or empty.")
    return api_key


def _ensure_extracted(zip_name, dir_name):
    """Extract ``zip_name`` into ``dir_name`` unless that directory already exists.

    The archive is unpacked into a staging directory next to ``dir_name`` and
    renamed into place only after extraction completes. An interrupted startup
    therefore never leaves a half-filled ``dir_name`` that later startups would
    skip as "already extracted".
    """
    if os.path.exists(dir_name):
        logger.info("Directory %s already exists.", dir_name)
        return
    parent = os.path.dirname(os.path.abspath(dir_name))
    staging_dir = tempfile.mkdtemp(prefix=".unzip-", dir=parent)
    try:
        with zipfile.ZipFile(zip_name, "r") as archive:
            archive.extractall(staging_dir)
        try:
            os.replace(staging_dir, dir_name)
        except OSError:
            # Another process may have finished extracting first; reuse it.
            if not os.path.isdir(dir_name):
                raise
    finally:
        # No-op after a successful rename; removes leftovers on any failure.
        shutil.rmtree(staging_dir, ignore_errors=True)
    logger.info("Unzipped %s to %s.", zip_name, dir_name)


def _load_vectorstore(embedding_model):
    """Load every index in ``_FAISS_INDEX_ARCHIVES`` and merge them into one.

    Returns the first index after every later index has been merged into it
    with ``merge_from``.
    """
    for zip_name, dir_name in _FAISS_INDEX_ARCHIVES:
        _ensure_extracted(zip_name, dir_name)

    stores = []
    for number, (_, dir_name) in enumerate(_FAISS_INDEX_ARCHIVES, start=1):
        # SECURITY: allow_dangerous_deserialization lets load_local deserialize
        # pickled data from disk; only deploy index archives from a trusted source.
        stores.append(
            FAISS.load_local(
                dir_name, embedding_model, allow_dangerous_deserialization=True
            )
        )
        logger.info("FAISS index %d loaded.", number)

    merged, *extras = stores
    for extra in extras:
        merged.merge_from(extra)
    logger.info("Merged FAISS indexes into a single vectorstore.")
    return merged


@app.on_event("startup")
def load_components():
    """Build the shared LLM, embeddings, vector store, retriever and quiz chain.

    Runs as a startup hook and assigns the module-level globals ``llm``,
    ``embeddings``, ``vectorstore``, ``retriever`` and ``quiz_chain``.
    ``grade_chain`` is listed as global but is not built here.

    Raises:
        RuntimeError: if no API key is configured.
        Exception: any failure while building the components is logged with
            its traceback and re-raised.

    NOTE: FastAPI recommends ``lifespan`` handlers over ``on_event``.
    """
    global llm, embeddings, vectorstore, retriever, quiz_chain, grade_chain
    try:
        api_key = _read_api_key()
        logger.info("API_KEY is set.")

        # 1) LLM and embedding model.
        llm = ChatGroq(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            temperature=0,
            max_tokens=1024,
            api_key=api_key,
        )
        embeddings = HuggingFaceEmbeddings(
            model_name="intfloat/multilingual-e5-large",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

        # 2) Merged FAISS index plus a retriever returning the top 10 matches.
        vectorstore = _load_vectorstore(embeddings)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

        # 3) Quiz chain over the retriever, driven by the quiz prompt.
        quiz_prompt = PromptTemplate(
            template=_QUIZ_PROMPT_TEMPLATE,
            input_variables=["context", "question"],
        )
        quiz_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=False,
            chain_type_kwargs={"prompt": quiz_prompt},
        )
        logger.info("Quiz chain ready.")
    except Exception:
        logger.exception("Error loading components")
        raise
@app.get("/")
def root():
    """Liveness probe: confirm the service is reachable."""
    status_message = "API is up and running!"
    return dict(message=status_message)
@app.post("/quiz")
def create_quiz(request: QuizRequest):
    """Generate a multiple-choice quiz on the requested topic.

    Args:
        request: Body whose ``question`` field is the quiz topic.

    Returns:
        ``{"quiz": text}``, where ``text`` is the chain output's ``"result"``
        entry (``None`` if the output has no such key).

    Raises:
        HTTPException: 400 for a blank topic, 503 if the quiz chain has not
            been initialised, 500 if quiz generation fails.
    """
    topic = request.question
    # Guards run before the try block: an HTTPException raised inside it would
    # be caught by the generic handler below and reported as a 500.
    if not topic.strip():
        raise HTTPException(status_code=400, detail="Quiz topic must not be empty.")
    if quiz_chain is None:
        logger.error("Quiz requested before the quiz chain was initialised.")
        raise HTTPException(status_code=503, detail="Quiz service is not ready.")
    try:
        logger.info("Generating quiz for topic: %s", topic)
        result = quiz_chain.invoke({"query": topic})
        logger.info("Quiz generated successfully.")
        return {"quiz": result.get("result")}
    except Exception as e:
        logger.exception("Error generating quiz")
        raise HTTPException(status_code=500, detail=str(e)) from e