File size: 4,545 Bytes
491e087 ac7f046 344f958 491e087 5344f77 491e087 344f958 491e087 5344f77 491e087 2704216 491e087 5344f77 2704216 491e087 8196f31 491e087 8196f31 491e087 8196f31 491e087 8196f31 491e087 5344f77 8196f31 a887cae 5344f77 8196f31 491e087 344f958 491e087 2704216 344f958 a887cae 2704216 ac7f046 344f958 2704216 344f958 2704216 344f958 2704216 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import streamlit as st
import pandas as pd
from plms.language_model import TransformersQG
import time
import os
import numpy as np
st.set_page_config(page_icon='🧪', page_title='ViQAG for Vietnamese Education', layout='centered', initial_sidebar_state="collapsed")

# Inject the app's custom CSS. The file is decoded as UTF-8 explicitly so the
# result does not depend on the platform's default encoding (e.g. cp1252 on Windows).
with open(r"./static/styles.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Logo header. NOTE(review): the ./app/static/ URL presumably relies on
# Streamlit static file serving being enabled for ./static — confirm in config.
st.markdown("""
<div class=logo_area>
<img src="./app/static/AlphaEdu_logo_trans.png"/>
</div>
""", unsafe_allow_html=True)
st.markdown("<h1 style='text-align: center;'>AlphaEdu</h1>", unsafe_allow_html=True)
# =====================================================================================================
# Generated QA text survives Streamlit reruns via session state; '' means "nothing generated yet".
if 'output' not in st.session_state:
    st.session_state.output = ''
def file_selector(folder_path=r'./Resources/'):
    """List the subject data files available in ``folder_path``.

    Only regular, non-hidden files are returned: sub-directories and dotfiles
    (e.g. ``.gitkeep``, ``.DS_Store``) would otherwise be offered as subjects
    and then fail when parsed as CSV.

    Args:
        folder_path: Directory to scan.

    Returns:
        The file names, sorted so the subject list (and therefore the default
        selection) is stable; ``os.listdir`` order is arbitrary.
    """
    return sorted(
        name
        for name in os.listdir(folder_path)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(folder_path, name))
    )
# Names of the files found by file_selector in its default ./Resources/ folder;
# the sidebar offers them as subjects and load_grades reads the chosen one as CSV.
filenames = file_selector()
def load_grades(file_name, folder_path=r'./Resources/'):
    """Read a subject CSV and list the grades it covers.

    Args:
        file_name: Name of the CSV file inside ``folder_path``.
        folder_path: Directory holding the subject files. A trailing path
            separator is optional.

    Returns:
        ``(grades, df)``: the distinct values of the ``grade`` column in
        first-appearance order, and the full DataFrame read from the file.
    """
    # os.path.join works whether or not folder_path ends with a separator;
    # plain string concatenation produced a wrong path when it did not.
    file_path = os.path.join(folder_path, file_name)
    df = pd.read_csv(file_path)
    list_grades = df['grade'].drop_duplicates().values
    return list_grades, df
def load_chapters(df, grade_name):
    """List the chapters taught in one grade.

    Args:
        df: Subject DataFrame with ``grade`` and ``chapter`` columns.
        grade_name: Grade value to filter on.

    Returns:
        ``(chapters, df)``: the distinct ``chapter`` values of rows whose
        ``grade`` equals ``grade_name`` (first-appearance order), and the
        input DataFrame itself, unfiltered.
    """
    in_grade = df['grade'] == grade_name
    chapters = df.loc[in_grade, 'chapter'].drop_duplicates().values
    # The caller rebinds its df from this return value, so hand back the
    # original, unfiltered frame.
    return chapters, df
def load_lessons(df, grade_name, chapter_name):
    """List the lessons of one chapter within one grade.

    Args:
        df: Subject DataFrame with ``grade``, ``chapter`` and ``lesson`` columns.
        grade_name: Grade value to filter on.
        chapter_name: Chapter value to filter on.

    Returns:
        The distinct ``lesson`` values of the matching rows, in
        first-appearance order.
    """
    in_chapter = (df['grade'] == grade_name) & (df['chapter'] == chapter_name)
    lessons = df.loc[in_chapter, 'lesson']
    return lessons.drop_duplicates().values
def load_context(df, grade_name, chapter_name, lesson_name):
    """Collect the paragraphs stored for one lesson.

    Args:
        df: Subject DataFrame with ``grade``, ``chapter``, ``lesson`` and
            ``context`` columns.
        grade_name: Grade value to filter on.
        chapter_name: Chapter value to filter on.
        lesson_name: Lesson value to filter on.

    Returns:
        ``(count, paragraphs)``: the number of matching rows and their
        ``context`` values, in DataFrame order.
    """
    in_lesson = (
        (df['grade'] == grade_name)
        & (df['chapter'] == chapter_name)
        & (df['lesson'] == lesson_name)
    )
    paragraphs = df.loc[in_lesson, 'context'].values
    return len(paragraphs), paragraphs
@st.cache_resource
def _load_qg_model(model_path):
    """Build the TransformersQG model for ``model_path`` once and reuse it across reruns."""
    # Streamlit re-executes the whole script on every interaction, so without
    # st.cache_resource the model was reloaded on every "Generate" click.
    # NOTE(review): the cached instance is shared by all sessions; confirm
    # TransformersQG.generate_qa is safe to call concurrently.
    return TransformersQG(model=model_path, max_length=512)


def generateQA(context, model_path = 'shnl/vit5-vinewsqa-qg-ae'):
    """Generate question/answer pairs from a paragraph and format them as text.

    Args:
        context: Paragraph text passed to the model's ``generate_qa``.
        model_path: Model identifier handed to ``TransformersQG``.

    Returns:
        One ``'question: <q> \\nanswer: <a>'`` entry per distinct
        (question, answer) pair, in generation order, joined by blank lines.
        Returns ``''`` when nothing was generated.
    """
    model = _load_qg_model(model_path)
    # Treat an empty/None result as "no pairs" instead of failing on iteration.
    output = model.generate_qa(context) or []
    seen = set()
    entries = []
    for question, answer in output:
        if (question, answer) in seen:
            continue
        seen.add((question, answer))
        entries.append(f'question: {question} \nanswer: {answer}')
    # Joining directly (instead of the old ' [SEP] ' round-trip) avoids a
    # trailing blank entry and cannot mis-split text containing ' [SEP] '.
    return '\n\n'.join(entries)
# =====================================================================================================
# QAG models offered in the sidebar, mapped to the checkpoint generateQA loads.
QAG_MODELS = {'ViT5-ViNewsQA': 'shnl/vit5-vinewsqa-qg-ae'}

col_1, col_2 = st.sidebar.columns(spec=[1, 1])
subject = col_1.selectbox(label='Select your subject:', options=filenames, label_visibility='visible')
list_grades, df = load_grades(file_name=subject)
grade = col_2.selectbox(label='Select your grade:', options=list_grades, label_visibility='visible')
list_chapters, df = load_chapters(df=df, grade_name=grade)
chapter = st.sidebar.selectbox(label='Select your chapter:', options=list_chapters, label_visibility='visible')
lesson_names = load_lessons(df=df, grade_name=grade, chapter_name=chapter)
lesson = st.sidebar.selectbox(label='Lesson:', options=lesson_names, label_visibility='visible')
total_paragraph, context_values = load_context(df=df, grade_name=grade, chapter_name=chapter, lesson_name=lesson)
col_12, col_22 = st.sidebar.columns(spec=[4, 6])
paragraph_idx = col_12.selectbox(label='Paragraph:', options=list(range(1, total_paragraph + 1)), label_visibility='visible')
# selectbox yields None when the lesson has no paragraphs; show an empty editor instead of crashing.
default_paragraph = context_values[paragraph_idx - 1] if paragraph_idx is not None else ''
paragraph = st.text_area(label='Paragraph content', label_visibility='visible', height=200, value=default_paragraph)
model_name = col_22.selectbox(label='QAG model:', options=list(QAG_MODELS), label_visibility='visible')
btn_show_answer = st.sidebar.toggle(label='Show answers', disabled=False)
# Five equal columns; only the middle one is used, which centres the button.
col_14, col_24, col_34, col_44, col_54 = st.columns(spec=[1, 1, 1, 1, 1])
btn_generate = col_34.button(label='Generate', use_container_width=True)
if btn_generate:
    with st.spinner(text='Generating QA pairs from the selected paragraph. Please wait ...'):
        st.session_state.output = generateQA(context=paragraph, model_path=QAG_MODELS[model_name])

if st.session_state.output:
    st.markdown("<h8 style='text-align: left; font-weight: normal'>Generated QA pairs:</h8>", unsafe_allow_html=True)
    if btn_show_answer:
        body = st.session_state.output
    else:
        # Answers hidden: keep only the 'question: ...' lines produced by
        # generateQA and drop the prefix. (The old code looked for ' [SEP] ' and
        # ', answer: ' markers that never occur in the stored text, so answers
        # were still displayed.)
        prefix = 'question: '
        body = '\n\n'.join(
            line[len(prefix):].strip()
            for line in st.session_state.output.splitlines()
            if line.startswith(prefix)
        )
    st.code(body=body, language='latex')