Commit 53e5542 · namberino committed · Parent(s): d0a9bfe

Testing old version

Files changed:
- enhanced_rag_mcq.py +21 -113
- fastapi_app.py +7 -20
enhanced_rag_mcq.py
CHANGED
```diff
@@ -11,7 +11,6 @@ import os
 import json
 import time
 import torch
-import re
 from typing import List, Dict, Any, Optional, Tuple
 from dataclasses import dataclass, asdict
 from pathlib import Path
```
```diff
@@ -29,6 +28,7 @@ from langchain_core.prompts import PromptTemplate
 from langchain_community.vectorstores import FAISS

 from langchain_core.documents import Document
+from unsloth import FastLanguageModel

 # Transformers imports
 from transformers import (
```
```diff
@@ -338,7 +338,7 @@ class EnhancedRAGMCQGenerator:
         """Get default configuration"""
         return {
             "embedding_model": "bkai-foundation-models/vietnamese-bi-encoder",
-            "llm_model": "
+            "llm_model": "unsloth/Qwen2.5-3B", # 7B, 1.5B
             "chunk_size": 500,
             "chunk_overlap": 50,
             "retrieval_k": 3,
```
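The chunking and retrieval defaults above (chunk_size 500, chunk_overlap 50, retrieval_k 3) are the usual LangChain splitter/retriever knobs. How the repository actually wires them is not shown in this diff; below is a minimal sketch under the assumption that a RecursiveCharacterTextSplitter and a FAISS retriever consume these values.

```python
# Sketch only: assumes LangChain's RecursiveCharacterTextSplitter, FAISS, and the
# community HuggingFaceEmbeddings wrapper; the repository's real wiring may differ.
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

config = {"chunk_size": 500, "chunk_overlap": 50, "retrieval_k": 3}

splitter = RecursiveCharacterTextSplitter(
    chunk_size=config["chunk_size"],
    chunk_overlap=config["chunk_overlap"],
)
chunks = splitter.split_text("...some Vietnamese source text...")

embeddings = HuggingFaceEmbeddings(model_name="bkai-foundation-models/vietnamese-bi-encoder")
store = FAISS.from_texts(chunks, embeddings)
retriever = store.as_retriever(search_kwargs={"k": config["retrieval_k"]})
```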
```diff
@@ -383,9 +383,10 @@ class EnhancedRAGMCQGenerator:
         # Vietnamese typically has ~0.75 tokens per character
         return int(len(text) * 0.75)

-
+    #? Parse Json String
     def _extract_json_from_response(self, response: str) -> dict:
         """Extract JSON from LLM response with multiple fallback strategies"""
+        import re

         # Strategy 1: Clean response of prompt repetition
         clean_response = response
```
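The docstring above promises multiple fallback strategies for pulling JSON out of the model output. As a rough illustration of the simplest such fallback (not the repository's actual implementation), a regex can grab the outermost {...} block and hand it to json.loads:

```python
# Minimal sketch of a regex-based JSON fallback, not the repository's code:
# take everything from the first '{' to the last '}' and parse it.
import json
import re

def extract_json_fallback(response: str) -> dict:
    match = re.search(r"\{.*\}", response, re.DOTALL)
    if match is None:
        raise ValueError("No JSON object found in LLM response")
    return json.loads(match.group(0))

if __name__ == "__main__":
    raw = 'Here is your question:\n{"question": "1 + 1 = ?", "correct_answer": "A"}'
    print(extract_json_fallback(raw))
```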
```diff
@@ -473,16 +474,16 @@ class EnhancedRAGMCQGenerator:
             bnb_4bit_quant_type="nf4"
         )

-        model =
+        model, tokenizer = FastLanguageModel.from_pretrained(
             self.config["llm_model"],
             quantization_config=bnb_config,
             low_cpu_mem_usage=True,
-            device_map="
+            device_map="auto",
             token=hf_token
         )

-        tokenizer = AutoTokenizer.from_pretrained(self.config["llm_model"])
-
+        # tokenizer = AutoTokenizer.from_pretrained(self.config["llm_model"])
+        tokenizer.pad_token = tokenizer.eos_token

         model_pipeline = pipeline(
             "text-generation",
```
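The new loading path goes through unsloth's FastLanguageModel.from_pretrained, which returns a (model, tokenizer) pair that is then fed into a transformers text-generation pipeline. For comparison, here is a minimal plain-transformers sketch of the same flow; the generation parameters and compute dtype are assumptions, since the old call is truncated in the diff above.

```python
# Sketch of 4-bit loading with plain transformers/bitsandbytes; values not shown
# in the diff (compute dtype, max_new_tokens, return_full_text) are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

model_id = "unsloth/Qwen2.5-3B"  # taken from the new default config above

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # matches the context line in the hunk
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumption
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # same padding fix as the new code

model_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,      # assumption
    return_full_text=False,  # assumption
)
```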
```diff
@@ -657,43 +658,6 @@ class EnhancedRAGMCQGenerator:
             print(f"Raw response: {response[:500]}...")
             raise ValueError(f"Failed to parse LLM response: {e}")

-    def _batch_invoke(self, prompts: List[str]) -> List[str]:
-        if not prompts:
-            return []
-
-        # Try to use transformers pipeline (batch mode)
-        pl = getattr(self.llm, "pipeline", None)
-        if pl is not None:
-            try:
-                # Call the pipeline with a list. Transformers will return a list of generation outputs.
-                raw_outputs = pl(prompts)
-
-                responses = []
-                for out in raw_outputs:
-                    # The pipeline may return either a dict (single result) or a list of dicts (if return_full_text or num_return_sequences was set)
-                    if isinstance(out, list) and out:
-                        text = out[0].get("generated_text", "")
-                    elif isinstance(out, dict):
-                        text = out.get("generated_text", "")
-                    else:
-                        # fallback: coerce to string
-                        text = str(out)
-                    responses.append(text)
-
-                if len(responses) == len(prompts):
-                    return responses
-                else:
-                    print("⚠️ Batch pipeline returned unexpected shape — falling back")
-            except Exception as e:
-                # Batch mode failed. Fall back to sequential invocations.
-                print(f"⚠️ Batch invoke failed: {e}. Falling back to sequential.")
-
-        # Sequential invocation to preserve behavior
-        results = []
-        for p in prompts:
-            results.append(self.llm.invoke(p))
-        return results
-
     def generate_batch(self,
                        topics: List[str],
                        question_per_topic: int = 5,
```
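The removed _batch_invoke leans on one behaviour of transformers pipelines: calling a pipeline with a list of prompts returns one generation result per prompt, each either a dict or a list of dicts with a "generated_text" field. A small standalone sketch of that shape, using a tiny public model purely for illustration:

```python
# Standalone sketch of list-input pipeline calls; gpt2 is used only for
# illustration, the repository builds its pipeline from its own config.
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2", max_new_tokens=20)

prompts = ["Hello, my name is", "The capital of France is"]
raw_outputs = generator(prompts)  # one entry per prompt

for out in raw_outputs:
    # Each entry is usually a list of dicts (one per returned sequence).
    first = out[0] if isinstance(out, list) else out
    print(first["generated_text"])
```

This is exactly the shape _batch_invoke normalized before falling back to sequential self.llm.invoke calls.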
```diff
@@ -707,9 +671,8 @@ class EnhancedRAGMCQGenerator:
         if question_types is None:
             question_types = [QuestionType.DEFINITION, QuestionType.APPLICATION]

+        mcqs = []
         total_questions = len(topics) * question_per_topic
-        prompt_metadatas = [] # stores tuples (topic, difficulty, question_type)
-        formatted_prompts = []

         print(f"🎯 Generating {total_questions} MCQs...")

```
```diff
@@ -717,76 +680,21 @@ class EnhancedRAGMCQGenerator:
             print(f"📝 Processing topic {i+1}/{len(topics)}: {topic}")

             for j in range(question_per_topic):
-
-
-
-
-
-                # retrieve context once per prompt
-                contexts = self.retriever.retrieve_diverse_contexts(query, k=self.config.get("k", 5)) if hasattr(self, "retriever") else []
-                context_text = "\n\n".join([d.page_content for d in contexts]) if contexts else topic
-
-                # Build prompt with PromptTemplateManager
-                prompt_template = self.template_manager.get_template(question_type)
-                prompt_input = {
-                    "context": context_text,
-                    "topic": topic,
-                    "difficulty": difficulty.value if hasattr(difficulty, 'value') else str(difficulty),
-                    "question_type": question_type.value if hasattr(question_type, 'value') else str(question_type)
-                }
-                formatted = prompt_template.format(**prompt_input)
-
-                # Length check and fallback
-                is_safe, token_count = self.check_prompt_length(formatted)
-                if not is_safe:
-                    truncated = formatted[: self.config.get("max_prompt_chars", 2000)]
-                    formatted = truncated
-
-                prompt_metadatas.append((topic, difficulty, question_type))
-                formatted_prompts.append(formatted)
-
-        total = len(formatted_prompts)
-        if total == 0:
-            return []
+                try:
+                    # Cycle through difficulties and question types
+                    difficulty = difficulties[j % len(difficulties)]
+                    question_type = question_types[j % len(question_types)]

-
-
-        raw_responses = self._batch_invoke(formatted_prompts)
-        elapsed = time.time() - start_t
-        print(f"⏱ LLM batch time: {elapsed:.2f}s for {total} prompts")
+                    mcq = self.generate_mcq(topic, difficulty, question_type)
+                    mcqs.append(mcq)

-
-
-
-
-
-                    response_data = self._extract_json_from_response(response)
-
-                    # Reconstruct MCQQuestion
-                    options = []
-                    for label, text in response_data["options"].items():
-                        is_correct = label == response_data["correct_answer"]
-                        options.append(self.MCQOption(label=label, text=text, is_correct=is_correct))
-
-                    mcq = self.MCQQuestion(
-                        question_id=response_data.get("id", None),
-                        topic=topic,
-                        question_text=response_data["question"],
-                        options=options,
-                        explanation=response_data.get("explanation", ""),
-                        difficulty=(difficulty if hasattr(difficulty, 'name') else difficulty),
-                        question_type=(question_type if hasattr(question_type, 'name') else question_type),
-                        confidence_score=response_data.get("confidence_score", 0.0)
-                    )
-
-                    if hasattr(self, 'validator'):
-                        mcq = self.validator.calculate_quality_score(mcq)
-
-                    mcqs.append(mcq)
-                except Exception as e:
-                    print(f"❌ Failed parsing response for topic={topic}: {e}")
+                    print(f" ✅ Generated question {j+1}/{question_per_topic} "
+                          f"(Quality: {mcq.confidence_score:.1f})")
+
+                except Exception as e:
+                    print(f" ❌ Failed to generate question {j+1}: {e}")

-        print(f"🎉 Generated {len(mcqs)}/{
+        print(f"🎉 Generated {len(mcqs)}/{total_questions} MCQs successfully")
         return mcqs

     def export_mcqs(self, mcqs: List[MCQQuestion], output_path: str):
```
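The rewritten loop round-robins difficulty and question type with modulo indexing before calling generate_mcq once per question. A tiny illustration of the schedule this produces, with plain strings standing in for the DifficultyLevel and QuestionType enum members:

```python
# Illustration of the j % len(...) round-robin used in generate_batch;
# plain strings stand in for the DifficultyLevel / QuestionType enum members.
difficulties = ["easy", "medium", "hard"]
question_types = ["definition", "application"]

for j in range(5):  # question_per_topic = 5
    difficulty = difficulties[j % len(difficulties)]
    question_type = question_types[j % len(question_types)]
    print(j, difficulty, question_type)
# 0 easy definition
# 1 medium application
# 2 hard definition
# 3 easy application
# 4 medium definition
```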
fastapi_app.py
CHANGED
```diff
@@ -9,8 +9,9 @@ from fastapi import FastAPI, Form, HTTPException, UploadFile, File
 from contextlib import asynccontextmanager


+
 generator: Optional[EnhancedRAGMCQGenerator] = None
-tmp_folder = "./tmp"
+tmp_folder = "./tmp"
 if not os.path.exists(tmp_folder):
     os.makedirs(tmp_folder)

```
```diff
@@ -38,6 +39,7 @@ class GenerateResponse(BaseModel):
     avg_confidence: float
     generation_time: float

+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global generator
```
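For context, the lifespan handler touched above follows FastAPI's standard startup/shutdown pattern: build the expensive generator once before the first request is served, then tear down after the yield. A generic sketch of that shape (not the repository's actual handler body):

```python
# Generic FastAPI lifespan sketch; the generator construction is a placeholder.
from contextlib import asynccontextmanager
from fastapi import FastAPI

generator = None  # stands in for the module-level EnhancedRAGMCQGenerator handle

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: initialize shared state once, before requests are served.
    global generator
    generator = object()  # placeholder for EnhancedRAGMCQGenerator(...)
    yield
    # Shutdown: cleanup after `yield` would go here.

app = FastAPI(lifespan=lifespan)
```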
```diff
@@ -57,14 +59,14 @@ app = FastAPI(
     lifespan=lifespan
 )

-#? cmd:
+#? cmd: uvicorn app:app --reload --reload-exclude unsloth_compiled_cache
 @app.post("/generate/")
 async def mcq_gen(
     file: UploadFile = File(...),
     topics: str = Form(...),
     n_questions: str = Form(...),
-    difficulty:
-    qtype:
+    difficulty: DifficultyLevel = Form(...),
+    qtype: QuestionType = Form(...)
 ):
     if not generator:
         raise HTTPException(status_code=500, detail="Generator not initialized")
```
```diff
@@ -73,19 +75,6 @@ async def mcq_gen(
     if not topic_list:
         raise HTTPException(status_code=400, detail="At least one topic must be provided")

-    # Validate and convert enum values
-    try:
-        difficulty_enum = DifficultyLevel(difficulty.lower())
-    except ValueError:
-        valid_difficulties = [d.value for d in DifficultyLevel]
-        raise HTTPException(status_code=400, detail=f"Invalid difficulty. Must be one of: {valid_difficulties}")
-
-    try:
-        qtype_enum = QuestionType(qtype.lower())
-    except ValueError:
-        valid_types = [q.value for q in QuestionType]
-        raise HTTPException(status_code=400, detail=f"Invalid question type. Must be one of: {valid_types}")
-
     # Save uploaded PDF to temporary folder
     filename = file.filename if file.filename else "uploaded_file"
     file_path = os.path.join(tmp_folder, filename)
```
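Declaring difficulty and qtype with enum types, as in the hunks above, lets FastAPI reject invalid form values with a 422 before the handler runs, which is why the manual DifficultyLevel/QuestionType try/except validation could be deleted. A self-contained sketch with hypothetical enum members (not the repository's definitions):

```python
# Self-contained sketch of enum-typed Form parameters; the enum members here
# are hypothetical, not the repository's DifficultyLevel/QuestionType values.
from enum import Enum
from fastapi import FastAPI, Form

class Difficulty(str, Enum):
    easy = "easy"
    medium = "medium"
    hard = "hard"

app = FastAPI()

@app.post("/echo/")
async def echo(difficulty: Difficulty = Form(...)):
    # FastAPI has already validated the field; invalid values never reach here,
    # the client gets a 422 listing the allowed values instead.
    return {"difficulty": difficulty.value}
```

The 422 response body lists the permitted values, which covers what the removed manual error messages used to report.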
```diff
@@ -104,9 +93,7 @@ async def mcq_gen(
     try:
         mcqs = generator.generate_batch(
             topics=topic_list,
-            question_per_topic=int(n_questions)
-            difficulties=[difficulty_enum],
-            question_types=[qtype_enum]
+            question_per_topic=int(n_questions)
         )
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
```
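A hedged example of calling the updated endpoint with requests; the URL, file name, topic strings, and the difficulty/qtype values are placeholders that must match the app's actual enum definitions:

```python
# Example client call for POST /generate/; URL, file name, and the
# difficulty/qtype strings are placeholders (they must match the
# DifficultyLevel / QuestionType enum values defined in the app).
import requests

with open("lecture_notes.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/generate/",
        files={"file": ("lecture_notes.pdf", f, "application/pdf")},
        data={
            "topics": "machine learning, neural networks",
            "n_questions": "3",
            "difficulty": "easy",    # placeholder enum value
            "qtype": "definition",   # placeholder enum value
        },
    )

print(resp.status_code)
print(resp.json())
```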