namberino committed
Commit 53e5542 · Parent: d0a9bfe

Testing old version

Files changed (2):
  1. enhanced_rag_mcq.py +21 -113
  2. fastapi_app.py +7 -20
enhanced_rag_mcq.py CHANGED
@@ -11,7 +11,6 @@ import os
 import json
 import time
 import torch
-import re
 from typing import List, Dict, Any, Optional, Tuple
 from dataclasses import dataclass, asdict
 from pathlib import Path
@@ -29,6 +28,7 @@ from langchain_core.prompts import PromptTemplate
 from langchain_community.vectorstores import FAISS

 from langchain_core.documents import Document
+from unsloth import FastLanguageModel

 # Transformers imports
 from transformers import (
@@ -338,7 +338,7 @@ class EnhancedRAGMCQGenerator:
         """Get default configuration"""
         return {
             "embedding_model": "bkai-foundation-models/vietnamese-bi-encoder",
-            "llm_model": "Qwen/Qwen2.5-3B-Instruct",  # 7B, 1.5B
+            "llm_model": "unsloth/Qwen2.5-3B",  # 7B, 1.5B
             "chunk_size": 500,
             "chunk_overlap": 50,
             "retrieval_k": 3,
@@ -383,9 +383,10 @@ class EnhancedRAGMCQGenerator:
         # Vietnamese typically has ~0.75 tokens per character
         return int(len(text) * 0.75)

-    #? Parse Json String
+    #? Parse Json String
     def _extract_json_from_response(self, response: str) -> dict:
         """Extract JSON from LLM response with multiple fallback strategies"""
+        import re

         # Strategy 1: Clean response of prompt repetition
         clean_response = response
@@ -473,16 +474,16 @@ class EnhancedRAGMCQGenerator:
             bnb_4bit_quant_type="nf4"
         )

-        model = AutoModelForCausalLM.from_pretrained(
+        model, tokenizer = FastLanguageModel.from_pretrained(
             self.config["llm_model"],
             quantization_config=bnb_config,
             low_cpu_mem_usage=True,
-            device_map="cuda",  # Use CUDA if available
+            device_map="auto",
             token=hf_token
         )

-        tokenizer = AutoTokenizer.from_pretrained(self.config["llm_model"])
-        # tokenizer.pad_token = tokenizer.eos_token
+        # tokenizer = AutoTokenizer.from_pretrained(self.config["llm_model"])
+        tokenizer.pad_token = tokenizer.eos_token

         model_pipeline = pipeline(
             "text-generation",
@@ -657,43 +658,6 @@ class EnhancedRAGMCQGenerator:
             print(f"Raw response: {response[:500]}...")
             raise ValueError(f"Failed to parse LLM response: {e}")

-    def _batch_invoke(self, prompts: List[str]) -> List[str]:
-        if not prompts:
-            return []
-
-        # Try to use transformers pipeline (batch mode)
-        pl = getattr(self.llm, "pipeline", None)
-        if pl is not None:
-            try:
-                # Call the pipeline with a list. Transformers will return a list of generation outputs.
-                raw_outputs = pl(prompts)
-
-                responses = []
-                for out in raw_outputs:
-                    # The pipeline may return either a dict (single result) or a list of dicts (if return_full_text or num_return_sequences was set)
-                    if isinstance(out, list) and out:
-                        text = out[0].get("generated_text", "")
-                    elif isinstance(out, dict):
-                        text = out.get("generated_text", "")
-                    else:
-                        # fallback: coerce to string
-                        text = str(out)
-                    responses.append(text)
-
-                if len(responses) == len(prompts):
-                    return responses
-                else:
-                    print("⚠️ Batch pipeline returned unexpected shape — falling back")
-            except Exception as e:
-                # Batch mode failed. Fall back to sequential invocations.
-                print(f"⚠️ Batch invoke failed: {e}. Falling back to sequential.")
-
-        # Sequential invocation to preserve behavior
-        results = []
-        for p in prompts:
-            results.append(self.llm.invoke(p))
-        return results
-
     def generate_batch(self,
                        topics: List[str],
                        question_per_topic: int = 5,
@@ -707,9 +671,8 @@
         if question_types is None:
             question_types = [QuestionType.DEFINITION, QuestionType.APPLICATION]

+        mcqs = []
         total_questions = len(topics) * question_per_topic
-        prompt_metadatas = []  # stores tuples (topic, difficulty, question_type)
-        formatted_prompts = []

         print(f"🎯 Generating {total_questions} MCQs...")

@@ -717,76 +680,21 @@
             print(f"📝 Processing topic {i+1}/{len(topics)}: {topic}")

             for j in range(question_per_topic):
-                difficulty = difficulties[j % len(difficulties)]
-                question_type = question_types[j % len(question_types)]
-
-                query = topic
-
-                # retrieve context once per prompt
-                contexts = self.retriever.retrieve_diverse_contexts(query, k=self.config.get("k", 5)) if hasattr(self, "retriever") else []
-                context_text = "\n\n".join([d.page_content for d in contexts]) if contexts else topic
-
-                # Build prompt with PromptTemplateManager
-                prompt_template = self.template_manager.get_template(question_type)
-                prompt_input = {
-                    "context": context_text,
-                    "topic": topic,
-                    "difficulty": difficulty.value if hasattr(difficulty, 'value') else str(difficulty),
-                    "question_type": question_type.value if hasattr(question_type, 'value') else str(question_type)
-                }
-                formatted = prompt_template.format(**prompt_input)
-
-                # Length check and fallback
-                is_safe, token_count = self.check_prompt_length(formatted)
-                if not is_safe:
-                    truncated = formatted[: self.config.get("max_prompt_chars", 2000)]
-                    formatted = truncated
-
-                prompt_metadatas.append((topic, difficulty, question_type))
-                formatted_prompts.append(formatted)
-
-        total = len(formatted_prompts)
-        if total == 0:
-            return []
+                try:
+                    # Cycle through difficulties and question types
+                    difficulty = difficulties[j % len(difficulties)]
+                    question_type = question_types[j % len(question_types)]

-        print(f"📦 Sending {total} prompts to the LLM in batch mode (if supported)")
-        start_t = time.time()
-        raw_responses = self._batch_invoke(formatted_prompts)
-        elapsed = time.time() - start_t
-        print(f"⏱ LLM batch time: {elapsed:.2f}s for {total} prompts")
+                    mcq = self.generate_mcq(topic, difficulty, question_type)
+                    mcqs.append(mcq)

-        # Parse raw responses back into MCQQuestion objects
-        mcqs = []
-        for meta, response in zip(prompt_metadatas, raw_responses):
-            topic, difficulty, question_type = meta
-            try:
-                response_data = self._extract_json_from_response(response)
-
-                # Reconstruct MCQQuestion
-                options = []
-                for label, text in response_data["options"].items():
-                    is_correct = label == response_data["correct_answer"]
-                    options.append(self.MCQOption(label=label, text=text, is_correct=is_correct))
-
-                mcq = self.MCQQuestion(
-                    question_id=response_data.get("id", None),
-                    topic=topic,
-                    question_text=response_data["question"],
-                    options=options,
-                    explanation=response_data.get("explanation", ""),
-                    difficulty=(difficulty if hasattr(difficulty, 'name') else difficulty),
-                    question_type=(question_type if hasattr(question_type, 'name') else question_type),
-                    confidence_score=response_data.get("confidence_score", 0.0)
-                )
-
-                if hasattr(self, 'validator'):
-                    mcq = self.validator.calculate_quality_score(mcq)
-
-                mcqs.append(mcq)
-            except Exception as e:
-                print(f"❌ Failed parsing response for topic={topic}: {e}")
+                    print(f" ✅ Generated question {j+1}/{question_per_topic} "
+                          f"(Quality: {mcq.confidence_score:.1f})")
+
+                except Exception as e:
+                    print(f" ❌ Failed to generate question {j+1}: {e}")

-        print(f"🎉 Generated {len(mcqs)}/{total} MCQs successfully (batched)")
+        print(f"🎉 Generated {len(mcqs)}/{total_questions} MCQs successfully")
         return mcqs

     def export_mcqs(self, mcqs: List[MCQQuestion], output_path: str):
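
Note on the loader swap in this file: unsloth's FastLanguageModel.from_pretrained returns the model and its tokenizer together, which is why the separate AutoTokenizer call is now commented out and the pad token is set on the returned tokenizer. Below is a minimal sketch of the same wiring, not the repo's exact code; max_seq_length and load_in_4bit here are assumptions standing in for the BitsAndBytesConfig used in the diff.

# Minimal sketch: load the unsloth model once and reuse it in a
# transformers text-generation pipeline, as the class above does.
from unsloth import FastLanguageModel
from transformers import pipeline

model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Qwen2.5-3B",   # default "llm_model" from the updated config
    max_seq_length=2048,    # assumed context length for this sketch
    load_in_4bit=True,      # assumed stand-in for the 4-bit BitsAndBytesConfig
)
tokenizer.pad_token = tokenizer.eos_token  # same padding fix as in the diff

llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(llm_pipeline("Xin chào", max_new_tokens=20)[0]["generated_text"])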
fastapi_app.py CHANGED
@@ -9,8 +9,9 @@ from fastapi import FastAPI, Form, HTTPException, UploadFile, File
 from contextlib import asynccontextmanager


+
 generator: Optional[EnhancedRAGMCQGenerator] = None
-tmp_folder = "./tmp"  #? make sure folder upload here
+tmp_folder = "./tmp"
 if not os.path.exists(tmp_folder):
     os.makedirs(tmp_folder)

@@ -38,6 +39,7 @@ class GenerateResponse(BaseModel):
     avg_confidence: float
     generation_time: float

+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global generator
@@ -57,14 +59,14 @@
     lifespan=lifespan
 )

-#? cmd: fastapi run app.py
+#? cmd: uvicorn app:app --reload --reload-exclude unsloth_compiled_cache
 @app.post("/generate/")
 async def mcq_gen(
     file: UploadFile = File(...),
     topics: str = Form(...),
     n_questions: str = Form(...),
-    difficulty: str = Form(...),
-    qtype: str = Form(...)
+    difficulty: DifficultyLevel = Form(...),
+    qtype: QuestionType = Form(...)
 ):
     if not generator:
         raise HTTPException(status_code=500, detail="Generator not initialized")
@@ -73,19 +75,6 @@ async def mcq_gen(
     if not topic_list:
         raise HTTPException(status_code=400, detail="At least one topic must be provided")

-    # Validate and convert enum values
-    try:
-        difficulty_enum = DifficultyLevel(difficulty.lower())
-    except ValueError:
-        valid_difficulties = [d.value for d in DifficultyLevel]
-        raise HTTPException(status_code=400, detail=f"Invalid difficulty. Must be one of: {valid_difficulties}")
-
-    try:
-        qtype_enum = QuestionType(qtype.lower())
-    except ValueError:
-        valid_types = [q.value for q in QuestionType]
-        raise HTTPException(status_code=400, detail=f"Invalid question type. Must be one of: {valid_types}")
-
     # Save uploaded PDF to temporary folder
     filename = file.filename if file.filename else "uploaded_file"
     file_path = os.path.join(tmp_folder, filename)
@@ -104,9 +93,7 @@
     try:
         mcqs = generator.generate_batch(
             topics=topic_list,
-            question_per_topic=int(n_questions),
-            difficulties=[difficulty_enum],
-            question_types=[qtype_enum]
+            question_per_topic=int(n_questions)
         )
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
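
Note on the handler signature: typing the form fields as the DifficultyLevel and QuestionType enums lets FastAPI coerce and validate the incoming values itself, which is what makes the manual try/ValueError blocks removable; in this version the two fields are accepted but no longer forwarded to generate_batch. A minimal, self-contained sketch of the pattern follows, with illustrative enum members that are assumptions rather than the repo's actual values.

# Minimal sketch, assuming a str-backed Enum; the real DifficultyLevel lives in
# enhanced_rag_mcq.py and may define different members.
from enum import Enum
from fastapi import FastAPI, Form

class DifficultyLevel(str, Enum):
    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"

app = FastAPI()

@app.post("/demo/")
async def demo(difficulty: DifficultyLevel = Form(...)):
    # FastAPI rejects any value outside the enum with a 422 before this body runs
    return {"difficulty": difficulty.value}

A multipart/form-data request whose difficulty field is not one of the enum values gets a 422 validation error automatically, without any handler-level checks.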