# NOTE: "Spaces: Sleeping" below is Hugging Face Spaces page chrome captured by
# the scrape, not source code; preserved as a comment so the file stays valid.
import logging
import os
import zipfile

from fastapi import FastAPI, HTTPException
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from pydantic import BaseModel
# --- Logging ------------------------------------------------------------------
# One-time root-logger configuration, plus the module-level logger used below.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
app = FastAPI()

# --- Module-level singletons --------------------------------------------------
# All start as None and are populated by load_components().
llm = None          # ChatGroq chat model
embeddings = None   # HuggingFace embedding model
vectorstore = None  # merged FAISS vector store
retriever = None    # retriever over the merged store
quiz_chain = None   # RetrievalQA chain for quiz generation
grade_chain = None  # grading chain -- declared but not built in this file
class QuizRequest(BaseModel):
    """Request body for quiz generation: the topic to build a quiz about."""

    question: str
class GradeRequest(BaseModel):
    """Request body for grading."""

    question: str  # string of Q/A pairs
def load_components():
    """Initialise the LLM, embeddings, merged FAISS retriever and quiz chain.

    Populates the module-level singletons ``llm``, ``embeddings``,
    ``vectorstore``, ``retriever`` and ``quiz_chain``.  ``grade_chain`` is
    declared global but never assigned here -- presumably built elsewhere;
    TODO(review): confirm.  NOTE(review): nothing in this file calls
    load_components(); confirm it is wired to app startup (e.g. an
    ``@app.on_event("startup")`` hook) elsewhere.

    Raises:
        RuntimeError: if the Groq API key environment variable is missing/empty.
        Exception: any failure while loading models or indexes is logged with
            its traceback and re-raised so startup fails loudly.
    """
    global llm, embeddings, vectorstore, retriever, quiz_chain, grade_chain
    try:
        # Bug fix: the original read os.getenv("api_key") while every message
        # referred to API_KEY.  Prefer the conventional upper-case name, but
        # keep the lower-case lookup for backward compatibility.
        api_key = os.getenv("API_KEY") or os.getenv("api_key")
        if not api_key:
            logger.error("API_KEY environment variable is not set or empty.")
            raise RuntimeError("API_KEY environment variable is not set or empty.")
        logger.info("API_KEY is set.")

        # 1) Init LLM & Embeddings
        llm = ChatGroq(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            temperature=0,
            max_tokens=1024,
            api_key=api_key,
        )
        embeddings = HuggingFaceEmbeddings(
            model_name="intfloat/multilingual-e5-large",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

        # 2) Extract the bundled FAISS archives on first run, then load both.
        for zip_name, dir_name in [
            ("faiss_index.zip", "faiss_index"),
            ("faiss_index(1).zip", "faiss_index_extra"),
        ]:
            if not os.path.exists(dir_name):
                with zipfile.ZipFile(zip_name, "r") as z:
                    z.extractall(dir_name)
                logger.info(f"Unzipped {zip_name} to {dir_name}.")
            else:
                logger.info(f"Directory {dir_name} already exists.")

        # allow_dangerous_deserialization unpickles index metadata; acceptable
        # only because these archives ship with the app (trusted input).
        vs1 = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
        logger.info("FAISS index 1 loaded.")
        vs2 = FAISS.load_local("faiss_index_extra", embeddings, allow_dangerous_deserialization=True)
        logger.info("FAISS index 2 loaded.")
        vs1.merge_from(vs2)
        vectorstore = vs1
        logger.info("Merged FAISS indexes into a single vectorstore.")
        retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

        # Quiz generation chain: answers strictly from the retrieved context.
        quiz_prompt = PromptTemplate(
            template="""
Generate a quiz on the topic "{question}" using **only** the information in the "Retrieved context".
Include clear questions and multiple-choice options (A, B, C, D). Also provide the answers of the questions with them.
If context is insufficient, reply with "I don't know".
Retrieved context:
{context}
Quiz topic:
{question}
Quiz:
""",
            input_variables=["context", "question"],
        )
        quiz_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=False,
            chain_type_kwargs={"prompt": quiz_prompt},
        )
        logger.info("Quiz chain ready.")
    except Exception:
        # Log with traceback, then re-raise so the caller sees the failure.
        logger.error("Error loading components", exc_info=True)
        raise
@app.get("/")  # bug fix: the handler was defined but never registered as a route
def root():
    """Health-check endpoint: confirms the API process is up."""
    return {"message": "API is up and running!"}
@app.post("/create_quiz")  # bug fix: handler was never registered; TODO(review): confirm intended path
def create_quiz(request: QuizRequest):
    """Generate a multiple-choice quiz on ``request.question``.

    Returns:
        dict: ``{"quiz": <generated quiz text>}``.

    Raises:
        HTTPException: 503 if the quiz chain has not been initialised yet,
            500 for any failure during generation.
    """
    # Guard: load_components() may not have run (or may have failed); fail
    # with a clear 503 instead of an opaque AttributeError-turned-500.
    if quiz_chain is None:
        raise HTTPException(status_code=503, detail="Quiz chain not initialised.")
    try:
        logger.info("Generating quiz for topic: %s", request.question)
        result = quiz_chain.invoke({"query": request.question})
        logger.info("Quiz generated successfully.")
        return {"quiz": result.get("result")}
    except Exception as e:
        logger.error("Error generating quiz", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e