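"""Gradio demo for aspect-based sentiment analysis (ABSA).

Tokens are tagged with a BERT-LSTM-CRF model, the token tags are collapsed
into a sentence-level sentiment label, and each tagged word is mapped to an
aspect category via a dictionary lookup with a WordNet-hypernym fallback.
"""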
import json

import gradio as gr
import nltk
import torch
from huggingface_hub import hf_hub_download
from nltk.corpus import wordnet
from transformers import AutoConfig, AutoTokenizer

from models import BERTLstmCRF

# Only the WordNet data is needed for the hypernym lookups; downloading the
# entire NLTK collection ("python -m nltk.downloader all") is slow and
# unnecessary.
nltk.download("wordnet")
nltk.download("omw-1.4")
# All artifacts (config, tokenizer, weights, category dictionary) live in the
# same Hub repository.
repo = "gundruke/bert-lstm-crf-absa"

config = AutoConfig.from_pretrained(repo)
id2label = config.id2label
tokenizer = AutoTokenizer.from_pretrained(repo)

model = BERTLstmCRF(config)
weights_path = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
model.eval()  # inference only; disables dropout

# Load the word -> aspect-category dictionary once at startup instead of
# re-reading the JSON file for every token.
dictionary_file_path = hf_hub_download(repo_id=repo, filename="dictionary.json")
with open(dictionary_file_path, "r") as file:
    category_dictionary = json.load(file)
def tokenize_text(text):
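    """Return the WordPiece tokens (without special tokens) and the full
    encoding of `text` (input IDs, token type IDs, attention mask)."""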
tokens = tokenizer.tokenize(text)
tokenized_text = tokenizer(text)
return tokens, tokenized_text
def convert_to_multilabel(label_list):
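    """Collapse token-level BIO sentiment tags into a sentence-level label,
    e.g. ["B-POS", "I-POS", "O", "B-NEG"] -> "Positive and Negative"."""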
multilabel = []
if "B-POS" in label_list or "I-POS" in label_list:
multilabel.append("Positive")
if "B-NEG" in label_list or "I-NEG" in label_list:
multilabel.append("Negative")
if "B-NEU" in label_list or "I-NEU" in label_list:
multilabel.append("Neutral")
return " and ".join(multilabel)
def classify_word(word, dictionary):
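    """Guess the aspect category of `word`.

    Returns (category, wordnet_guess): the curated dictionary is tried first,
    then the lemma of the first hypernym of the word's first WordNet synset,
    falling back to 'other'.
    """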
synsets = wordnet.synsets(word)
if synsets:
hypernyms = synsets[0].hypernyms() # Get the hypernym of the first synset
if hypernyms:
nltk_result = hypernyms[0].lemmas()[0].name()
else:
nltk_result = "Unknown"
else:
nltk_result = "Unknown"
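    # A dictionary entry takes precedence over the WordNet-derived guess.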
if word in dictionary:
result = dictionary[word]
elif nltk_result in ['atmosphere', 'drinks', 'food', 'price', 'service']:
result = nltk_result
else:
result = 'other'
return result, nltk_result
def get_outputs(tokenized_text):
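    """Run the model on one encoded sentence and return a BIO tag per
    WordPiece token."""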
input_ids = tokenized_text["input_ids"]
token_type_ids = tokenized_text["token_type_ids"]
attention_mask = tokenized_text["attention_mask"]
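    # Re-wrap each field as a tensor with a batch dimension of size 1.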
inputs = {
'input_ids': torch.tensor([input_ids]),
'token_type_ids': torch.tensor([token_type_ids]),
'attention_mask': torch.tensor([attention_mask])
}
with torch.no_grad():
outputs = model(**inputs)
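    # outputs[1] holds the predicted tag IDs; the [1:-1] slice drops the
    # [CLS] and [SEP] positions so labels align with the WordPiece tokens.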
labels = [id2label.get(i) for i in torch.flatten(outputs[1]).tolist()][1:-1]
return labels
def join_wordpieces(tokens, labels):
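    """Merge WordPiece continuations ("##...") back into whole words.

    A merged word keeps the label of its last piece; "O" labels become None
    so Gradio's HighlightedText leaves those tokens unhighlighted. E.g., for
    a hypothetical split ["wait", "##er", "was", "rude"] labelled
    ["B-NEG", "I-NEG", "O", "O"], the result is
    [("waiter", "I-NEG"), ("was", None), ("rude", None)].
    """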
joined_tokens = []
for token, label in zip(tokens, labels):
if label == "O":
label = None
if token.startswith("##"):
last_token = joined_tokens[-1][0]
joined_tokens[-1] = (last_token+token[2:], label)
else:
joined_tokens.append((token, label))
return joined_tokens
def get_category(word):
    """Return the aspect category for `word`, using the preloaded dictionary
    with the WordNet fallback from classify_word."""
    category, _ = classify_word(word, category_dictionary)
    return category
def text_analysis(text):
    """Build the three demo outputs: highlighted token labels, a sentence-level
    sentiment label, and an aspect category for each labelled word."""
    tokens, tokenized_text = tokenize_text(text)
    labels = get_outputs(tokenized_text)
    multilabel = convert_to_multilabel(labels)
    token_tuple = join_wordpieces(tokens, labels)

    # Only words that carry a sentiment label get a category lookup.
    categories = []
    for tok in token_tuple:
        if tok[1]:
            categories.append((tok[0], get_category(tok[0])))
        else:
            categories.append((tok[0], None))

    return token_tuple, multilabel, categories
theme = gr.themes.Base()
with gr.Blocks(theme=theme) as demo:
with gr.Column():
input_textbox = gr.Textbox(placeholder="Enter sentence here...")
btn = gr.Button("Submit", variant="primary")
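        # The outputs correspond one-to-one to the three values returned by
        # text_analysis.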
btn.click(fn=text_analysis,
inputs=input_textbox,
outputs=[gr.HighlightedText(label="Token labels"),
gr.Label(label="Multilabel classification"),
gr.HighlightedText(label="Category")],
queue=False)
with gr.Column():
        examples = [
["I've been coming here as a child and always come back for the taste."],
["The tea is great and all the sweets are homemade."],
["Strong build which really adds to its durability but poor battery life."],
["We loved the recommendation for the wine, and I think the eggplant parmigiana appetizer should become an entree."],
["chicken pasta was tasty, wine was super nice but waiter was rude."]
]
gr.Examples(examples, input_textbox)
demo.launch(debug=True)