import os
import streamlit as st
from st_aggrid import AgGrid
import pandas as pd
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer

# Set the page layout for Streamlit
st.set_page_config(layout="wide")

# CSS styling
style = '''
    <style>
        body {background-color: #F5F5F5; color: #000000;}
        header {visibility: hidden;}
        div.block-container {padding-top:4rem;}
        section[data-testid="stSidebar"] div:first-child {
        padding-top: 0;
    }
     .font {                                          
    text-align:center;
    font-family:sans-serif;font-size: 1.25rem;}
    </style>
'''
st.markdown(style, unsafe_allow_html=True)

st.markdown('<p style="font-family:sans-serif;font-size: 1.5rem;text-align: right;"> HertogAI Table Q&A using TAPAS and a T5 Language Model</p>', unsafe_allow_html=True)
st.markdown('<p style="font-family:sans-serif;font-size: 0.7rem;text-align: right;"> This code is based on work by Jordan Skinner; I enhanced it using the T5 language model</p>', unsafe_allow_html=True)
st.markdown("<p style='font-family:sans-serif;font-size: 0.6rem;text-align: right;'>The pre-trained TAPAS model handles at most 64 rows and 32 columns. Make sure the uploaded data does not exceed these dimensions.</p>", unsafe_allow_html=True)



# Initialize TAPAS pipeline
tqa = pipeline(task="table-question-answering", 
              model="google/tapas-large-finetuned-wtq",
              device="cpu")

# Initialize T5 tokenizer and model for text generation
t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
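
# Optional: to avoid reloading the models on every Streamlit rerun, the loaders above
# could be wrapped in a cached function. A minimal sketch (an assumption, not part of
# the original app; st.cache_resource requires a reasonably recent Streamlit release):
#
# @st.cache_resource
# def load_models():
#     tqa = pipeline(task="table-question-answering",
#                    model="google/tapas-large-finetuned-wtq")
#     tokenizer = T5Tokenizer.from_pretrained("t5-small")
#     model = T5ForConditionalGeneration.from_pretrained("t5-small")
#     return tqa, tokenizer, model
#
# tqa, t5_tokenizer, t5_model = load_models()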

# File uploader in the sidebar
file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])

# File processing and question answering
if file_name is None:
    st.markdown('<p class="font">Please upload an Excel or CSV file</p>', unsafe_allow_html=True)
else:
    try:
        # Check file type and handle reading accordingly
        if file_name.name.endswith('.csv'):
            df = pd.read_csv(file_name, sep=';', encoding='ISO-8859-1')  # Adjust encoding if needed
        elif file_name.name.endswith('.xlsx'):
            df = pd.read_excel(file_name, engine='openpyxl')  # Use openpyxl to read .xlsx files
        else:
            st.error("Unsupported file type")
            df = None
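        # Note: the CSV branch above hard-codes sep=';' and ISO-8859-1 encoding.
        # A possible alternative (an assumption, not part of the original code) is to
        # let pandas sniff the delimiter instead:
        # df = pd.read_csv(file_name, sep=None, engine='python', encoding='ISO-8859-1')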

        # Continue with further processing if df is loaded
        if df is not None:
            # Try to convert object (string) columns to numeric where possible
            for col in df.select_dtypes(include=['object']).columns:
                try:
                    df[col] = pd.to_numeric(df[col])
                except (ValueError, TypeError):
                    pass  # leave non-numeric columns unchanged

            st.write("Original Data:")
            st.write(df)

            # Keep a numeric copy, then cast the working frame to strings
            # (the TAPAS pipeline expects every table cell to be a string)
            df_numeric = df.copy()
            df = df.astype(str)

            # Display the first 5 rows of the dataframe in an editable grid
            grid_response = AgGrid(
                df.head(5),
                columns_auto_size_mode='FIT_CONTENTS',
                editable=True, 
                height=300, 
                width='100%',
            )
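            # The grid above is editable; st_aggrid typically returns the (possibly
            # edited) rows under the 'data' key. A sketch of feeding those edits back
            # into the table used for Q&A (an assumption; verify against your
            # installed st_aggrid version):
            # df = pd.DataFrame(grid_response['data']).astype(str)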
            
    except Exception as e:
        st.error(f"Error reading file: {str(e)}")

    # User input for the question
    question = st.text_input('Type your question')

    # Process the answer using TAPAS and T5
    if st.button('Answer'):
        with st.spinner('Querying TAPAS and generating a response...'):
            try:
                # Get the raw answer from TAPAS
                raw_answer = tqa(table=df, query=question, truncation=True)
                
                st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>",
                           unsafe_allow_html=True)
                st.success(raw_answer)
                
                # Extract relevant information from the TAPAS result
                answer = raw_answer['answer']
                aggregator = raw_answer.get('aggregator', '')
                coordinates = raw_answer.get('coordinates', [])
                cells = raw_answer.get('cells', [])
                
                # Construct a base sentence from the question and the TAPAS answer
                base_sentence = f"The {question.lower()} of the selected data is {answer}."
                if coordinates and cells:
                    rows_info = [f"Row {coordinate[0] + 1}, Column '{df.columns[coordinate[1]]}' with value {cell}" 
                                 for coordinate, cell in zip(coordinates, cells)]
                    rows_description = " and ".join(rows_info)
                    base_sentence += f" This includes the following data: {rows_description}."

                # Generate a fluent response using the T5 model, rephrasing the base sentence
                input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"

                # Tokenize the input and generate a fluent response using T5
                inputs = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
                summary_ids = t5_model.generate(inputs, max_length=150, num_beams=4, early_stopping=True)

                # Decode the generated text
                generated_text = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

                # Display the final generated response
                st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
                st.success(generated_text)

            except Exception as e:
                st.warning("Please rephrase your question and make sure the column names and cell values are used correctly.")
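
# Usage: save this script (e.g. as app.py) and start the app with:
#   streamlit run app.py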