jishnuprakash committed on
Commit
2281681
·
1 Parent(s): 88afd90
Files changed (1) hide show
  1. home.py +0 -152
home.py DELETED
@@ -1,152 +0,0 @@
1
- """
2
- @author:jishnuprakash
3
- """
4
- import os
5
- import torch
6
- import spacy
7
- import utils as ut
8
- import streamlit as st
9
- import pandas as pd
10
- import plotly.express as px
11
- import plotly.graph_objects as go
12
- import pandas as pd
13
- import numpy as np
14
- import matplotlib.pyplot as plt
15
- import seaborn as sns
16
- from nltk import word_tokenize
17
- from nltk.probability import FreqDist
18
- from matplotlib import pyplot as plt
19
- from nltk.corpus import stopwords
20
- from tqdm.auto import tqdm
21
- from datasets import load_dataset
22
- from transformers import AutoTokenizer, AutoModel
23
- from pytorch_lightning.metrics.functional import accuracy, f1, auroc
24
- from sklearn.metrics import classification_report
25
-
26
-
27
# Page-level configuration; set_page_config is issued before any other
# Streamlit element is rendered.
st.set_page_config(page_title='NLP Challenge- JP', layout='wide', page_icon=':computer:')
st.set_option('deprecation.showPyplotGlobalUse', False)

# Page header: title, subtitle and author/source links, rendered as raw HTML.
st.markdown("<h1 style='text-align: center; color: black;'>NLP Challenge - HM Land Registry</h1>", unsafe_allow_html=True)
st.markdown("<h3 style='text-align: center; color: grey;'>Multi-label classification using BERT Transformers</h3>", unsafe_allow_html=True)
# The style attribute previously ended with a doubled quote (center''),
# which is malformed HTML; it is now closed with a single quote.
st.markdown("<div style='text-align: center'> Submission by: Jishnu Prakash Kunnanath Poduvattil | Portfolio:<a href='https://jishnuprakash.github.io/'>jishnuprakash.github.io</a> | Source Code: <a href='https://github.com/Jishnuprakash/lexGLUE_jishnuprakash'>Github</a> </div>", unsafe_allow_html=True)
st.text('')

# Collapsible description of what the page offers.
expander = st.expander("View Description")
expander.write("""This is a minimal user interface implementation to view and interact with
results obtained from fine-tuned BERT transformers trained on LEX GLUE: ECTHR_A dataset.
Try inputting a text below and see the model predictions. You can also extract the location
and Date entities from the text using the checkbox.\\
Below, you can do the same on test data. """)
41
-
42
-
43
- #Load trained model
44
@st.cache(allow_output_mutation=True)
def load_model():
    """Load everything the page needs and hand it back in one tuple.

    Wrapped in ``st.cache`` so Streamlit can reuse the result instead of
    reloading on every script rerun.

    Returns:
        tuple: ``(trained_model, tokenizer, test, NER)`` where
        ``trained_model`` is the ``ut.LexGlueTagger`` restored from the
        checkpoint, ``tokenizer`` is the tokenizer for ``ut.bert_model``,
        ``test`` is the preprocessed LexGLUE ``ecthr_a`` test split and
        ``NER`` is the spaCy ``en_core_web_sm`` pipeline.
    """
    # Restore the fine-tuned classifier from its checkpoint file.
    checkpoint_path = ut.check_filename + '.ckpt'
    classifier = ut.LexGlueTagger.load_from_checkpoint(
        checkpoint_path,
        num_classes=ut.num_classes,
    )

    # Tokenizer matching the BERT variant named in utils.
    bert_tokenizer = AutoTokenizer.from_pretrained(ut.bert_model)

    # Inference only: switch to eval mode and freeze to avoid weight updates.
    classifier.eval()
    classifier.freeze()

    # Test split of the ECtHR (task A) dataset, run through the project's
    # preprocessing helper.
    raw_test_split = load_dataset("lex_glue", "ecthr_a")['test']
    test_frame = ut.preprocess_data(pd.DataFrame(raw_test_split))

    # spaCy pipeline used for entity extraction.
    entity_pipeline = spacy.load("en_core_web_sm")

    return (classifier, bert_tokenizer, test_frame, entity_pipeline)
57
-
58
trained_model, tokenizer, test, NER = load_model()

# --- Section: classify a single, user-supplied text -------------------------
st.header("Try out a text!")
with st.form('model_prediction'):
    # Pre-fill the text area with the first test document, cut to 1525 characters.
    text = st.text_area("Input Text", " ".join(test.iloc[0]['text'])[:1525])
    n1, n2, n3 = st.columns((0.13, 0.3, 0.4))
    ner_check = n1.checkbox("Extract Location and Date", value=True)
    predict = n2.form_submit_button("Predict")

# NOTE(review): results are rendered below the form; confirm this matches the
# intended layout.
with st.spinner("Predicting..."):
    if predict:
        encoding = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=512,
            # Without explicit truncation the tokenizer does not cut inputs
            # longer than max_length, and the model then receives a sequence
            # it cannot handle.
            truncation=True,
            return_token_type_ids=False,
            padding="max_length",
            return_attention_mask=True,
            return_tensors='pt',
        )

        # Predict on text; the model's first return value is not needed here.
        _, prediction = trained_model(encoding["input_ids"], encoding["attention_mask"])
        prediction = list(prediction.flatten().numpy())

        # Collect class indices by position. Looking scores up by value with
        # list.index() returns the wrong class when two scores are equal.
        final_predictions = [idx for idx, score in enumerate(prediction) if score > ut.threshold]
        if final_predictions:
            for idx in final_predictions:
                st.write('Violations: ' + ut.lex_classes[idx] + ' : ' + str(round(prediction[idx]*100, 2)) + ' %')
        else:
            # NOTE(review): the message hard-codes 50% while the cut-off is
            # ut.threshold; confirm the two agree.
            st.write("Confidence less than 50%, Please try another text.")

        if ner_check:
            # Perform NER on the single text and keep DATE and GPE entities.
            n_text = NER(text)
            dates = [ent.text for ent in n_text.ents if ent.label_ == 'DATE']
            locations = [ent.text for ent in n_text.ents if ent.label_ == 'GPE']
            # join() avoids the dangling ", " left by repeated concatenation.
            loc = ", ".join(locations) if locations else "None found"
            date = ", ".join(dates) if dates else "None found"
            st.write("Location entities: " + loc)
            st.write("Date entities: " + date)
101
-
102
# --- Section: run the model over the first N test documents -----------------
st.header("Predict on test data")
with st.form('model_test_prediction'):
    s1, s2, s3 = st.columns((0.1, 0.3, 0.6))
    # Clamp the default so it never exceeds the maximum on a small test set.
    top = s1.number_input("Count", 1, len(test), value=min(10, len(test)))
    ner_check2 = s2.checkbox("Extract Location and Date", value=True)
    predict2 = s2.form_submit_button("Predict")

# NOTE(review): results are rendered below the form; confirm this matches the
# intended layout.
with st.spinner("Predicting on test data"):
    if predict2:
        test_dataset = ut.LexGlueDataset(test.head(top), tokenizer, max_tokens=512)

        # Predict on test data, one document at a time.
        predictions = []
        labels = []
        for item in tqdm(test_dataset):
            # unsqueeze adds the leading batch dimension for a single item.
            _, prediction = trained_model(item["input_ids"].unsqueeze(dim=0),
                                          item["attention_mask"].unsqueeze(dim=0))
            predictions.append(prediction.flatten())
            labels.append(item["labels"].int())

        predictions = torch.stack(predictions)
        labels = torch.stack(labels)

        # Boolean mask of the scores that clear the decision threshold.
        above_threshold = predictions.numpy() > ut.threshold

        # Accuracy, as a percentage rounded to two decimals.
        acc = round(float(accuracy(predictions, labels, threshold=ut.threshold))*100, 2)

        out = test_dataset.data
        # Map every positive position to its class name. The earlier
        # list.index(1) lookup always returned the first positive position,
        # so multi-label rows repeated one label.
        out['predictions'] = [
            [ut.lex_classes[int(idx)] for idx in np.flatnonzero(row)]
            for row in above_threshold
        ]
        out['labels'] = out['labels'].apply(lambda x: [ut.lex_classes[i] for i in x])

        if ner_check2:
            # Perform NER on each test document (joined from its text pieces).
            docs = [NER(" ".join(parts)) for parts in out.text]

            # Extract entities; the parsed docs are kept out of the frame so
            # no temporary column has to be dropped before display.
            out['location'] = [{ent.text for ent in doc.ents if ent.label_ == 'GPE'} for doc in docs]
            out['date'] = [{ent.text for ent in doc.ents if ent.label_ == 'DATE'} for doc in docs]

        st.dataframe(out)
        s3.metric(label='Accuracy', value=acc, delta='', delta_color='inverse')