import tensorflow.compat.v1 as tf
import os 
import shutil
import csv
import pandas as pd
import numpy as np
import IPython
import streamlit as st
import subprocess
import sys
from itertools import islice
import random
#from transformers import pipeline
from transformers import TapasTokenizer, TapasForQuestionAnswering

tf.get_logger().setLevel('ERROR')

def install(package):
  # The argument may contain extra pip flags, so split it into separate arguments
  subprocess.check_call([sys.executable, "-m", "pip", "install", *package.split()])

# torch-scatter is required by TAPAS; pull the wheel built for torch 1.10.0 + CUDA 10.2
install('torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu102.html')

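# google/tapas-base-finetuned-wtq is a TAPAS model fine-tuned on WikiTableQuestions: it answers
# natural-language queries over a table and predicts an aggregation (NONE/SUM/AVERAGE/COUNT) per query.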
model_name = 'google/tapas-base-finetuned-wtq'
#model_name =  "table-question-answering"
#model = pipeline(model_name)

model = TapasForQuestionAnswering.from_pretrained(model_name, local_files_only=False)
tokenizer = TapasTokenizer.from_pretrained(model_name)

st.set_option('deprecation.showfileUploaderEncoding', False)

st.title('Query your Table')
st.header('Upload CSV file')

uploaded_file = st.file_uploader("Choose your CSV file",type = 'csv')
placeholder = st.empty()

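# Read the uploaded CSV and drop commas from cell values, so thousands separators like "1,000"
# can be parsed as numbers later and cell values do not clash with the comma-separated answer formatting below.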
if uploaded_file is not None:
    data = pd.read_csv(uploaded_file)
    data.replace(',','', regex=True, inplace=True)
    if st.checkbox('Want to see the data?'):
        placeholder.dataframe(data)

st.header('Enter your queries')
input_queries = st.text_input('Type your queries separated by comma(,)',value='')
input_queries = [q.strip() for q in input_queries.split(',')]

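# One random highlight colour per query; colors2 holds the corresponding CSS rules for cell styling.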
colors1 = ["#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)]) for i in range(len(input_queries))]
colors2 = ['background-color:'+str(color)+'; color: black' for color in colors1]

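# Build a style DataFrame that colours every predicted answer cell (row, col) with its query's colour.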
def styling_specific_cell(x,tags,colors):
    df_styler = pd.DataFrame('', index=x.index, columns=x.columns)
    for idx,tag in enumerate(tags):
        for r,c in tag:
            df_styler.iloc[r, c] = colors[idx]
    return df_styler
    
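# On click: run TAPAS over the whole table for all queries, then display the answers
# and highlight the answer cells in the table shown above.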
if st.button('Predict Answers'):
    if uploaded_file is None:
        st.error('Please upload a CSV file first')
        st.stop()
    with st.spinner('This may take about a minute'):
        # TAPAS expects every cell value as a string
        data = data.astype(str)
        inputs = tokenizer(table=data, queries=input_queries, padding='max_length', return_tensors="pt")
        outputs = model(**inputs)
        # Convert the token-level logits back to table cell coordinates and aggregation indices
        predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(inputs, outputs.logits.detach(), outputs.logits_aggregation.detach())
        
        id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3:"COUNT"}
        aggregation_predictions_string = [id2aggregation[x] for x in predicted_aggregation_indices]
    
        answers = []
        
        for coordinates in predicted_answer_coordinates:
            if len(coordinates) == 1:
                # only a single cell
                answers.append(data.iat[coordinates[0]])
            else:
                # multiple cells: join their values into one comma-separated answer
                cell_values = []
                for coordinate in coordinates:
                    cell_values.append(data.iat[coordinate])
                answers.append(", ".join(cell_values))
             
    st.success('Done! The answers are listed below and their cells are highlighted in the table above')
    
    placeholder.dataframe(data.style.apply(styling_specific_cell,tags=predicted_answer_coordinates,colors=colors2,axis=None))
      
    for query, answer, predicted_agg, c in zip(input_queries, answers, aggregation_predictions_string, colors1):
        st.write('\n')
        st.markdown('<font color={} size=4>**{}**</font>'.format(c,query), unsafe_allow_html=True)
        st.write('\n')
        
        if predicted_agg == "NONE" or predicted_agg == 'COUNT':
            st.markdown('**>** '+str(answer))
        else:
            if predicted_agg == 'SUM':
                st.markdown('**>** '+str(sum(answer.split(','))))
            else:
                st.markdown('**>** '+str(np.round(np.mean(answer.split(',')),2)))