namberino
committed on
Commit
·
ffd37d2
1
Parent(s):
2f5ff1a
Testing
Browse files- enhanced_rag_mcq.py +13 -106
- fastapi_app.py +0 -12
enhanced_rag_mcq.py
CHANGED
@@ -657,43 +657,6 @@ class EnhancedRAGMCQGenerator:
|
|
657 |
print(f"Raw response: {response[:500]}...")
|
658 |
raise ValueError(f"Failed to parse LLM response: {e}")
|
659 |
|
660 |
-
def _batch_invoke(self, prompts: List[str]) -> List[str]:
|
661 |
-
if not prompts:
|
662 |
-
return []
|
663 |
-
|
664 |
-
# Try to use transformers pipeline (batch mode)
|
665 |
-
pl = getattr(self.llm, "pipeline", None)
|
666 |
-
if pl is not None:
|
667 |
-
try:
|
668 |
-
# Call the pipeline with a list. Transformers will return a list of generation outputs.
|
669 |
-
raw_outputs = pl(prompts)
|
670 |
-
|
671 |
-
responses = []
|
672 |
-
for out in raw_outputs:
|
673 |
-
# The pipeline may return either a dict (single result) or a list of dicts (if return_full_text or num_return_sequences was set)
|
674 |
-
if isinstance(out, list) and out:
|
675 |
-
text = out[0].get("generated_text", "")
|
676 |
-
elif isinstance(out, dict):
|
677 |
-
text = out.get("generated_text", "")
|
678 |
-
else:
|
679 |
-
# fallback: coerce to string
|
680 |
-
text = str(out)
|
681 |
-
responses.append(text)
|
682 |
-
|
683 |
-
if len(responses) == len(prompts):
|
684 |
-
return responses
|
685 |
-
else:
|
686 |
-
print("⚠️ Batch pipeline returned unexpected shape — falling back")
|
687 |
-
except Exception as e:
|
688 |
-
# Batch mode failed. Fall back to sequential invocations.
|
689 |
-
print(f"⚠️ Batch invoke failed: {e}. Falling back to sequential.")
|
690 |
-
|
691 |
-
# Sequential invocation to preserve behavior
|
692 |
-
results = []
|
693 |
-
for p in prompts:
|
694 |
-
results.append(self.llm.invoke(p))
|
695 |
-
return results
|
696 |
-
|
697 |
def generate_batch(self,
|
698 |
topics: List[str],
|
699 |
question_per_topic: int = 5,
|
@@ -707,9 +670,8 @@ class EnhancedRAGMCQGenerator:
|
|
707 |
if question_types is None:
|
708 |
question_types = [QuestionType.DEFINITION, QuestionType.APPLICATION]
|
709 |
|
|
|
710 |
total_questions = len(topics) * question_per_topic
|
711 |
-
prompt_metadatas = [] # stores tuples (topic, difficulty, question_type)
|
712 |
-
formatted_prompts = []
|
713 |
|
714 |
print(f"🎯 Generating {total_questions} MCQs...")
|
715 |
|
@@ -717,76 +679,21 @@ class EnhancedRAGMCQGenerator:
|
|
717 |
print(f"π Processing topic {i+1}/{len(topics)}: {topic}")
|
718 |
|
719 |
for j in range(question_per_topic):
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
# retrieve context once per prompt
|
726 |
-
contexts = self.retriever.retrieve_diverse_contexts(query, k=self.config.get("k", 5)) if hasattr(self, "retriever") else []
|
727 |
-
context_text = "\n\n".join([d.page_content for d in contexts]) if contexts else topic
|
728 |
-
|
729 |
-
# Build prompt with PromptTemplateManager
|
730 |
-
prompt_template = self.template_manager.get_template(question_type)
|
731 |
-
prompt_input = {
|
732 |
-
"context": context_text,
|
733 |
-
"topic": topic,
|
734 |
-
"difficulty": difficulty.value if hasattr(difficulty, 'value') else str(difficulty),
|
735 |
-
"question_type": question_type.value if hasattr(question_type, 'value') else str(question_type)
|
736 |
-
}
|
737 |
-
formatted = prompt_template.format(**prompt_input)
|
738 |
-
|
739 |
-
# Length check and fallback
|
740 |
-
is_safe, token_count = self.check_prompt_length(formatted)
|
741 |
-
if not is_safe:
|
742 |
-
truncated = formatted[: self.config.get("max_prompt_chars", 2000)]
|
743 |
-
formatted = truncated
|
744 |
-
|
745 |
-
prompt_metadatas.append((topic, difficulty, question_type))
|
746 |
-
formatted_prompts.append(formatted)
|
747 |
-
|
748 |
-
total = len(formatted_prompts)
|
749 |
-
if total == 0:
|
750 |
-
return []
|
751 |
|
752 |
-
|
753 |
-
|
754 |
-
raw_responses = self._batch_invoke(formatted_prompts)
|
755 |
-
elapsed = time.time() - start_t
|
756 |
-
print(f"⏱ LLM batch time: {elapsed:.2f}s for {total} prompts")
|
757 |
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
response_data = self._extract_json_from_response(response)
|
764 |
-
|
765 |
-
# Reconstruct MCQQuestion
|
766 |
-
options = []
|
767 |
-
for label, text in response_data["options"].items():
|
768 |
-
is_correct = label == response_data["correct_answer"]
|
769 |
-
options.append(self.MCQOption(label=label, text=text, is_correct=is_correct))
|
770 |
-
|
771 |
-
mcq = self.MCQQuestion(
|
772 |
-
question_id=response_data.get("id", None),
|
773 |
-
topic=topic,
|
774 |
-
question_text=response_data["question"],
|
775 |
-
options=options,
|
776 |
-
explanation=response_data.get("explanation", ""),
|
777 |
-
difficulty=(difficulty if hasattr(difficulty, 'name') else difficulty),
|
778 |
-
question_type=(question_type if hasattr(question_type, 'name') else question_type),
|
779 |
-
confidence_score=response_data.get("confidence_score", 0.0)
|
780 |
-
)
|
781 |
-
|
782 |
-
if hasattr(self, 'validator'):
|
783 |
-
mcq = self.validator.calculate_quality_score(mcq)
|
784 |
-
|
785 |
-
mcqs.append(mcq)
|
786 |
-
except Exception as e:
|
787 |
-
print(f"β Failed parsing response for topic={topic}: {e}")
|
788 |
|
789 |
-
print(f"π Generated {len(mcqs)}/{
|
790 |
return mcqs
|
791 |
|
792 |
def export_mcqs(self, mcqs: List[MCQQuestion], output_path: str):
|
|
|
657 |
print(f"Raw response: {response[:500]}...")
|
658 |
raise ValueError(f"Failed to parse LLM response: {e}")
|
659 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
660 |
def generate_batch(self,
|
661 |
topics: List[str],
|
662 |
question_per_topic: int = 5,
|
|
|
670 |
if question_types is None:
|
671 |
question_types = [QuestionType.DEFINITION, QuestionType.APPLICATION]
|
672 |
|
673 |
+
mcqs = []
|
674 |
total_questions = len(topics) * question_per_topic
|
|
|
|
|
675 |
|
676 |
print(f"🎯 Generating {total_questions} MCQs...")
|
677 |
|
|
|
679 |
print(f"π Processing topic {i+1}/{len(topics)}: {topic}")
|
680 |
|
681 |
for j in range(question_per_topic):
|
682 |
+
try:
|
683 |
+
# Cycle through difficulties and question types
|
684 |
+
difficulty = difficulties[j % len(difficulties)]
|
685 |
+
question_type = question_types[j % len(question_types)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
686 |
|
687 |
+
mcq = self.generate_mcq(topic, difficulty, question_type)
|
688 |
+
mcqs.append(mcq)
|
|
|
|
|
|
|
689 |
|
690 |
+
print(f" ✅ Generated question {j+1}/{question_per_topic} "
|
691 |
+
f"(Quality: {mcq.confidence_score:.1f})")
|
692 |
+
|
693 |
+
except Exception as e:
|
694 |
+
print(f" β Failed to generate question {j+1}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
695 |
|
696 |
+
print(f"π Generated {len(mcqs)}/{total_questions} MCQs successfully")
|
697 |
return mcqs
|
698 |
|
699 |
def export_mcqs(self, mcqs: List[MCQQuestion], output_path: str):
|
fastapi_app.py
CHANGED
@@ -68,14 +68,10 @@ async def mcq_gen(
|
|
68 |
):
|
69 |
if not generator:
|
70 |
raise HTTPException(status_code=500, detail="Generator not initialized")
|
71 |
-
|
72 |
-
print("generator initialized")
|
73 |
|
74 |
topic_list = [t.strip() for t in topics.split(',') if t.strip()]
|
75 |
if not topic_list:
|
76 |
raise HTTPException(status_code=400, detail="At least one topic must be provided")
|
77 |
-
|
78 |
-
print("topic validated")
|
79 |
|
80 |
# Validate and convert enum values
|
81 |
try:
|
@@ -84,16 +80,12 @@ async def mcq_gen(
|
|
84 |
valid_difficulties = [d.value for d in DifficultyLevel]
|
85 |
raise HTTPException(status_code=400, detail=f"Invalid difficulty. Must be one of: {valid_difficulties}")
|
86 |
|
87 |
-
print("difficulty validated")
|
88 |
-
|
89 |
try:
|
90 |
qtype_enum = QuestionType(qtype.lower())
|
91 |
except ValueError:
|
92 |
valid_types = [q.value for q in QuestionType]
|
93 |
raise HTTPException(status_code=400, detail=f"Invalid question type. Must be one of: {valid_types}")
|
94 |
|
95 |
-
print("question type validated")
|
96 |
-
|
97 |
# Save uploaded PDF to temporary folder
|
98 |
filename = file.filename if file.filename else "uploaded_file"
|
99 |
file_path = os.path.join(tmp_folder, filename)
|
@@ -101,8 +93,6 @@ async def mcq_gen(
|
|
101 |
shutil.copyfileobj(file.file, buffer)
|
102 |
file.file.close()
|
103 |
|
104 |
-
print(f"file read and written to {file_path}")
|
105 |
-
|
106 |
try:
|
107 |
# Load and index the uploaded document
|
108 |
docs, _ = generator.load_documents(tmp_folder)
|
@@ -110,8 +100,6 @@ async def mcq_gen(
|
|
110 |
except Exception as e:
|
111 |
raise HTTPException(status_code=500, detail=f"Document processing error: {e}")
|
112 |
|
113 |
-
print("loaded and indexed the document")
|
114 |
-
|
115 |
start_time = time.time()
|
116 |
try:
|
117 |
mcqs = generator.generate_batch(
|
|
|
68 |
):
|
69 |
if not generator:
|
70 |
raise HTTPException(status_code=500, detail="Generator not initialized")
|
|
|
|
|
71 |
|
72 |
topic_list = [t.strip() for t in topics.split(',') if t.strip()]
|
73 |
if not topic_list:
|
74 |
raise HTTPException(status_code=400, detail="At least one topic must be provided")
|
|
|
|
|
75 |
|
76 |
# Validate and convert enum values
|
77 |
try:
|
|
|
80 |
valid_difficulties = [d.value for d in DifficultyLevel]
|
81 |
raise HTTPException(status_code=400, detail=f"Invalid difficulty. Must be one of: {valid_difficulties}")
|
82 |
|
|
|
|
|
83 |
try:
|
84 |
qtype_enum = QuestionType(qtype.lower())
|
85 |
except ValueError:
|
86 |
valid_types = [q.value for q in QuestionType]
|
87 |
raise HTTPException(status_code=400, detail=f"Invalid question type. Must be one of: {valid_types}")
|
88 |
|
|
|
|
|
89 |
# Save uploaded PDF to temporary folder
|
90 |
filename = file.filename if file.filename else "uploaded_file"
|
91 |
file_path = os.path.join(tmp_folder, filename)
|
|
|
93 |
shutil.copyfileobj(file.file, buffer)
|
94 |
file.file.close()
|
95 |
|
|
|
|
|
96 |
try:
|
97 |
# Load and index the uploaded document
|
98 |
docs, _ = generator.load_documents(tmp_folder)
|
|
|
100 |
except Exception as e:
|
101 |
raise HTTPException(status_code=500, detail=f"Document processing error: {e}")
|
102 |
|
|
|
|
|
103 |
start_time = time.time()
|
104 |
try:
|
105 |
mcqs = generator.generate_batch(
|