grading / main.py
Hammad712's picture
Create main.py
7f269b9 verified
raw
history blame
2.33 kB
import json
import os
import tempfile
from typing import List

from fastapi import FastAPI, UploadFile, File, HTTPException
from langchain.document_loaders import PyPDFLoader
from langchain_groq import ChatGroq
from pydantic import BaseModel
# The Groq credential is taken from the process environment so that it never
# has to live in the source tree; the module refuses to load without it.
API_KEY = os.environ.get("GROQ_API_KEY")
if API_KEY is None or API_KEY == "":
    raise ValueError("GROQ_API_KEY environment variable not set.")

# Application object that the route decorator further down registers against.
app = FastAPI(
    title="PDF Question Extractor",
    version="1.0",
)
# Shape of the endpoint's response. The JSON schema generated from this model
# is also pasted into the LLM prompt, so the documentation here is kept in
# comments rather than a class docstring (pydantic would copy a docstring into
# the generated schema and thereby change the prompt).
class ExtractionResult(BaseModel):
    # One string per answer extracted from the document.
    answers: List[str]
# Factory for the Groq-hosted chat model used by the endpoint.
def get_llm():
    """Build the ChatGroq client used for answer extraction.

    The client targets the ``llama-3.3-70b-versatile`` model with
    ``temperature=0`` and ``max_tokens=1024``, authenticated with the
    module-level ``API_KEY``.

    Returns:
        A configured ``ChatGroq`` instance.
    """
    settings = {
        "model": "llama-3.3-70b-versatile",
        "temperature": 0,
        "max_tokens": 1024,
        "api_key": API_KEY,
    }
    return ChatGroq(**settings)


# Single shared client, created once when the module is imported.
llm = get_llm()
@app.post("/extract-answers/")
async def extract_answers(file: UploadFile = File(...)):
    """Extract the answers from an uploaded PDF with the Groq LLM.

    The upload is written to a private temporary file (``PyPDFLoader`` is
    given a filesystem path), its text is extracted and joined, and the model
    is asked for a JSON object that follows the ``ExtractionResult`` schema.

    Args:
        file: The uploaded PDF document.

    Returns:
        The validated ``ExtractionResult`` as a plain dict, i.e.
        ``{"answers": [...]}``.

    Raises:
        HTTPException: Status 500 with the underlying error message as the
            detail if any step (reading, parsing, LLM call, validation) fails.
    """
    temp_path = None
    try:
        # UploadFile.read() is the async accessor; it avoids a blocking read
        # on the event loop.
        payload = await file.read()

        # Never build the path from file.filename: it is client-controlled
        # (may contain "../", may be None) and two concurrent uploads with the
        # same name would overwrite each other. Let tempfile pick a unique,
        # private path instead.
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as buffer:
            temp_path = buffer.name
            buffer.write(payload)

        # Load and extract text from PDF.
        # NOTE(review): PDF parsing and llm.invoke() below are synchronous and
        # run on the event loop; consider offloading them if this endpoint
        # needs to serve concurrent requests.
        loader = PyPDFLoader(temp_path)
        pages = loader.load_and_split()
        all_page_content = "\n".join(page.page_content for page in pages)

        # JSON schema definition embedded in the prompt so the model knows the
        # exact output shape.
        schema_dict = ExtractionResult.model_json_schema()
        schema = json.dumps(schema_dict, indent=2)

        # System message
        system_message = (
            "You are a document analysis tool that extracts the options and correct answers from the provided document content. "
            "The output must be a JSON object that strictly follows the schema: " + schema
        )

        # User message
        user_message = (
            "Please extract the correct answers and options (A, B, C, D, E) from the following document content:\n\n"
            + all_page_content
        )

        # Construct final prompt
        prompt = system_message + "\n\n" + user_message

        # Get LLM response (JSON mode)
        response = llm.invoke(prompt, response_format={"type": "json_object"})

        # Parse and validate response against the schema
        result = ExtractionResult.model_validate_json(response.content)
        return result.model_dump()
    except Exception as exc:
        # Top-level boundary: surface every failure as a 500, keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        # Cleanup runs on success and on failure so temp files never pile up.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                # Best-effort: a failed delete must not mask the real outcome.
                pass