# "Spaces: Sleeping" -- appears to be page-status text captured along with the
# source; it is not part of the program.
# Load the TF-IDF dataset (one row per document) from the Excel workbook.
import pandas as pd

newsdf_sample = pd.read_excel(
    "complete_tfidf_25.xlsx",
    engine="openpyxl",
)
# Report how many rows were read.
print("file size", len(newsdf_sample.index))
# Preprocessing resources for better tokenization (needed for tfidf).
import nltk

# Fetch the corpora used by the cleaning step: the stopword list, plus the
# WordNet data requested for lemmatisation.
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')
from nltk.corpus import stopwords

# English stopwords, consulted by process_row() when filtering tokens.
stopwords_list = stopwords.words('english')
def process_row(row):
    """Clean a raw text string into the space-separated tokens used for TF-IDF.

    Steps, in order: strip e-mail addresses and ``user@host`` handles, delete
    punctuation/digits/tabs/newlines, lower-case, remove English stopwords,
    lemmatise each token (TextBlob ``Word``), stem each token (Snowball), and
    finally drop tokens of two characters or fewer.

    Args:
        row: The text to clean (a ``str``).

    Returns:
        The surviving tokens joined by single spaces; may be the empty string.
    """
    import re
    from string import punctuation

    from nltk.stem.snowball import SnowballStemmer
    from textblob import Word

    # Mail addresses first (so a trailing "com" is consumed as well), then any
    # remaining user@host style handle.  Raw strings keep "\S" / "\s" from
    # being parsed as (invalid) string escape sequences.
    row = re.sub(r'(\S+@\S+)(com|\s+com)', ' ', row)
    row = re.sub(r'(\S+@\S+)', ' ', row)

    # Characters that are deleted outright (not replaced by a space), so text
    # on either side of a removed character is joined together.
    # NOTE(review): the 'ββ,ββ-β' run looks like mis-encoded typographic
    # quotes/dashes -- confirm against the original file before changing it.
    # NOTE(review): presumably this must stay in sync with the preprocessing
    # used to build the pickled vectorizer -- verify before altering.
    removed_chars = punctuation + '\n' + 'ββ,ββ-β' + '0123456789' + "\t"
    row = row.translate(str.maketrans('', '', removed_chars))

    row = row.lower()

    # Stopword removal; a set gives O(1) membership instead of a list scan
    # per token.
    stop = set(stopwords_list)
    row = ' '.join(word for word in row.split() if word not in stop)

    # Lemmatise, then stem.
    row = " ".join(Word(word).lemmatize() for word in row.split())
    stemmer = SnowballStemmer(language='english')
    row = " ".join(stemmer.stem(word) for word in row.split())

    # split() already collapses any run of whitespace, so no separate
    # whitespace-normalising substitution is needed before this final filter.
    return " ".join(word for word in row.split() if len(word) > 2)
import pickle

# NOTE: unpickling executes code embedded in the file -- only load these
# artifacts from a trusted source.
# The files are opened with context managers so the handles are closed as
# soon as each object has been loaded.
with open("kmeans_tfidf_25_complete.p", "rb") as kmeans_file:
    kmeans_tfidf = pickle.load(kmeans_file)
with open("tfidf_vectorizer_complete.p", "rb") as vectorizer_file:
    vectorizer = pickle.load(vectorizer_file)
import matplotlib.pyplot as plt
from wordcloud import WordCloud

# Build, for each of the 25 cluster ids, the word -> weight mapping that
# WordCloud derives from all cleaned documents assigned to that cluster.
dictt_cluster_words = {}
for i in range(25):
    temp_df = newsdf_sample[newsdf_sample["exp25"] == i]
    # Skip cells whose string form is "nan" (missing text).
    text_list = [
        element
        for element in temp_df["tfidf_cleaned"].values
        if str(element) != "nan"
    ]
    single_text = " ".join(text_list)
    wordcloud = WordCloud(
        width=1000,
        height=500,
        max_words=1000,
    ).generate(single_text)
    dictt_cluster_words[i] = wordcloud.words_
# Summarization model: Pegasus fine-tuned on CNN/DailyMail (per the checkpoint
# name), placed on the GPU when one is available.
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers import pipeline
import torch

model_name = 'google/pegasus-cnn_dailymail'
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)
def return_summary(text):
    """Generate a summary of ``text`` with the module-level Pegasus model.

    The text is tokenized as a single-item batch (truncated to what the
    tokenizer allows), run through ``model.generate``, decoded, and returned
    with every ``<n>`` marker replaced by a space.
    """
    encoded = tokenizer(
        [text],
        truncation=True,
        padding="longest",
        return_tensors="pt",
    ).to(device)
    generated_ids = model.generate(**encoded)
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    first_summary = decoded[0]
    return first_summary.replace("<n>", " ")
############ | |
# Checkpoint used for extractive question answering.
# (Alternative tried previously: "csarron/bert-base-uncased-squad-v1".)
_QA_MODEL_NAME = "mvonwyl/distilbert-base-uncased-finetuned-squad2"

# Lazily created question-answering pipeline, shared by all calls.
_qa_pipeline = None


def _get_qa_pipeline():
    """Return the question-answering pipeline, building it on first use."""
    global _qa_pipeline
    if _qa_pipeline is None:
        _qa_pipeline = pipeline(
            "question-answering",
            model=_QA_MODEL_NAME,
            tokenizer=_QA_MODEL_NAME,
        )
    return _qa_pipeline


def return_squad_answer(question, relevant_text):
    """Extract an answer to ``question`` from ``relevant_text``.

    The pipeline is constructed once and reused, instead of reloading the
    model and tokenizer on every call.

    Args:
        question: The question to answer.
        relevant_text: The context passage the answer is extracted from.

    Returns:
        The ``"answer"`` field of the pipeline's prediction.
    """
    predictions = _get_qa_pipeline()({
        'context': relevant_text,
        'question': question,
    })
    print(predictions)
    return predictions["answer"]
# keyword based cluster selection would be better
# document selection based on tfidf vector
import numpy as np
import math


def l2_norm(a):
    """Return the Euclidean (L2) norm of the 1-D vector ``a``."""
    return math.sqrt(np.dot(a, a))


def cosine_similarity(a, b):
    """Return the absolute cosine similarity of the 1-D vectors ``a`` and ``b``.

    Returns ``0.0`` when either vector has zero norm: a vector with no
    non-zero component has no direction, and dividing by the zero norm would
    otherwise yield NaN.
    """
    denominator = l2_norm(a) * l2_norm(b)
    if denominator == 0:
        return 0.0
    return abs(np.dot(a, b) / denominator)
def return_selected_cluster(ques):
    """Pick the cluster whose word weights best match the question.

    The question is cleaned with ``process_row``.  For every cluster in
    ``dictt_cluster_words`` the weights of the question tokens found in that
    cluster are summed and scaled by the fraction of question tokens that
    matched.  The cluster with the highest strictly positive score wins.

    Args:
        ques: The raw question text.

    Returns:
        The id of the best-scoring cluster, or ``-1`` when no cluster scores
        above zero (including when no token survives cleaning).
    """
    tokens = process_row(ques).split()
    if not tokens:
        # Nothing survived cleaning; without this guard the coverage ratio
        # below would divide by zero.
        return -1

    cluster_selected = -1
    cluster_score = 0
    for clus_id, word_weights in dictt_cluster_words.items():
        matched_weights = [
            word_weights[word] for word in tokens if word in word_weights
        ]
        # Total matched weight, scaled by the share of tokens that matched.
        score = sum(matched_weights) * (len(matched_weights) / len(tokens))
        if score > cluster_score:
            cluster_selected = clus_id
            cluster_score = score
    return cluster_selected
def get_summary_answer(Question):
    """Answer a question from the news corpus.

    Selects the best-matching cluster, ranks that cluster's documents by
    cosine similarity between their TF-IDF vectors and the question's, joins
    the 20 best ``cleaned_doc`` texts, and runs both the summarization model
    and the extractive QA model on the joined text.

    Args:
        Question: The raw question text.

    Returns:
        A tuple ``(relevant_text, summary, squad_answer)`` where
        ``relevant_text`` is cut to its first 250 words.  When no document
        can be selected, a short message and two empty strings are returned.
    """
    no_result = ("No relevant documents found for this question.", "", "")

    print("question: ", Question)
    cluster_selected = return_selected_cluster(Question)
    # Copy the slice so adding the score column below does not write into a
    # view of newsdf_sample.
    cluster_df = newsdf_sample[newsdf_sample.exp25 == cluster_selected].copy()
    if cluster_df.empty:
        # No cluster matched (return_selected_cluster gave -1) or the cluster
        # has no rows: skip the models rather than feed them an empty context.
        return no_result

    question_vec = vectorizer.transform([process_row(Question)]).toarray()[0]

    # Vectorize every document of the cluster in one call; missing text is
    # treated as an empty document.
    documents = cluster_df["tfidf_cleaned"].fillna("").astype(str).tolist()
    doc_matrix = vectorizer.transform(documents)
    cluster_df["cos_score"] = [
        cosine_similarity(question_vec, doc_matrix[row_idx].toarray()[0])
        for row_idx in range(doc_matrix.shape[0])
    ]

    cluster_df = cluster_df.sort_values(by=['cos_score'], ascending=False)
    relevant_docs = cluster_df["cleaned_doc"][:20]
    relevant_text = " ".join(relevant_docs)
    print("relevant_text", relevant_text)
    if not relevant_text.strip():
        return no_result

    summary = return_summary(relevant_text)
    squad_answer = return_squad_answer(Question, relevant_text)
    # Only the first 250 words of the context are shown in the UI.
    relevant_text = " ".join(relevant_text.split()[:250])
    return relevant_text, summary, squad_answer
import gradio as gr

# Web front end: one question box in, three text boxes out (the retrieved
# context, the generated summary, and the extractive answer), in the order
# returned by get_summary_answer.
iface = gr.Interface(
    fn=get_summary_answer,
    inputs=gr.Textbox(type="text", label="Type your question"),
    outputs=[
        gr.Textbox(type="text", value=position, label=caption)
        for position, caption in enumerate(
            (
                "Relevant text",
                "Answer from Generative Model",
                "Answer from SQuAD model",
            ),
            start=1,
        )
    ],
    title="20NewsGroup_QA",
    description="Returns answer from 20NewsGroup dataset",
)
iface.launch(inline=False, debug=True)