Spaces: DishaKushwah
Commit cac9aa5 (1 Parent(s): c11256e) · committed by DishaKushwah
Created using Colab

mcq_generator.ipynb  CHANGED  (+12 -80)
@@ -5,7 +5,7 @@
     "colab": {
       "provenance": [],
       "gpuType": "T4",
-      "authorship_tag": "
+      "authorship_tag": "ABX9TyNlLgN36uc2PRyXWiLUUS03",
       "include_colab_link": true
     },
     "kernelspec": {
@@ -79,12 +79,7 @@
     "        self.nlp = None\n",
     "\n",
     "        # Load fill-mask pipeline for generating distractors\n",
-    "        self.fill_mask = pipeline(\n",
-    "            \"fill-mask\",\n",
-    "            model=\"roberta-large\",\n",
-    "            tokenizer=\"roberta-large\",\n",
-    "            device=0 if torch.cuda.is_available() else -1\n",
-    "        )\n",
+    "        self.fill_mask = pipeline(\"fill-mask\",model=\"roberta-large\",tokenizer=\"roberta-large\",device=0 if torch.cuda.is_available() else -1)\n",
     "\n",
     "        # Download NLTK data\n",
     "        try:\n",
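
The collapsed call keeps the same arguments as the multi-line version it replaces. For reference, a minimal standalone sketch of the same fill-mask setup, assuming transformers and torch are installed (the example sentence is made up):

import torch
from transformers import pipeline

# Same arguments as the notebook: RoBERTa-large on GPU 0 if available, else CPU.
fill_mask = pipeline(
    "fill-mask",
    model="roberta-large",
    tokenizer="roberta-large",
    device=0 if torch.cuda.is_available() else -1,
)

# RoBERTa's mask token is "<mask>"; each candidate dict carries "token_str" and "score".
for candidate in fill_mask("Leonardo da Vinci painted the <mask>.", top_k=3):
    print(candidate["token_str"], candidate["score"])
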
@@ -99,17 +94,11 @@
     "            return {\"entities\": [], \"noun_chunks\": [], \"sentences\": []}\n",
     "\n",
     "        doc = self.nlp(text)\n",
-    "\n",
     "        # Extract named entities\n",
     "        entities = []\n",
     "        for ent in doc.ents:\n",
     "            if ent.label_ in ['PERSON', 'ORG', 'GPE', 'DATE', 'EVENT', 'WORK_OF_ART', 'CARDINAL', 'ORDINAL']:\n",
-    "                entities.append({\n",
-    "                    'text': ent.text,\n",
-    "                    'label': ent.label_,\n",
-    "                    'start': ent.start_char,\n",
-    "                    'end': ent.end_char\n",
-    "                })\n",
+    "                entities.append({'text': ent.text,'label': ent.label_,'start': ent.start_char,'end': ent.end_char})\n",
     "\n",
     "        # Extract noun chunks\n",
     "        noun_chunks = [chunk.text for chunk in doc.noun_chunks if len(chunk.text.split()) <= 4]\n",
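
The flattened entities.append keeps the same four fields. A sketch of the extraction loop on its own, assuming spaCy with the en_core_web_sm model (the diff never shows which model self.nlp loads):

import spacy

nlp = spacy.load("en_core_web_sm")  # assumed model; any English pipeline with NER works
doc = nlp("Michelangelo painted the Sistine Chapel ceiling between 1508 and 1512.")

# Keep only the entity labels the notebook filters for.
wanted = ['PERSON', 'ORG', 'GPE', 'DATE', 'EVENT', 'WORK_OF_ART', 'CARDINAL', 'ORDINAL']
entities = [
    {'text': ent.text, 'label': ent.label_, 'start': ent.start_char, 'end': ent.end_char}
    for ent in doc.ents
    if ent.label_ in wanted
]
print(entities)
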
@@ -117,36 +106,17 @@
     "        # Extract sentences\n",
     "        sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.split()) > 5]\n",
     "\n",
-    "        return {\n",
-    "            \"entities\": entities,\n",
-    "            \"noun_chunks\": noun_chunks,\n",
-    "            \"sentences\": sentences\n",
-    "        }\n",
+    "        return {\"entities\": entities,\"noun_chunks\": noun_chunks,\"sentences\": sentences}\n",
     "\n",
     "    def generate_question_from_context(self, context: str, answer_text: str) -> str:\n",
     "        \"\"\"Generate a question given context and answer.\"\"\"\n",
     "        # Highlight the answer in the context for T5\n",
     "        highlighted_context = context.replace(answer_text, f\"<hl>{answer_text}<hl>\")\n",
     "        input_text = f\"generate question: {highlighted_context}\"\n",
-    "\n",
-    "        inputs = self.qg_tokenizer.encode_plus(\n",
-    "            input_text,\n",
-    "            max_length=512,\n",
-    "            truncation=True,\n",
-    "            padding=True,\n",
-    "            return_tensors=\"pt\"\n",
-    "        ).to(self.device)\n",
+    "        inputs = self.qg_tokenizer.encode_plus(input_text,max_length=512,truncation=True,padding=True,return_tensors=\"pt\").to(self.device)\n",
     "\n",
     "        with torch.no_grad():\n",
-    "            outputs = self.qg_model.generate(\n",
-    "                inputs[\"input_ids\"],\n",
-    "                attention_mask=inputs[\"attention_mask\"],\n",
-    "                max_length=64,\n",
-    "                num_beams=4,\n",
-    "                temperature=0.8,\n",
-    "                do_sample=True,\n",
-    "                early_stopping=True\n",
-    "            )\n",
+    "            outputs = self.qg_model.generate(inputs[\"input_ids\"],attention_mask=inputs[\"attention_mask\"],max_length=64,num_beams=4,temperature=0.8,do_sample=True,early_stopping=True)\n",
     "\n",
     "        question = self.qg_tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
     "        return question\n",
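
The encode/generate pair behaves identically collapsed or expanded. A self-contained sketch of the highlight-and-generate step, assuming a T5-style question-generation checkpoint such as valhalla/t5-base-qg-hl (an assumption; the diff never names the model behind qg_model/qg_tokenizer):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "valhalla/t5-base-qg-hl"  # assumed checkpoint; substitute the notebook's actual one
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

context = "Leonardo da Vinci was born in 1452."
answer = "1452"
# Wrap the answer span in <hl> markers, as the notebook does, then add the task prefix.
text = "generate question: " + context.replace(answer, f"<hl>{answer}<hl>")

inputs = tokenizer(text, max_length=512, truncation=True, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=64,
        num_beams=4,
        early_stopping=True,
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Note the notebook combines num_beams=4 with do_sample=True and temperature=0.8 (beam-sample decoding); the sketch uses plain beam search for deterministic output.
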
@@ -186,9 +156,7 @@
     "\n",
     "        # Find similar entities\n",
     "        for ent in doc.ents:\n",
-    "            if (ent.label_ == answer_label and\n",
-    "                ent.text != correct_answer and\n",
-    "                ent.text not in distractors):\n",
+    "            if (ent.label_ == answer_label and ent.text != correct_answer and ent.text not in distractors):\n",
     "                distractors.append(ent.text)\n",
     "                if len(distractors) >= num_distractors:\n",
     "                    break\n",
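
The joined condition is equivalent to the three-line version. A sketch of the same-label distractor pick, again assuming an en_core_web_sm pipeline (text and answer are made up):

import spacy

nlp = spacy.load("en_core_web_sm")  # assumed model
doc = nlp("Raphael, Titian, and Donatello were also Renaissance artists.")

correct_answer, answer_label = "Michelangelo", "PERSON"
num_distractors = 3
distractors = []
# Collect other entities with the same label as the answer, skipping duplicates.
for ent in doc.ents:
    if (ent.label_ == answer_label and ent.text != correct_answer and ent.text not in distractors):
        distractors.append(ent.text)
        if len(distractors) >= num_distractors:
            break
print(distractors)
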
@@ -226,7 +194,6 @@
     "            if d.lower() not in seen and d.lower() != correct_answer.lower():\n",
     "                seen.add(d.lower())\n",
     "                unique_distractors.append(d)\n",
-    "\n",
     "        return unique_distractors[:num_distractors]\n",
     "\n",
     "    def validate_mcq_quality(self, question: str, correct_answer: str, distractors: List[str], context: str) -> Dict:\n",
@@ -242,7 +209,6 @@
     "            correct_embedding = self.sentence_model.encode([correct_answer])\n",
     "            predicted_embedding = self.sentence_model.encode([predicted_answer])\n",
     "            similarity = cosine_similarity(correct_embedding, predicted_embedding)[0][0]\n",
-    "\n",
     "            is_answerable = similarity > similarity_threshold or correct_answer.lower() in predicted_answer.lower()\n",
     "\n",
     "        except:\n",
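
The answerability test is unchanged: embed both strings, take cosine similarity, and accept either a high score or a literal substring match. A standalone sketch, assuming sentence-transformers with the all-MiniLM-L6-v2 model and a 0.5 threshold (both assumptions; the notebook's sentence_model and similarity_threshold are defined outside this hunk):

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

sentence_model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model
similarity_threshold = 0.5                                # assumed value

correct_answer = "Johannes Gutenberg"
predicted_answer = "Gutenberg"  # e.g. what a QA model returned

correct_embedding = sentence_model.encode([correct_answer])
predicted_embedding = sentence_model.encode([predicted_answer])
similarity = cosine_similarity(correct_embedding, predicted_embedding)[0][0]

# Same rule as the notebook: similar enough, or a substring hit.
is_answerable = similarity > similarity_threshold or correct_answer.lower() in predicted_answer.lower()
print(similarity, is_answerable)
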
@@ -265,13 +231,7 @@
     "            distractor_quality = \"poor\"\n",
     "            avg_distractor_similarity = 0.0\n",
     "\n",
-    "        return {\n",
-    "            \"is_answerable\": is_answerable,\n",
-    "            \"confidence\": confidence,\n",
-    "            \"answer_similarity\": similarity,\n",
-    "            \"distractor_quality\": distractor_quality,\n",
-    "            \"avg_distractor_similarity\": avg_distractor_similarity\n",
-    "        }\n",
+    "        return {\"is_answerable\": is_answerable,\"confidence\": confidence,\"answer_similarity\": similarity,\"distractor_quality\": distractor_quality,\"avg_distractor_similarity\": avg_distractor_similarity }\n",
     "\n",
     "    def generate_mcq(self, context: str, num_questions: int = 5) -> List[Dict]:\n",
     "        \"\"\"Generate multiple choice questions from context.\"\"\"\n",
@@ -304,12 +264,7 @@
     "\n",
     "            mcq = {\n",
     "                \"question\": question,\n",
-    "                \"options\": {\n",
-    "                    \"A\": options[0],\n",
-    "                    \"B\": options[1],\n",
-    "                    \"C\": options[2] if len(options) > 2 else \"None of the above\",\n",
-    "                    \"D\": options[3] if len(options) > 3 else \"All of the above\"\n",
-    "                },\n",
+    "                \"options\": {\"A\": options[0],\"B\": options[1],\"C\": options[2] if len(options) > 2 else \"None of the above\",\"D\": options[3] if len(options) > 3 else \"All of the above\"},\n",
     "                \"correct_answer\": correct_option,\n",
     "                \"correct_text\": correct_answer,\n",
     "                \"entity_type\": entity[\"label\"],\n",
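
This one-liner (repeated in the noun-chunk branch below) pads missing distractors with generic fillers. A sketch of how the A-D mapping behaves when fewer than three distractors survive filtering; the shuffle is a hypothetical stand-in for however the notebook orders `options`:

import random

correct_answer = "Leonardo da Vinci"        # hypothetical
distractors = ["Michelangelo", "Raphael"]   # hypothetical: only two distractors found

options = [correct_answer] + distractors
random.shuffle(options)  # assumed; the notebook builds `options` before this point

mcq_options = {
    "A": options[0],
    "B": options[1],
    "C": options[2] if len(options) > 2 else "None of the above",
    "D": options[3] if len(options) > 3 else "All of the above",
}
correct_option = next(k for k, v in mcq_options.items() if v == correct_answer)
print(mcq_options, correct_option)
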
@@ -337,12 +292,7 @@
     "\n",
     "            mcq = {\n",
     "                \"question\": question,\n",
-    "                \"options\": {\n",
-    "                    \"A\": options[0],\n",
-    "                    \"B\": options[1],\n",
-    "                    \"C\": options[2] if len(options) > 2 else \"None of the above\",\n",
-    "                    \"D\": options[3] if len(options) > 3 else \"All of the above\"\n",
-    "                },\n",
+    "                \"options\": {\"A\": options[0],\"B\": options[1],\"C\": options[2] if len(options) > 2 else \"None of the above\",\"D\": options[3] if len(options) > 3 else \"All of the above\"},\n",
     "                \"correct_answer\": correct_option,\n",
     "                \"correct_text\": chunk,\n",
     "                \"entity_type\": \"NOUN_CHUNK\",\n",
@@ -358,36 +308,18 @@
     "def main():\n",
     "    \"\"\"Main function to demonstrate the MCQ generator.\"\"\"\n",
     "    generator = MultipleChoiceQuestionGenerator()\n",
-    "\n",
-    "    # Sample context\n",
-    "    sample_context = \"\"\"\n",
-    "    The Renaissance was a period of cultural, artistic, political and economic rebirth following the Middle Ages.\n",
-    "    It began in Italy in the 14th century and spread throughout Europe. Leonardo da Vinci, born in 1452, was one\n",
-    "    of the most famous Renaissance artists and inventors. He created masterpieces like the Mona Lisa and The Last Supper.\n",
-    "    Michelangelo, another renowned artist, painted the ceiling of the Sistine Chapel between 1508 and 1512.\n",
-    "    The Renaissance emphasized humanism, scientific inquiry, and artistic innovation. The printing press,\n",
-    "    invented by Johannes Gutenberg around 1440, helped spread Renaissance ideas across Europe.\n",
-    "    This period lasted approximately 300 years, from the 14th to the 17th century.\n",
-    "    \"\"\"\n",
-    "\n",
     "    print(\"Multiple Choice Question Generator\")\n",
     "\n",
     "    # Get user input\n",
-    "    user_context = input(\"Enter your context
-    "    if not user_context:\n",
-    "        user_context = sample_context\n",
-    "        print(\"Using sample context about the Renaissance...\")\n",
-    "\n",
+    "    user_context = input(\"Enter your context: \").strip()\n",
     "    try:\n",
     "        num_questions = int(input(\"Number of MCQs to generate (default 5): \") or \"5\")\n",
     "    except ValueError:\n",
     "        num_questions = 5\n",
-    "\n",
     "    print(f\"\\nGenerating {num_questions} multiple choice questions...\")\n",
     "\n",
     "    # Generate MCQs\n",
     "    mcqs = generator.generate_mcq(user_context, num_questions)\n",
-    "\n",
     "    # Display results\n",
     "    if mcqs:\n",
     "        for i, mcq in enumerate(mcqs, 1):\n",
@@ -413,7 +345,7 @@
      "id": "o1ic84jCGc-u",
      "outputId": "82f5601f-0e8f-4ca7-d9a4-b514a58df793"
     },
-    "execution_count":
+    "execution_count": null,
     "outputs": [
      {
       "output_type": "stream",