namberino committed on
Commit
5f9ac06
·
1 Parent(s): 0ba52f3

Initial commit

Files changed (5)
  1. app.py +75 -0
  2. enhanced_rag_mcq.py +933 -0
  3. fastapi_app.py +135 -0
  4. requirements.txt +0 -0
  5. tmp/README.md +1 -0
app.py ADDED
@@ -0,0 +1,75 @@
+ import gradio as gr
+ import json
+ from fastapi.testclient import TestClient
+ from fastapi_app import app as fastapi_app  # import your renamed FastAPI module
+ import io
+
+ # In-process client for the FastAPI app
+ client = TestClient(fastapi_app)
+
+ def call_generate(file_obj, topics, n_questions, difficulty, question_type):
+     if file_obj is None:
+         return {"error": "No file uploaded."}
+
+     # Read the uploaded file bytes and create a multipart payload.
+     # Depending on the Gradio version, file_obj may be a file-like object or a file path.
+     if hasattr(file_obj, "read"):
+         file_bytes = file_obj.read()
+     else:
+         with open(file_obj, "rb") as f:
+             file_bytes = f.read()
+     files = {"file": ("uploaded_file", file_bytes, "application/octet-stream")}
+
+     # Field names must match the /generate/ form parameters (topics, n_questions, difficulty, qtype)
+     data = {
+         "topics": topics if topics is not None else "",
+         "n_questions": str(n_questions),
+         "difficulty": difficulty if difficulty is not None else "",
+         "qtype": question_type if question_type is not None else ""
+     }
+
+     try:
+         resp = client.post("/generate/", files=files, data=data, timeout=120)  # increase timeout if needed
+     except Exception as e:
+         return {"error": f"Request failed: {e}"}
+
+     if resp.status_code != 200:
+         # return helpful debug info
+         return {
+             "status_code": resp.status_code,
+             "text": resp.text,
+             "json": None
+         }
+
+     # print(resp.text)
+
+     # Parse JSON response
+     try:
+         out = resp.json()
+     except Exception:
+         # maybe the endpoint returns text: return it directly
+         return {"text": resp.text}
+
+     # pretty-format the JSON for display
+     return out
+
+ # Gradio UI
+ with gr.Blocks(title="RAG MCQ generator") as gradio_app:
+     gr.Markdown("## Upload a file and generate MCQs")
+
+     with gr.Row():
+         file_input = gr.File(label="Upload file (PDF)")
+         topics = gr.Textbox(label="Topics (comma separated)", placeholder="e.g. calculus, derivatives")
+     with gr.Row():
+         n_questions = gr.Slider(minimum=1, maximum=50, step=1, value=5, label="Number of questions")
+         difficulty = gr.Dropdown(choices=["easy", "medium", "hard", "expert"], value="medium", label="Difficulty")
+         # Choices must match the QuestionType enum values accepted by the API
+         question_type = gr.Dropdown(
+             choices=["definition", "comparison", "application", "analysis", "evaluation"],
+             value="definition",
+             label="Question type"
+         )
+
+     generate_btn = gr.Button("Generate")
+     output = gr.JSON(label="Response")
+
+     generate_btn.click(
+         fn=call_generate,
+         inputs=[file_input, topics, n_questions, difficulty, question_type],
+         outputs=[output],
+     )
+
+ app = gradio_app
+
+ if __name__ == "__main__":
+     # gradio_app.launch(server_name="0.0.0.0", server_port=7860, share=False)
+     gradio_app.launch()
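
For reference, the same /generate/ endpoint can also be exercised without the Gradio UI. The sketch below is a hypothetical standalone client (not part of this commit) that assumes the FastAPI app is running separately on localhost port 8000; the PDF path and topic values are placeholders. In app.py itself this is unnecessary because the TestClient calls the FastAPI app in-process.

# Hypothetical standalone client for the /generate/ endpoint (not part of this commit).
# Assumes the API is running locally on port 8000 (e.g. `fastapi run fastapi_app.py`).
import httpx

data = {
    "topics": "Object Oriented Programming",   # comma-separated topics
    "n_questions": "2",
    "difficulty": "medium",
    "qtype": "definition",
}
with open("sample.pdf", "rb") as f:             # placeholder PDF path
    files = {"file": ("sample.pdf", f, "application/pdf")}
    resp = httpx.post("http://localhost:8000/generate/", data=data, files=files, timeout=300)

print(resp.status_code)
print(resp.json())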
enhanced_rag_mcq.py ADDED
@@ -0,0 +1,933 @@
+ """
+ Enhanced RAG System for Multiple Choice Question Generation
+ Author: MathematicGuy
+ Date: July 2025
+
+ This module implements an advanced RAG system specifically designed for generating
+ high-quality Multiple Choice Questions from educational documents.
+ """
+
+ import os
+ import json
+ import time
+ import torch
+ import re
+ from typing import List, Dict, Any, Optional, Tuple
+ from dataclasses import dataclass, asdict
+ from pathlib import Path
+ from enum import Enum
+ import numpy as np
+
+ # LangChain imports
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
+ from langchain_experimental.text_splitter import SemanticChunker
+ from langchain_huggingface.llms import HuggingFacePipeline
+ from langchain_core.output_parsers import JsonOutputParser
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
+ from langchain_core.prompts import PromptTemplate
+ from langchain_community.vectorstores import FAISS
+
+ from langchain_core.documents import Document
+
+ # Transformers imports
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+ )
+ from transformers.pipelines import pipeline
+ from transformers.utils.quantization_config import BitsAndBytesConfig
+
+
+ class QuestionType(Enum):
+     """Enumeration of different question types"""
+     DEFINITION = "definition"
+     COMPARISON = "comparison"
+     APPLICATION = "application"
+     ANALYSIS = "analysis"
+     EVALUATION = "evaluation"
+
+
+ class DifficultyLevel(Enum):
+     """Enumeration of difficulty levels"""
+     EASY = "easy"
+     MEDIUM = "medium"
+     HARD = "hard"
+     EXPERT = "expert"
+
+
+ @dataclass
+ class MCQOption:
+     """Data class for MCQ options"""
+     label: str
+     text: str
+     is_correct: bool
+
+
+ @dataclass
+ class MCQQuestion:
+     """Data class for Multiple Choice Question"""
+     question: str
+     context: str
+     options: List[MCQOption]
+     explanation: str
+     difficulty: str
+     topic: str
+     question_type: str
+     source: str
+     confidence_score: float = 0.0
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert to dictionary format"""
+         return {
+             "question": self.question,
+             "context": self.context,
+             "options": {opt.label: opt.text for opt in self.options},
+             "correct_answer": next(opt.label for opt in self.options if opt.is_correct),
+             "explanation": self.explanation,
+             "difficulty": self.difficulty,
+             "topic": self.topic,
+             "question_type": self.question_type,
+             "source": self.source,
+             "confidence_score": self.confidence_score
+         }
+
+
+ class PromptTemplateManager:
+     """Manages different prompt templates for various question types"""
+
+     def __init__(self):
+         self.base_template = self._create_base_template()
+         self.templates = self._initialize_templates()
+
+     def _create_base_template(self) -> str:
+         """Create the base template used by all question types"""
+         # Vietnamese prompt, roughly: "Create one multiple-choice question from the
+         # following content ... IMPORTANT: return only valid JSON, no extra text,
+         # always answer in Vietnamese."
+         return """
+ Hãy tạo 1 câu hỏi trắc nghiệm dựa trên nội dung sau đây.
+
+ Nội dung: {context}
+ Chủ đề: {topic}
+ Mức độ: {difficulty}
+ Loại câu hỏi: {question_type}
+
+ QUAN TRỌNG: Chỉ trả về JSON hợp lệ, không có text bổ sung, luôn trả lời bằng tiếng việt:
+
+ {{
+     "question": "Câu hỏi rõ ràng về {topic}",
+     "options": {{
+         "A": "Đáp án A",
+         "B": "Đáp án B",
+         "C": "Đáp án C",
+         "D": "Đáp án D"
+     }},
+     "correct_answer": "A",
+     "explanation": "Giải thích tại sao đáp án A đúng",
+     "topic": "{topic}",
+     "difficulty": "{difficulty}",
+     "question_type": "{question_type}"
+ }}
+
+ Trả lời:
+ """
+
+     def _create_question_specific_instruction(self, question_type: str) -> str:
+         """Create specific instructions for each question type"""
+         # Vietnamese instructions per question type (definition / application / analysis / comparison / evaluation)
+         instructions = {
+             "definition": "Tạo câu hỏi định nghĩa thuật ngữ. Tập trung vào định nghĩa chính xác.",
+             "application": "Tạo câu hỏi ứng dụng thực tế. Bao gồm tình huống cụ thể.",
+             "analysis": "Tạo câu hỏi phân tích code/sơ đồ. Kiểm tra tư duy phản biện.",
+             "comparison": "Tạo câu hỏi so sánh khái niệm. Tập trung vào điểm khác biệt.",
+             "evaluation": "Tạo câu hỏi đánh giá phương pháp. Yêu cầu quyết định dựa trên tiêu chí."
+         }
+         return instructions.get(question_type, "Tạo câu hỏi chất lượng cao.")
+
+     def _initialize_templates(self) -> Dict[str, str]:
+         """Initialize all templates using the base template"""
+         templates = {"base": self.base_template}
+
+         # Create shorter, more direct templates for each question type
+         instructions = {
+             "definition": "Tạo câu hỏi định nghĩa thuật ngữ.",
+             "application": "Tạo câu hỏi ứng dụng thực tế.",
+             "analysis": "Tạo câu hỏi phân tích code/sơ đồ.",
+             "comparison": "Tạo câu hỏi so sánh khái niệm.",
+             "evaluation": "Tạo câu hỏi đánh giá phương pháp."
+         }
+
+         # Add specific templates for each question type
+         for question_type in QuestionType:
+             instruction = instructions.get(question_type.value, "Tạo câu hỏi chất lượng cao.")
+             templates[question_type.value] = f"{instruction}\n\n{self.base_template}"
+
+         return templates
+
+     def get_template(self, question_type: QuestionType = QuestionType.DEFINITION) -> str:
+         """Get prompt template for specific question type"""
+         return self.templates.get(question_type.value, self.templates["base"])
+
+     def update_base_template(self, new_base_template: str):
+         """Update the base template and regenerate all templates"""
+         self.base_template = new_base_template
+         self.templates = self._initialize_templates()
+         print("✅ Base template updated and all templates regenerated")
+
+     def get_template_info(self) -> Dict[str, int]:
+         """Get information about all templates (for debugging)"""
+         return {
+             template_type: len(template)
+             for template_type, template in self.templates.items()
+         }
+
+
+ class QualityValidator:
+     """Validates the quality of generated MCQ questions"""
+
+     def __init__(self):
+         self.min_question_length = 10
+         self.max_question_length = 200
+         self.min_explanation_length = 20
+
+     def validate_mcq(self, mcq: MCQQuestion) -> Tuple[bool, List[str]]:
+         """Validate MCQ and return validation result with issues"""
+         issues = []
+
+         # Check question length
+         if len(mcq.question) < self.min_question_length:
+             issues.append("Question too short")
+         elif len(mcq.question) > self.max_question_length:
+             issues.append("Question too long")
+
+         # Check options count
+         if len(mcq.options) != 4:
+             issues.append("Must have exactly 4 options")
+
+         # Check for single correct answer
+         correct_count = sum(1 for opt in mcq.options if opt.is_correct)
+         if correct_count != 1:
+             issues.append("Must have exactly one correct answer")
+
+         # Check explanation
+         if len(mcq.explanation) < self.min_explanation_length:
+             issues.append("Explanation too short")
+
+         # Check for distinct options
+         option_texts = [opt.text for opt in mcq.options]
+         if len(set(option_texts)) != len(option_texts):
+             issues.append("Options must be distinct")
+
+         return len(issues) == 0, issues
+
+     def calculate_quality_score(self, mcq: MCQQuestion) -> float:
+         """Calculate quality score from 0 to 100"""
+         is_valid, issues = self.validate_mcq(mcq)
+         print("MCQ validation issues:", issues)
+
+         if not is_valid:
+             return 0.0
+
+         # Start with base score
+         score = 70.0
+
+         # Bonus for good explanation length
+         if len(mcq.explanation) > 50:
+             score += 10
+
+         # Bonus for appropriate question length
+         if 20 <= len(mcq.question) <= 100:
+             score += 10
+
+         # Bonus for options of similar length (balanced distractors)
+         option_lengths = [len(opt.text) for opt in mcq.options]
+         if max(option_lengths) - min(option_lengths) < 50:  # similar lengths
+             score += 10
+
+         return min(score, 100.0)
+
+
+ class DifficultyAnalyzer:
+     """Analyzes and adjusts question difficulty"""
+
+     def __init__(self):
+         self.difficulty_keywords = {
+             DifficultyLevel.EASY: ["là gì", "định nghĩa", "ví dụ", "đơn giản"],
+             DifficultyLevel.MEDIUM: ["so sánh", "khác biệt", "ứng dụng", "khi nào"],
+             DifficultyLevel.HARD: ["phân tích", "đánh giá", "tối ưu", "thiết kế"],
+             DifficultyLevel.EXPERT: ["tổng hợp", "sáng tạo", "nghiên cứu", "phát triển"]
+         }
+
+     def assess_difficulty(self, question: str, context: str) -> DifficultyLevel:
+         """Assess question difficulty based on content analysis"""
+         question_lower = question.lower()
+
+         # Count difficulty indicators
+         difficulty_scores = {}
+         for level, keywords in self.difficulty_keywords.items():
+             score = sum(1 for keyword in keywords if keyword in question_lower)
+             difficulty_scores[level] = score
+
+         # Return highest scoring difficulty
+         if not any(difficulty_scores.values()):
+             return DifficultyLevel.MEDIUM
+
+         return max(difficulty_scores.keys(), key=lambda k: difficulty_scores[k])
+
+
+ class ContextAwareRetriever:
+     """Enhanced retriever with context awareness and diversity"""
+
+     def __init__(self, vector_db: FAISS, diversity_threshold: float = 0.7):
+         self.vector_db = vector_db
+         self.diversity_threshold = diversity_threshold
+
+     def retrieve_diverse_contexts(self, query: str, k: int = 5) -> List[Document]:
+         """Retrieve documents with semantic diversity"""
+         # Get more candidates than needed
+         candidates = self.vector_db.similarity_search(query, k=k*2)
+
+         if not candidates:
+             return []
+
+         # Select diverse documents
+         selected = [candidates[0]]  # Always include the most relevant
+
+         for candidate in candidates[1:]:
+             if len(selected) >= k:
+                 break
+
+             # Check diversity with already selected documents
+             is_diverse = True
+             for selected_doc in selected:
+                 similarity = self._calculate_similarity(candidate.page_content,
+                                                         selected_doc.page_content)
+                 if similarity > self.diversity_threshold:
+                     is_diverse = False
+                     break
+
+             if is_diverse:
+                 selected.append(candidate)
+
+         return selected[:k]
+
+     def _calculate_similarity(self, text1: str, text2: str) -> float:
+         """Calculate text similarity (simplified implementation)"""
+         words1 = set(text1.lower().split())
+         words2 = set(text2.lower().split())
+
+         if not words1 or not words2:
+             return 0.0
+
+         intersection = words1.intersection(words2)
+         union = words1.union(words2)
+
+         return len(intersection) / len(union)
+
+
+ class EnhancedRAGMCQGenerator:
+     """Enhanced RAG system for MCQ generation"""
+
+     def __init__(self, config: Optional[Dict[str, Any]] = None):
+         self.config = config or self._get_default_config()
+         self.embeddings = None
+         self.llm = None
+         self.vector_db = None
+         self.retriever = None
+         self.prompt_manager = PromptTemplateManager()
+         self.validator = QualityValidator()
+         self.difficulty_analyzer = DifficultyAnalyzer()
+
+     def _get_default_config(self) -> Dict[str, Any]:
+         """Get default configuration"""
+         return {
+             "embedding_model": "bkai-foundation-models/vietnamese-bi-encoder",
+             "llm_model": "Qwen/Qwen2.5-3B-Instruct",  # 7B, 1.5B
+             "chunk_size": 500,
+             "chunk_overlap": 50,
+             "retrieval_k": 3,
+             "generation_temperature": 0.7,
+             "max_tokens": 512,
+             "diversity_threshold": 0.7,
+             "max_context_length": 600,  # Maximum context characters
+             "max_input_tokens": 1600  # Maximum total input tokens
+         }
+
+     def _truncate_context(self, context: str, max_length: Optional[int] = None) -> str:
+         """Intelligently truncate context to fit within token limits"""
+         actual_max_length = max_length if max_length is not None else self.config["max_context_length"]
+
+         if len(context) <= actual_max_length:
+             return context
+
+         # Try to truncate at sentence boundary
+         sentences = context.split('. ')
+         truncated = ""
+
+         for sentence in sentences:
+             if len(truncated + sentence + '. ') <= actual_max_length:
+                 truncated += sentence + '. '
+             else:
+                 break
+
+         # If no complete sentences fit, truncate at word boundary
+         if not truncated:
+             words = context.split()
+             truncated = ""
+             for word in words:
+                 if len(truncated + word + ' ') <= actual_max_length:
+                     truncated += word + ' '
+                 else:
+                     break
+
+         return truncated.strip()
+
+     def _estimate_token_count(self, text: str) -> int:
+         """Estimate token count for Vietnamese text (approximation)"""
+         # Vietnamese typically has ~0.75 tokens per character
+         return int(len(text) * 0.75)
+
+     #? Parse Json String
+     def _extract_json_from_response(self, response: str) -> dict:
+         """Extract JSON from LLM response with multiple fallback strategies"""
+
+         # Strategy 1: Clean response of prompt repetition
+         clean_response = response
+         if "Tạo câu hỏi" in response:
+             # Find where the actual response starts (after the prompt)
+             response_parts = response.split("JSON:")
+             if len(response_parts) > 1:
+                 clean_response = response_parts[-1].strip()
+             else:
+                 # Try splitting on common phrases
+                 for split_phrase in ["QUAN TRỌNG:", "Trả về JSON:", "{"]:
+                     if split_phrase in response:
+                         clean_response = response.split(split_phrase)[-1].strip()
+                         if split_phrase == "{":
+                             clean_response = "{" + clean_response
+                         break
+
+         # Strategy 2: Find JSON boundaries
+         json_start = clean_response.find("{")
+         json_end = clean_response.rfind("}") + 1
+
+         if json_start != -1 and json_end > json_start:
+             json_text = clean_response[json_start:json_end]
+             try:
+                 return json.loads(json_text)
+             except json.JSONDecodeError:
+                 pass
+
+         # Strategy 3: Use regex to find JSON-like structures
+         json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
+         json_matches = re.findall(json_pattern, clean_response, re.DOTALL)
+
+         for json_match in reversed(json_matches):  # Try from last to first
+             try:
+                 return json.loads(json_match)
+             except json.JSONDecodeError:
+                 continue
+
+         # Strategy 4: Try to fix common JSON issues
+         for json_match in reversed(json_matches):
+             try:
+                 # Fix common issues like trailing commas
+                 fixed_json = re.sub(r',(\s*[}\]])', r'\1', json_match)
+                 return json.loads(fixed_json)
+             except json.JSONDecodeError:
+                 continue
+
+         raise ValueError("No valid JSON found in response")
+
+     def check_prompt_length(self, prompt: str) -> Tuple[bool, int]:
+         """Check if prompt length is within safe limits"""
+         estimated_tokens = self._estimate_token_count(prompt)
+         max_safe_tokens = self.config["max_input_tokens"]
+
+         is_safe = estimated_tokens <= max_safe_tokens
+         return is_safe, estimated_tokens
+
+     def initialize_system(self):
+         """Initialize all system components"""
+         print("🔧 Initializing Enhanced RAG MCQ Generator...")
+
+         # Load embeddings
+         self.embeddings = HuggingFaceEmbeddings(
+             model_name=self.config["embedding_model"]
+         )
+         print("✅ Embeddings loaded")
+
+         # Load LLM
+         self.llm = self._load_llm()
+         print("✅ LLM loaded")
+
+     def _load_llm(self) -> HuggingFacePipeline:
+         """Load and configure the LLM"""
+         token_path = Path("./tokens/hugging_face_token.txt")
+         if token_path.exists():
+             with token_path.open("r") as f:
+                 hf_token = f.read().strip()
+         else:
+             hf_token = None
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_compute_dtype=torch.bfloat16,
+             bnb_4bit_quant_type="nf4"
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             self.config["llm_model"],
+             quantization_config=bnb_config,
+             low_cpu_mem_usage=True,
+             device_map="cuda",  # requires a CUDA-capable GPU for 4-bit loading
+             token=hf_token
+         )
+
+         tokenizer = AutoTokenizer.from_pretrained(self.config["llm_model"])
+         # tokenizer.pad_token = tokenizer.eos_token
+
+         model_pipeline = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             max_new_tokens=self.config["max_tokens"],
+             temperature=self.config["generation_temperature"],
+             pad_token_id=tokenizer.eos_token_id,
+             device_map="auto"
+         )
+
+         return HuggingFacePipeline(pipeline=model_pipeline)
+
+     def load_documents(self, folder_path: str) -> Tuple[List[Document], List[str]]:
+         """Load and process documents"""
+         folder = Path(folder_path)
+         if not folder.exists():
+             raise FileNotFoundError(f"Folder not found: {folder}")
+
+         pdf_files = list(folder.glob("*.pdf"))
+         if not pdf_files:
+             raise ValueError(f"No PDF files found in: {folder}")
+
+         all_docs, filenames = [], []
+         for pdf_file in pdf_files:
+             try:
+                 loader = PyPDFLoader(str(pdf_file))
+                 docs = loader.load()
+                 all_docs.extend(docs)
+                 filenames.append(pdf_file.name)
+                 print(f"✅ Loaded {pdf_file.name} ({len(docs)} pages)")
+             except Exception as e:
+                 print(f"❌ Failed loading {pdf_file.name}: {e}")
+
+         return all_docs, filenames
+
+     def build_vector_database(self, docs: List[Document]) -> int:
+         """Build vector database with semantic chunking"""
+         if not self.embeddings:
+             raise RuntimeError("Embeddings not initialized. Call initialize_system() first.")
+
+         chunker = SemanticChunker(
+             embeddings=self.embeddings,
+             buffer_size=1,
+             breakpoint_threshold_type="percentile",
+             breakpoint_threshold_amount=95,
+             min_chunk_size=self.config["chunk_size"],
+             add_start_index=True
+         )
+
+         chunks = chunker.split_documents(docs)
+         self.vector_db = FAISS.from_documents(chunks, embedding=self.embeddings)
+
+         # Initialize enhanced retriever
+         self.retriever = ContextAwareRetriever(
+             self.vector_db,
+             self.config["diversity_threshold"]
+         )
+
+         print(f"✅ Created vector database with {len(chunks)} chunks")
+         return len(chunks)
+
+     def generate_mcq(self,
+                      topic: str,
+                      difficulty: DifficultyLevel = DifficultyLevel.MEDIUM,
+                      question_type: QuestionType = QuestionType.DEFINITION,
+                      context_query: Optional[str] = None) -> MCQQuestion:
+         """Generate a single MCQ with proper length management"""
+
+         if not self.retriever:
+             raise RuntimeError("System not initialized. Call initialize_system() first.")
+
+         if not self.llm:
+             raise RuntimeError("LLM not initialized. Call initialize_system() first.")
+
+         # Use topic as query if no specific context query provided
+         query = context_query or topic
+
+         # Retrieve relevant contexts (reduced number)
+         contexts = self.retriever.retrieve_diverse_contexts(
+             query, k=self.config["retrieval_k"]
+         )
+
+         if not contexts:
+             raise ValueError(f"No relevant context found for topic: {topic}")
+
+         # Format and truncate contexts
+         context_text = "\n\n".join(doc.page_content for doc in contexts)
+         context_text = self._truncate_context(context_text)
+
+         # Get appropriate prompt template
+         template_text = self.prompt_manager.get_template(question_type)
+         template_info_debug = self.prompt_manager.get_template_info()
+         prompt_template = PromptTemplate.from_template(template_text)
+         print(f"Template structure info:\n{template_info_debug}")
+
+         # Generate question with length checking
+         prompt_input = {
+             "context": context_text,
+             "topic": topic,
+             "difficulty": difficulty.value,
+             "question_type": question_type.value
+         }
+
+         formatted_prompt = prompt_template.format(**prompt_input)
+
+         # Check prompt length and truncate if necessary
+         is_safe, token_count = self.check_prompt_length(formatted_prompt)
+
+         if not is_safe:
+             print(f"⚠️ Prompt too long ({token_count} tokens), truncating context...")
+             # Further reduce context
+             reduced_context = self._truncate_context(context_text, 400)
+             prompt_input["context"] = reduced_context
+             formatted_prompt = prompt_template.format(**prompt_input)
+             is_safe, token_count = self.check_prompt_length(formatted_prompt)
+
+             if not is_safe:
+                 print(f"⚠️ Still too long ({token_count} tokens), using minimal context...")
+                 # Use only first paragraph
+                 minimal_context = context_text.split('\n')[0][:300]
+                 prompt_input["context"] = minimal_context
+                 formatted_prompt = prompt_template.format(**prompt_input)
+
+         print(f"📏 Prompt length: {self._estimate_token_count(formatted_prompt)} tokens")
+
+         try:
+             response = self.llm.invoke(formatted_prompt)
+             print(f"✅ Generated response length: {len(response)} characters")
+         except Exception as e:
+             if "length" in str(e).lower():
+                 print(f"❌ Length error persists: {e}")
+                 # Emergency fallback - use very short context
+                 emergency_context = topic + ": " + context_text[:200]
+                 prompt_input["context"] = emergency_context
+                 formatted_prompt = prompt_template.format(**prompt_input)
+                 print(f"🚨 Emergency context length: {self._estimate_token_count(formatted_prompt)} tokens")
+                 response = self.llm.invoke(formatted_prompt)
+             else:
+                 raise e
+
+         # Parse JSON response
+         try:
+             print(f"🔍 Parsing response (first 300 chars): {response[:300]}...")
+             response_data = self._extract_json_from_response(response)
+             print("✅ Successfully parsed JSON response")
+
+             # Create MCQ object
+             options = []
+             for label, text in response_data["options"].items():
+                 is_correct = label == response_data["correct_answer"]
+                 options.append(MCQOption(label, text, is_correct))
+
+             mcq = MCQQuestion(
+                 question=response_data["question"],
+                 context=prompt_input["context"],  # Use the truncated context
+                 options=options,
+                 explanation=response_data.get("explanation", ""),
+                 difficulty=difficulty.value,
+                 topic=topic,
+                 question_type=question_type.value,
+                 source=f"{contexts[0].metadata.get('source', 'Unknown')}"
+             )
+
+             # Calculate quality score
+             mcq.confidence_score = self.validator.calculate_quality_score(mcq)
+
+             return mcq
+
+         except (json.JSONDecodeError, KeyError) as e:
+             print(f"❌ Response parsing error: {e}")
+             print(f"Raw response: {response[:500]}...")
+             raise ValueError(f"Failed to parse LLM response: {e}")
+
+     def _batch_invoke(self, prompts: List[str]) -> List[str]:
+         if not prompts:
+             return []
+
+         # Try to use transformers pipeline (batch mode)
+         pl = getattr(self.llm, "pipeline", None)
+         if pl is not None:
+             try:
+                 # Call the pipeline with a list. Transformers will return a list of generation outputs.
+                 raw_outputs = pl(prompts)
+
+                 responses = []
+                 for out in raw_outputs:
+                     # The pipeline may return either a dict (single result) or a list of dicts
+                     # (if return_full_text or num_return_sequences was set)
+                     if isinstance(out, list) and out:
+                         text = out[0].get("generated_text", "")
+                     elif isinstance(out, dict):
+                         text = out.get("generated_text", "")
+                     else:
+                         # fallback: coerce to string
+                         text = str(out)
+                     responses.append(text)
+
+                 if len(responses) == len(prompts):
+                     return responses
+                 else:
+                     print("⚠️ Batch pipeline returned unexpected shape — falling back")
+             except Exception as e:
+                 # Batch mode failed. Fall back to sequential invocations.
+                 print(f"⚠️ Batch invoke failed: {e}. Falling back to sequential.")
+
+         # Sequential invocation to preserve behavior
+         results = []
+         for p in prompts:
+             results.append(self.llm.invoke(p))
+         return results
+
+     def generate_batch(self,
+                        topics: List[str],
+                        question_per_topic: int = 5,
+                        difficulties: Optional[List[DifficultyLevel]] = None,
+                        question_types: Optional[List[QuestionType]] = None) -> List[MCQQuestion]:
+         """Generate a batch of MCQs"""
+
+         if difficulties is None:
+             difficulties = [DifficultyLevel.EASY, DifficultyLevel.MEDIUM, DifficultyLevel.HARD]
+
+         if question_types is None:
+             question_types = [QuestionType.DEFINITION, QuestionType.APPLICATION]
+
+         total_questions = len(topics) * question_per_topic
+         prompt_metadatas = []  # stores tuples (topic, difficulty, question_type, context, source)
+         formatted_prompts = []
+
+         print(f"🎯 Generating {total_questions} MCQs...")
+
+         for i, topic in enumerate(topics):
+             print(f"📝 Processing topic {i+1}/{len(topics)}: {topic}")
+
+             for j in range(question_per_topic):
+                 difficulty = difficulties[j % len(difficulties)]
+                 question_type = question_types[j % len(question_types)]
+
+                 query = topic
+
+                 # Retrieve context once per prompt
+                 contexts = self.retriever.retrieve_diverse_contexts(query, k=self.config.get("retrieval_k", 5)) if self.retriever else []
+                 context_text = "\n\n".join([d.page_content for d in contexts]) if contexts else topic
+                 source = contexts[0].metadata.get("source", "Unknown") if contexts else "Unknown"
+
+                 # Build prompt with PromptTemplateManager
+                 prompt_template = self.prompt_manager.get_template(question_type)
+                 prompt_input = {
+                     "context": context_text,
+                     "topic": topic,
+                     "difficulty": difficulty.value if hasattr(difficulty, 'value') else str(difficulty),
+                     "question_type": question_type.value if hasattr(question_type, 'value') else str(question_type)
+                 }
+                 formatted = prompt_template.format(**prompt_input)
+
+                 # Length check and fallback
+                 is_safe, token_count = self.check_prompt_length(formatted)
+                 if not is_safe:
+                     truncated = formatted[: self.config.get("max_prompt_chars", 2000)]
+                     formatted = truncated
+
+                 prompt_metadatas.append((topic, difficulty, question_type, context_text, source))
+                 formatted_prompts.append(formatted)
+
+         total = len(formatted_prompts)
+         if total == 0:
+             return []
+
+         print(f"📦 Sending {total} prompts to the LLM in batch mode (if supported)")
+         start_t = time.time()
+         raw_responses = self._batch_invoke(formatted_prompts)
+         elapsed = time.time() - start_t
+         print(f"⏱ LLM batch time: {elapsed:.2f}s for {total} prompts")
+
+         # Parse raw responses back into MCQQuestion objects
+         mcqs = []
+         for meta, response in zip(prompt_metadatas, raw_responses):
+             topic, difficulty, question_type, context_text, source = meta
+             try:
+                 response_data = self._extract_json_from_response(response)
+
+                 # Reconstruct MCQQuestion
+                 options = []
+                 for label, text in response_data["options"].items():
+                     is_correct = label == response_data["correct_answer"]
+                     options.append(MCQOption(label=label, text=text, is_correct=is_correct))
+
+                 mcq = MCQQuestion(
+                     question=response_data["question"],
+                     context=context_text,
+                     options=options,
+                     explanation=response_data.get("explanation", ""),
+                     difficulty=difficulty.value if hasattr(difficulty, 'value') else str(difficulty),
+                     topic=topic,
+                     question_type=question_type.value if hasattr(question_type, 'value') else str(question_type),
+                     source=source
+                 )
+
+                 # Score the generated question
+                 mcq.confidence_score = self.validator.calculate_quality_score(mcq)
+
+                 mcqs.append(mcq)
+             except Exception as e:
+                 print(f"❌ Failed parsing response for topic={topic}: {e}")
+
+         print(f"🎉 Generated {len(mcqs)}/{total} MCQs successfully (batched)")
+         return mcqs
+
+     def export_mcqs(self, mcqs: List[MCQQuestion], output_path: str):
+         """Export MCQs to JSON file"""
+         output_data = {
+             "metadata": {
+                 "total_questions": len(mcqs),
+                 "generation_timestamp": time.time(),
+                 "average_quality": np.mean([mcq.confidence_score for mcq in mcqs]) if mcqs else 0
+             },
+             "questions": [mcq.to_dict() for mcq in mcqs]
+         }
+
+         with open(output_path, 'w', encoding='utf-8') as f:
+             json.dump(output_data, f, ensure_ascii=False, indent=2)
+
+         print(f"📁 Exported {len(mcqs)} MCQs to {output_path}")
+
+     def debug_system_state(self):
+         """Debug function to check system initialization state"""
+         print("🔍 System Debug Information:")
+         print(f"   Embeddings initialized: {'✅' if self.embeddings else '❌'}")
+         print(f"   LLM initialized: {'✅' if self.llm else '❌'}")
+         print(f"   Vector database created: {'✅' if self.vector_db else '❌'}")
+         print(f"   Retriever initialized: {'✅' if self.retriever else '❌'}")
+         print(f"   Config loaded: {'✅' if self.config else '❌'}")
+
+         if self.config:
+             print(f"   Embedding model: {self.config.get('embedding_model', 'Not set')}")
+             print(f"   LLM model: {self.config.get('llm_model', 'Not set')}")
+             print(f"   Max context length: {self.config.get('max_context_length', 'Not set')}")
+             print(f"   Max input tokens: {self.config.get('max_input_tokens', 'Not set')}")
+
+         # Show template information
+         template_info = self.prompt_manager.get_template_info()
+         print("   Template sizes:")
+         for template_type, size in template_info.items():
+             print(f"      {template_type}: {size} characters")
+
+
+ def debug_prompt_templates():
+     """Debug function to test prompt template generation"""
+     print("🔍 Testing Prompt Templates:")
+     prompt_manager = PromptTemplateManager()
+
+     for question_type in QuestionType:
+         try:
+             template = prompt_manager.get_template(question_type)
+             print(f"   {question_type.value}: ✅ Template loaded ({len(template)} chars)")
+         except Exception as e:
+             print(f"   {question_type.value}: ❌ Error - {e}")
+
+
+ def main():
+     """Main function demonstrating the enhanced RAG MCQ system"""
+     print("🚀 Starting Enhanced RAG MCQ Generation System")
+
+     # Test prompt templates first
+     print("\n🧪 Testing prompt templates...")
+     debug_prompt_templates()
+
+     # Initialize system
+     generator = EnhancedRAGMCQGenerator()
+
+     # Check initial state
+     print("\n🔍 Initial system state:")
+     generator.debug_system_state()
+
+     try:
+         generator.initialize_system()
+
+         # Check state after initialization
+         print("\n🔍 Post-initialization state:")
+         generator.debug_system_state()
+
+         # Load documents
+         folder_path = "pdfs"  # path to your PDF folder
+         try:
+             load_start = time.time()  #? document loading timer
+             docs, filenames = generator.load_documents(folder_path)
+             num_chunks = generator.build_vector_database(docs)
+
+             print(f"⏱️ Loading time: {time.time() - load_start:.2f}s")  #? document loading + indexing time
+             print(f"📚 System ready with {len(filenames)} files and {num_chunks} chunks")
+
+             # Generate sample MCQs
+             topics = ["Object Oriented Programming", "Malware Reverse Engineering"]
+
+             # Single question generation
+             # print("\n🎯 Generating single MCQ...")
+             # mcq = generator.generate_mcq(
+             #     topic=topics[0],
+             #     difficulty=DifficultyLevel.MEDIUM,
+             #     question_type=QuestionType.DEFINITION
+             # )
+
+             # print(f"Question: {mcq.question}")
+             # print(f"Quality Score: {mcq.confidence_score:.1f}")
+
+             # Batch generation
+             n_question = 2
+             print("\n🎯 Generating batch MCQs...")
+
+             #? MAIN OUTPUT: Multiple Choice Questions
+             start = time.time()  #? generation timer
+             mcqs = generator.generate_batch(
+                 topics=topics,
+                 question_per_topic=n_question
+             )
+
+             # Export results
+             output_path = "generated_mcqs.json"
+             generator.export_mcqs(mcqs, output_path)
+
+             # Quality summary
+             print(f"Average MCQ generation time: {((time.time() - start) / max(len(mcqs), 1)) / 60:.2f} min per question")
+             quality_scores = [mcq.confidence_score for mcq in mcqs]
+             if quality_scores:
+                 print("\n📊 Quality Summary:")
+                 print(f"Average Quality: {np.mean(quality_scores):.1f}")
+                 print(f"Min Quality: {np.min(quality_scores):.1f}")
+                 print(f"Max Quality: {np.max(quality_scores):.1f}")
+
+         except FileNotFoundError as e:
+             print(f"❌ Document folder error: {e}")
+             print("💡 Please ensure your PDF files are in the 'pdfs' folder")
+         except Exception as e:
+             print(f"❌ Document processing error: {e}")
+             print("💡 Check your PDF files and folder structure")
+
+     except Exception as e:
+         print(f"❌ System initialization error: {e}")
+         print("💡 Check your dependencies and API keys")
+         generator.debug_system_state()
+
+
+ if __name__ == "__main__":
+     # Check system components
+     # generator = EnhancedRAGMCQGenerator()
+     # generator.debug_system_state()
+
+     # # Test templates separately
+     # debug_prompt_templates()
+
+     main()
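
As a quick orientation, here is a minimal sketch of driving EnhancedRAGMCQGenerator directly from Python, mirroring what main() does. It assumes a ./pdfs folder containing at least one PDF and a CUDA-capable GPU for the 4-bit quantized Qwen model; the folder name and topic are placeholders.

# Minimal usage sketch (assumptions: PDFs in ./pdfs, a CUDA GPU for 4-bit loading).
from enhanced_rag_mcq import EnhancedRAGMCQGenerator, DifficultyLevel, QuestionType

generator = EnhancedRAGMCQGenerator()
generator.initialize_system()               # loads embeddings + quantized LLM
docs, _ = generator.load_documents("pdfs")  # folder of PDF files (placeholder name)
generator.build_vector_database(docs)       # semantic chunking + FAISS index

mcq = generator.generate_mcq(
    topic="Object Oriented Programming",
    difficulty=DifficultyLevel.MEDIUM,
    question_type=QuestionType.DEFINITION,
)
print(mcq.to_dict())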
fastapi_app.py ADDED
@@ -0,0 +1,135 @@
+ import time
+ from enhanced_rag_mcq import EnhancedRAGMCQGenerator, debug_prompt_templates, DifficultyLevel, QuestionType
+ import numpy as np
+ import os
+ import shutil
+ from pydantic import BaseModel, Field
+ from typing import List, Optional, Dict
+ from fastapi import FastAPI, Form, HTTPException, UploadFile, File
+ from contextlib import asynccontextmanager
+
+
+ generator: Optional[EnhancedRAGMCQGenerator] = None
+ tmp_folder = "./tmp"  #? uploaded files are saved here
+ if not os.path.exists(tmp_folder):
+     os.makedirs(tmp_folder)
+
+ class GenerateRequest(BaseModel):
+     topics: List[str] = Field(..., description="List of topics for MCQ generation")
+     question_per_topic: int = Field(1, ge=1, le=10, description="Number of questions per topic")
+     difficulty: Optional[DifficultyLevel] = Field(
+         DifficultyLevel.MEDIUM,
+         description="Difficulty level for generated questions"
+     )
+     qtype: Optional[QuestionType] = Field(
+         QuestionType.DEFINITION,
+         description="Type of question to generate"
+     )
+
+ class MCQResponse(BaseModel):
+     question: str
+     options: Dict
+     correct_answer: str
+     confidence_score: float
+
+ class GenerateResponse(BaseModel):
+     topics: List[str]
+     generated: List[MCQResponse]
+     avg_confidence: float
+     generation_time: float
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     global generator
+     generator = EnhancedRAGMCQGenerator()
+     try:
+         generator.initialize_system()
+         print("RAG system initialized.")
+     except Exception as e:
+         print(f"❌ Failed to initialize RAG system: {e}")
+     yield
+     # Optional: Cleanup code after shutdown
+
+ app = FastAPI(
+     title="Enhanced RAG MCQ Generation API",
+     description="An API wrapping the RAG-based MCQ generator using FastAPI",
+     version="1.0.0",
+     lifespan=lifespan
+ )
+
+ #? cmd: fastapi run fastapi_app.py
+ @app.post("/generate/")
+ async def mcq_gen(
+     file: UploadFile = File(...),
+     topics: str = Form(...),
+     n_questions: str = Form(...),
+     difficulty: str = Form(...),
+     qtype: str = Form(...)
+ ):
+     if not generator:
+         raise HTTPException(status_code=500, detail="Generator not initialized")
+
+     topic_list = [t.strip() for t in topics.split(',') if t.strip()]
+     if not topic_list:
+         raise HTTPException(status_code=400, detail="At least one topic must be provided")
+
+     # Validate and convert enum values
+     try:
+         difficulty_enum = DifficultyLevel(difficulty.lower())
+     except ValueError:
+         valid_difficulties = [d.value for d in DifficultyLevel]
+         raise HTTPException(status_code=400, detail=f"Invalid difficulty. Must be one of: {valid_difficulties}")
+
+     try:
+         qtype_enum = QuestionType(qtype.lower())
+     except ValueError:
+         valid_types = [q.value for q in QuestionType]
+         raise HTTPException(status_code=400, detail=f"Invalid question type. Must be one of: {valid_types}")
+
+     # Save uploaded PDF to temporary folder
+     filename = file.filename if file.filename else "uploaded_file"
+     file_path = os.path.join(tmp_folder, filename)
+     with open(file_path, "wb") as buffer:
+         shutil.copyfileobj(file.file, buffer)
+     file.file.close()
+
+     try:
+         # Load and index the uploaded document
+         docs, _ = generator.load_documents(tmp_folder)
+         generator.build_vector_database(docs)
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Document processing error: {e}")
+
+     start_time = time.time()
+     try:
+         mcqs = generator.generate_batch(
+             topics=topic_list,
+             question_per_topic=int(n_questions),
+             difficulties=[difficulty_enum],
+             question_types=[qtype_enum]
+         )
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+     end_time = time.time()
+
+     res_dict = [m.to_dict() for m in mcqs]
+
+     responses = [
+         MCQResponse(
+             question=m["question"],
+             options=m["options"],
+             correct_answer=m["correct_answer"],
+             confidence_score=m["confidence_score"]
+         ) for m in res_dict
+     ]
+     avg_conf = sum(m["confidence_score"] for m in res_dict) / len(res_dict) if res_dict else 0.0
+     # Clean up temporary files
+     shutil.rmtree(tmp_folder)
+     os.makedirs(tmp_folder)
+
+     return GenerateResponse(
+         topics=topic_list,
+         generated=responses,
+         avg_confidence=avg_conf,
+         generation_time=end_time - start_time
+     )
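
For orientation, a successful call to /generate/ serializes the GenerateResponse model above. The sketch below shows only the expected shape; the values are illustrative placeholders, not real model output.

# Illustrative shape of a /generate/ response (placeholder values, not real output).
example_response = {
    "topics": ["Object Oriented Programming"],
    "generated": [
        {
            "question": "...",                              # generated question text (Vietnamese)
            "options": {"A": "...", "B": "...", "C": "...", "D": "..."},
            "correct_answer": "A",
            "confidence_score": 80.0,                       # QualityValidator score, 0-100
        }
    ],
    "avg_confidence": 80.0,
    "generation_time": 42.0,                                # seconds (placeholder)
}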
requirements.txt ADDED
Binary file (3.18 kB).
tmp/README.md ADDED
@@ -0,0 +1 @@
+ This is the temporary directory for storing PDF files