# NOTE(review): removed stray paste artifacts ("Spaces:" / "Runtime error"
# headers) — they were not Python and would prevent the script from running.
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

import streamlit as st
from sklearn import metrics
from transformers import AutoTokenizer, BertModel
# Define constants. Enable CUDA if available.
MAX_LENGTH = 100          # max tokens kept per comment
INFER_BATCH_SIZE = 128    # batch size used at inference time
HEAD_DROP_OUT = 0.4       # dropout probability on the classification head
device = "cuda" if torch.cuda.is_available() else "cpu"
bert_path = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_path)

# Read and format data: pair each comment with its six-way toxicity vector.
tweets_raw = pd.read_csv("test.csv", nrows=20)
labels_raw = pd.read_csv("test_labels.csv", nrows=20)
label_set = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
label_vector = labels_raw[label_set].values.tolist()
# .copy() makes tweet_df an independent frame; assigning a new column into a
# column-slice view raises SettingWithCopyWarning and may not stick.
tweet_df = tweets_raw[["comment_text"]].copy()
tweet_df["labels"] = label_vector
# Dataset for loading tables into DataLoader
class ToxicityDataset(Dataset):
    """Wrap a dataframe of comments and multi-hot labels for a DataLoader.

    The dataframe must provide a ``comment_text`` column (str) and a
    ``labels`` column (sequence of 6 numbers, one per toxicity class).
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        # Reset to positional indexing: __getitem__ receives 0..len-1, which
        # would KeyError on a dataframe with a non-default index (e.g. after
        # filtering) if label-based Series indexing were used.
        self.text = self.data.comment_text.reset_index(drop=True)
        self.targets = self.data.labels.reset_index(drop=True)
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        # Collapse all runs of whitespace to single spaces before tokenizing.
        text = " ".join(str(self.text[index]).split())
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_token_type_ids=True,
        )
        return {
            "ids": torch.tensor(inputs["input_ids"], dtype=torch.long),
            "mask": torch.tensor(inputs["attention_mask"], dtype=torch.long),
            "token_type_ids": torch.tensor(inputs["token_type_ids"], dtype=torch.long),
            "targets": torch.tensor(self.targets[index], dtype=torch.float),
        }
# Based on user model selection, prepare Dataset and DataLoader.
# shuffle=False keeps rows in file order so predictions line up with inputs.
infer_params = {"batch_size": INFER_BATCH_SIZE, "shuffle": False}
infer_dataset = ToxicityDataset(tweet_df, tokenizer, MAX_LENGTH)
infer_loader = DataLoader(infer_dataset, **infer_params)
class BertClass(torch.nn.Module):
    """BERT encoder topped with a dropout + 6-way linear toxicity head."""

    def __init__(self):
        super(BertClass, self).__init__()
        # Attribute names (l1 / dropout / classifier) are state_dict keys;
        # keep them stable so saved checkpoints still load.
        self.l1 = BertModel.from_pretrained(bert_path)
        self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
        self.classifier = torch.nn.Linear(768, 6)

    # return_dict must equal False for Huggingface Transformers v4+
    def forward(self, input_ids, attention_mask, token_type_ids):
        encoded = self.l1(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=False,
        )
        # Use the first token's ([CLS]) hidden state as the sentence summary.
        cls_state = encoded[0][:, 0]
        return self.classifier(self.dropout(cls_state))
class PretrainedBertClass(torch.nn.Module):
    """Off-the-shelf BERT with a dropout + 6-way linear head (untrained head)."""

    def __init__(self):
        super(PretrainedBertClass, self).__init__()
        # Attribute names (l1 / l2 / l3) are state_dict keys; keep them stable.
        self.l1 = BertModel.from_pretrained(bert_path)
        self.l2 = torch.nn.Dropout(HEAD_DROP_OUT)
        self.l3 = torch.nn.Linear(768, 6)

    def forward(self, ids, mask, token_type_ids):
        # With return_dict=False BertModel yields (sequence_output, pooled_output);
        # only the pooled [CLS] output feeds the classification head.
        _, pooled = self.l1(ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False)
        return self.l3(self.l2(pooled))
# User selects model for front-end.
option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
if option == "BERT":
    model = PretrainedBertClass()
else:
    # NOTE(review): torch.load unpickles arbitrary code — only load trusted files.
    # map_location=device (not hard-coded "cpu") keeps load target consistent
    # with where inference() sends the batches.
    model = torch.load("pytorch_bert_toxic.bin", map_location=torch.device(device))
# Move the model to the same device the batches use in inference(); without
# this, CUDA runs crash with a device-mismatch error.
model.to(device)
# Freeze model and input tokens
def inference():
    """Run the selected model over infer_loader without gradient tracking.

    Returns:
        (outputs, targets): parallel lists, one entry per row; each output is
        the per-class sigmoid probability vector, each target the ground-truth
        label vector.
    """
    model.eval()
    final_targets = []
    final_outputs = []
    with torch.no_grad():
        for data in infer_loader:  # enumerate index was unused — iterate directly
            ids = data["ids"].to(device, dtype=torch.long)
            mask = data["mask"].to(device, dtype=torch.long)
            token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
            targets = data["targets"].to(device, dtype=torch.float)
            outputs = model(ids, mask, token_type_ids)
            # .detach() is redundant inside torch.no_grad(); .cpu() suffices.
            final_targets.extend(targets.cpu().numpy().tolist())
            final_outputs.extend(torch.sigmoid(outputs).cpu().numpy().tolist())
    return final_outputs, final_targets
# Run inference and show raw per-class probabilities plus ground-truth labels
# in the Streamlit front-end.
prediction, targets = inference()
st.write("before argmax")
st.write(prediction)
st.write(targets)