File size: 2,899 Bytes
a578005
 
 
 
88424d7
a578005
88424d7
 
 
 
00b5c9c
 
88424d7
 
 
 
 
a578005
88424d7
00b5c9c
88424d7
 
00b5c9c
88424d7
 
a578005
 
2f36d79
a578005
2f36d79
a578005
 
 
88424d7
 
 
 
a578005
88424d7
 
a578005
2f36d79
 
 
a578005
88424d7
 
00b5c9c
 
a578005
88424d7
2f36d79
88424d7
 
a578005
88424d7
2f36d79
 
 
88424d7
 
 
 
2f36d79
88424d7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import streamlit as st
import pandas as pd
import re

# Page-wide settings: browser tab title/icon, wide layout, sidebar collapsed on load.
st.set_page_config(
    page_title='MRC for Legal Document Dataset checker',
    page_icon='🍃',
    layout='wide',
    initial_sidebar_state="collapsed",
)

# start processing events
def load_data(file_uploader):
    """Return the uploaded CSV as a DataFrame, or an empty placeholder frame.

    Args:
        file_uploader: A path or file-like object accepted by
            ``pandas.read_csv``, or ``None`` when nothing has been uploaded.

    Returns:
        The parsed DataFrame, or an empty DataFrame with the columns
        ``context``, ``question`` and ``answer`` when *file_uploader* is
        ``None``.
    """
    if file_uploader is None:
        # Nothing to read yet: keep the column layout the rest of the app uses.
        return pd.DataFrame(columns=['context', 'question', 'answer'])
    return pd.read_csv(file_uploader)
    
def convert_df(df):
    """Serialize *df* to UTF-8 encoded CSV bytes suitable for a download.

    The row index is deliberately left out so the exported file has exactly
    the DataFrame's columns; writing it would add an unnamed leading column
    that shows up as ``Unnamed: 0`` when the file is read back with
    ``pandas.read_csv``.

    NOTE: the result is not cached; the conversion runs on every call.

    Args:
        df: The DataFrame to export.

    Returns:
        The CSV text encoded as UTF-8 ``bytes``.
    """
    return df.to_csv(index=False).encode("utf-8")
# end processing events

# --- Page header, file upload, session state and navigation buttons ---
st.markdown("<h1 style='text-align: center;'>Investigation Legal Dataset checker for Machine Reading Comprehension</h1>", unsafe_allow_html=True)

file = st.file_uploader(label='Upload your file here:', type=['csv'], accept_multiple_files=False, label_visibility='hidden')
df = load_data(file_uploader=file)

# Identify the current upload so a new (or removed) file can be detected.
# `file_id` is included when the running Streamlit version provides it.
upload_key = None if file is None else (getattr(file, 'file_id', None), file.name, file.size)

# The script's first run happens before any file is uploaded, so the working
# copy must be re-synchronised whenever the upload changes; otherwise the
# session keeps the empty placeholder frame. The index is reset at the same
# time so it can never point past the end of a newly loaded file.
if (
    'upload_key' not in st.session_state
    or st.session_state.upload_key != upload_key
    or 'df' not in st.session_state
    or 'idx' not in st.session_state
):
    st.session_state.upload_key = upload_key
    st.session_state.df = df
    st.session_state.idx = 0

# Show "0/0" rather than "1/0" while no data is loaded.
total_samples = len(df)
current_sample = st.session_state.idx + 1 if total_samples else 0
st.markdown(f"<h3 style='text-align: center;'>Sample {current_sample}/{total_samples}</h3>", unsafe_allow_html=True)

# Ten equal-width columns; only the first four are used as a button toolbar.
col_1, col_2, col_3, col_4, col_5, col_6, col_7, col_8, col_9, col_10 = st.columns([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

btn_previous = col_1.button(label=':arrow_backward: Previous sample', use_container_width=True)
btn_next = col_2.button(label='Next sample :arrow_forward:', use_container_width=True)
btn_save = col_3.button(label=':heavy_check_mark: Save change', use_container_width=True)
# TODO: add a jump-to-sample selector (an earlier selectbox sketch was disabled).

# --- Editor for the current sample: edit fields, preview, navigate, save ---
# The session copy is the working data set; edits are written to it on save.
records = st.session_state.df

if not records.empty:
    index = st.session_state.idx

    def _cell_text(column):
        """Return the current row's *column* value as text ('' for missing cells)."""
        value = records.at[index, column]
        # Missing CSV cells are NaN; show them as empty instead of "nan".
        return '' if pd.isna(value) else str(value)

    txt_context = st.text_area(height=300, label='Your context:', value=_cell_text('context'))
    txt_question = st.text_area(height=100, label='Your question:', value=_cell_text('question'))
    txt_answer = st.text_area(height=100, label='Your answer:', value=_cell_text('answer'))

    if txt_answer.strip() and txt_context.strip():
        # A function replacement keeps the matched text exactly as it appears
        # in the context (the search is case-insensitive) and avoids treating
        # backslashes in the answer as regex replacement escapes.
        highlighted_context = re.sub(
            re.escape(txt_answer),
            lambda match: f"<mark>{match.group(0)}</mark>",
            txt_context,
            flags=re.IGNORECASE,
        )
        # NOTE(review): the context is rendered with unsafe_allow_html and is
        # not HTML-escaped, so markup inside an uploaded file is rendered too.
        st.markdown(highlighted_context, unsafe_allow_html=True)

    if btn_next:
        if index < len(records) - 1:
            st.session_state.idx += 1
            st.rerun()

    if btn_save:
        # Single-step .at assignment: chained `df[col][i] = ...` may write to a
        # temporary copy and is a no-op under pandas Copy-on-Write.
        records.at[index, 'context'] = txt_context
        records.at[index, 'question'] = txt_question
        records.at[index, 'answer'] = txt_answer
        csv_file = convert_df(df=records)
        # The download button is only rendered on the run triggered by "Save".
        col_4.download_button(data=csv_file, label=':arrow_down_small: Download file', use_container_width=True, file_name="large_df.csv", mime="text/csv")

    if btn_previous:
        if index > 0:
            st.session_state.idx -= 1
            st.rerun()