# Spaces status banner captured with this file: "Runtime error".
import json
import os
import random

import streamlit as st
# We'll be using Torch this time around
import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DistilBertForSequenceClassification,
    DistilBertTokenizerFast,
    TextClassificationPipeline,
    pipeline,
)
# Model identifiers offered in the sentiment-analysis model selectbox.
emotion_model_names = (
    "cardiffnlp/twitter-roberta-base-sentiment",
    "finiteautomata/beto-sentiment-analysis",
    "bhadresh-savani/distilbert-base-uncased-emotion",
    "siebert/sentiment-roberta-large-english",
)
class ModelImplementation:
    """Bundle a pretrained model, its tokenizer, a classifier and a result parser.

    Attributes:
        transformer_model_name: Name/path handed to ``model_transformer.from_pretrained``.
        tokenizer_model_name: Name/path handed to ``tokenizer_func.from_pretrained``.
        placeholders: Example inputs; the UI reads ``placeholders[0]``.
        model: Object returned by ``model_transformer.from_pretrained``.
        tokenizer: Object returned by ``tokenizer_func.from_pretrained``.
        classifier: Callable built by ``pipeline_func``; invoked by ``predict``.
        parser: Function called as ``parser(self, raw_result)``.
        history: Empty list created per instance (not otherwise used here).
    """

    def __init__(
        self,
        transformer_model_name,
        model_transformer,
        tokenizer_model_name,
        tokenizer_func,
        pipeline_func,
        parser_func,
        classifier_args=None,
        placeholders=None,
    ):
        """Load the model and tokenizer, then build the classifier.

        Args:
            transformer_model_name: Identifier passed to
                ``model_transformer.from_pretrained``.
            model_transformer: Object exposing ``from_pretrained(name)``.
            tokenizer_model_name: Identifier passed to
                ``tokenizer_func.from_pretrained``.
            tokenizer_func: Object exposing ``from_pretrained(name)``.
            pipeline_func: Callable accepting ``model``, ``tokenizer``,
                ``padding``, ``truncation`` plus any ``classifier_args``.
            parser_func: Callable ``(implementation, raw_result) -> parsed``.
            classifier_args: Extra keyword arguments for ``pipeline_func``;
                ``None`` means no extras.
            placeholders: Example inputs; ``None`` means ``[""]``.
        """
        # None sentinels instead of `{}` / `[""]` defaults: a mutable default is
        # created once and would be shared (and mutated) across every instance.
        if classifier_args is None:
            classifier_args = {}
        if placeholders is None:
            placeholders = [""]

        self.transformer_model_name = transformer_model_name
        self.tokenizer_model_name = tokenizer_model_name
        self.placeholders = placeholders
        self.model = model_transformer.from_pretrained(self.transformer_model_name)
        self.tokenizer = tokenizer_func.from_pretrained(self.tokenizer_model_name)
        self.classifier = pipeline_func(
            model=self.model,
            tokenizer=self.tokenizer,
            padding=True,
            truncation=True,
            **classifier_args,
        )
        self.parser = parser_func
        self.history = []

    def predict(self, val):
        """Run the classifier on ``val`` and return the parser's output."""
        result = self.classifier(val)
        # `parser` is a plain function stored on the instance, so it is not
        # bound automatically; pass `self` explicitly.
        return self.parser(self, result)
# Human-readable names for the generic labels that are remapped for
# "cardiffnlp/twitter-roberta-base-sentiment"; any other label from that
# model is reported as "Neutral".
_TWITTER_ROBERTA_LABELS = {
    "LABEL_0": "Negative",
    "LABEL_2": "Positive",
}


def ParseEmotionOutput(self, result):
    """Extract ``(label, score)`` from the first entry of a classifier result.

    Args:
        self: Object exposing ``transformer_model_name`` (the
            ``ModelImplementation`` that produced ``result``).
        result: Sequence whose first item is a mapping with ``'label'`` and
            ``'score'`` keys.

    Returns:
        A ``(label, score)`` tuple. For the cardiffnlp Twitter RoBERTa model
        the label is translated to "Negative" / "Positive" / "Neutral"; for
        every other model it is returned unchanged.
    """
    top = result[0]
    label = top['label']
    score = top['score']
    if self.transformer_model_name == "cardiffnlp/twitter-roberta-base-sentiment":
        label = _TWITTER_ROBERTA_LABELS.get(label, "Neutral")
    return label, score
def ParsePatentOutput(self, result):
    """Return the classifier ``result`` untouched.

    The patent page consumes the raw classifier output directly, so this
    parser is a pass-through; ``self`` is accepted only to match the
    ``parser(self, result)`` calling convention and is not used.
    """
    return result
def emotion_model_change():
    """Rebuild the sentiment model from the name stored in session state.

    Reads ``st.session_state.emotion_model_name`` and stores a fresh
    ``ModelImplementation`` in ``st.session_state.emotion_model``. Used both
    for the initial load and as the ``on_change`` callback of the model
    selectbox.
    """
    model_name = st.session_state.emotion_model_name
    st.session_state.emotion_model = ModelImplementation(
        model_name,
        AutoModelForSequenceClassification,
        model_name,  # the tokenizer is loaded from the same identifier
        AutoTokenizer,
        pipeline,
        ParseEmotionOutput,
        classifier_args={"task": "sentiment-analysis"},
        placeholders=["@AmericanAir just landed - 3hours Late Flight - and now we need to wait TWENTY MORE MINUTES for a gate! I have patience but none for incompetence."],
    )
# ---------------------------------------------------------------------------
# One-time session bootstrap. Each guard keys off a value kept in
# st.session_state, so the body is skipped once that value has been stored.
# ---------------------------------------------------------------------------

# Default landing page.
if "page" not in st.session_state:
    st.session_state.page = "home"

# Default sentiment model; building it is delegated to emotion_model_change().
if "emotion_model_name" not in st.session_state:
    st.session_state.emotion_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
    emotion_model_change()

# Patent validation data plus the two patent classifiers.
if "patent_data" not in st.session_state:
    # Context manager so the handle is released even if json.load raises.
    with open('./data/val.json', encoding="utf-8") as f:
        valData = json.load(f)

    # Index the parallel lists from the JSON file by patent number.
    patent_data = {
        num: {"patent_number": num, "label": label, "abstract": abstract, "claim": claim}
        for num, label, abstract, claim in zip(
            valData["patent_numbers"],
            valData["labels"],
            valData["abstracts"],
            valData["claims"],
        )
    }

    st.session_state.patent_data = patent_data
    # Pre-select the first patent in the file.
    st.session_state.patent_num = next(iter(patent_data))
    st.session_state.weight = 0.5

    st.session_state.patent_abstract_model = ModelImplementation(
        './models/uspto_abstracts',
        DistilBertForSequenceClassification,
        'distilbert-base-uncased',
        DistilBertTokenizerFast,
        TextClassificationPipeline,
        ParsePatentOutput,
        classifier_args={"return_all_scores": True},
    )
    print("Patent abstracts model initialized")

    st.session_state.patent_claim_model = ModelImplementation(
        './models/uspto_claims',
        DistilBertForSequenceClassification,
        'distilbert-base-uncased',
        DistilBertTokenizerFast,
        TextClassificationPipeline,
        ParsePatentOutput,
        classifier_args={"return_all_scores": True},
    )
    print("Patent claims model initialized")
# Page header: project title, author line, then a horizontal rule.
st.title("CSGY-6613 Project")
st.markdown("_**Ryan Kim (rk2546)**_")
st.markdown("---")
def PageToHome():
    """Sidebar callback: switch the active page to "home"."""
    st.session_state.page = "home"
def PageToEmotion():
    """Sidebar callback: switch the active page to "emotion"."""
    st.session_state.page = "emotion"
def PageToPatent():
    """Sidebar callback: switch the active page to "patent"."""
    st.session_state.page = "patent"
# Sidebar navigation. Page switching is done by the on_click callbacks, which
# write the target page into st.session_state.page.
with st.sidebar:
    st.subheader("Toolbox")
    home_selected = st.button("Home", on_click=PageToHome)
    emotion_selected = st.button(
        "Emotion Analysis [Milestone #2]",
        on_click=PageToEmotion,
    )
    patent_selected = st.button(
        "Patent Prediction [Milestone #3]",
        on_click=PageToPatent,
    )
# Markdown snippet shared by the three rating read-outs on the patent page.
_RATING_TEMPLATE = "\n> **Reject**: {}\n> **Accept**: {}\n"


def _text_or_default(entered, default):
    """Return the stripped user text, or ``default`` when nothing was typed."""
    cleaned = entered.strip() if entered else ""
    return cleaned if cleaned else default


def _blend_scores(abstract_scores, claim_scores, weight_val):
    """Combine per-class scores of the abstract and claim models.

    ``weight_val`` comes from a slider in [-1, 1]; it is mapped to a claim
    weight in [0, 1] and the abstract model receives the remainder.
    Index 0 of each score list is reported as REJECTED and index 1 as
    ACCEPTED (assumes both models emit their classes in that order --
    TODO confirm against the trained models).
    """
    claim_weight = (1 + weight_val) / 2
    abstract_weight = 1 - claim_weight
    return [
        {'label': 'REJECTED', 'score': abstract_scores[0]['score']*abstract_weight + claim_scores[0]['score']*claim_weight},
        {'label': 'ACCEPTED', 'score': abstract_scores[1]['score']*abstract_weight + claim_scores[1]['score']*claim_weight},
    ]


def _render_emotion_page():
    """Draw the sentiment-analysis page and, on submit, the prediction."""
    st.subheader("Sentiment Analysis")
    if "emotion_model" not in st.session_state:
        st.write("Loading model...")
        return

    # Changing the selection triggers emotion_model_change(), which reloads
    # the model named by st.session_state.emotion_model_name.
    st.selectbox(
        "What sentiment analysis model do you want to use? NOTE: Lag may occur when loading a new model!",
        emotion_model_names,
        on_change=emotion_model_change,
        key="emotion_model_name",
    )

    placeholder_text = st.session_state.emotion_model.placeholders[0]
    form = st.form(key='sentiment-analysis-form')
    text_input = form.text_area(
        "Enter some text for sentiment analysis! If you just want to test it out without entering anything, just press the \"Submit\" button and the model will look at the placeholder.",
        placeholder=placeholder_text,
    )
    if not form.form_submit_button('Submit'):
        return

    # Empty submissions fall back to the placeholder example.
    to_eval = _text_or_default(text_input, placeholder_text)

    st.write("You entered:")
    st.markdown("> {}".format(to_eval))
    st.write("Using the NLP model:")
    st.markdown("> {}".format(st.session_state.emotion_model_name))
    label, score = st.session_state.emotion_model.predict(to_eval)
    st.markdown("#### Result:")
    st.markdown("**{}**: {}".format(label, score))


def _render_patent_page():
    """Draw the patent-evaluation page and, on submit, the three ratings."""
    st.subheader("USPTO Patent Evaluation")
    st.markdown("Below are two inputs - one for an **ABSTRACT** and another for a list of **CLAIMS**. Enter both and select the \"Submit\" button to evaluate the patenteability of your idea.")

    patent_index_option = st.selectbox(
        "Want to pre-populate with an existing patent? Select the index number of below.",
        list(st.session_state.patent_data.keys()),
        key="patent_num",
    )
    print(patent_index_option)

    if "patent_abstract_model" not in st.session_state or "patent_claim_model" not in st.session_state:
        st.write("Loading models...")
        return

    # Record of the patent currently chosen in the selectbox above.
    selected = st.session_state.patent_data[st.session_state.patent_num]

    with st.form(key='patent-form'):
        col1, col2 = st.columns(2)
        with col1:
            abstract_input = st.text_area(
                "Enter the abstract of the patent below",
                placeholder=selected["abstract"],
                height=400,
            )
        with col2:
            claim_input = st.text_area(
                "Enter the claims of the patent below",
                placeholder=selected["claim"],
                height=400,
            )
        weight_val = st.slider(
            "How much do the abstract and claims weight when aggregating a total softmax score?",
            min_value=-1.0,
            max_value=1.0,
            value=0.5,
        )
        submit = st.form_submit_button('Submit')

    if not submit:
        return

    # Blank fields fall back to the text of the selected patent.
    abstract_to_eval = _text_or_default(abstract_input, selected["abstract"].strip())
    claim_to_eval = _text_or_default(claim_input, selected["claim"].strip())

    abstract_response = st.session_state.patent_abstract_model.predict(abstract_to_eval)
    claim_response = st.session_state.patent_claim_model.predict(claim_to_eval)
    # Element 0 of each response holds the per-class score dicts for the
    # single text that was submitted.
    abstract_scores = abstract_response[0]
    claim_scores = claim_response[0]
    print(abstract_scores)
    print(claim_scores)
    print(weight_val)

    aggregate_score = _blend_scores(abstract_scores, claim_scores, weight_val)
    aggregate_score_sorted = sorted(aggregate_score, key=lambda d: d['score'], reverse=True)
    print(aggregate_score_sorted)
    print(f'Original Rating: {selected["label"]}')

    st.markdown("---")
    answerCol1, answerCol2 = st.columns(2)
    with answerCol1:
        st.markdown("### Abstract Ratings")
        st.markdown(_RATING_TEMPLATE.format(abstract_scores[0]["score"], abstract_scores[1]["score"]))
    with answerCol2:
        st.markdown("### Claims Ratings")
        st.markdown(_RATING_TEMPLATE.format(claim_scores[0]["score"], claim_scores[1]["score"]))

    # The highest blended score decides the headline verdict.
    st.markdown(f'### Final Rating: **{aggregate_score_sorted[0]["label"]}**')
    st.markdown(_RATING_TEMPLATE.format(aggregate_score[0]['score'], aggregate_score[1]['score']))


# Route to the page chosen through the sidebar callbacks.
if st.session_state.page == "emotion":
    _render_emotion_page()
elif st.session_state.page == "patent":
    _render_patent_page()
else:
    st.write("To get started, access the sidebar on the left (click the arrow in the top-left corner of the screen) and select a tool.")

# NOTE(review): indentation was lost in the captured source; these two blank
# writes are assumed to be page-bottom padding at top level -- confirm they
# were not meant to live inside the home-page branch only.
st.write("")
st.write("")