saadob12 committed on
Commit
4be804d
·
1 Parent(s): f2390d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -6
app.py CHANGED
@@ -1,9 +1,79 @@
1
  import streamlit as st
2
- from transformers import pipeline
 
 
3
 
4
- pipe = pipeline('sentiment-analysis')
5
- text = st.text_area('enter text')
 
 
6
 
7
- if text:
8
- out = pipe(text)
9
- st.json(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ import pandas as pd
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
class preProcess:
    """Build the text prompt for the model from a CSV table and a chart title.

    The prompt is the title (with a trailing newline), a space, and then
    `` labels <col> - <col> values <row> , <row>``, where each row is the
    space-joined string form of its cells.
    """

    def __init__(self, filename, titlename):
        """Store the CSV source and the chart title.

        Args:
            filename: Passed unchanged to ``pandas.read_csv``.
            titlename: Chart title; a newline is appended to it.
        """
        self.filename = filename
        self.title = titlename + '\n'

    def read_data(self):
        """Read ``self.filename`` with ``pandas.read_csv`` and return the frame."""
        return pd.read_csv(self.filename)

    def check_columns(self, df):
        """Return True when ``df`` has between one and three columns.

        When the check fails, the reason is reported through ``st.error`` and
        False is returned.
        """
        column_count = len(df.columns)
        if column_count > 3:
            st.error('File has more than 3 columns.')
            return False
        if column_count == 0:
            st.error('File has no column.')
            return False
        return True

    def format_data(self, df):
        """Serialise ``df`` into the `` labels ... values ...`` string.

        Column labels are joined with `` - ``; rows are joined with `` , ``
        and the cells inside a row with a single space. Labels and cells are
        converted with ``str`` so non-string headers (e.g. years) also work.
        """
        columns = [list(df[name]) for name in df.columns]
        # zip(*columns) transposes the per-column lists into per-row tuples.
        rows = [' '.join(map(str, cells)) for cells in zip(*columns)]
        labels = ' - '.join(map(str, df.columns))
        return ' labels ' + labels + ' values ' + ' , '.join(rows)

    def combine_title_data(self, df):
        """Return the title and the formatted table joined by one space."""
        return ' '.join([self.title, self.format_data(df)])
42
+
43
class Model:
    """Wrap a pretrained seq2seq checkpoint that turns chart text into a summary.

    ``mode`` selects the checkpoint (matched case-insensitively):
    ``'simple'`` or ``'analytical'``.
    """

    # Hugging Face checkpoint loaded for each supported mode.
    _CHECKPOINTS = {
        'simple': 'saadob12/t5_C2T_big',
        'analytical': 'saadob12/t5_C2T_autochart',
    }

    def __init__(self, text, mode):
        """Load the tokenizer and model for ``mode`` and keep ``text`` for later.

        Args:
            text: Input text; ``generate`` prepends ``self.prefix`` to it.
            mode: ``'simple'`` or ``'analytical'`` (any letter case).

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        # Validate first so a bad mode fails here rather than as an
        # AttributeError on the missing tokenizer inside generate().
        checkpoint = self._CHECKPOINTS.get(mode.lower())
        if checkpoint is None:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected 'simple' or 'analytical'."
            )

        self.padding = 'max_length'
        self.truncation = True
        self.prefix = 'C2T: '
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.text = text
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(self.device)

    def generate(self):
        """Return the generated summary for ``self.text``.

        The prefixed text is encoded, decoded from the first sequence produced
        by beam search (4 beams, ``max_length=256``), and stripped of leading
        and trailing ``[``, ``]`` and ``"`` characters.
        """
        tokens = self.tokenizer.encode(
            self.prefix + self.text,
            truncation=self.truncation,
            padding=self.padding,
            return_tensors='pt',
        ).to(self.device)
        generated = self.model.generate(tokens, num_beams=4, max_length=256)
        tgt_text = self.tokenizer.decode(
            generated[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return str(tgt_text).strip('[]""')
63
+
64
+
65
+
66
def main():
    """Render the Streamlit page, which currently shows only a file uploader.

    The uploaded file is not consumed yet.
    """
    # Pipeline sketch kept from the original, still disabled (it was stored in
    # a string literal and never executed):
    #   pre = preProcess('test.csv', 'Comparison between two models')
    #   contents = pre.read_data()
    #   check = pre.check_columns(contents)
    #   if check:
    #       title_data = pre.combine_title_data(contents)
    #       print(title_data)
    #       model = Model(title_data, 'simple')
    #       summary = model.generate()
    st.file_uploader("Choose a file")


if __name__ == "__main__":
    main()
79
+