# Hugging Face Space (status: Running)
import functools
import json
import os

import gradio as gr
import nltk
import torch
from huggingface_hub import hf_hub_download
from nltk.corpus import wordnet
from transformers import AutoConfig, AutoTokenizer

from models import BERTLstmCRF
# Fetch the NLTK data in-process instead of shelling out to whichever
# `python` happens to be first on PATH.
# NOTE(review): only the WordNet corpus is used in this file; "all" is kept
# to preserve the original download scope — consider narrowing to "wordnet".
nltk.download("all")

# Model, tokenizer and label map all come from the same Hub repository.
checkpoint = "gundruke/bert-lstm-crf-absa"
config = AutoConfig.from_pretrained(checkpoint)
id2label = config.id2label
tokenizer = AutoTokenizer.from_pretrained("gundruke/bert-lstm-crf-absa")
model = BERTLstmCRF(config)

repo = "gundruke/bert-lstm-crf-absa"
filename = "pytorch_model.bin"
# NOTE(security): torch.load unpickles the downloaded checkpoint; only point
# `repo` at repositories you trust.
model.load_state_dict(
    torch.load(
        hf_hub_download(repo_id=repo, filename=filename),
        map_location=torch.device('cpu'),
    )
)
# A freshly constructed module starts in training mode; switch to eval so
# dropout is disabled during inference (torch.no_grad() alone does not do it).
model.eval()

# Local path of the JSON dictionary consumed by get_category().
dictionary_file_path = hf_hub_download(repo_id=repo, filename="dictionary.json")
def tokenize_text(text):
    """Tokenize *text* with the module-level tokenizer.

    Returns a ``(tokens, tokenized_text)`` pair: the token strings from
    ``tokenizer.tokenize`` and the encoding object obtained by calling the
    tokenizer directly on the same text.
    """
    encoding = tokenizer(text)
    wordpieces = tokenizer.tokenize(text)
    return wordpieces, encoding
def convert_to_multilabel(label_list):
    """Summarise which sentiment polarities occur in *label_list*.

    A polarity counts as present when either its ``B-`` or ``I-`` tag appears.
    The names of the present polarities are joined with " and ", always in the
    order Positive, Negative, Neutral; the result is an empty string when no
    polarity tag is found.
    """
    polarity_tags = (
        (("B-POS", "I-POS"), "Positive"),
        (("B-NEG", "I-NEG"), "Negative"),
        (("B-NEU", "I-NEU"), "Neutral"),
    )
    present = [
        name
        for tags, name in polarity_tags
        if any(tag in label_list for tag in tags)
    ]
    return " and ".join(present)
def classify_word(word, dictionary):
    """Assign an aspect category to *word*.

    The WordNet lookup takes the first lemma name of the first hypernym of
    the word's first synset, or "Unknown" when the word has no synsets or
    that synset has no hypernyms.

    Category resolution order:
      1. the entry for *word* in *dictionary*, if present;
      2. the WordNet result, if it is one of the five known categories;
      3. 'other'.

    Returns a ``(result, nltk_result)`` tuple.
    """
    known_categories = ['atmosphere', 'drinks', 'food', 'price', 'service']

    nltk_result = "Unknown"
    synsets = wordnet.synsets(word)
    if synsets:
        # Only the first synset and its first hypernym are consulted.
        hypernyms = synsets[0].hypernyms()
        if hypernyms:
            nltk_result = hypernyms[0].lemmas()[0].name()

    if word in dictionary:
        return dictionary[word], nltk_result
    if nltk_result in known_categories:
        return nltk_result, nltk_result
    return 'other', nltk_result
def get_outputs(tokenized_text):
    """Run the module-level model on one encoded sentence and return labels.

    *tokenized_text* must provide "input_ids", "token_type_ids" and
    "attention_mask"; each is wrapped in a list to form a batch of one.
    The values in ``outputs[1]`` are flattened and mapped through
    ``id2label`` (unknown ids map to None).
    NOTE(review): assumes outputs[1] holds the predicted label ids — confirm
    against BERTLstmCRF.forward.

    Returns the label list without its first and last entries (presumably
    the special tokens the tokenizer adds around the sentence).
    """
    feature_names = ("input_ids", "token_type_ids", "attention_mask")
    inputs = {name: torch.tensor([tokenized_text[name]]) for name in feature_names}

    with torch.no_grad():
        outputs = model(**inputs)

    predicted_ids = torch.flatten(outputs[1]).tolist()
    all_labels = [id2label.get(label_id) for label_id in predicted_ids]
    return all_labels[1:-1]
def join_wordpieces(tokens, labels):
    """Merge "##"-prefixed continuation pieces back into whole words.

    Args:
        tokens: token strings; a leading "##" marks a continuation of the
            previous token.
        labels: one label per token; "O" is converted to None.

    Returns:
        A list of ``(word, label)`` tuples. A merged word keeps the label of
        its first piece; if that piece was unlabelled, the first labelled
        continuation piece supplies the label.

    Pairs are consumed with ``zip``, so extra entries in the longer of the
    two sequences are ignored.
    """
    joined_tokens = []
    for token, label in zip(tokens, labels):
        if label == "O":
            label = None
        # Only merge when there is a previous word to attach to; a stray
        # leading "##" piece is kept as its own token instead of raising
        # IndexError.
        if token.startswith("##") and joined_tokens:
            word, word_label = joined_tokens[-1]
            # Preserve the head piece's tag rather than letting the last
            # continuation piece overwrite (or erase) it.
            merged_label = word_label if word_label is not None else label
            joined_tokens[-1] = (word + token[2:], merged_label)
        else:
            joined_tokens.append((token, label))
    return joined_tokens
@functools.lru_cache(maxsize=None)
def _load_dictionary(dict_file):
    """Parse the JSON dictionary at *dict_file*, once per distinct path."""
    with open(dict_file, "r", encoding="utf-8") as file:
        return json.load(file)


def get_category(word, dict_file):
    """Return the category that classify_word assigns to *word*.

    Args:
        word: the word to categorise.
        dict_file: path to a JSON file mapping words to categories. The
            parsed content is cached per path, so the file is read only on
            the first call; later changes to the file are not picked up
            while the process is running.

    Returns:
        The first element of classify_word's result (the category); the raw
        WordNet result is discarded.
    """
    category, _ = classify_word(word, _load_dictionary(dict_file))
    return category
def text_analysis(text):
    """Run the full analysis pipeline on one sentence.

    Steps: tokenize the text, predict a label per token, summarise the
    polarities, merge word pieces into words, then look up a category for
    every word that carries a label.

    Returns:
        A ``(token_tuple, multilabel, categories)`` triple:
        - token_tuple: ``(word, label)`` pairs from join_wordpieces;
        - multilabel: the polarity summary string from convert_to_multilabel;
        - categories: ``(word, category)`` pairs, with None as the category
          for words that have no label.
    """
    tokens, tokenized_text = tokenize_text(text)
    labels = get_outputs(tokenized_text)
    multilabel = convert_to_multilabel(labels)
    token_tuple = join_wordpieces(tokens, labels)

    # Only labelled words are categorised; the rest stay unhighlighted.
    categories = [
        (word, get_category(word, dictionary_file_path) if label else None)
        for word, label in token_tuple
    ]
    return token_tuple, multilabel, categories
# Gradio UI: one column with the input box, submit button and three result
# views, followed by a column of clickable example sentences.
theme = gr.themes.Base()

with gr.Blocks(theme=theme) as demo:
    with gr.Column():
        input_textbox = gr.Textbox(placeholder="Enter sentence here...")
        btn = gr.Button("Submit", variant="primary")

        # Output components, in the order text_analysis returns its values.
        token_label_view = gr.HighlightedText(label="Token labels")
        multilabel_view = gr.Label(label="Multilabel classification")
        category_view = gr.HighlightedText(label="Category")

        btn.click(
            fn=text_analysis,
            inputs=input_textbox,
            outputs=[token_label_view, multilabel_view, category_view],
            queue=False,
        )

    with gr.Column():
        examples = [
            ["I've been coming here as a child and always come back for the taste."],
            ["The tea is great and all the sweets are homemade."],
            ["Strong build which really adds to its durability but poor battery life."],
            ["We loved the recommendation for the wine, and I think the eggplant parmigiana appetizer should become an entree."],
            ["chicken pasta was tasty, wine was super nice but waiter was rude."],
        ]
        gr.Examples(examples, input_textbox)

# Start the app with debug output enabled.
demo.launch(debug=True)