# grading / main.py
# (Hugging Face upload metadata: Hammad712, "Update main.py",
#  commit 16aaeed verified, 2.76 kB)
import json
import os
import tempfile
from typing import List

from fastapi import FastAPI, File, HTTPException, UploadFile
from langchain.document_loaders import PyPDFLoader
from langchain_groq import ChatGroq
from pydantic import BaseModel
# Read the Groq API key from the environment; fail fast at import time so the
# service never starts in an unusable state.
API_KEY = os.environ.get("GROQ_API_KEY")
if API_KEY is None:
    raise ValueError("GROQ_API_KEY environment variable not set.")

app = FastAPI(title="PDF Question Extractor", version="1.0")
# Define the expected JSON response schema
class ExtractionResult(BaseModel):
    """Schema the LLM must produce: a flat list of extracted answer strings."""

    # One entry per extracted answer, in document order.
    answers: List[str]
# Initialize the language model (LLM)
def get_llm(
    *,
    model: str = "llama-3.3-70b-versatile",
    temperature: float = 0,
    max_tokens: int = 1024,
):
    """Build a ChatGroq client authenticated with the module-level API key.

    The previous version hard-coded the model configuration; the values are
    now keyword-only parameters with the same defaults, so existing
    ``get_llm()`` callers are unaffected while tests or alternate endpoints
    can request a different model or budget.

    Args:
        model: Groq model identifier.
        temperature: Sampling temperature (0 = deterministic).
        max_tokens: Maximum tokens to generate per completion.

    Returns:
        A configured ``ChatGroq`` instance.
    """
    return ChatGroq(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        api_key=API_KEY,
    )


# Single shared client for all requests (stateless per-call usage).
llm = get_llm()
# Root endpoint: a small self-describing landing payload.
@app.get("/")
async def root():
    """Describe the service and point callers at the extraction endpoint."""
    welcome = "Welcome to the PDF Question Extractor API."
    usage = "POST your PDF to /extract-answers/ to extract answers."
    return {"message": welcome, "usage": usage}
# PDF extraction endpoint: Processes a PDF file upload
@app.post("/extract-answers/")
async def extract_answers(file: UploadFile = File(...)):
    """Extract answer options (A-E) and correct answers from an uploaded PDF.

    The upload is written to a secure temporary file, loaded with
    ``PyPDFLoader``, and the concatenated page text is sent to the LLM with a
    JSON-schema-constrained prompt. The response is validated against
    :class:`ExtractionResult` before being returned.

    Args:
        file: The uploaded PDF (multipart form field).

    Returns:
        ``{"answers": [...]}`` — the validated extraction result.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    tmp_path = None
    try:
        # Spool the upload into a temp file. Never build the path from the
        # client-supplied filename: the old ``f"./temp_{file.filename}"``
        # allowed path traversal (e.g. a filename containing "../").
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name

        # Load and split the PDF into pages, then flatten to one text blob.
        loader = PyPDFLoader(tmp_path)
        pages = loader.load_and_split()
        all_page_content = "\n".join(page.page_content for page in pages)

        # Generate the JSON schema from the Pydantic model so the LLM output
        # can be validated mechanically below.
        schema = json.dumps(ExtractionResult.model_json_schema(), indent=2)

        # Build the prompt with system and user messages.
        system_message = (
            "You are a document analysis tool that extracts the options and correct answers "
            "from the provided document content. The output must be a JSON object that strictly follows the schema: "
            + schema
        )
        user_message = (
            "Please extract the correct answers and options (A, B, C, D, E) from the following document content:\n\n"
            + all_page_content
        )
        prompt = system_message + "\n\n" + user_message

        # Invoke the LLM in JSON mode and validate the reply with Pydantic.
        response = llm.invoke(prompt, response_format={"type": "json_object"})
        result = ExtractionResult.model_validate_json(response.content)
        return result.model_dump()
    except Exception as e:
        # Surface any failure as a 500; chain the cause for server-side logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Cleanup always runs — the old code leaked the temp file whenever
        # loading, the LLM call, or validation raised before os.remove().
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)