Spaces:
Sleeping
Sleeping
using llama3 for option generation
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
|
|
3 |
import spacy
|
4 |
import nltk
|
5 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
@@ -11,7 +12,8 @@ from functools import lru_cache
|
|
11 |
nltk.download('punkt')
|
12 |
nltk.download('stopwords')
|
13 |
nltk.download('brown')
|
14 |
-
from nltk.tokenize import sent_tokenize
|
|
|
15 |
nltk.download('wordnet')
|
16 |
from nltk.corpus import wordnet
|
17 |
import random
|
@@ -30,6 +32,8 @@ import uuid
|
|
30 |
import time
|
31 |
import asyncio
|
32 |
import aiohttp
|
|
|
|
|
33 |
print("***************************************************************")
|
34 |
|
35 |
st.set_page_config(
|
@@ -84,7 +88,7 @@ def load_model(modelname):
|
|
84 |
# Load Spacy Model
|
85 |
@st.cache_resource
|
86 |
def load_nlp_models():
|
87 |
-
nlp = spacy.load("
|
88 |
s2v = sense2vec.Sense2Vec().from_disk('s2v_old')
|
89 |
return nlp, s2v
|
90 |
|
@@ -97,6 +101,13 @@ def load_qa_models():
|
|
97 |
spell = SpellChecker()
|
98 |
return similarity_model, spell
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
with st.sidebar:
|
101 |
select_model = st.selectbox("Select Model", ("T5-large","T5-small"))
|
102 |
if select_model == "T5-large":
|
@@ -106,7 +117,12 @@ elif select_model == "T5-small":
|
|
106 |
nlp, s2v = load_nlp_models()
|
107 |
similarity_model, spell = load_qa_models()
|
108 |
context_model = similarity_model
|
|
|
109 |
model, tokenizer = load_model(modelname)
|
|
|
|
|
|
|
|
|
110 |
# Info Section
|
111 |
def display_info():
|
112 |
st.sidebar.title("Information")
|
@@ -251,7 +267,7 @@ def get_synonyms(word, n=3):
|
|
251 |
return synonyms
|
252 |
return synonyms
|
253 |
|
254 |
-
def
|
255 |
options = [answer]
|
256 |
|
257 |
# Add contextually relevant words using a pre-trained model
|
@@ -292,6 +308,142 @@ def generate_options(answer, context, n=3):
|
|
292 |
|
293 |
return options
|
294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
# Function to map keywords to sentences with customizable context window size
|
296 |
def map_keywords_to_sentences(text, keywords, context_window_size):
|
297 |
sentences = sent_tokenize(text)
|
@@ -331,38 +483,8 @@ async def generate_question_async(context, answer, num_beams):
|
|
331 |
except Exception as e:
|
332 |
raise QuestionGenerationError(f"Error in question generation: {str(e)}")
|
333 |
|
334 |
-
async def generate_options_async(answer, context, n=3):
|
335 |
-
try:
|
336 |
-
options = [answer]
|
337 |
-
|
338 |
-
# Add contextually relevant words using a pre-trained model
|
339 |
-
context_embedding = await asyncio.to_thread(context_model.encode, context)
|
340 |
-
answer_embedding = await asyncio.to_thread(context_model.encode, answer)
|
341 |
-
context_words = [token.text for token in nlp(context) if token.is_alpha and token.text.lower() != answer.lower()]
|
342 |
|
343 |
-
# Compute similarity scores and sort context words
|
344 |
-
similarity_scores = [util.pytorch_cos_sim(await asyncio.to_thread(context_model.encode, word), answer_embedding).item() for word in context_words]
|
345 |
-
sorted_context_words = [word for _, word in sorted(zip(similarity_scores, context_words), reverse=True)]
|
346 |
-
options.extend(sorted_context_words[:n])
|
347 |
|
348 |
-
# Try to get similar words based on sense2vec
|
349 |
-
similar_words = await asyncio.to_thread(get_similar_words_sense2vec, answer, n)
|
350 |
-
options.extend(similar_words)
|
351 |
-
|
352 |
-
# If we don't have enough options, try synonyms
|
353 |
-
if len(options) < n + 1:
|
354 |
-
synonyms = await asyncio.to_thread(get_synonyms, answer, n - len(options) + 1)
|
355 |
-
options.extend(synonyms)
|
356 |
-
|
357 |
-
# Ensure we have the correct number of unique options
|
358 |
-
options = list(dict.fromkeys(options))[:n+1]
|
359 |
-
|
360 |
-
# Shuffle the options
|
361 |
-
random.shuffle(options)
|
362 |
-
|
363 |
-
return options
|
364 |
-
except Exception as e:
|
365 |
-
raise QuestionGenerationError(f"Error in generating options: {str(e)}")
|
366 |
|
367 |
|
368 |
# Function to generate questions using beam search
|
@@ -395,13 +517,16 @@ async def generate_questions_async(text, num_questions, context_window_size, num
|
|
395 |
st.error(f"An unexpected error occurred: {str(e)}")
|
396 |
return []
|
397 |
|
398 |
-
async def process_batch(batch, keywords, context_window_size, num_beams):
|
399 |
questions = []
|
400 |
for text in batch:
|
401 |
keyword_sentence_mapping = map_keywords_to_sentences(text, keywords, context_window_size)
|
402 |
for keyword, context in keyword_sentence_mapping.items():
|
403 |
question = await generate_question_async(context, keyword, num_beams)
|
404 |
-
|
|
|
|
|
|
|
405 |
overall_score, relevance_score, complexity_score, spelling_correctness = assess_question_quality(context, question, keyword)
|
406 |
if overall_score >= 0.5:
|
407 |
questions.append({
|
@@ -477,6 +602,7 @@ def assess_question_quality(context, question, answer):
|
|
477 |
return overall_score, relevance_score, complexity_score, spelling_correctness
|
478 |
|
479 |
def main():
|
|
|
480 |
# Streamlit interface
|
481 |
st.title(":blue[Question Generator System]")
|
482 |
session_id = get_session_id()
|
@@ -498,6 +624,7 @@ def main():
|
|
498 |
num_beams = st.slider("Select number of beams for question generation", min_value=2, max_value=10, value=2)
|
499 |
context_window_size = st.slider("Select context window size (number of sentences before and after)", min_value=1, max_value=5, value=1)
|
500 |
num_questions = st.slider("Select number of questions to generate", min_value=1, max_value=1000, value=5)
|
|
|
501 |
col1, col2 = st.columns(2)
|
502 |
with col1:
|
503 |
extract_all_keywords = st.toggle("Extract Max Keywords",value=False)
|
@@ -518,14 +645,14 @@ def main():
|
|
518 |
if text:
|
519 |
text = clean_text(text)
|
520 |
generate_questions_button = st.button("Generate Questions")
|
521 |
-
st.markdown('<span aria-label="Generate questions button">Above is the generate questions button</span>', unsafe_allow_html=True)
|
522 |
|
523 |
# if generate_questions_button:
|
524 |
if generate_questions_button and text:
|
525 |
start_time = time.time()
|
526 |
with st.spinner("Generating questions..."):
|
527 |
try:
|
528 |
-
state['generated_questions'] = asyncio.run(generate_questions_async(text, num_questions, context_window_size, num_beams, extract_all_keywords))
|
529 |
if not state['generated_questions']:
|
530 |
st.warning("No questions were generated. The text might be too short or lack suitable content.")
|
531 |
else:
|
|
|
1 |
import streamlit as st
|
2 |
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, LlamaForCausalLM
|
4 |
import spacy
|
5 |
import nltk
|
6 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
|
12 |
nltk.download('punkt')
|
13 |
nltk.download('stopwords')
|
14 |
nltk.download('brown')
|
15 |
+
from nltk.tokenize import sent_tokenize, word_tokenize
|
16 |
+
from nltk.tag import pos_tag
|
17 |
nltk.download('wordnet')
|
18 |
from nltk.corpus import wordnet
|
19 |
import random
|
|
|
32 |
import time
|
33 |
import asyncio
|
34 |
import aiohttp
|
35 |
+
import torch
|
36 |
+
from dotenv import load_dotenv
|
37 |
print("***************************************************************")
|
38 |
|
39 |
st.set_page_config(
|
|
|
88 |
# Load Spacy Model
|
89 |
@st.cache_resource
|
90 |
def load_nlp_models():
|
91 |
+
nlp = spacy.load("en_core_web_lg")
|
92 |
s2v = sense2vec.Sense2Vec().from_disk('s2v_old')
|
93 |
return nlp, s2v
|
94 |
|
|
|
101 |
spell = SpellChecker()
|
102 |
return similarity_model, spell
|
103 |
|
104 |
+
@st.cache_resource
|
105 |
+
def load_llm_model():
|
106 |
+
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
|
107 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
108 |
+
model = LlamaForCausalLM.from_pretrained(model_name,torch_dtype=torch.float16, device_map="auto")
|
109 |
+
return tokenizer, model
|
110 |
+
|
111 |
with st.sidebar:
|
112 |
select_model = st.selectbox("Select Model", ("T5-large","T5-small"))
|
113 |
if select_model == "T5-large":
|
|
|
117 |
nlp, s2v = load_nlp_models()
|
118 |
similarity_model, spell = load_qa_models()
|
119 |
context_model = similarity_model
|
120 |
+
sentence_model = similarity_model
|
121 |
model, tokenizer = load_model(modelname)
|
122 |
+
# llm_tokenizer, llm_model = load_llm_model()
|
123 |
+
llm_tokenizer, llm_model = "meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct"
|
124 |
+
pipe = pipeline("text-generation", model=llm_model, tokenizer=llm_tokenizer, max_new_tokens=200)
|
125 |
+
|
126 |
# Info Section
|
127 |
def display_info():
|
128 |
st.sidebar.title("Information")
|
|
|
267 |
return synonyms
|
268 |
return synonyms
|
269 |
|
270 |
+
def get_fallback_options(answer, context, n=3):
|
271 |
options = [answer]
|
272 |
|
273 |
# Add contextually relevant words using a pre-trained model
|
|
|
308 |
|
309 |
return options
|
310 |
|
311 |
+
def get_semantic_similarity(word1, word2):
|
312 |
+
embeddings = sentence_model.encode([word1, word2])
|
313 |
+
return util.pytorch_cos_sim(embeddings[0], embeddings[1]).item()
|
314 |
+
|
315 |
+
def ensure_grammatical_consistency(question, answer, option):
|
316 |
+
question_pos = pos_tag(word_tokenize(question))
|
317 |
+
answer_pos = pos_tag(word_tokenize(answer))
|
318 |
+
option_pos = pos_tag(word_tokenize(option))
|
319 |
+
|
320 |
+
# Check if the answer and option have the same part of speech
|
321 |
+
if answer_pos[-1][1] != option_pos[-1][1]:
|
322 |
+
return False
|
323 |
+
|
324 |
+
# Check if the option fits grammatically in the question
|
325 |
+
question_template = question.replace(answer, "PLACEHOLDER")
|
326 |
+
option_question = question_template.replace("PLACEHOLDER", option)
|
327 |
+
option_question_pos = pos_tag(word_tokenize(option_question))
|
328 |
+
|
329 |
+
return question_pos == option_question_pos
|
330 |
+
|
331 |
+
def get_word_type(word):
|
332 |
+
doc = nlp(word)
|
333 |
+
return doc[0].pos_
|
334 |
+
|
335 |
+
def generate_text_with_llama(prompt):
|
336 |
+
full_prompt = f"""[INST] {prompt} [/INST]"""
|
337 |
+
result = pipe(prompt, temperature=0.7, do_sample=True)[0]['generated_text']
|
338 |
+
# Extract the generated part after the prompt
|
339 |
+
# return result.split('[/INST]')[-1].strip()
|
340 |
+
return result
|
341 |
+
|
342 |
+
async def generate_options_with_llm(answer, context, question, n=4):
|
343 |
+
prompt = f"""Given the following context, question, and correct answer, generate {n-1} incorrect but plausible answer options. The options should be:
|
344 |
+
1. Contextually related to the given context
|
345 |
+
2. Grammatically consistent with the question
|
346 |
+
3. Different from the correct answer
|
347 |
+
4. Not explicitly mentioned in the given context
|
348 |
+
|
349 |
+
Context: {context}
|
350 |
+
Question: {question}
|
351 |
+
Correct Answer: {answer}
|
352 |
+
|
353 |
+
Provide the options in a comma-separated list.
|
354 |
+
"""
|
355 |
+
|
356 |
+
try:
|
357 |
+
response = await asyncio.to_thread(generate_text_with_llama, prompt)
|
358 |
+
options = [option.strip() for option in response.split(',')]
|
359 |
+
options = [option for option in options if option.lower() != answer.lower()]
|
360 |
+
print(f"\n\nLLM Options are: {options}\n\n")
|
361 |
+
return options[:n-1] # Ensure we only return n-1 options
|
362 |
+
except Exception as e:
|
363 |
+
st.error(f"Error generating options with LLM: {e}")
|
364 |
+
return []
|
365 |
+
|
366 |
+
|
367 |
+
async def generate_options_async(answer, context, question, n=4):
|
368 |
+
options = [answer]
|
369 |
+
|
370 |
+
# Generate options using the language model
|
371 |
+
llm_options = await generate_options_with_llm(answer, context, question, n)
|
372 |
+
options.extend(llm_options)
|
373 |
+
|
374 |
+
# If we don't have enough options, fall back to previous methods
|
375 |
+
if len(options) < n:
|
376 |
+
semantic_options = await generate_semantic_options(answer, context, question, n - len(options))
|
377 |
+
options.extend(semantic_options)
|
378 |
+
|
379 |
+
# If we still don't have enough options, use the fallback method
|
380 |
+
while len(options) < n:
|
381 |
+
fallback_options = await get_fallback_options(answer, context)
|
382 |
+
for option in fallback_options:
|
383 |
+
if option not in options and ensure_grammatical_consistency(question, answer, option):
|
384 |
+
options.append(option)
|
385 |
+
if len(options) == n:
|
386 |
+
break
|
387 |
+
|
388 |
+
# Shuffle the options
|
389 |
+
random.shuffle(options)
|
390 |
+
|
391 |
+
return options
|
392 |
+
|
393 |
+
async def generate_semantic_options(answer, context, question, n=4):
|
394 |
+
try:
|
395 |
+
options = [answer]
|
396 |
+
|
397 |
+
# Get context words
|
398 |
+
doc = nlp(context)
|
399 |
+
context_words = [token.text for token in doc if token.is_alpha and token.text.lower() != answer.lower()]
|
400 |
+
|
401 |
+
# Get answer type
|
402 |
+
answer_type = get_word_type(answer)
|
403 |
+
print(answer_type,"\n")
|
404 |
+
|
405 |
+
# Get semantically similar words
|
406 |
+
similar_words = []
|
407 |
+
for word in context_words:
|
408 |
+
if get_word_type(word) == answer_type:
|
409 |
+
similarity = get_semantic_similarity(answer, word)
|
410 |
+
if 0.2 < similarity < 0.8: # Adjust these thresholds as needed
|
411 |
+
similar_words.append((word, similarity))
|
412 |
+
|
413 |
+
# Sort by similarity (descending) and take top n-1
|
414 |
+
similar_words.sort(key=lambda x: x[1], reverse=True)
|
415 |
+
top_similar_words = [word for word, _ in similar_words[:n-1]]
|
416 |
+
|
417 |
+
# Ensure grammatical consistency
|
418 |
+
consistent_options = []
|
419 |
+
for word in top_similar_words:
|
420 |
+
if ensure_grammatical_consistency(question, answer, word):
|
421 |
+
consistent_options.append(word)
|
422 |
+
if len(consistent_options) == n-1:
|
423 |
+
break
|
424 |
+
|
425 |
+
options.extend(consistent_options)
|
426 |
+
|
427 |
+
# If we don't have enough options, fall back to original method
|
428 |
+
while len(options) < n:
|
429 |
+
fallback_options = get_fallback_options(answer, context, 3)
|
430 |
+
for option in fallback_options:
|
431 |
+
if option not in options and ensure_grammatical_consistency(question, answer, option):
|
432 |
+
options.append(option)
|
433 |
+
break
|
434 |
+
|
435 |
+
# Shuffle the options
|
436 |
+
random.shuffle(options)
|
437 |
+
print(options)
|
438 |
+
st.write("All possibel options are: ", options, "\n")
|
439 |
+
return options
|
440 |
+
|
441 |
+
except Exception as e:
|
442 |
+
raise QuestionGenerationError(f"Error in generating options: {str(e)}")
|
443 |
+
|
444 |
+
|
445 |
+
|
446 |
+
|
447 |
# Function to map keywords to sentences with customizable context window size
|
448 |
def map_keywords_to_sentences(text, keywords, context_window_size):
|
449 |
sentences = sent_tokenize(text)
|
|
|
483 |
except Exception as e:
|
484 |
raise QuestionGenerationError(f"Error in question generation: {str(e)}")
|
485 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
|
|
|
|
|
|
|
|
|
487 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
488 |
|
489 |
|
490 |
# Function to generate questions using beam search
|
|
|
517 |
st.error(f"An unexpected error occurred: {str(e)}")
|
518 |
return []
|
519 |
|
520 |
+
async def process_batch(batch, keywords, context_window_size, num_beams, use_llm_options):
|
521 |
questions = []
|
522 |
for text in batch:
|
523 |
keyword_sentence_mapping = map_keywords_to_sentences(text, keywords, context_window_size)
|
524 |
for keyword, context in keyword_sentence_mapping.items():
|
525 |
question = await generate_question_async(context, keyword, num_beams)
|
526 |
+
if use_llm_options:
|
527 |
+
options = await generate_options_async(keyword, context, question)
|
528 |
+
else:
|
529 |
+
options =await generate_semantic_options(keyword, context, question)
|
530 |
overall_score, relevance_score, complexity_score, spelling_correctness = assess_question_quality(context, question, keyword)
|
531 |
if overall_score >= 0.5:
|
532 |
questions.append({
|
|
|
602 |
return overall_score, relevance_score, complexity_score, spelling_correctness
|
603 |
|
604 |
def main():
|
605 |
+
load_dotenv()
|
606 |
# Streamlit interface
|
607 |
st.title(":blue[Question Generator System]")
|
608 |
session_id = get_session_id()
|
|
|
624 |
num_beams = st.slider("Select number of beams for question generation", min_value=2, max_value=10, value=2)
|
625 |
context_window_size = st.slider("Select context window size (number of sentences before and after)", min_value=1, max_value=5, value=1)
|
626 |
num_questions = st.slider("Select number of questions to generate", min_value=1, max_value=1000, value=5)
|
627 |
+
use_llm_for_options = st.toggle("Use AI for Advanced option generation", False)
|
628 |
col1, col2 = st.columns(2)
|
629 |
with col1:
|
630 |
extract_all_keywords = st.toggle("Extract Max Keywords",value=False)
|
|
|
645 |
if text:
|
646 |
text = clean_text(text)
|
647 |
generate_questions_button = st.button("Generate Questions")
|
648 |
+
# st.markdown('<span aria-label="Generate questions button">Above is the generate questions button</span>', unsafe_allow_html=True)
|
649 |
|
650 |
# if generate_questions_button:
|
651 |
if generate_questions_button and text:
|
652 |
start_time = time.time()
|
653 |
with st.spinner("Generating questions..."):
|
654 |
try:
|
655 |
+
state['generated_questions'] = asyncio.run(generate_questions_async(text, num_questions, context_window_size, num_beams, extract_all_keywords, use_llm_for_options))
|
656 |
if not state['generated_questions']:
|
657 |
st.warning("No questions were generated. The text might be too short or lack suitable content.")
|
658 |
else:
|