namberino committed
Commit d405869 · Parent(s): f69f5bd
Revert

Files changed:
- enhanced_rag_mcq.py +113 -21
- fastapi_app.py +20 -7
- requirements.txt +0 -0
enhanced_rag_mcq.py
CHANGED

@@ -11,6 +11,7 @@ import os
 import json
 import time
 import torch
+import re
 from typing import List, Dict, Any, Optional, Tuple
 from dataclasses import dataclass, asdict
 from pathlib import Path

@@ -28,7 +29,6 @@ from langchain_core.prompts import PromptTemplate
 from langchain_community.vectorstores import FAISS
 
 from langchain_core.documents import Document
-from unsloth import FastLanguageModel
 
 # Transformers imports
 from transformers import (

@@ -338,7 +338,7 @@ class EnhancedRAGMCQGenerator:
         """Get default configuration"""
         return {
             "embedding_model": "bkai-foundation-models/vietnamese-bi-encoder",
-            "llm_model": "
+            "llm_model": "Qwen/Qwen2.5-3B-Instruct", # 7B, 1.5B
             "chunk_size": 500,
             "chunk_overlap": 50,
             "retrieval_k": 3,

@@ -383,10 +383,9 @@ class EnhancedRAGMCQGenerator:
         # Vietnamese typically has ~0.75 tokens per character
         return int(len(text) * 0.75)
 
-
+    #? Parse Json String
     def _extract_json_from_response(self, response: str) -> dict:
         """Extract JSON from LLM response with multiple fallback strategies"""
-        import re
 
         # Strategy 1: Clean response of prompt repetition
         clean_response = response

@@ -474,16 +473,16 @@ class EnhancedRAGMCQGenerator:
             bnb_4bit_quant_type="nf4"
         )
 
-        model
+        model = AutoModelForCausalLM.from_pretrained(
             self.config["llm_model"],
             quantization_config=bnb_config,
             low_cpu_mem_usage=True,
-            device_map="
+            device_map="cuda", # Use CUDA if available
             token=hf_token
         )
 
-
-        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer = AutoTokenizer.from_pretrained(self.config["llm_model"])
+        # tokenizer.pad_token = tokenizer.eos_token
 
         model_pipeline = pipeline(
             "text-generation",
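Note: this hunk swaps the Unsloth loading path for a plain Transformers 4-bit load. A minimal self-contained sketch of that path follows. Only bnb_4bit_quant_type="nf4", low_cpu_mem_usage=True, device_map="cuda" and the AutoTokenizer line are visible in the diff; the remaining BitsAndBytesConfig fields, the generation budget, and the token handling are assumptions.

# Minimal sketch, assuming the parts of the config not shown in the hunk.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

model_id = "Qwen/Qwen2.5-3B-Instruct"  # value of config["llm_model"] in this commit

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # assumed; needed for 4-bit weights
    bnb_4bit_quant_type="nf4",             # shown in the hunk
    bnb_4bit_compute_dtype=torch.float16,  # assumed compute dtype
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
    device_map="cuda",     # as in the hunk; "auto" also works if CUDA may be absent
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

model_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,    # assumed generation budget
)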
@@ -658,6 +657,43 @@ class EnhancedRAGMCQGenerator:
             print(f"Raw response: {response[:500]}...")
             raise ValueError(f"Failed to parse LLM response: {e}")
 
+    def _batch_invoke(self, prompts: List[str]) -> List[str]:
+        if not prompts:
+            return []
+
+        # Try to use transformers pipeline (batch mode)
+        pl = getattr(self.llm, "pipeline", None)
+        if pl is not None:
+            try:
+                # Call the pipeline with a list. Transformers will return a list of generation outputs.
+                raw_outputs = pl(prompts)
+
+                responses = []
+                for out in raw_outputs:
+                    # The pipeline may return either a dict (single result) or a list of dicts (if return_full_text or num_return_sequences was set)
+                    if isinstance(out, list) and out:
+                        text = out[0].get("generated_text", "")
+                    elif isinstance(out, dict):
+                        text = out.get("generated_text", "")
+                    else:
+                        # fallback: coerce to string
+                        text = str(out)
+                    responses.append(text)
+
+                if len(responses) == len(prompts):
+                    return responses
+                else:
+                    print("⚠️ Batch pipeline returned unexpected shape — falling back")
+            except Exception as e:
+                # Batch mode failed. Fall back to sequential invocations.
+                print(f"⚠️ Batch invoke failed: {e}. Falling back to sequential.")
+
+        # Sequential invocation to preserve behavior
+        results = []
+        for p in prompts:
+            results.append(self.llm.invoke(p))
+        return results
+
     def generate_batch(self,
                        topics: List[str],
                        question_per_topic: int = 5,
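The shape handling in _batch_invoke exists because a Transformers text-generation pipeline returns a list of dicts for a single prompt but a list of lists of dicts when called with a list of prompts. A tiny sketch of those shapes, using an arbitrary small model (distilgpt2) purely to keep it runnable; the project's model comes from config["llm_model"]:

# Sketch of the output shapes _batch_invoke normalizes.
from transformers import pipeline

pl = pipeline("text-generation", model="distilgpt2")  # distilgpt2 is just a small stand-in

single = pl("Hello", max_new_tokens=8)               # -> [{"generated_text": ...}]
batch = pl(["Hello", "Hi there"], max_new_tokens=8)  # -> one inner list of dicts per prompt

for out in batch:
    entry = out[0] if isinstance(out, list) else out  # same normalization as _batch_invoke
    print(entry["generated_text"][:40])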
@@ -671,8 +707,9 @@ class EnhancedRAGMCQGenerator:
         if question_types is None:
             question_types = [QuestionType.DEFINITION, QuestionType.APPLICATION]
 
-        mcqs = []
         total_questions = len(topics) * question_per_topic
+        prompt_metadatas = [] # stores tuples (topic, difficulty, question_type)
+        formatted_prompts = []
 
         print(f"🎯 Generating {total_questions} MCQs...")
 

@@ -680,21 +717,76 @@ class EnhancedRAGMCQGenerator:
             print(f"📝 Processing topic {i+1}/{len(topics)}: {topic}")
 
             for j in range(question_per_topic):
-
-
-
-
-
-
-
+                difficulty = difficulties[j % len(difficulties)]
+                question_type = question_types[j % len(question_types)]
+
+                query = topic
+
+                # retrieve context once per prompt
+                contexts = self.retriever.retrieve_diverse_contexts(query, k=self.config.get("k", 5)) if hasattr(self, "retriever") else []
+                context_text = "\n\n".join([d.page_content for d in contexts]) if contexts else topic
+
+                # Build prompt with PromptTemplateManager
+                prompt_template = self.template_manager.get_template(question_type)
+                prompt_input = {
+                    "context": context_text,
+                    "topic": topic,
+                    "difficulty": difficulty.value if hasattr(difficulty, 'value') else str(difficulty),
+                    "question_type": question_type.value if hasattr(question_type, 'value') else str(question_type)
+                }
+                formatted = prompt_template.format(**prompt_input)
+
+                # Length check and fallback
+                is_safe, token_count = self.check_prompt_length(formatted)
+                if not is_safe:
+                    truncated = formatted[: self.config.get("max_prompt_chars", 2000)]
+                    formatted = truncated
+
+                prompt_metadatas.append((topic, difficulty, question_type))
+                formatted_prompts.append(formatted)
+
+        total = len(formatted_prompts)
+        if total == 0:
+            return []
 
-
-
+        print(f"📦 Sending {total} prompts to the LLM in batch mode (if supported)")
+        start_t = time.time()
+        raw_responses = self._batch_invoke(formatted_prompts)
+        elapsed = time.time() - start_t
+        print(f"⏱ LLM batch time: {elapsed:.2f}s for {total} prompts")
 
-
-
+        # Parse raw responses back into MCQQuestion objects
+        mcqs = []
+        for meta, response in zip(prompt_metadatas, raw_responses):
+            topic, difficulty, question_type = meta
+            try:
+                response_data = self._extract_json_from_response(response)
+
+                # Reconstruct MCQQuestion
+                options = []
+                for label, text in response_data["options"].items():
+                    is_correct = label == response_data["correct_answer"]
+                    options.append(self.MCQOption(label=label, text=text, is_correct=is_correct))
+
+                mcq = self.MCQQuestion(
+                    question_id=response_data.get("id", None),
+                    topic=topic,
+                    question_text=response_data["question"],
+                    options=options,
+                    explanation=response_data.get("explanation", ""),
+                    difficulty=(difficulty if hasattr(difficulty, 'name') else difficulty),
+                    question_type=(question_type if hasattr(question_type, 'name') else question_type),
+                    confidence_score=response_data.get("confidence_score", 0.0)
+                )
+
+                if hasattr(self, 'validator'):
+                    mcq = self.validator.calculate_quality_score(mcq)
+
+                mcqs.append(mcq)
+            except Exception as e:
+                print(f"❌ Failed parsing response for topic={topic}: {e}")
 
-        print(f"🎉 Generated {len(mcqs)}/{
+        print(f"🎉 Generated {len(mcqs)}/{total} MCQs successfully (batched)")
         return mcqs
 
     def export_mcqs(self, mcqs: List[MCQQuestion], output_path: str):
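A hedged usage sketch of the batched generate_batch follows. Only the keyword arguments are taken from this commit (they mirror the fastapi_app.py call below); the constructor call, the build_knowledge_base step, the DifficultyLevel.EASY member, and the import location of the enums are hypothetical placeholders.

# Hedged sketch: calling the batched generate_batch directly.
from enhanced_rag_mcq import EnhancedRAGMCQGenerator, DifficultyLevel, QuestionType

generator = EnhancedRAGMCQGenerator()          # hypothetical: default config
generator.build_knowledge_base("notes.pdf")    # hypothetical indexing step

mcqs = generator.generate_batch(
    topics=["Mạng máy tính"],
    question_per_topic=3,
    difficulties=[DifficultyLevel.EASY],       # assumed enum member
    question_types=[QuestionType.DEFINITION],  # QuestionType.DEFINITION appears in the diff
)
print(f"Generated {len(mcqs)} MCQs")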
fastapi_app.py
CHANGED

@@ -9,9 +9,8 @@ from fastapi import FastAPI, Form, HTTPException, UploadFile, File
 from contextlib import asynccontextmanager
 
 
-
 generator: Optional[EnhancedRAGMCQGenerator] = None
-tmp_folder = "./tmp"
+tmp_folder = "./tmp" #? make sure folder upload here
 if not os.path.exists(tmp_folder):
     os.makedirs(tmp_folder)
 

@@ -39,7 +38,6 @@ class GenerateResponse(BaseModel):
     avg_confidence: float
     generation_time: float
 
-
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global generator

@@ -59,14 +57,14 @@ app = FastAPI(
     lifespan=lifespan
 )
 
-#? cmd:
+#? cmd: fastapi run app.py
 @app.post("/generate/")
 async def mcq_gen(
     file: UploadFile = File(...),
     topics: str = Form(...),
     n_questions: str = Form(...),
-    difficulty:
-    qtype:
+    difficulty: str = Form(...),
+    qtype: str = Form(...)
 ):
     if not generator:
         raise HTTPException(status_code=500, detail="Generator not initialized")

@@ -75,6 +73,19 @@ async def mcq_gen(
     if not topic_list:
         raise HTTPException(status_code=400, detail="At least one topic must be provided")
 
+    # Validate and convert enum values
+    try:
+        difficulty_enum = DifficultyLevel(difficulty.lower())
+    except ValueError:
+        valid_difficulties = [d.value for d in DifficultyLevel]
+        raise HTTPException(status_code=400, detail=f"Invalid difficulty. Must be one of: {valid_difficulties}")
+
+    try:
+        qtype_enum = QuestionType(qtype.lower())
+    except ValueError:
+        valid_types = [q.value for q in QuestionType]
+        raise HTTPException(status_code=400, detail=f"Invalid question type. Must be one of: {valid_types}")
+
     # Save uploaded PDF to temporary folder
     filename = file.filename if file.filename else "uploaded_file"
     file_path = os.path.join(tmp_folder, filename)

@@ -93,7 +104,9 @@ async def mcq_gen(
     try:
         mcqs = generator.generate_batch(
             topics=topic_list,
-            question_per_topic=int(n_questions)
+            question_per_topic=int(n_questions),
+            difficulties=[difficulty_enum],
+            question_types=[qtype_enum]
         )
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
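A minimal client sketch for the updated /generate/ endpoint. The form field names come from the diff; the concrete difficulty and qtype values are assumptions (the valid DifficultyLevel and QuestionType values are validated server-side but not listed here), the comma-separated topics format is assumed, and localhost:8000 assumes the default port of `fastapi run`.

# Hedged client sketch for POST /generate/.
import requests

with open("lecture_notes.pdf", "rb") as f:  # hypothetical input PDF
    resp = requests.post(
        "http://localhost:8000/generate/",
        files={"file": ("lecture_notes.pdf", f, "application/pdf")},
        data={
            "topics": "Mạng máy tính, TCP/IP",  # assumed comma-separated topics
            "n_questions": "3",
            "difficulty": "easy",        # must be a valid DifficultyLevel value (assumed)
            "qtype": "definition",       # must be a valid QuestionType value (assumed)
        },
        timeout=600,
    )

print(resp.status_code)
print(resp.json())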
requirements.txt
CHANGED

Binary files a/requirements.txt and b/requirements.txt differ