Hammad712 committed on
Commit
84f2275
·
verified ·
1 Parent(s): e32fb25

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +290 -38
main.py CHANGED
@@ -1,56 +1,308 @@
1
- import os
2
- import shutil
3
- from fastapi import FastAPI, UploadFile, File, HTTPException
4
- from fastapi.responses import JSONResponse
5
- from models import EvaluationRequest
6
- from pdf_processor import extract_answers_from_pdf, evaluate_student
 
 
 
7
 
8
- app = FastAPI()
 
 
9
 
10
# Directory to temporarily store uploaded files.
UPLOAD_DIR = "uploads"
# exist_ok=True removes the check-then-create race: two processes starting at
# the same time can no longer both pass an "exists?" test and then collide.
os.makedirs(UPLOAD_DIR, exist_ok=True)
14
 
15
@app.get("/")
async def root():
    """
    Landing endpoint.

    Returns a single-key JSON object whose "message" value is a welcome text
    naming the '/extract/' and '/evaluate/' routes.
    """
    welcome = "Welcome to the PDF Processing API. Use '/extract/' to extract answers from a PDF or '/evaluate/' to calculate marks using pre-extracted answers."
    return {"message": welcome}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
@app.post("/extract/")
async def extract_pdf(file: UploadFile = File(...)):
    """
    Endpoint to extract answers from a PDF file.

    The upload is spooled to a uniquely named temporary file inside
    UPLOAD_DIR, handed to ``extract_answers_from_pdf``, and removed again
    whether or not extraction succeeds.

    Args:
        file: The uploaded PDF.

    Returns:
        JSONResponse with the ``model_dump()`` of the extraction result.

    Raises:
        HTTPException: status 500 carrying the message of any failure.
    """
    # Local import so this block does not depend on the file-level imports.
    import tempfile

    # Only the extension of the client-supplied name is reused.  Joining the
    # raw filename into the path would let a name such as "../../x" escape
    # UPLOAD_DIR, and two uploads with the same name would overwrite each other.
    suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]

    # Bound before the try so the cleanup in ``finally`` can never hit an
    # unbound name if creating the temporary file itself fails.
    file_location = None
    try:
        with tempfile.NamedTemporaryFile(dir=UPLOAD_DIR, suffix=suffix, delete=False) as tmp:
            file_location = tmp.name
            shutil.copyfileobj(file.file, tmp)

        result = extract_answers_from_pdf(file_location)
        return JSONResponse(content=result.model_dump())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if file_location is not None and os.path.exists(file_location):
            os.remove(file_location)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
@app.post("/evaluate/")
async def evaluate(evaluation_request: EvaluationRequest):
    """
    Score a student's answers against an answer key.

    The JSON payload carries the already-extracted answer key and student
    answers; both are passed to ``evaluate_student`` and the ``model_dump()``
    of its result is returned as JSON.  Any failure is reported as an HTTP 500
    whose detail is the exception message.
    """
    try:
        answer_key = evaluation_request.answer_key
        student = evaluation_request.student
        graded = evaluate_student(answer_key, student)
        payload = graded.model_dump()
        return JSONResponse(content=payload)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
53
 
 
54
  if __name__ == "__main__":
55
- import uvicorn
56
  uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
 
1
import io
import json
import os

import cv2
import numpy as np
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException, Query
from fastapi.responses import JSONResponse, StreamingResponse
# Import the Google GenAI client libraries.
from google import genai
from google.genai import types
from pdf2image import convert_from_bytes
from PIL import Image
14
 
15
# Initialize the GenAI client.  The API key is taken from the GEMINI_API_KEY
# environment variable so that no credential lives in source control.
# NOTE(security): a key was previously hard-coded on this line; any key that
# has been committed must be treated as leaked and revoked/rotated.
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))

app = FastAPI(title="Student Result Card API")
19
+
20
+ # -----------------------------
21
+ # Preprocessing Methods
22
+ # -----------------------------
23
def preprocess_candidate_info(image_cv):
    """
    Crop the candidate-information band out of a page image.

    The kept region spans the full page width, starts 10% of the page height
    below the top edge and ends 25% of the page height above the bottom edge.
    The input is treated as a BGR array; the crop is returned as an RGB PIL
    image.
    """
    rows, cols = image_cv.shape[:2]
    top_left = (0, int(rows * 0.10))
    bottom_right = (cols, rows - int(rows * 0.25))

    # Paint the region of interest white on an otherwise black mask.
    region_mask = np.zeros((rows, cols), dtype="uint8")
    cv2.rectangle(region_mask, top_left, bottom_right, 255, -1)

    # Blank everything outside the mask, then cut the mask's bounding box.
    blanked = cv2.bitwise_and(image_cv, image_cv, mask=region_mask)
    left, upper, box_w, box_h = cv2.boundingRect(cv2.findNonZero(region_mask))
    region = blanked[upper:upper + box_h, left:left + box_w]

    rgb = cv2.cvtColor(region, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)
38
+
39
def preprocess_mcq(image_cv):
    """
    Crop the multiple-choice area (questions 1 to 10) out of a page image.

    The kept region runs from the left edge to 35% of the page width, starts
    27% of the page height below the top edge and ends 23% of the page height
    above the bottom edge.  The input is treated as a BGR array; the crop is
    returned as an RGB PIL image.
    """
    rows, cols = image_cv.shape[:2]
    top_left = (0, int(rows * 0.27))
    bottom_right = (int(cols * 0.35), rows - int(rows * 0.23))

    # Paint the region of interest white on an otherwise black mask.
    region_mask = np.zeros((rows, cols), dtype="uint8")
    cv2.rectangle(region_mask, top_left, bottom_right, 255, -1)

    # Blank everything outside the mask, then cut the mask's bounding box.
    blanked = cv2.bitwise_and(image_cv, image_cv, mask=region_mask)
    left, upper, box_w, box_h = cv2.boundingRect(cv2.findNonZero(region_mask))
    region = blanked[upper:upper + box_h, left:left + box_w]

    rgb = cv2.cvtColor(region, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)
55
 
56
def preprocess_free_response(image_cv):
    """
    Crop the free-response area (questions 11 to 15) out of a page image.

    The kept region lies between 35% and 68% of the page width, starts 27% of
    the page height below the top edge and ends 38% of the page height above
    the bottom edge.  The input is treated as a BGR array; the crop is
    returned as an RGB PIL image.
    """
    rows, cols = image_cv.shape[:2]
    top_left = (int(cols * 0.35), int(rows * 0.27))
    bottom_right = (int(cols * 0.68), rows - int(rows * 0.38))

    # Paint the region of interest white on an otherwise black mask.
    region_mask = np.zeros((rows, cols), dtype="uint8")
    cv2.rectangle(region_mask, top_left, bottom_right, 255, -1)

    # Blank everything outside the mask, then cut the mask's bounding box.
    blanked = cv2.bitwise_and(image_cv, image_cv, mask=region_mask)
    left, upper, box_w, box_h = cv2.boundingRect(cv2.findNonZero(region_mask))
    region = blanked[upper:upper + box_h, left:left + box_w]

    rgb = cv2.cvtColor(region, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)
73
+
74
def preprocess_full_answers(image_cv):
    """
    Prepare an answer-key page for extraction.

    No cropping is done: the whole page is assumed to hold the answers, so the
    BGR array is only converted to RGB and wrapped in a PIL image.
    """
    rgb_page = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_page)
79
+
80
+ # -----------------------------
81
+ # Extraction Methods using Gemini
82
+ # -----------------------------
83
def extract_json_from_output(output_str):
    """
    Extracts a JSON object from a string containing extra text.

    Parsing starts at the first '{' and stops where that JSON object ends, so
    surrounding prose, Markdown code fences, or later text that happens to
    contain braces do not prevent the object from being recovered.

    Args:
        output_str: Raw model output.  Anything that is not a string (for
            example ``None`` when the model produced no text) is treated as
            "no JSON found".

    Returns:
        The decoded dict, or None when no '{' is present or the text starting
        at the first '{' is not a valid JSON object.
    """
    if not isinstance(output_str, str):
        return None

    start = output_str.find('{')
    if start == -1:
        return None

    try:
        # raw_decode parses exactly one JSON document beginning at `start`
        # and ignores whatever follows it.
        parsed, _end = json.JSONDecoder().raw_decode(output_str, start)
    except json.JSONDecodeError:
        return None
    return parsed
96
+
97
def get_student_info(image_input):
    """
    Ask the Gemini model to read the candidate details from an image.

    Sends a prompt plus the image to "gemini-2.0-flash" and returns whatever
    ``extract_json_from_output`` recovers from the reply text (the prompt asks
    for a "Candidate Info" object with Name, Number, Country and Level).
    """
    json_schema = """
    Answer in the following JSON format. Do not write anything else:
    {
        "Candidate Info": {
            "Name": "<name>",
            "Number": "<number>",
            "Country": "<country>",
            "Level": "<level>"
        }
    }
    """
    instructions = f"""
    You are an assistant that extracts candidate information from an image.
    The image contains details including name, candidate number, country, and level.
    Extract the information accurately and provide the result in JSON using the format below:
    {json_schema}
    """
    reply = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=[instructions, image_input],
    )
    return extract_json_from_output(reply.text)
120
 
121
def get_mcq_answers(image_input):
    """
    Ask the Gemini model which option is marked for questions 1 to 10.

    Sends a prompt plus the image to "gemini-2.0-flash" and returns whatever
    ``extract_json_from_output`` recovers from the reply text (the prompt asks
    for an "Answers" object keyed "1" through "10").
    """
    json_schema = """
    Answer in the following JSON format do not write anything else:
    {
        "Answers": {
            "1": "<option>",
            "2": "<option>",
            "3": "<option>",
            "4": "<option>",
            "5": "<option>",
            "6": "<option>",
            "7": "<option>",
            "8": "<option>",
            "9": "<option>",
            "10": "<option>"
        }
    }
    """
    instructions = f"""
    You are an assistant that extracts MCQ answers from an image.
    The image is a screenshot of a 10-question multiple-choice answer sheet.
    Extract which option is marked for each question (1 to 10) and provide the answers in JSON using the format below:
    {json_schema}
    """
    reply = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=[instructions, image_input],
    )
    return extract_json_from_output(reply.text)
150
+
151
def get_free_response_answers(image_input):
    """
    Ask the Gemini model for the written answers to questions 11 to 15.

    Sends a prompt plus the image to "gemini-2.0-flash" and returns whatever
    ``extract_json_from_output`` recovers from the reply text (the prompt asks
    for a "Free Answers" object keyed "11" through "15").
    """
    json_schema = """
    Answer in the following JSON format. Do not write anything else:
    {
        "Free Answers": {
            "11": "<answer for question 11>",
            "12": "<answer for question 12>",
            "13": "<answer for question 13>",
            "14": "<answer for question 14>",
            "15": "<answer for question 15>"
        }
    }
    """
    instructions = f"""
    You are an assistant that extracts free-text answers from an image.
    The image contains responses for questions 11 to 15.
    Extract the answers accurately and provide the result in JSON using the format below:
    {json_schema}
    """
    reply = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=[instructions, image_input],
    )
    return extract_json_from_output(reply.text)
175
+
176
def get_all_answers(image_input):
    """
    Ask the Gemini model for every answer (questions 1 to 15) on an answer-key image.

    Sends a prompt plus the image to "gemini-2.0-flash" and returns whatever
    ``extract_json_from_output`` recovers from the reply text (the prompt asks
    for an "Answers" object keyed "1" through "15").
    """
    json_schema = """
    Answer in the following JSON format. Do not write anything else:
    {
        "Answers": {
            "1": "<option or text>",
            "2": "<option or text>",
            "3": "<option or text>",
            "4": "<option or text>",
            "5": "<option or text>",
            "6": "<option or text>",
            "7": "<option or text>",
            "8": "<option or text>",
            "9": "<option or text>",
            "10": "<option or text>",
            "11": "<free-text answer>",
            "12": "<free-text answer>",
            "13": "<free-text answer>",
            "14": "<free-text answer>",
            "15": "<free-text answer>"
        }
    }
    """
    instructions = f"""
    You are an assistant that extracts answers from an image.
    The image is a screenshot of an answer sheet containing 15 questions.
    For questions 1 to 10, the answers are multiple-choice selections.
    For questions 11 to 15, the answers are free-text responses.
    Extract the answer for each question and provide the result in JSON using the format below:
    {json_schema}
    """
    reply = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=[instructions, image_input],
    )
    return extract_json_from_output(reply.text)
212
+
213
+ # -----------------------------
214
+ # Method to calculate result card
215
+ # -----------------------------
216
def _section(payload, key):
    """Return payload[key] when it is a dict, otherwise an empty dict."""
    section = payload.get(key) if isinstance(payload, dict) else None
    return section if isinstance(section, dict) else {}


def _answer_text(value):
    """Normalise one extracted answer to a stripped string ("" when missing)."""
    if value is None:
        return ""
    return str(value).strip()


def calculate_result(student_info, student_mcq, student_free, correct_answers):
    """
    Compares student's answers with the correct answers, calculates marks and percentage,
    and returns a result card in JSON.

    Every argument is the (possibly failed) output of a model extraction, so
    each one may be None or lack the expected section; missing data is treated
    as "no answer" instead of raising.

    Args:
        student_info: dict with a "Candidate Info" section, or None.
        student_mcq: dict with an "Answers" section (questions 1-10), or None.
        student_free: dict with a "Free Answers" section (questions 11-15), or None.
        correct_answers: dict with an "Answers" section (questions 1-15), or None.

    Returns:
        dict with "Candidate Info", "Total Marks", "Total Questions",
        "Percentage" and per-question "Detailed Results".
    """
    student_all = {}
    student_all.update(_section(student_mcq, "Answers"))
    student_all.update(_section(student_free, "Free Answers"))

    correct_all = _section(correct_answers, "Answers")
    total_questions = 15
    marks = 0
    detailed = {}

    for q in map(str, range(1, total_questions + 1)):
        # Values come from a model and may be numbers or null, not only strings.
        student_ans = _answer_text(student_all.get(q))
        correct_ans = _answer_text(correct_all.get(q))
        # A question with no known correct answer can never earn a mark;
        # otherwise blank-vs-blank would count as a match.
        is_correct = bool(correct_ans) and student_ans == correct_ans
        if is_correct:
            marks += 1
        detailed[q] = {
            "Student": student_ans,
            "Correct": correct_ans,
            "Result": "Correct" if is_correct else "Incorrect",
        }

    percentage = (marks / total_questions) * 100
    return {
        "Candidate Info": _section(student_info, "Candidate Info"),
        "Total Marks": marks,
        "Total Questions": total_questions,
        "Percentage": percentage,
        "Detailed Results": detailed,
    }
250
+
251
+ # -----------------------------
252
+ # API Endpoint to process PDFs and return student result cards
253
+ # -----------------------------
254
def _pil_to_bgr(page):
    """Convert a PIL page image into a BGR array for the OpenCV-based croppers."""
    return cv2.cvtColor(np.array(page), cv2.COLOR_RGB2BGR)


def _grade_student_page(page, correct_answers):
    """Crop one student page, run the three extractions on it and score the result."""
    page_cv = _pil_to_bgr(page)
    student_info = get_student_info(preprocess_candidate_info(page_cv))
    student_mcq = get_mcq_answers(preprocess_mcq(page_cv))
    student_free = get_free_response_answers(preprocess_free_response(page_cv))
    # A failed candidate-info extraction yields None; score the page anyway.
    return calculate_result(student_info or {}, student_mcq, student_free, correct_answers)


@app.post("/process")
async def process_pdfs(
    student_pdf: UploadFile = File(...),
    answer_key_pdf: UploadFile = File(...),
    download: bool = Query(False, description="Set to true to download result card list as a JSON file")
):
    """
    Grade every page of a student PDF against the last page of an answer-key PDF.

    Args:
        student_pdf: PDF in which each page is one student's answer sheet.
        answer_key_pdf: PDF whose last page is assumed to hold the correct answers.
        download: When true, the result is sent as a "result_cards.json" attachment.

    Returns:
        JSON of the form {"result_cards": [...]}; each card carries a 1-based
        "Student Index".

    Raises:
        HTTPException: 422 when either PDF has no pages, 502 when no answer
            key could be extracted, 500 for any other failure.
    """
    # Local import: fastapi is already a dependency of this file.
    from fastapi.concurrency import run_in_threadpool

    try:
        student_bytes = await student_pdf.read()
        answer_key_bytes = await answer_key_pdf.read()

        # PDF rasterisation and the model calls below are blocking; running
        # them in the thread pool keeps the event loop free for other requests.
        student_images = await run_in_threadpool(convert_from_bytes, student_bytes)
        answer_key_images = await run_in_threadpool(convert_from_bytes, answer_key_bytes)
        if not student_images or not answer_key_images:
            raise HTTPException(status_code=422, detail="Both PDFs must contain at least one page.")

        # Assume the correct key is on the last page of the answer key PDF.
        correct_image = preprocess_full_answers(_pil_to_bgr(answer_key_images[-1]))
        correct_answers = await run_in_threadpool(get_all_answers, correct_image)
        if not isinstance(correct_answers, dict) or not correct_answers.get("Answers"):
            # Without a key every sheet would be graded against nothing.
            raise HTTPException(status_code=502, detail="Could not extract the answer key from the answer key PDF.")

        student_result_cards = []
        for idx, page in enumerate(student_images, start=1):
            result_card = await run_in_threadpool(_grade_student_page, page, correct_answers)
            result_card["Student Index"] = idx
            student_result_cards.append(result_card)

        payload = {"result_cards": student_result_cards}
        if download:
            # Create downloadable JSON file
            json_bytes = json.dumps(payload, indent=2).encode("utf-8")
            return StreamingResponse(
                io.BytesIO(json_bytes),
                media_type="application/json",
                headers={"Content-Disposition": "attachment; filename=result_cards.json"}
            )
        return JSONResponse(content=payload)

    except HTTPException:
        # Keep the deliberate 4xx/5xx responses raised above intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
305
 
306
+
307
  if __name__ == "__main__":
 
308
  uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)