Vipul-Chauhan committed on
Commit
14c4173
•
1 Parent(s): 798a280

Create app.py

Files changed (1)
  1. app.py +191 -0
app.py ADDED
@@ -0,0 +1,191 @@
+ # Load the pre-clustered 20 Newsgroups sample (200 documents per group) together with its tf-idf-cleaned text
+ import pandas as pd
+ newsdf_sample = pd.read_excel("/content/drive/MyDrive/Colab Notebooks/quantiphi/200_sample_each_20newsgroup_4k_tfidf.xlsx", engine="openpyxl")
+
+ # Preprocessing resources for tokenization (needed to match the tf-idf vocabulary)
+ import nltk
+ nltk.download('stopwords')
+ nltk.download('wordnet')
+ nltk.download('omw-1.4')
+ from nltk.corpus import stopwords
+ stopwords_list = stopwords.words('english')
+
+ def process_row(row):
+     """Clean one document or question so it matches the preprocessing used for the tf-idf model."""
+     import re
+     from textblob import Word
+     from string import punctuation
+     from nltk.stem.snowball import SnowballStemmer
+
+     # Mail addresses
+     row = re.sub(r'(\S+@\S+)(com|\s+com)', ' ', row)
+
+     # Usernames
+     row = re.sub(r'(\S+@\S+)', ' ', row)
+
+     # Punctuation, typographic quotes, digits and tabs
+     punctuation = punctuation + '\n' + '—“,”‘-’' + '0123456789' + "\t"
+     row = ''.join(char for char in row if char not in punctuation)
+
+     # Lower case
+     row = row.lower()
+
+     # Stopwords
+     row = ' '.join(word for word in row.split() if word not in stopwords_list)
+
+     # Lemmatization
+     row = " ".join([Word(word).lemmatize() for word in row.split()])
+
+     # Stemming
+     stemmer = SnowballStemmer(language='english')
+     row = " ".join([stemmer.stem(word) for word in row.split()])
+
+     # Collapse extra whitespace and drop tokens shorter than three characters
+     row = re.sub(r'\s{1,}', ' ', row)
+     row = " ".join([word for word in row.split() if len(word) > 2])
+
+     return row
+
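+ # Assumption (not shown in this commit): the "tfidf_cleaned" column in the Excel file is taken to
+ # have been produced with this same preprocessing, roughly:
+ #   newsdf_sample["tfidf_cleaned"] = newsdf_sample["cleaned_doc"].astype(str).apply(process_row)
+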
+ # Load the fitted k-means model (20 clusters) and the fitted tf-idf vectorizer
+ import pickle
+ kmeans_tfidf = pickle.load(open("/content/drive/MyDrive/Colab Notebooks/quantiphi/kmeans_tfidf_20.p", "rb"))
+ vectorizer = pickle.load(open("/content/drive/MyDrive/Colab Notebooks/quantiphi/tfidf_vectorizer.p", "rb"))
+
+ # Build a word-frequency dictionary per cluster from a word cloud over its cleaned text
+ import matplotlib.pyplot as plt
+ from wordcloud import WordCloud
+
+ dictt_cluster_words = {}
+ for i in range(0, 20):
+     temp_df = newsdf_sample[newsdf_sample.exp1 == i]
+     text_list = temp_df["tfidf_cleaned"].values
+     text_list = [element for element in text_list if str(element) != "nan"]
+     single_text = " ".join(text_list)
+     wordcloud = WordCloud(width=1000, height=500).generate(single_text)
+     dictt_cluster_words[i] = wordcloud.words_
+
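+ # Note: WordCloud.words_ maps each word to its relative frequency within the cluster (most
+ # frequent word = 1.0); these per-cluster frequencies are what the keyword-based cluster
+ # scoring further below sums over the words of an incoming question.
+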
+ # Summarization model: PEGASUS fine-tuned on CNN/DailyMail
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
+ from transformers import pipeline
+ import torch
+
+ model_name = 'google/pegasus-cnn_dailymail'
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ tokenizer = PegasusTokenizer.from_pretrained(model_name)
+ model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
+
+ def return_summary(text):
+     """Generate an abstractive summary of `text` with PEGASUS."""
+     src_text = [text]
+     batch = tokenizer(src_text, truncation=True, padding="longest", return_tensors="pt").to(device)
+     translated = model.generate(**batch)
+     tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
+     tgt_text = tgt_text[0].replace("<n>", " ")
+     return tgt_text
+
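+ # Note: with truncation=True the tokenizer clips the input to the model's maximum length, so only
+ # the beginning of the concatenated relevant_text is actually summarized; chunking the context and
+ # summarizing each piece would be a possible refinement (sketch only, not part of this commit).
+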
+ # Extractive question answering over the retrieved documents
+ def return_squad_answer(question, relevant_text):
+     # NOTE: the pipeline is rebuilt on every call; it could be created once at module level instead.
+     qa_pipeline = pipeline(
+         "question-answering",
+         model="mvonwyl/distilbert-base-uncased-finetuned-squad2",  # alternative: "csarron/bert-base-uncased-squad-v1"
+         tokenizer="mvonwyl/distilbert-base-uncased-finetuned-squad2",
+     )
+
+     predictions = qa_pipeline({
+         'context': relevant_text,
+         'question': question
+     })
+     print(predictions)
+     return predictions["answer"]
+
+ # Keyword-based cluster selection, then document selection by tf-idf cosine similarity
+ import numpy as np
+ import math
+
+ def l2_norm(a):
+     return math.sqrt(np.dot(a, a))
+
+ def cosine_similarity(a, b):
+     return abs(np.dot(a, b) / (l2_norm(a) * l2_norm(b)))
+
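+ # Equivalent check (sketch, optional): scikit-learn's pairwise cosine_similarity returns the same
+ # score for two non-negative tf-idf row vectors, without the manual l2_norm helper:
+ #   from sklearn.metrics.pairwise import cosine_similarity as sk_cosine
+ #   sk_cosine(tfidf_ques, val)[0, 0]
+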
+ def return_selected_cluster(ques):
+     """Pick the cluster whose frequent words best overlap with the cleaned question."""
+     ques_clean = process_row(ques)
+     cluster_selected = -1
+     cluster_score = 0
+     for clus_id in dictt_cluster_words:
+         dictt_temp = dictt_cluster_words[clus_id]
+         score_temp = 0
+         for word in ques_clean.split():
+             if word in dictt_temp:
+                 score_temp += dictt_temp[word]
+         if score_temp > cluster_score:
+             cluster_selected = clus_id
+             cluster_score = score_temp
+     return cluster_selected
+
+ def get_summary_answer(Question):
+     print("question: ", Question)
+     cluster_selected = return_selected_cluster(Question)
+
+     # Rank the selected cluster's documents by cosine similarity to the question's tf-idf vector
+     # (rows without cleaned text are dropped, mirroring the nan filtering above)
+     temp_df = newsdf_sample[newsdf_sample.exp1 == cluster_selected].dropna(subset=["tfidf_cleaned"]).copy()
+     tfidf_ques = vectorizer.transform([process_row(Question)]).todense()
+     cosine_score = []
+     for sent in temp_df["tfidf_cleaned"].values:
+         val = vectorizer.transform([sent]).todense()
+         cos_score = cosine_similarity(np.array(tfidf_ques)[0], np.array(val)[0])
+         cosine_score.append(cos_score)
+
+     temp_df["cos_score"] = cosine_score
+     temp_df = temp_df.sort_values(by=['cos_score'], ascending=False)
+
+     # Concatenate the 20 most similar documents as context for both models
+     relevant_docs = temp_df["cleaned_doc"][:20]
+     relevant_text = " ".join(relevant_docs)
+     print("relevant_text", relevant_text)
+
+     summary = return_summary(relevant_text)
+     squad_answer = return_squad_answer(Question, relevant_text)
+
+     return summary, squad_answer
+
+ # Gradio UI: a question in, the generative summary and the extractive SQuAD answer out
+ import gradio as gr
+ iface = gr.Interface(fn=get_summary_answer,
+                      inputs=gr.Textbox(type="text", label="Type your question"),
+                      outputs=[
+                          gr.Textbox(type="text", label="Answer from Generative Model"),
+                          gr.Textbox(type="text", label="Answer from SQuAD model"),
+                      ],
+                      title="20NewsGroup_QA",
+                      description="Returns answers from the 20NewsGroup dataset")
+ iface.launch(inline=False)
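+
+ # For a quick check without the UI, one could call the pipeline directly (illustrative only):
+ #   summary, answer = get_summary_answer("What do people write about space missions?")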