namberino committed
Commit 5f9ac06 · Parent(s): 0ba52f3

Initial commit

Files changed:
- app.py +75 -0
- enhanced_rag_mcq.py +933 -0
- fastapi_app.py +135 -0
- requirements.txt +0 -0
- tmp/README.md +1 -0
app.py
ADDED
@@ -0,0 +1,75 @@
import gradio as gr
import json
import os
from fastapi.testclient import TestClient
from fastapi_app import app as fastapi_app  # import your renamed FastAPI module
import io

# In-process client for the FastAPI app
client = TestClient(fastapi_app)

def call_generate(file_obj, topics, n_questions, difficulty, question_type):
    if file_obj is None:
        return {"error": "No file uploaded."}

    # Read the uploaded file bytes and create the multipart payload.
    # Depending on the Gradio version, gr.File hands back either a file path (str)
    # or a file-like object; keep the original name so the backend's *.pdf glob still matches.
    if isinstance(file_obj, (str, os.PathLike)):
        file_name = os.path.basename(str(file_obj))
        with open(file_obj, "rb") as f:
            file_bytes = f.read()
    else:
        file_name = os.path.basename(str(getattr(file_obj, "name", "uploaded_file.pdf")))
        file_bytes = file_obj.read()
    files = {"file": (file_name, file_bytes, "application/octet-stream")}

    # Form field names must match the /generate/ endpoint in fastapi_app.py
    data = {
        "topics": topics if topics is not None else "",
        "n_questions": str(n_questions),
        "difficulty": difficulty if difficulty is not None else "",
        "qtype": question_type if question_type is not None else ""
    }

    try:
        resp = client.post("/generate/", files=files, data=data, timeout=120)  # increase timeout if needed
    except Exception as e:
        return {"error": f"Request failed: {e}"}

    if resp.status_code != 200:
        # return helpful debug info
        return {
            "status_code": resp.status_code,
            "text": resp.text,
            "json": None
        }

    # print(resp.text)

    # Parse JSON response
    try:
        out = resp.json()
    except Exception:
        # maybe the endpoint returns text: return it directly
        return {"text": resp.text}

    # pretty-format the JSON for display
    return out

# Gradio UI
with gr.Blocks(title="RAG MCQ generator") as gradio_app:
    gr.Markdown("## Upload a file and generate MCQs")

    with gr.Row():
        file_input = gr.File(label="Upload file (PDF, docx, etc)")
        topics = gr.Textbox(label="Topics (comma separated)", placeholder="e.g. calculus, derivatives")
    with gr.Row():
        n_questions = gr.Slider(minimum=1, maximum=50, step=1, value=5, label="Number of questions")
        difficulty = gr.Dropdown(choices=["easy", "medium", "hard", "expert"], value="medium", label="Difficulty")
        # Choices must match the QuestionType enum values accepted by the backend
        question_type = gr.Dropdown(
            choices=["definition", "comparison", "application", "analysis", "evaluation"],
            value="definition",
            label="Question type"
        )

    generate_btn = gr.Button("Generate")
    output = gr.JSON(label="Response")

    generate_btn.click(
        fn=call_generate,
        inputs=[file_input, topics, n_questions, difficulty, question_type],
        outputs=[output],
    )

app = gradio_app

if __name__ == "__main__":
    # demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    gradio_app.launch()
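Note: because the Gradio callback talks to fastapi_app through an in-process TestClient, the /generate/ contract can also be smoke-tested without the UI. A minimal sketch, assuming a local sample.pdf exists (the file path is a placeholder, not part of this commit):

import json
from fastapi.testclient import TestClient
from fastapi_app import app

client = TestClient(app)
with open("sample.pdf", "rb") as f:  # placeholder document, supply your own PDF
    resp = client.post(
        "/generate/",
        files={"file": ("sample.pdf", f.read(), "application/pdf")},
        data={"topics": "Object Oriented Programming", "n_questions": "1",
              "difficulty": "medium", "qtype": "definition"},
    )
print(resp.status_code)
print(json.dumps(resp.json(), ensure_ascii=False, indent=2) if resp.status_code == 200 else resp.text)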
enhanced_rag_mcq.py
ADDED
@@ -0,0 +1,933 @@
"""
Enhanced RAG System for Multiple Choice Question Generation
Author: MathematicGuy
Date: July 2025

This module implements an advanced RAG system specifically designed for generating
high-quality Multiple Choice Questions from educational documents.
"""

import os
import json
import time
import torch
import re
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, asdict
from pathlib import Path
from enum import Enum
import numpy as np

# LangChain imports
from langchain_community.document_loaders import PyPDFLoader
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface.llms import HuggingFacePipeline
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS

from langchain_core.documents import Document

# Transformers imports
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from transformers.pipelines import pipeline
from transformers.utils.quantization_config import BitsAndBytesConfig


class QuestionType(Enum):
    """Enumeration of different question types"""
    DEFINITION = "definition"
    COMPARISON = "comparison"
    APPLICATION = "application"
    ANALYSIS = "analysis"
    EVALUATION = "evaluation"


class DifficultyLevel(Enum):
    """Enumeration of difficulty levels"""
    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"
    EXPERT = "expert"


@dataclass
class MCQOption:
    """Data class for MCQ options"""
    label: str
    text: str
    is_correct: bool


@dataclass
class MCQQuestion:
    """Data class for Multiple Choice Question"""
    question: str
    context: str
    options: List[MCQOption]
    explanation: str
    difficulty: str
    topic: str
    question_type: str
    source: str
    confidence_score: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format"""
        return {
            "question": self.question,
            "context": self.context,
            "options": {opt.label: opt.text for opt in self.options},
            "correct_answer": next(opt.label for opt in self.options if opt.is_correct),
            "explanation": self.explanation,
            "difficulty": self.difficulty,
            "topic": self.topic,
            "question_type": self.question_type,
            "source": self.source,
            "confidence_score": self.confidence_score
        }


class PromptTemplateManager:
    """Manages different prompt templates for various question types"""

    def __init__(self):
        self.base_template = self._create_base_template()
        self.templates = self._initialize_templates()

    def _create_base_template(self) -> str:
        """Create the base template used by all question types"""
        # Vietnamese prompt, roughly: "Create 1 multiple-choice question from the content below.
        # IMPORTANT: return only valid JSON, no extra text, always answer in Vietnamese."
        return """
Hãy tạo 1 câu hỏi trắc nghiệm dựa trên nội dung sau đây.

Nội dung: {context}
Chủ đề: {topic}
Mức độ: {difficulty}
Loại câu hỏi: {question_type}

QUAN TRỌNG: Chỉ trả về JSON hợp lệ, không có text bổ sung, luôn trả lời bằng tiếng việt:

{{
    "question": "Câu hỏi rõ ràng về {topic}",
    "options": {{
        "A": "Đáp án A",
        "B": "Đáp án B",
        "C": "Đáp án C",
        "D": "Đáp án D"
    }},
    "correct_answer": "A",
    "explanation": "Giải thích tại sao đáp án A đúng",
    "topic": "{topic}",
    "difficulty": "{difficulty}",
    "question_type": "{question_type}"
}}

Trả lời:
"""

    def _create_question_specific_instruction(self, question_type: str) -> str:
        """Create specific instructions for each question type"""
        instructions = {
            "definition": "Tạo câu hỏi định nghĩa thuật ngữ. Tập trung vào định nghĩa chính xác.",
            "application": "Tạo câu hỏi ứng dụng thực tế. Bao gồm tình huống cụ thể.",
            "analysis": "Tạo câu hỏi phân tích code/sơ đồ. Kiểm tra tư duy phản biện.",
            "comparison": "Tạo câu hỏi so sánh khái niệm. Tập trung vào điểm khác biệt.",
            "evaluation": "Tạo câu hỏi đánh giá phương pháp. Yêu cầu quyết định dựa trên tiêu chí."
        }
        return instructions.get(question_type, "Tạo câu hỏi chất lượng cao.")

    def _initialize_templates(self) -> Dict[str, str]:
        """Initialize all templates using the base template"""
        templates = {"base": self.base_template}

        # Create shorter, more direct templates for each question type
        instructions = {
            "definition": "Tạo câu hỏi định nghĩa thuật ngữ.",
            "application": "Tạo câu hỏi ứng dụng thực tế.",
            "analysis": "Tạo câu hỏi phân tích code/sơ đồ.",
            "comparison": "Tạo câu hỏi so sánh khái niệm.",
            "evaluation": "Tạo câu hỏi đánh giá phương pháp."
        }

        # Add specific templates for each question type
        for question_type in QuestionType:
            instruction = instructions.get(question_type.value, "Tạo câu hỏi chất lượng cao.")
            templates[question_type.value] = f"{instruction}\n\n{self.base_template}"

        return templates

    def get_template(self, question_type: QuestionType = QuestionType.DEFINITION) -> str:
        """Get prompt template for specific question type"""
        return self.templates.get(question_type.value, self.templates["base"])

    def update_base_template(self, new_base_template: str):
        """Update the base template and regenerate all templates"""
        self.base_template = new_base_template
        self.templates = self._initialize_templates()
        print("✅ Base template updated and all templates regenerated")

    def get_template_info(self) -> Dict[str, int]:
        """Get information about all templates (for debugging)"""
        return {
            template_type: len(template)
            for template_type, template in self.templates.items()
        }


class QualityValidator:
    """Validates the quality of generated MCQ questions"""

    def __init__(self):
        self.min_question_length = 10
        self.max_question_length = 200
        self.min_explanation_length = 20

    def validate_mcq(self, mcq: MCQQuestion) -> Tuple[bool, List[str]]:
        """Validate MCQ and return validation result with issues"""
        issues = []

        # Check question length
        if len(mcq.question) < self.min_question_length:
            issues.append("Question too short")
        elif len(mcq.question) > self.max_question_length:
            issues.append("Question too long")

        # Check options count
        if len(mcq.options) != 4:
            issues.append("Must have exactly 4 options")

        # Check for single correct answer
        correct_count = sum(1 for opt in mcq.options if opt.is_correct)
        if correct_count != 1:
            issues.append("Must have exactly one correct answer")

        # Check explanation
        if len(mcq.explanation) < self.min_explanation_length:
            issues.append("Explanation too short")

        # Check for distinct options
        option_texts = [opt.text for opt in mcq.options]
        if len(set(option_texts)) != len(option_texts):
            issues.append("Options must be distinct")

        return len(issues) == 0, issues

    def calculate_quality_score(self, mcq: MCQQuestion) -> float:
        """Calculate quality score from 0 to 100"""
        is_valid, issues = self.validate_mcq(mcq)
        print("MCQ validation issues:", issues)

        if not is_valid:
            return 0.0

        # Start with base score
        score = 70.0

        # Bonus for good explanation length
        if len(mcq.explanation) > 50:
            score += 10

        # Bonus for appropriate question length
        if 20 <= len(mcq.question) <= 100:
            score += 10

        # Bonus for diverse option lengths (indicates good distractors)
        option_lengths = [len(opt.text) for opt in mcq.options]
        if max(option_lengths) - min(option_lengths) < 50:  # Similar lengths
            score += 10

        return min(score, 100.0)


class DifficultyAnalyzer:
    """Analyzes and adjusts question difficulty"""

    def __init__(self):
        # Vietnamese cue words per level, roughly: "what is"/"definition"/"example"/"simple",
        # "compare"/"difference"/"application"/"when", "analyze"/"evaluate"/"optimize"/"design",
        # "synthesize"/"create"/"research"/"develop"
        self.difficulty_keywords = {
            DifficultyLevel.EASY: ["là gì", "định nghĩa", "ví dụ", "đơn giản"],
            DifficultyLevel.MEDIUM: ["so sánh", "khác biệt", "ứng dụng", "khi nào"],
            DifficultyLevel.HARD: ["phân tích", "đánh giá", "tối ưu", "thiết kế"],
            DifficultyLevel.EXPERT: ["tổng hợp", "sáng tạo", "nghiên cứu", "phát triển"]
        }

    def assess_difficulty(self, question: str, context: str) -> DifficultyLevel:
        """Assess question difficulty based on content analysis"""
        question_lower = question.lower()

        # Count difficulty indicators
        difficulty_scores = {}
        for level, keywords in self.difficulty_keywords.items():
            score = sum(1 for keyword in keywords if keyword in question_lower)
            difficulty_scores[level] = score

        # Return highest scoring difficulty
        if not any(difficulty_scores.values()):
            return DifficultyLevel.MEDIUM

        return max(difficulty_scores.keys(), key=lambda k: difficulty_scores[k])


class ContextAwareRetriever:
    """Enhanced retriever with context awareness and diversity"""

    def __init__(self, vector_db: FAISS, diversity_threshold: float = 0.7):
        self.vector_db = vector_db
        self.diversity_threshold = diversity_threshold

    def retrieve_diverse_contexts(self, query: str, k: int = 5) -> List[Document]:
        """Retrieve documents with semantic diversity"""
        # Get more candidates than needed
        candidates = self.vector_db.similarity_search(query, k=k*2)

        if not candidates:
            return []

        # Select diverse documents
        selected = [candidates[0]]  # Always include the most relevant

        for candidate in candidates[1:]:
            if len(selected) >= k:
                break

            # Check diversity with already selected documents
            is_diverse = True
            for selected_doc in selected:
                similarity = self._calculate_similarity(candidate.page_content,
                                                        selected_doc.page_content)
                if similarity > self.diversity_threshold:
                    is_diverse = False
                    break

            if is_diverse:
                selected.append(candidate)

        return selected[:k]

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """Calculate text similarity (simplified implementation)"""
        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())

        if not words1 or not words2:
            return 0.0

        intersection = words1.intersection(words2)
        union = words1.union(words2)

        return len(intersection) / len(union)


class EnhancedRAGMCQGenerator:
    """Enhanced RAG system for MCQ generation"""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.config = config or self._get_default_config()
        self.embeddings = None
        self.llm = None
        self.vector_db = None
        self.retriever = None
        self.prompt_manager = PromptTemplateManager()
        self.validator = QualityValidator()
        self.difficulty_analyzer = DifficultyAnalyzer()

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration"""
        return {
            "embedding_model": "bkai-foundation-models/vietnamese-bi-encoder",
            "llm_model": "Qwen/Qwen2.5-3B-Instruct",  # 7B, 1.5B
            "chunk_size": 500,
            "chunk_overlap": 50,
            "retrieval_k": 3,
            "generation_temperature": 0.7,
            "max_tokens": 512,
            "diversity_threshold": 0.7,
            "max_context_length": 600,  # Maximum context characters
            "max_input_tokens": 1600  # Maximum total input tokens
        }

    def _truncate_context(self, context: str, max_length: Optional[int] = None) -> str:
        """Intelligently truncate context to fit within token limits"""
        actual_max_length = max_length if max_length is not None else self.config["max_context_length"]

        if len(context) <= actual_max_length:
            return context

        # Try to truncate at sentence boundary
        sentences = context.split('. ')
        truncated = ""

        for sentence in sentences:
            if len(truncated + sentence + '. ') <= actual_max_length:
                truncated += sentence + '. '
            else:
                break

        # If no complete sentences fit, truncate at word boundary
        if not truncated:
            words = context.split()
            truncated = ""
            for word in words:
                if len(truncated + word + ' ') <= actual_max_length:
                    truncated += word + ' '
                else:
                    break

        return truncated.strip()

    def _estimate_token_count(self, text: str) -> int:
        """Estimate token count for Vietnamese text (approximation)"""
        # Vietnamese typically has ~0.75 tokens per character
        return int(len(text) * 0.75)

    #? Parse JSON string
    def _extract_json_from_response(self, response: str) -> dict:
        """Extract JSON from LLM response with multiple fallback strategies"""

        # Strategy 1: Clean response of prompt repetition
        clean_response = response
        if "Tạo câu hỏi" in response:
            # Find where the actual response starts (after the prompt)
            response_parts = response.split("JSON:")
            if len(response_parts) > 1:
                clean_response = response_parts[-1].strip()
            else:
                # Try splitting on common phrases
                for split_phrase in ["QUAN TRỌNG:", "Trả về JSON:", "{"]:
                    if split_phrase in response:
                        clean_response = response.split(split_phrase)[-1].strip()
                        if split_phrase == "{":
                            clean_response = "{" + clean_response
                        break

        # Strategy 2: Find JSON boundaries
        json_start = clean_response.find("{")
        json_end = clean_response.rfind("}") + 1

        if json_start != -1 and json_end > json_start:
            json_text = clean_response[json_start:json_end]
            try:
                return json.loads(json_text)
            except json.JSONDecodeError:
                pass

        # Strategy 3: Use regex to find JSON-like structures
        json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
        json_matches = re.findall(json_pattern, clean_response, re.DOTALL)

        for json_match in reversed(json_matches):  # Try from last to first
            try:
                return json.loads(json_match)
            except json.JSONDecodeError:
                continue

        # Strategy 4: Try to fix common JSON issues
        for json_match in reversed(json_matches):
            try:
                # Fix common issues like trailing commas
                fixed_json = re.sub(r',(\s*[}\]])', r'\1', json_match)
                return json.loads(fixed_json)
            except json.JSONDecodeError:
                continue

        raise ValueError("No valid JSON found in response")

    def check_prompt_length(self, prompt: str) -> Tuple[bool, int]:
        """Check if prompt length is within safe limits"""
        estimated_tokens = self._estimate_token_count(prompt)
        max_safe_tokens = self.config["max_input_tokens"]

        is_safe = estimated_tokens <= max_safe_tokens
        return is_safe, estimated_tokens

    def initialize_system(self):
        """Initialize all system components"""
        print("🔧 Initializing Enhanced RAG MCQ Generator...")

        # Load embeddings
        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.config["embedding_model"]
        )
        print("✅ Embeddings loaded")

        # Load LLM
        self.llm = self._load_llm()
        print("✅ LLM loaded")

    def _load_llm(self) -> HuggingFacePipeline:
        """Load and configure the LLM"""
        token_path = Path("./tokens/hugging_face_token.txt")
        if token_path.exists():
            with token_path.open("r") as f:
                hf_token = f.read().strip()
        else:
            hf_token = None

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4"
        )

        model = AutoModelForCausalLM.from_pretrained(
            self.config["llm_model"],
            quantization_config=bnb_config,
            low_cpu_mem_usage=True,
            device_map="cuda",  # requires a CUDA-capable GPU
            token=hf_token
        )

        tokenizer = AutoTokenizer.from_pretrained(self.config["llm_model"])
        # tokenizer.pad_token = tokenizer.eos_token

        model_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=self.config["max_tokens"],
            temperature=self.config["generation_temperature"],
            pad_token_id=tokenizer.eos_token_id,
            device_map="auto"
        )

        return HuggingFacePipeline(pipeline=model_pipeline)

    def load_documents(self, folder_path: str) -> Tuple[List[Document], List[str]]:
        """Load and process documents"""
        folder = Path(folder_path)
        if not folder.exists():
            raise FileNotFoundError(f"Folder not found: {folder}")

        pdf_files = list(folder.glob("*.pdf"))
        if not pdf_files:
            raise ValueError(f"No PDF files found in: {folder}")

        all_docs, filenames = [], []
        for pdf_file in pdf_files:
            try:
                loader = PyPDFLoader(str(pdf_file))
                docs = loader.load()
                all_docs.extend(docs)
                filenames.append(pdf_file.name)
                print(f"✅ Loaded {pdf_file.name} ({len(docs)} pages)")
            except Exception as e:
                print(f"❌ Failed loading {pdf_file.name}: {e}")

        return all_docs, filenames

    def build_vector_database(self, docs: List[Document]) -> int:
        """Build vector database with semantic chunking"""
        if not self.embeddings:
            raise RuntimeError("Embeddings not initialized. Call initialize_system() first.")

        chunker = SemanticChunker(
            embeddings=self.embeddings,
            buffer_size=1,
            breakpoint_threshold_type="percentile",
            breakpoint_threshold_amount=95,
            min_chunk_size=self.config["chunk_size"],
            add_start_index=True
        )

        chunks = chunker.split_documents(docs)
        self.vector_db = FAISS.from_documents(chunks, embedding=self.embeddings)

        # Initialize enhanced retriever
        self.retriever = ContextAwareRetriever(
            self.vector_db,
            self.config["diversity_threshold"]
        )

        print(f"✅ Created vector database with {len(chunks)} chunks")
        return len(chunks)

    def generate_mcq(self,
                     topic: str,
                     difficulty: DifficultyLevel = DifficultyLevel.MEDIUM,
                     question_type: QuestionType = QuestionType.DEFINITION,
                     context_query: Optional[str] = None) -> MCQQuestion:
        """Generate a single MCQ with proper length management"""

        if not self.retriever:
            raise RuntimeError("System not initialized. Call initialize_system() first.")

        if not self.llm:
            raise RuntimeError("LLM not initialized. Call initialize_system() first.")

        # Use topic as query if no specific context query provided
        query = context_query or topic

        # Retrieve relevant contexts (reduced number)
        contexts = self.retriever.retrieve_diverse_contexts(
            query, k=self.config["retrieval_k"]
        )

        if not contexts:
            raise ValueError(f"No relevant context found for topic: {topic}")

        # Format and truncate contexts
        context_text = "\n\n".join(doc.page_content for doc in contexts)
        context_text = self._truncate_context(context_text)

        # Get appropriate prompt template
        template_text = self.prompt_manager.get_template(question_type)
        template_info_debug = self.prompt_manager.get_template_info()
        prompt_template = PromptTemplate.from_template(template_text)
        print(f"Template Structure Info: \n {template_info_debug}")

        # Generate question with length checking
        prompt_input = {
            "context": context_text,
            "topic": topic,
            "difficulty": difficulty.value,
            "question_type": question_type.value
        }

        formatted_prompt = prompt_template.format(**prompt_input)

        # Check prompt length and truncate if necessary
        is_safe, token_count = self.check_prompt_length(formatted_prompt)

        if not is_safe:
            print(f"⚠️ Prompt too long ({token_count} tokens), truncating context...")
            # Further reduce context
            reduced_context = self._truncate_context(context_text, 400)
            prompt_input["context"] = reduced_context
            formatted_prompt = prompt_template.format(**prompt_input)
            is_safe, token_count = self.check_prompt_length(formatted_prompt)

            if not is_safe:
                print(f"⚠️ Still too long ({token_count} tokens), using minimal context...")
                # Use only first paragraph
                minimal_context = context_text.split('\n')[0][:300]
                prompt_input["context"] = minimal_context
                formatted_prompt = prompt_template.format(**prompt_input)

        print(f"📏 Prompt length: {self._estimate_token_count(formatted_prompt)} tokens")

        try:
            response = self.llm.invoke(formatted_prompt)
            print(f"✅ Generated response length: {len(response)} characters")
        except Exception as e:
            if "length" in str(e).lower():
                print(f"❌ Length error persists: {e}")
                # Emergency fallback - use very short context
                emergency_context = topic + ": " + context_text[:200]
                prompt_input["context"] = emergency_context
                formatted_prompt = prompt_template.format(**prompt_input)
                print(f"🚨 Emergency context length: {self._estimate_token_count(formatted_prompt)} tokens")
                response = self.llm.invoke(formatted_prompt)
            else:
                raise e

        # Parse JSON response
        try:
            print(f"🔍 Parsing response (first 300 chars): {response[:300]}...")
            response_data = self._extract_json_from_response(response)
            print("✅ Successfully parsed JSON response")

            # Create MCQ object
            options = []
            for label, text in response_data["options"].items():
                is_correct = label == response_data["correct_answer"]
                options.append(MCQOption(label, text, is_correct))

            mcq = MCQQuestion(
                question=response_data["question"],
                context=prompt_input["context"],  # Use the truncated context
                options=options,
                explanation=response_data.get("explanation", ""),
                difficulty=difficulty.value,
                topic=topic,
                question_type=question_type.value,
                source=f"{contexts[0].metadata.get('source', 'Unknown')}"
            )

            # Calculate quality score
            mcq.confidence_score = self.validator.calculate_quality_score(mcq)

            return mcq

        except (json.JSONDecodeError, KeyError) as e:
            print(f"❌ Response parsing error: {e}")
            print(f"Raw response: {response[:500]}...")
            raise ValueError(f"Failed to parse LLM response: {e}")

    def _batch_invoke(self, prompts: List[str]) -> List[str]:
        if not prompts:
            return []

        # Try to use transformers pipeline (batch mode)
        pl = getattr(self.llm, "pipeline", None)
        if pl is not None:
            try:
                # Call the pipeline with a list. Transformers will return a list of generation outputs.
                raw_outputs = pl(prompts)

                responses = []
                for out in raw_outputs:
                    # The pipeline may return either a dict (single result) or a list of dicts
                    # (if return_full_text or num_return_sequences was set)
                    if isinstance(out, list) and out:
                        text = out[0].get("generated_text", "")
                    elif isinstance(out, dict):
                        text = out.get("generated_text", "")
                    else:
                        # fallback: coerce to string
                        text = str(out)
                    responses.append(text)

                if len(responses) == len(prompts):
                    return responses
                else:
                    print("⚠️ Batch pipeline returned unexpected shape — falling back")
            except Exception as e:
                # Batch mode failed. Fall back to sequential invocations.
                print(f"⚠️ Batch invoke failed: {e}. Falling back to sequential.")

        # Sequential invocation to preserve behavior
        results = []
        for p in prompts:
            results.append(self.llm.invoke(p))
        return results

    def generate_batch(self,
                       topics: List[str],
                       question_per_topic: int = 5,
                       difficulties: Optional[List[DifficultyLevel]] = None,
                       question_types: Optional[List[QuestionType]] = None) -> List[MCQQuestion]:
        """Generate batch of MCQs"""

        if difficulties is None:
            difficulties = [DifficultyLevel.EASY, DifficultyLevel.MEDIUM, DifficultyLevel.HARD]

        if question_types is None:
            question_types = [QuestionType.DEFINITION, QuestionType.APPLICATION]

        total_questions = len(topics) * question_per_topic
        prompt_metadatas = []  # stores tuples (topic, difficulty, question_type, context, source)
        formatted_prompts = []

        print(f"🎯 Generating {total_questions} MCQs...")

        for i, topic in enumerate(topics):
            print(f"📝 Processing topic {i+1}/{len(topics)}: {topic}")

            for j in range(question_per_topic):
                difficulty = difficulties[j % len(difficulties)]
                question_type = question_types[j % len(question_types)]

                query = topic

                # retrieve context once per prompt
                contexts = self.retriever.retrieve_diverse_contexts(query, k=self.config.get("retrieval_k", 5)) if self.retriever else []
                context_text = "\n\n".join([d.page_content for d in contexts]) if contexts else topic
                source = contexts[0].metadata.get("source", "Unknown") if contexts else "Unknown"

                # Build prompt with PromptTemplateManager (the template is a plain format string)
                prompt_template = self.prompt_manager.get_template(question_type)
                prompt_input = {
                    "context": context_text,
                    "topic": topic,
                    "difficulty": difficulty.value if hasattr(difficulty, 'value') else str(difficulty),
                    "question_type": question_type.value if hasattr(question_type, 'value') else str(question_type)
                }
                formatted = prompt_template.format(**prompt_input)

                # Length check and fallback
                is_safe, token_count = self.check_prompt_length(formatted)
                if not is_safe:
                    truncated = formatted[: self.config.get("max_prompt_chars", 2000)]
                    formatted = truncated

                prompt_metadatas.append((topic, difficulty, question_type, context_text, source))
                formatted_prompts.append(formatted)

        total = len(formatted_prompts)
        if total == 0:
            return []

        print(f"📦 Sending {total} prompts to the LLM in batch mode (if supported)")
        start_t = time.time()
        raw_responses = self._batch_invoke(formatted_prompts)
        elapsed = time.time() - start_t
        print(f"⏱ LLM batch time: {elapsed:.2f}s for {total} prompts")

        # Parse raw responses back into MCQQuestion objects
        mcqs = []
        for meta, response in zip(prompt_metadatas, raw_responses):
            topic, difficulty, question_type, context_text, source = meta
            try:
                response_data = self._extract_json_from_response(response)

                # Reconstruct options and MCQQuestion using the dataclasses defined above
                options = []
                for label, text in response_data["options"].items():
                    is_correct = label == response_data["correct_answer"]
                    options.append(MCQOption(label=label, text=text, is_correct=is_correct))

                mcq = MCQQuestion(
                    question=response_data["question"],
                    context=context_text,
                    options=options,
                    explanation=response_data.get("explanation", ""),
                    difficulty=difficulty.value if hasattr(difficulty, 'value') else str(difficulty),
                    topic=topic,
                    question_type=question_type.value if hasattr(question_type, 'value') else str(question_type),
                    source=source
                )

                if hasattr(self, 'validator'):
                    mcq.confidence_score = self.validator.calculate_quality_score(mcq)

                mcqs.append(mcq)
            except Exception as e:
                print(f"❌ Failed parsing response for topic={topic}: {e}")

        print(f"🎉 Generated {len(mcqs)}/{total} MCQs successfully (batched)")
        return mcqs

    def export_mcqs(self, mcqs: List[MCQQuestion], output_path: str):
        """Export MCQs to JSON file"""
        output_data = {
            "metadata": {
                "total_questions": len(mcqs),
                "generation_timestamp": time.time(),
                "average_quality": np.mean([mcq.confidence_score for mcq in mcqs]) if mcqs else 0
            },
            "questions": [mcq.to_dict() for mcq in mcqs]
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)

        print(f"📁 Exported {len(mcqs)} MCQs to {output_path}")

    def debug_system_state(self):
        """Debug function to check system initialization state"""
        print("🔍 System Debug Information:")
        print(f"   Embeddings initialized: {'✅' if self.embeddings else '❌'}")
        print(f"   LLM initialized: {'✅' if self.llm else '❌'}")
        print(f"   Vector database created: {'✅' if self.vector_db else '❌'}")
        print(f"   Retriever initialized: {'✅' if self.retriever else '❌'}")
        print(f"   Config loaded: {'✅' if self.config else '❌'}")

        if self.config:
            print(f"   Embedding model: {self.config.get('embedding_model', 'Not set')}")
            print(f"   LLM model: {self.config.get('llm_model', 'Not set')}")
            print(f"   Max context length: {self.config.get('max_context_length', 'Not set')}")
            print(f"   Max input tokens: {self.config.get('max_input_tokens', 'Not set')}")

        # Show template information
        template_info = self.prompt_manager.get_template_info()
        print("   Template sizes:")
        for template_type, size in template_info.items():
            print(f"      {template_type}: {size} characters")


def debug_prompt_templates():
    """Debug function to test prompt template generation"""
    print("🔍 Testing Prompt Templates:")
    prompt_manager = PromptTemplateManager()

    for question_type in QuestionType:
        try:
            template = prompt_manager.get_template(question_type)
            print(f"   {question_type.value}: ✅ Template loaded ({len(template)} chars)")
        except Exception as e:
            print(f"   {question_type.value}: ❌ Error - {e}")


def main():
    """Main function demonstrating the enhanced RAG MCQ system"""
    print("🚀 Starting Enhanced RAG MCQ Generation System")

    # Test prompt templates first
    print("\n🧪 Testing prompt templates...")
    debug_prompt_templates()

    # Initialize system
    generator = EnhancedRAGMCQGenerator()

    # Check initial state
    print("\n🔍 Initial system state:")
    generator.debug_system_state()

    try:
        generator.initialize_system()

        # Check state after initialization
        print("\n🔍 Post-initialization state:")
        generator.debug_system_state()

        # Load documents
        folder_path = "pdfs"  # Updated path to your PDF folder
        try:
            load_start = time.time()  #? Time document loading and indexing
            docs, filenames = generator.load_documents(folder_path)
            num_chunks = generator.build_vector_database(docs)

            print(f"⏱️ Loading Time: {time.time() - load_start:.2f}s")  #? Loading document time
            print(f"📚 System ready with {len(filenames)} files and {num_chunks} chunks")

            # Generate sample MCQs
            topics = ["Object Oriented Programming", "Malware Reverse Engineering"]

            # Single question generation
            # print("\n🎯 Generating single MCQ...")
            # mcq = generator.generate_mcq(
            #     topic=topics[0],
            #     difficulty=DifficultyLevel.MEDIUM,
            #     question_type=QuestionType.DEFINITION
            # )

            # print(f"Question: {mcq.question}")
            # print(f"Quality Score: {mcq.confidence_score:.1f}")

            # Batch generation
            n_question = 2
            print("\n🎯 Generating batch MCQs...")

            #? MAIN OUTPUT: Multiple Choice Question
            start = time.time()  #? Calc generation time
            mcqs = generator.generate_batch(
                topics=topics,
                question_per_topic=n_question
            )

            # Export results
            output_path = "generated_mcqs.json"
            generator.export_mcqs(mcqs, output_path)

            # Quality summary
            print(f"Average mcq generation time taken: {((time.time() - start)/max(len(mcqs), 1))/60:.2f} min")
            quality_scores = [mcq.confidence_score for mcq in mcqs]
            print("\n📊 Quality Summary:")
            print(f"Average Quality: {np.mean(quality_scores):.1f}")
            print(f"Min Quality: {np.min(quality_scores):.1f}")
            print(f"Max Quality: {np.max(quality_scores):.1f}")

        except FileNotFoundError as e:
            print(f"❌ Document folder error: {e}")
            print("💡 Please ensure your PDF files are in the 'pdfs' folder")
        except Exception as e:
            print(f"❌ Document processing error: {e}")
            print("💡 Check your PDF files and folder structure")

    except Exception as e:
        print(f"❌ System initialization error: {e}")
        print("💡 Check your dependencies and API keys")
        generator.debug_system_state()


if __name__ == "__main__":
    # Check system components
    # generator = EnhancedRAGMCQGenerator()
    # generator.debug_system_state()

    # # Test templates separately
    # debug_prompt_templates()

    main()
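Note: the same pipeline that main() drives can be used programmatically. A minimal sketch, assuming PDFs already sit in a local pdfs/ folder:

from enhanced_rag_mcq import EnhancedRAGMCQGenerator, DifficultyLevel, QuestionType

generator = EnhancedRAGMCQGenerator()
generator.initialize_system()                  # loads the embedding model and the quantized LLM
docs, _ = generator.load_documents("pdfs")     # folder of *.pdf files (assumed to exist)
generator.build_vector_database(docs)          # semantic chunking + FAISS index
mcq = generator.generate_mcq(
    topic="Object Oriented Programming",
    difficulty=DifficultyLevel.MEDIUM,
    question_type=QuestionType.DEFINITION,
)
print(mcq.to_dict())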
fastapi_app.py
ADDED
@@ -0,0 +1,135 @@
import time
from enhanced_rag_mcq import EnhancedRAGMCQGenerator, debug_prompt_templates, DifficultyLevel, QuestionType
import numpy as np
import os
import shutil
from pydantic import BaseModel, Field
from typing import List, Optional, Dict
from fastapi import FastAPI, Form, HTTPException, UploadFile, File
from contextlib import asynccontextmanager


generator: Optional[EnhancedRAGMCQGenerator] = None
tmp_folder = "./tmp"  #? uploaded files are saved here
if not os.path.exists(tmp_folder):
    os.makedirs(tmp_folder)

class GenerateRequest(BaseModel):
    topics: List[str] = Field(..., description="List of topics for MCQ generation")
    question_per_topic: int = Field(1, ge=1, le=10, description="Number of questions per topic")
    difficulty: Optional[DifficultyLevel] = Field(
        DifficultyLevel.MEDIUM,
        description="Difficulty level for generated questions"
    )
    qtype: Optional[QuestionType] = Field(
        QuestionType.DEFINITION,
        description="Type of question to generate"
    )

class MCQResponse(BaseModel):
    question: str
    options: Dict
    correct_answer: str
    confidence_score: float

class GenerateResponse(BaseModel):
    topics: List[str]
    generated: List[MCQResponse]
    avg_confidence: float
    generation_time: float

@asynccontextmanager
async def lifespan(app: FastAPI):
    global generator
    generator = EnhancedRAGMCQGenerator()
    try:
        generator.initialize_system()
        print("RAG system initialized.")
    except Exception as e:
        print(f"❌ Failed to initialize RAG system: {e}")
    yield
    # Optional: Cleanup code after shutdown

app = FastAPI(
    title="Enhanced RAG MCQ Generation API",
    description="An API wrapping the RAG-based MCQ generator using FastAPI",
    version="1.0.0",
    lifespan=lifespan
)

#? cmd: fastapi run fastapi_app.py
@app.post("/generate/")
async def mcq_gen(
    file: UploadFile = File(...),
    topics: str = Form(...),
    n_questions: str = Form(...),
    difficulty: str = Form(...),
    qtype: str = Form(...)
):
    if not generator:
        raise HTTPException(status_code=500, detail="Generator not initialized")

    topic_list = [t.strip() for t in topics.split(',') if t.strip()]
    if not topic_list:
        raise HTTPException(status_code=400, detail="At least one topic must be provided")

    # Validate and convert enum values
    try:
        difficulty_enum = DifficultyLevel(difficulty.lower())
    except ValueError:
        valid_difficulties = [d.value for d in DifficultyLevel]
        raise HTTPException(status_code=400, detail=f"Invalid difficulty. Must be one of: {valid_difficulties}")

    try:
        qtype_enum = QuestionType(qtype.lower())
    except ValueError:
        valid_types = [q.value for q in QuestionType]
        raise HTTPException(status_code=400, detail=f"Invalid question type. Must be one of: {valid_types}")

    # Save uploaded PDF to temporary folder
    filename = file.filename if file.filename else "uploaded_file"
    file_path = os.path.join(tmp_folder, filename)
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    file.file.close()

    try:
        # Load and index the uploaded document
        docs, _ = generator.load_documents(tmp_folder)
        generator.build_vector_database(docs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Document processing error: {e}")

    start_time = time.time()
    try:
        mcqs = generator.generate_batch(
            topics=topic_list,
            question_per_topic=int(n_questions),
            difficulties=[difficulty_enum],
            question_types=[qtype_enum]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    end_time = time.time()

    res_dict = [m.to_dict() for m in mcqs]

    responses = [
        MCQResponse(
            question=m["question"],
            options=m["options"],
            correct_answer=m["correct_answer"],
            confidence_score=m["confidence_score"]
        ) for m in res_dict
    ]
    # Guard against division by zero when no questions could be generated
    avg_conf = (sum(m["confidence_score"] for m in res_dict) / len(res_dict)) if res_dict else 0.0
    # Clean up temporary files
    shutil.rmtree(tmp_folder)
    os.makedirs(tmp_folder)

    return GenerateResponse(
        topics=topic_list,
        generated=responses,
        avg_confidence=avg_conf,
        generation_time=end_time - start_time
    )
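Note: once the service is running (for example with fastapi run fastapi_app.py), the endpoint takes a multipart file upload plus form fields. A minimal client sketch using the requests library; the host/port and sample.pdf path are assumptions, not part of this commit:

import requests

with open("sample.pdf", "rb") as f:  # placeholder document, supply your own PDF
    resp = requests.post(
        "http://localhost:8000/generate/",
        files={"file": ("sample.pdf", f, "application/pdf")},
        data={"topics": "Malware Reverse Engineering, Object Oriented Programming",
              "n_questions": "2", "difficulty": "hard", "qtype": "analysis"},
        timeout=600,
    )
resp.raise_for_status()
print(resp.json()["avg_confidence"])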
requirements.txt
ADDED
Binary file (3.18 kB).
tmp/README.md
ADDED
@@ -0,0 +1 @@
This is the temporary directory for storing PDF files