DishaKushwah commited on
Commit
0d20fc5
·
1 Parent(s): 4ef93f3

Create short_answer_generator.py

Browse files
Files changed (1) hide show
  1. short_answer_generator.py +97 -0
short_answer_generator.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
4
+
5
+ class QuestionGenerator:
6
+ def __init__(self, model_name='deepset/roberta-base-squad2'):
7
+ """
8
+ Initialize question generation system
9
+ """
10
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu' # Detect and set device
11
+ self.model = AutoModelForQuestionAnswering.from_pretrained(model_name) # Load model and tokenizer
12
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+
14
+ # Create QA pipeline
15
+ self.qa_pipeline = pipeline('question-answering', model=self.model, tokenizer=self.tokenizer,device=0 if self.device == 'cuda' else -1)
16
+
17
+ # Question templates
18
+ self.question_templates = ["What is the main idea of","Who is responsible for","When did this occur","Where does this take place","Why is this important","How does this work","What are the key features of","Explain the significance of","What is the purpose of","Describe the process of"]
19
+
20
+ def generate_questions(self, context, num_questions=3, difficulty='medium'):
21
+ """
22
+ Generate multiple questions from context
23
+ """
24
+ generated_questions = []
25
+ attempts = 0
26
+ max_attempts = num_questions * 10
27
+
28
+ while len(generated_questions) < num_questions and attempts < max_attempts:
29
+ try:
30
+ template = random.choice(self.question_templates) # Select random template
31
+ words = context.split() # Create question
32
+ start_index = random.randint(0, max(0, len(words) - 5))
33
+ full_question = f"{template} {' '.join(words[start_index:start_index+5])}?"
34
+ result = self.qa_pipeline(question=full_question, context=context) # Get answer
35
+
36
+ # Validate result
37
+ if (result['answer'] and len(result['answer']) > 3 and result['score'] > 0.5 and not any(q['answer'] == result['answer'] for q in generated_questions)):
38
+ generated_questions.append({'question': full_question,'answer': result['answer'],'confidence': result['score']})
39
+ attempts += 1
40
+
41
+ except Exception as e:
42
+ print(f"Question generation error: {e}")
43
+ attempts += 1
44
+ return generated_questions
45
+
46
+ def display_questions(self, questions):
47
+ """
48
+ Display generated questions
49
+ """
50
+ print("\n--- Generated Questions ---")
51
+ for idx, q in enumerate(questions, 1):
52
+ print(f"Q{idx}: {q['question']}")
53
+ print(f"Expected keyword: {q['answer']} \n")
54
+
55
+ def get_user_input():
56
+ """
57
+ Get user input for question generation
58
+ """
59
+ print("\n--- Interactive Question Generator ---")
60
+ print("\n>> Enter the context for question generation: ") # Context input
61
+ context = input().strip()
62
+
63
+ # Number of questions
64
+ while True:
65
+ try:
66
+ num_questions = int(input("\n>> How many questions do you want? (1-10): "))
67
+ if 1 <= num_questions <= 10:
68
+ break
69
+ else:
70
+ print("Please enter a number between 1 and 10.")
71
+ except ValueError:
72
+ print("Invalid input. Please enter a number.")
73
+ return context, num_questions
74
+
75
+ def main():
76
+ # Initialize generator
77
+ generator = QuestionGenerator()
78
+
79
+ while True:
80
+ try:
81
+ context, num_questions = get_user_input() # Get user input
82
+ questions = generator.generate_questions(context, num_questions=num_questions) # Generate questions
83
+ if questions: # Display questions
84
+ generator.display_questions(questions)
85
+ else:
86
+ print("Could not generate questions. Please try a different context.")
87
+
88
+ # Continue option
89
+ continue_choice = input("Generate more questions? (yes/no): ").lower()
90
+ if continue_choice not in ['yes', 'y']:
91
+ break
92
+ except Exception as e:
93
+ print(f"An error occurred: {e}")
94
+ print("Thank you for using the Question Generator!")
95
+
96
+ if __name__ == "__main__":
97
+ main()