File size: 4,713 Bytes
97beac5
4be804d
 
70f757d
4be804d
97beac5
4be804d
 
 
 
aff4412
4be804d
 
 
 
 
 
c344f79
4be804d
 
 
 
 
 
 
 
 
 
 
 
86f9681
4be804d
 
f590933
86f9681
 
 
 
4be804d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1dd4f4
 
4be804d
 
 
 
 
 
a154d32
3e77d90
 
a154d32
3e77d90
a154d32
3e77d90
a154d32
3e77d90
a154d32
3e77d90
a154d32
3e77d90
a154d32
3e77d90
a154d32
3e77d90
a154d32
3e77d90
a154d32
3e77d90
b7b149a
a154d32
 
 
4be804d
 
dcfd954
 
2b6b5fd
26e951a
 
9e7a428
0314595
 
 
 
 
 
 
 
c344f79
 
0314595
0c7ed46
 
0314595
 
 
 
 
0c7ed46
 
0314595
 
a02613f
8cd9729
 
529aa95
4be804d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import streamlit as st
import torch
import pandas as pd
from io import StringIO
from transformers import AutoTokenizer,  AutoModelForSeq2SeqLM

class preProcess:
    """Load an uploaded CSV and linearize it into the text the model consumes.

    Args:
        filename: Anything ``pandas.read_csv`` accepts (a path, a buffer, or
            a file-like upload object).
        titlename: Title describing the data; a trailing newline is appended
            and stored as ``self.title``.
    """

    # Largest number of columns accepted by ``check_columns``; the app's
    # description promises support for files with up to four columns.
    MAX_COLUMNS = 4

    def __init__(self, filename, titlename):
        self.filename = filename
        self.title = titlename + '\n'

    def read_data(self):
        """Read ``self.filename`` with ``pandas.read_csv`` and return the DataFrame."""
        return pd.read_csv(self.filename)

    def check_columns(self, df):
        """Return True when *df* has between 1 and ``MAX_COLUMNS`` columns.

        Otherwise report the problem through ``st.error`` and return False.
        """
        n_cols = len(df.columns)
        if n_cols > self.MAX_COLUMNS:
            # The message used to say "3 coloumns" although the limit is 4.
            st.error(f'File has more than {self.MAX_COLUMNS} columns.')
            return False
        if n_cols == 0:
            st.error('File has no column.')
            return False
        return True

    def format_data(self, df):
        """Linearize *df* into a single string.

        Each row becomes its cell values joined by single spaces, and rows are
        joined with ' , '. Tables with fewer than three columns get the prefix
        ' x-y values ', wider tables get ' labels '. The prefix is followed by
        the column names joined by ' - ', then ' values ', then the rows.
        """
        # str() so non-string labels (e.g. integer headers) can be joined.
        columns = [str(name) for name, _ in df.items()]
        # df.items() yields each column positionally, so duplicate column
        # names still give one Series per column; zip(*) turns them into rows.
        rows = zip(*(series for _, series in df.items()))
        values = ' , '.join(' '.join(map(str, row)) for row in rows)
        prefix = ' x-y values ' if len(columns) < 3 else ' labels '
        return prefix + ' - '.join(columns) + ' values ' + values

    def combine_title_data(self, df):
        """Return ``self.title`` and the linearized *df* joined by a single space."""
        return ' '.join([self.title, self.format_data(df)])
      
class Model:
    """Wrap a pretrained T5 chart-to-text checkpoint and generate a summary.

    Args:
        text: Linearized title + data string to summarize.
        mode: 'simple' or 'analytical' (case-insensitive); selects which
            Hugging Face checkpoint is loaded.

    Raises:
        ValueError: If *mode* is not one of the supported modes.
    """

    # Hugging Face checkpoint loaded for each (lower-cased) mode.
    _CHECKPOINTS = {
        'simple': 'saadob12/t5_C2T_big',
        'analytical': 'saadob12/t5_autochart_2',
    }

    # Chart-type phrases replaced by 'statistic', applied in this order. A
    # final pass in _neutralize_chart_terms also replaces any leftover 'graph'.
    _CHART_TERMS = (
        'barchart', 'bar graph', 'bar plot',
        'scatter plot', 'scatter graph', 'scatterchart', 'scatter chart',
        'line plot', 'line graph', 'linechart',
    )

    def __init__(self, text, mode):
        self.padding = 'max_length'
        self.truncation = True
        self.prefix = 'C2T: '
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.text = text
        checkpoint = self._CHECKPOINTS.get(mode.lower())
        if checkpoint is None:
            # Previously an unknown mode silently left tokenizer/model unset,
            # which only surfaced later as an AttributeError in generate().
            raise ValueError(
                f'Unsupported mode {mode!r}; expected one of {sorted(self._CHECKPOINTS)}.'
            )
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(self.device)

    @classmethod
    def _neutralize_chart_terms(cls, summary):
        """Return *summary* with chart-type wording replaced by 'statistic'."""
        for term in cls._CHART_TERMS:
            summary = summary.replace(term, 'statistic')
        return summary.replace('graph', 'statistic')

    def generate(self):
        """Run beam search on ``self.prefix + self.text`` and return the cleaned summary.

        Returns:
            The decoded text with leading/trailing '[', ']' and '"' characters
            stripped and chart-type wording replaced by 'statistic'.
        """
        tokens = self.tokenizer.encode(
            self.prefix + self.text,
            truncation=self.truncation,
            padding=self.padding,
            return_tensors='pt',
        ).to(self.device)
        generated = self.model.generate(tokens, num_beams=4, max_length=256)
        tgt_text = self.tokenizer.decode(
            generated[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        summary = str(tgt_text).strip('[]""')
        # str.replace returns a new string; the original code discarded every
        # result, so none of the substitutions ever took effect.
        return self._neutralize_chart_terms(summary)
      
# Streamlit page: collect the summary mode, a title and a CSV upload, then
# linearize the data and show the model-generated summary.
st.title('Chart and Data Summarization')
st.write('This application generates a summary of a datafile (.csv) (or the underlying data of a chart). Right now, it only generates summaries of files with maximum of four columns. If the file contains more than four columns, the app will throw an error.')
mode = st.selectbox('What kind of summary do you want?',
                    ('Simple', 'Analytical'))
st.write('You selected: ' + mode + ' summary.')
title = st.text_input('Add appropriate Title of the .csv file', 'State minimum wage rates in the United States as of January 1 , 2020')
st.write('Title of the file is: ' + title)
uploaded_file = st.file_uploader("Upload only .csv file")
if uploaded_file is not None and mode is not None and title is not None:
    st.write('Preprocessing file...')
    p = preProcess(uploaded_file, title)
    try:
        contents = p.read_data()
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        # Report unreadable uploads in the UI instead of crashing the script
        # with a raw traceback.
        st.error(f'Could not read the uploaded file as CSV: {err}')
        contents = None
    if contents is not None and p.check_columns(contents):
        st.write('Your file contents:\n')
        st.write(contents)
        title_data = p.combine_title_data(contents)
        st.write('Linearized input format of the data file:\n ')
        st.markdown('**' + title_data + '**')

        st.write('Loading model...')
        model = Model(title_data, mode)
        st.write('Model loading done!\nGenerating Summary...')
        summary = model.generate()
        st.write('Generated Summary:\n')
        st.markdown('**' + summary + '**')