Spaces:
Build error
Build error
File size: 7,659 Bytes
88afd90 3981553 88afd90 11dfa1d 88afd90 b30f704 88afd90 6a956f9 88afd90 6a956f9 88afd90 df17ec4 88afd90 0b9a94b 88afd90 d334f3e d1e00d8 74b1387 88afd90 d334f3e 88afd90 d334f3e 88afd90 d334f3e d1e00d8 d334f3e 11dfa1d 88afd90 74b1387 88afd90 2aae144 814772d 2aae144 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
"""
@author:jishnuprakash
"""
import nltk
nltk.download('stopwords')
import os
import torch
import spacy
import utils as ut
import streamlit as st
import pandas as pd
import pandas as pd
import numpy as np
from spacy import displacy
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer
from pytorch_lightning.metrics.functional import accuracy
# Page configuration is kept as the first Streamlit call in the script
# (older Streamlit versions require set_page_config to come first).
st.set_page_config(page_title='Classification - BERT', layout='wide', page_icon=':computer:')
# NOTE(review): this option only influences st.pyplot(), which this script never
# calls; recent Streamlit releases reject it as an unknown option — confirm it is
# still valid for the pinned Streamlit version.
st.set_option('deprecation.showPyplotGlobalUse', False)
# Page header: title plus author / portfolio / source-code links (raw HTML).
st.markdown("<h1 style='text-align: center; color: black;'>Multi-label classification using BERT Transformers</h1>", unsafe_allow_html=True)
# The style attribute previously ended in a doubled quote (center''>), which
# produced malformed HTML; it is now a single closing quote.
st.markdown("<div style='text-align: center'> Author: Jishnu Prakash Kunnanath Poduvattil | Portfolio:<a href='https://jishnuprakash.github.io/'>jishnuprakash.github.io</a> | Source Code: <a href='https://github.com/Jishnuprakash/lexGLUE_jishnuprakash'>Github</a> </div>", unsafe_allow_html=True)
# Empty line of spacing below the header.
st.text('')
# Collapsible description of the app; the trailing backslashes are markdown line breaks.
expander = st.expander("View Description")
expander.write("""This is a user interface to view and interact with
results obtained from fine-tuned BERT transformers trained on LEX GLUE: ECTHR_A dataset.
Try inputing a text below and see the model predictions. You can also extract the location
and Date entities from the text using the checkbox.\\
Below, you can do the same on test data.\\
Please find the test data here https://huggingface.co/datasets/lex_glue """)
# Load the fine-tuned classifier, its tokenizer, the preprocessed test split and
# the spaCy NER pipeline. st.cache keeps the result across Streamlit reruns;
# allow_output_mutation=True skips st.cache's mutation check on the returned objects.
@st.cache(allow_output_mutation=True)
def load_model():
    """Build and return every heavy resource the app uses.

    Returns:
        tuple: ``(trained_model, tokenizer, test, NER)`` where
            * ``trained_model`` is ``ut.LexGlueTagger`` restored from
              ``ut.check_filename + '.ckpt'``, put in eval mode and frozen;
            * ``tokenizer`` is the Hugging Face tokenizer for ``ut.bert_model``;
            * ``test`` is the ``lex_glue`` / ``ecthr_a`` test split, converted
              to a DataFrame and passed through ``ut.preprocess_data``;
            * ``NER`` is the ``en_core_web_sm`` spaCy pipeline.
    """
    checkpoint_path = ut.check_filename + '.ckpt'
    model = ut.LexGlueTagger.load_from_checkpoint(checkpoint_path, num_classes=ut.num_classes)
    # Inference only: eval mode plus freeze so the weights are never updated.
    model.eval()
    model.freeze()

    # Tokenizer that matches the BERT backbone named in utils.
    bert_tokenizer = AutoTokenizer.from_pretrained(ut.bert_model)

    # Test split of the ECtHR-A task, as a preprocessed DataFrame.
    raw_test_split = load_dataset("lex_glue", "ecthr_a")['test']
    test_frame = ut.preprocess_data(pd.DataFrame(raw_test_split))

    # Small English spaCy model, used for location/date entity extraction.
    spacy_ner = spacy.load("en_core_web_sm")

    return model, bert_tokenizer, test_frame, spacy_ner
trained_model, tokenizer, test, NER = load_model()

# --- Section 1: classify a single user-supplied text -------------------------
st.header("Try out a text!")
with st.form('model_prediction'):
    # Pre-fill with one element of the first test document's text.
    # NOTE(review): assumes test['text'] holds a list of paragraphs with more
    # than 20 entries in row 0 — confirm against ut.preprocess_data.
    text = st.text_area("Input Text", test.iloc[0]['text'][20])
    # Only the first 2000 characters are classified / passed to NER.
    text = text[:2000]
    n1, n2, n3 = st.columns((0.2,0.4,0.4))
    ner_check = n1.checkbox("Extract Location and Date", value=True)
    predict = n2.form_submit_button("Predict")

with st.spinner("Predicting..."):
    if predict:
        # truncation=True is required: when padding="max_length" is set, the
        # tokenizer does not truncate implicitly, so inputs longer than 512
        # tokens would otherwise reach the model at full length.
        encoding = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=512,
            truncation=True,
            return_token_type_ids=False,
            padding="max_length",
            return_attention_mask=True,
            return_tensors='pt',
        )
        # Predict on text (the model returns a pair; only the second item,
        # the per-class scores, is used here).
        _, prediction = trained_model(encoding["input_ids"], encoding["attention_mask"])
        prediction = list(prediction.flatten().numpy())
        # Positions of every class whose score clears the threshold. enumerate()
        # keeps the true position even when two classes share the same score
        # (list.index() would return the first match for both).
        final_predictions = [idx for idx, score in enumerate(prediction) if score > ut.threshold]
        if final_predictions:
            for i in final_predictions:
                st.write('Violations: '+ ut.lex_classes[i] + ' : ' + str(round(prediction[i]*100, 2)) + ' %')
        else:
            # Derive the percentage from the threshold actually applied above
            # instead of hard-coding 50%.
            st.write(f"Confidence less than {ut.threshold * 100:g}%, Please try another text.")

        if ner_check:
            # Perform NER on the (truncated) input text.
            n_text = NER(text)
            # Deduplicate while keeping first-appearance order. spaCy's GPE
            # label (countries, cities, states) stands in for "location".
            loc = list(dict.fromkeys(ent.text for ent in n_text.ents if ent.label_ == 'GPE'))
            date = list(dict.fromkeys(ent.text for ent in n_text.ents if ent.label_ == 'DATE'))
            loc = loc or ["None found"]
            date = date or ["None found"]
            st.write("Location entities: " + ",".join(loc))
            st.write("Date entities: " + ",".join(date))
            # Display all entities as highlighted HTML produced by displaCy.
            st.write("All Entities-")
            ent_html = displacy.render(n_text, style="ent", jupyter=False)
            st.markdown(ent_html, unsafe_allow_html=True)
# --- Section 2: batch prediction over the first `top` test documents ---------
st.header("Predict on test data")
with st.form('model_test_prediction'):
    s1, s2, s3 = st.columns((0.2, 0.4, 0.4))
    # Default of 10 rows, clamped so it never exceeds the allowed maximum.
    top = s1.number_input("Count", 1, len(test), value=min(10, len(test)))
    ner_check2 = s2.checkbox("Extract Location and Date", value=True)
    predict2 = s2.form_submit_button("Predict")

with st.spinner("Predicting on test data"):
    if predict2:
        test_dataset = ut.LexGlueDataset(test.head(top), tokenizer, max_tokens=512)
        # Predict one document at a time; unsqueeze adds the batch dimension.
        predictions = []
        labels = []
        for item in tqdm(test_dataset):
            _, prediction = trained_model(item["input_ids"].unsqueeze(dim=0),
                                          item["attention_mask"].unsqueeze(dim=0))
            predictions.append(prediction.flatten())
            labels.append(item["labels"].int())
        predictions = torch.stack(predictions)
        labels = torch.stack(labels)

        # Binarise scores into a 0/1 matrix, one row per document.
        y_pred = (predictions.numpy() > ut.threshold).astype(int)

        # Accuracy as a percentage, computed on the raw scores with the same threshold.
        acc = round(float(accuracy(predictions, labels, threshold=ut.threshold))*100, 2)

        # Work on a copy so the columns added below do not mutate the
        # dataset's own frame.
        # NOTE(review): assumes LexGlueDataset.data is a pandas DataFrame — confirm.
        out = test_dataset.data.copy()
        # Class names for every position set to 1. np.flatnonzero returns all
        # such positions; the previous list(row).index(1) returned the first
        # one repeatedly, so multi-label rows showed the same class twice.
        out['predictions'] = pd.Series(
            [[ut.lex_classes[i] for i in np.flatnonzero(row)] for row in y_pred],
            index=out.index,
        )
        out['labels'] = out['labels'].apply(lambda x: [ut.lex_classes[i] for i in x])

        if ner_check2:
            # Run NER on each document (its text entries joined with spaces),
            # keeping the parsed docs out of the displayed frame.
            nlp_docs = out.text.apply(lambda x: NER(" ".join(x)))
            out['location'] = nlp_docs.apply(lambda doc: {ent.text for ent in doc.ents if ent.label_ == 'GPE'})
            out['date'] = nlp_docs.apply(lambda doc: {ent.text for ent in doc.ents if ent.label_ == 'DATE'})
        st.dataframe(out)
        s3.metric(label='Accuracy', value=acc, delta='', delta_color='inverse')
# --- Section 3: static comparison of the fine-tuned models -------------------
st.header("Comparison - Model Performance")
st.write("""2 transformer models were finetuned and compared their performance on the test dataset. \\
- Bert uncased model (on Original & preprocessed text) \\
- Legal-Bert model (on Original & preprocessed text)\\
(Preprocessing steps were removal of numbers, symbols, stopwords followed by lemmatisation on tokens.)\\
The best performing model is Legal-BERT on original data. Please see the comparison below.""")
# Pre-computed results, read from a CSV that ships alongside the app.
met = pd.read_csv("model_comparison.csv")
a1, a2 = st.columns((0.5,0.5))
# NOTE(review): rows 0-10 are shown as ROC-AUC and rows 12+ as evaluation
# metrics, so row 11 appears in neither table — presumably a separator row in
# the CSV; confirm.
a1.subheader("Evaluation Metrics")
a1.dataframe(met[12:].reset_index(drop=True))
a2.subheader("Area under ROC curve")
# The trailing " |" on this line (scrape residue) made it a SyntaxError; removed.
a2.dataframe(met[:11])