import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, BertForSequenceClassification, BertModel
from sklearn import metrics
import streamlit as st

# Constants referenced by BertClass. HEAD_DROP_OUT is an assumed value here;
# it should match the dropout used when the checkpoint was fine-tuned.
model_path = "bert-base-uncased"
HEAD_DROP_OUT = 0.1

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BertClass must be defined before torch.load can unpickle our fine-tuned model below.
class BertClass(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = BertModel.from_pretrained(model_path)
        self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
        self.classifier = torch.nn.Linear(768, 6)

    def forward(self, input_ids, attention_mask, token_type_ids):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_state = output_1[0]  # last_hidden_state: (batch, seq_len, 768)
        pooler = hidden_state[:, 0]  # embedding of the [CLS] token
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)  # one raw logit per toxicity label
        return output
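
# Shape sketch for one forward pass (hypothetical batch of 8 at MAX_LENGTH of 100):
#   last_hidden_state: (8, 100, 768)
#   [CLS] pooled:      (8, 768)
#   classifier output: (8, 6) raw logits, passed through sigmoid at inference time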
# Define the models to be used: off-the-shelf BERT and our fine-tuned checkpoint.
bert_path = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
tuned_model = torch.load("pytorch_bert_toxic.bin", map_location=torch.device("cpu"))
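# Note: torch.load appears to unpickle a full BertClass module here (rather than
# a state_dict), which is why the class definition above must exist at load time.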
# Read and format the evaluation data: 20 comments and their six toxicity labels.
tweets_raw = pd.read_csv("test.csv", nrows=20)
labels_raw = pd.read_csv("test_labels.csv", nrows=20)
label_set = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
label_vector = labels_raw[label_set].values.tolist()
tweet_df = tweets_raw[["comment_text"]].copy()  # copy to avoid pandas SettingWithCopyWarning
tweet_df["labels"] = label_vector
# User selects a model on the front end; both models share the same tokenizer.
option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
if option == "BERT":
    tokenizer = bert_tokenizer
    model = bert_model
else:
    tokenizer = bert_tokenizer
    model = tuned_model
# Dataset wrapper for loading the labeled comments into a DataLoader.
class ToxicityDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text = self.data.comment_text
        self.targets = self.data.labels
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        text = str(self.text[index])
        text = " ".join(text.split())  # collapse runs of whitespace
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_token_type_ids=True,
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        token_type_ids = inputs["token_type_ids"]
        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "targets": torch.tensor(self.targets[index], dtype=torch.float),
        }
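
# Example item (hypothetical): infer_dataset[0] yields a dict of tensors —
#   ids / mask / token_type_ids: shape (100,), padded or truncated to MAX_LENGTH
#   targets: shape (6,), float multi-hot vector over the six toxicity labels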
# Based on the user's model selection, prepare the Dataset and DataLoader.
MAX_LENGTH = 100
TEST_BATCH_SIZE = 128
infer_dataset = ToxicityDataset(tweet_df, tokenizer, MAX_LENGTH)
# shuffle=False keeps predictions aligned with the order of the input rows.
infer_params = {"batch_size": TEST_BATCH_SIZE, "shuffle": False, "num_workers": 0}
infer_loader = DataLoader(infer_dataset, **infer_params)
# Run the selected model over the DataLoader with gradients disabled.
def inference():
    model.to(device)
    model.eval()
    final_targets = []
    final_outputs = []
    with torch.no_grad():
        for data in infer_loader:
            ids = data["ids"].to(device, dtype=torch.long)
            mask = data["mask"].to(device, dtype=torch.long)
            token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
            targets = data["targets"].to(device, dtype=torch.float)
            outputs = model(ids, mask, token_type_ids)
            # BertForSequenceClassification returns an output object; BertClass returns raw logits.
            logits = outputs.logits if hasattr(outputs, "logits") else outputs
            final_targets.extend(targets.cpu().detach().numpy().tolist())
            final_outputs.extend(torch.sigmoid(logits).cpu().detach().numpy().tolist())
    return final_outputs, final_targets
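
# inference() returns per-label sigmoid probabilities in [0, 1] alongside the
# ground-truth label vectors, in matching order.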
prediction, targets = inference()
# Threshold each label probability at 0.5, then collapse the multi-label vectors to a
# single class index: argmax over the thresholded booleans picks the first label that
# crosses the threshold (index 0 if none does).
prediction = np.array(prediction) >= 0.5
targets = np.argmax(targets, axis=1)
prediction = np.argmax(prediction, axis=1)
accuracy = metrics.accuracy_score(targets, prediction)
f1_score_micro = metrics.f1_score(targets, prediction, average="micro")
f1_score_macro = metrics.f1_score(targets, prediction, average="macro")
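# Micro-F1 pools true/false positives across all classes; macro-F1 averages the
# per-class F1 scores, so rare classes weigh as much as common ones.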
st.write(prediction)
st.write(f"Accuracy Score = {accuracy}")
st.write(f"F1 Score (Micro) = {f1_score_micro}")
st.write(f"F1 Score (Macro) = {f1_score_macro}")