DevBM committed
Commit b38adec · verified · 1 Parent(s): f174f61

using llama3 for option generation

Files changed (1):
  1. app.py +164 -37

app.py CHANGED
@@ -1,5 +1,6 @@
 import streamlit as st
 from transformers import T5ForConditionalGeneration, T5Tokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, LlamaForCausalLM
 import spacy
 import nltk
 from sklearn.feature_extraction.text import TfidfVectorizer
@@ -11,7 +12,8 @@ from functools import lru_cache
 nltk.download('punkt')
 nltk.download('stopwords')
 nltk.download('brown')
-from nltk.tokenize import sent_tokenize
+from nltk.tokenize import sent_tokenize, word_tokenize
+from nltk.tag import pos_tag
 nltk.download('wordnet')
 from nltk.corpus import wordnet
 import random
@@ -30,6 +32,8 @@ import uuid
 import time
 import asyncio
 import aiohttp
+import torch
+from dotenv import load_dotenv
 print("***************************************************************")

 st.set_page_config(
@@ -84,7 +88,7 @@ def load_model(modelname):
 # Load Spacy Model
 @st.cache_resource
 def load_nlp_models():
-    nlp = spacy.load("en_core_web_md")
+    nlp = spacy.load("en_core_web_lg")
     s2v = sense2vec.Sense2Vec().from_disk('s2v_old')
     return nlp, s2v

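Note: en_core_web_lg is not bundled with spaCy, so this line raises OSError if the environment only ships en_core_web_md. A minimal guard, assuming it is acceptable to download the model once at startup (the fallback below is illustrative, not part of the commit):

    import spacy

    def load_spacy_model(name="en_core_web_lg"):
        try:
            return spacy.load(name)
        except OSError:
            # Model not installed yet: fetch it once, then load it.
            from spacy.cli import download
            download(name)
            return spacy.load(name)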
@@ -97,6 +101,13 @@ def load_qa_models():
     spell = SpellChecker()
     return similarity_model, spell

+@st.cache_resource
+def load_llm_model():
+    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
+    return tokenizer, model
+
 with st.sidebar:
     select_model = st.selectbox("Select Model", ("T5-large","T5-small"))
     if select_model == "T5-large":
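Note: meta-llama/Meta-Llama-3-8B-Instruct is a gated repository, so from_pretrained will fail without an authenticated Hugging Face token. The commit adds python-dotenv but only calls load_dotenv() inside main(), which runs after this module-level loading code; reading the environment before any model load is safer. A sketch of one plausible wiring (the HF_TOKEN variable name is an assumption, not shown in the diff):

    import os
    from dotenv import load_dotenv
    from huggingface_hub import login

    load_dotenv()                        # read .env into the environment first
    login(token=os.getenv("HF_TOKEN"))   # HF_TOKEN is an assumed variable name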
@@ -106,7 +117,12 @@ elif select_model == "T5-small":
 nlp, s2v = load_nlp_models()
 similarity_model, spell = load_qa_models()
 context_model = similarity_model
+sentence_model = similarity_model
 model, tokenizer = load_model(modelname)
+# llm_tokenizer, llm_model = load_llm_model()
+llm_tokenizer, llm_model = "meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct"
+pipe = pipeline("text-generation", model=llm_model, tokenizer=llm_tokenizer, max_new_tokens=200)
+
 # Info Section
 def display_info():
     st.sidebar.title("Information")
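Note: passing repo-id strings to pipeline() makes the pipeline download and load the 8B model itself, bypassing both the float16/device_map settings and the @st.cache_resource caching of load_llm_model (which the commit leaves commented out), so the weights can be reloaded on every Streamlit rerun. A sketch reusing the cached loader instead:

    llm_tokenizer, llm_model = load_llm_model()   # cached by @st.cache_resource
    pipe = pipeline("text-generation", model=llm_model, tokenizer=llm_tokenizer, max_new_tokens=200)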
@@ -251,7 +267,7 @@ def get_synonyms(word, n=3):
         return synonyms
     return synonyms

-def generate_options(answer, context, n=3):
+def get_fallback_options(answer, context, n=3):
     options = [answer]

     # Add contextually relevant words using a pre-trained model
@@ -292,6 +308,142 @@ def generate_options(answer, context, n=3):

     return options

+def get_semantic_similarity(word1, word2):
+    embeddings = sentence_model.encode([word1, word2])
+    return util.pytorch_cos_sim(embeddings[0], embeddings[1]).item()
+
+def ensure_grammatical_consistency(question, answer, option):
+    question_pos = pos_tag(word_tokenize(question))
+    answer_pos = pos_tag(word_tokenize(answer))
+    option_pos = pos_tag(word_tokenize(option))
+
+    # Check if the answer and option have the same part of speech
+    if answer_pos[-1][1] != option_pos[-1][1]:
+        return False
+
+    # Check if the option fits grammatically in the question
+    question_template = question.replace(answer, "PLACEHOLDER")
+    option_question = question_template.replace("PLACEHOLDER", option)
+    option_question_pos = pos_tag(word_tokenize(option_question))
+
+    # Compare tag sequences only: comparing the full (word, tag) pairs would
+    # always fail for any option whose tokens differ from the answer's.
+    return [tag for _, tag in question_pos] == [tag for _, tag in option_question_pos]
+
+def get_word_type(word):
+    doc = nlp(word)
+    return doc[0].pos_
+
+def generate_text_with_llama(prompt):
+    full_prompt = f"""[INST] {prompt} [/INST]"""
+    result = pipe(full_prompt, temperature=0.7, do_sample=True)[0]['generated_text']
+    # Extract the generated part after the prompt
+    # return result.split('[/INST]')[-1].strip()
+    return result
+
+async def generate_options_with_llm(answer, context, question, n=4):
+    prompt = f"""Given the following context, question, and correct answer, generate {n-1} incorrect but plausible answer options. The options should be:
+1. Contextually related to the given context
+2. Grammatically consistent with the question
+3. Different from the correct answer
+4. Not explicitly mentioned in the given context
+
+Context: {context}
+Question: {question}
+Correct Answer: {answer}
+
+Provide the options in a comma-separated list.
+"""
+
+    try:
+        response = await asyncio.to_thread(generate_text_with_llama, prompt)
+        options = [option.strip() for option in response.split(',')]
+        options = [option for option in options if option.lower() != answer.lower()]
+        print(f"\n\nLLM Options are: {options}\n\n")
+        return options[:n-1]  # Ensure we only return n-1 options
+    except Exception as e:
+        st.error(f"Error generating options with LLM: {e}")
+        return []
+
+async def generate_options_async(answer, context, question, n=4):
+    options = [answer]
+
+    # Generate options using the language model
+    llm_options = await generate_options_with_llm(answer, context, question, n)
+    options.extend(llm_options)
+
+    # If we don't have enough options, fall back to previous methods
+    if len(options) < n:
+        semantic_options = await generate_semantic_options(answer, context, question, n - len(options))
+        options.extend(semantic_options)
+
+    # If we still don't have enough options, use the fallback method
+    while len(options) < n:
+        fallback_options = get_fallback_options(answer, context)  # synchronous helper, so no await
+        added = False
+        for option in fallback_options:
+            if option not in options and ensure_grammatical_consistency(question, answer, option):
+                options.append(option)
+                added = True
+            if len(options) == n:
+                break
+        if not added:
+            break  # no new option passed the checks; avoid looping forever
+
+    # Shuffle the options
+    random.shuffle(options)
+
+    return options
+
+async def generate_semantic_options(answer, context, question, n=4):
+    try:
+        options = [answer]
+
+        # Get context words
+        doc = nlp(context)
+        context_words = [token.text for token in doc if token.is_alpha and token.text.lower() != answer.lower()]
+
+        # Get answer type
+        answer_type = get_word_type(answer)
+        print(answer_type, "\n")
+
+        # Get semantically similar words
+        similar_words = []
+        for word in context_words:
+            if get_word_type(word) == answer_type:
+                similarity = get_semantic_similarity(answer, word)
+                if 0.2 < similarity < 0.8:  # Adjust these thresholds as needed
+                    similar_words.append((word, similarity))
+
+        # Sort by similarity (descending) and take top n-1
+        similar_words.sort(key=lambda x: x[1], reverse=True)
+        top_similar_words = [word for word, _ in similar_words[:n-1]]
+
+        # Ensure grammatical consistency
+        consistent_options = []
+        for word in top_similar_words:
+            if ensure_grammatical_consistency(question, answer, word):
+                consistent_options.append(word)
+            if len(consistent_options) == n-1:
+                break
+
+        options.extend(consistent_options)
+
+        # If we don't have enough options, fall back to original method
+        while len(options) < n:
+            fallback_options = get_fallback_options(answer, context, 3)
+            added = False
+            for option in fallback_options:
+                if option not in options and ensure_grammatical_consistency(question, answer, option):
+                    options.append(option)
+                    added = True
+                    break
+            if not added:
+                break  # nothing new to add; avoid an infinite loop
+
+        # Shuffle the options
+        random.shuffle(options)
+        print(options)
+        st.write("All possible options are: ", options, "\n")
+        return options
+
+    except Exception as e:
+        raise QuestionGenerationError(f"Error in generating options: {str(e)}")
+
+
 # Function to map keywords to sentences with customizable context window size
 def map_keywords_to_sentences(text, keywords, context_window_size):
     sentences = sent_tokenize(text)
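Note: two caveats in generate_text_with_llama. The [INST] ... [/INST] wrapper is Llama-2's instruction format, while Meta-Llama-3-8B-Instruct expects its own chat template, and the raw pipeline output echoes the prompt back (the commented-out split tries to strip it). A sketch of a template-aware variant, assuming the tokenizer ships a chat template (transformers >= 4.34):

    def generate_text_with_llama(prompt):
        # Let the tokenizer build the model-specific chat prompt instead of
        # hard-coding Llama-2's [INST] ... [/INST] markers.
        messages = [{"role": "user", "content": prompt}]
        full_prompt = pipe.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # return_full_text=False strips the echoed prompt from the output.
        out = pipe(full_prompt, temperature=0.7, do_sample=True, return_full_text=False)
        return out[0]['generated_text'].strip()

Parsing the reply by splitting on commas is also fragile when the model answers with a numbered list or options that themselves contain commas; a stricter output format (e.g. one option per line) would make the split more reliable.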
@@ -331,38 +483,8 @@ async def generate_question_async(context, answer, num_beams):
     except Exception as e:
         raise QuestionGenerationError(f"Error in question generation: {str(e)}")

-async def generate_options_async(answer, context, n=3):
-    try:
-        options = [answer]
-
-        # Add contextually relevant words using a pre-trained model
-        context_embedding = await asyncio.to_thread(context_model.encode, context)
-        answer_embedding = await asyncio.to_thread(context_model.encode, answer)
-        context_words = [token.text for token in nlp(context) if token.is_alpha and token.text.lower() != answer.lower()]
-
-        # Compute similarity scores and sort context words
-        similarity_scores = [util.pytorch_cos_sim(await asyncio.to_thread(context_model.encode, word), answer_embedding).item() for word in context_words]
-        sorted_context_words = [word for _, word in sorted(zip(similarity_scores, context_words), reverse=True)]
-        options.extend(sorted_context_words[:n])
-
-        # Try to get similar words based on sense2vec
-        similar_words = await asyncio.to_thread(get_similar_words_sense2vec, answer, n)
-        options.extend(similar_words)
-
-        # If we don't have enough options, try synonyms
-        if len(options) < n + 1:
-            synonyms = await asyncio.to_thread(get_synonyms, answer, n - len(options) + 1)
-            options.extend(synonyms)
-
-        # Ensure we have the correct number of unique options
-        options = list(dict.fromkeys(options))[:n+1]
-
-        # Shuffle the options
-        random.shuffle(options)
-
-        return options
-    except Exception as e:
-        raise QuestionGenerationError(f"Error in generating options: {str(e)}")


 # Function to generate questions using beam search
@@ -395,13 +517,16 @@ async def generate_questions_async(text, num_questions, context_window_size, num
         st.error(f"An unexpected error occurred: {str(e)}")
         return []

-async def process_batch(batch, keywords, context_window_size, num_beams):
+async def process_batch(batch, keywords, context_window_size, num_beams, use_llm_options):
     questions = []
     for text in batch:
         keyword_sentence_mapping = map_keywords_to_sentences(text, keywords, context_window_size)
         for keyword, context in keyword_sentence_mapping.items():
             question = await generate_question_async(context, keyword, num_beams)
-            options = await generate_options_async(keyword, context)
+            if use_llm_options:
+                options = await generate_options_async(keyword, context, question)
+            else:
+                options = await generate_semantic_options(keyword, context, question)
             overall_score, relevance_score, complexity_score, spelling_correctness = assess_question_quality(context, question, keyword)
             if overall_score >= 0.5:
                 questions.append({
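Note: process_batch awaits each question and its options strictly in sequence, so an 8B-model round trip per keyword adds up. A sketch of fanning the per-keyword work out with asyncio.gather, assuming the helpers are safe to run concurrently (they offload blocking work via asyncio.to_thread):

    async def process_keyword(keyword, context, num_beams, use_llm_options):
        # One keyword's full pipeline: question first, then its options.
        question = await generate_question_async(context, keyword, num_beams)
        if use_llm_options:
            options = await generate_options_async(keyword, context, question)
        else:
            options = await generate_semantic_options(keyword, context, question)
        return keyword, context, question, options

    async def process_text(keyword_sentence_mapping, num_beams, use_llm_options):
        # Run all keywords concurrently instead of one at a time.
        return await asyncio.gather(*(
            process_keyword(k, c, num_beams, use_llm_options)
            for k, c in keyword_sentence_mapping.items()
        ))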
@@ -477,6 +602,7 @@ def assess_question_quality(context, question, answer):
     return overall_score, relevance_score, complexity_score, spelling_correctness

 def main():
+    load_dotenv()
     # Streamlit interface
     st.title(":blue[Question Generator System]")
     session_id = get_session_id()
@@ -498,6 +624,7 @@ def main():
     num_beams = st.slider("Select number of beams for question generation", min_value=2, max_value=10, value=2)
     context_window_size = st.slider("Select context window size (number of sentences before and after)", min_value=1, max_value=5, value=1)
     num_questions = st.slider("Select number of questions to generate", min_value=1, max_value=1000, value=5)
+    use_llm_for_options = st.toggle("Use AI for Advanced option generation", False)
     col1, col2 = st.columns(2)
     with col1:
         extract_all_keywords = st.toggle("Extract Max Keywords", value=False)
@@ -518,14 +645,14 @@ def main():
     if text:
         text = clean_text(text)
     generate_questions_button = st.button("Generate Questions")
-    st.markdown('<span aria-label="Generate questions button">Above is the generate questions button</span>', unsafe_allow_html=True)
+    # st.markdown('<span aria-label="Generate questions button">Above is the generate questions button</span>', unsafe_allow_html=True)

     # if generate_questions_button:
     if generate_questions_button and text:
         start_time = time.time()
         with st.spinner("Generating questions..."):
             try:
-                state['generated_questions'] = asyncio.run(generate_questions_async(text, num_questions, context_window_size, num_beams, extract_all_keywords))
+                state['generated_questions'] = asyncio.run(generate_questions_async(text, num_questions, context_window_size, num_beams, extract_all_keywords, use_llm_for_options))
                 if not state['generated_questions']:
                     st.warning("No questions were generated. The text might be too short or lack suitable content.")
                 else:
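Note: the call site now passes six arguments, but no hunk in this diff changes the generate_questions_async signature itself; presumably it was updated to accept the new flag and forward it to process_batch, otherwise this asyncio.run call raises TypeError. The expected shape, with the parameter name assumed from the call sites:

    async def generate_questions_async(text, num_questions, context_window_size,
                                       num_beams, extract_all_keywords, use_llm_options):
        # ... keyword extraction and batching unchanged ...
        return await process_batch(batch, keywords, context_window_size,
                                   num_beams, use_llm_options)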