Spaces:
Runtime error
Michelle Li
committed on
add fine-tune reddit models
- fine_tune_bert.py +115 -0
- rename_labels.py +13 -0
- test_model.py +21 -0
fine_tune_bert.py
ADDED
@@ -0,0 +1,115 @@
# Convert Reddit data into a format that can be used by the BERT model

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import torch
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
)

# train test split
def read_reddit_split(reddit_csv):
    df = pd.read_csv(reddit_csv)
    texts = df["body"].tolist()
    labels = df["Class"].tolist()
    # hold out 20% for test, then 1/8 of the remaining train for validation (~70/20/10 overall)
    train_texts, test_texts, train_labels, test_labels = train_test_split(
        texts, labels, test_size=0.2, stratify=labels
    )
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        train_texts, train_labels, test_size=1.0 / 8.0, stratify=train_labels
    )
    print(f"size of train data is {len(train_texts)}")
    print(f"size of test data is {len(test_texts)}")
    print(f"size of valid data is {len(val_texts)}")
    return train_texts, test_texts, val_texts, train_labels, test_labels, val_labels


# tokenize data
def tokenize_data(train_texts, test_texts, val_texts, tokenizer):
    train_enc = tokenizer(train_texts, truncation=True, padding=True)
    test_enc = tokenizer(test_texts, truncation=True, padding=True)
    valid_enc = tokenizer(val_texts, truncation=True, padding=True)
    return train_enc, test_enc, valid_enc


# convert to Dataset object
class RedditDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="binary"
    )
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}


# run on main
if __name__ == "__main__":
    # tokenizer
    bert_tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    # model
    model = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased"
    )
    # read data
    x_train, x_test, x_valid, y_train, y_test, y_valid = read_reddit_split(
        "/workspaces/Michelle_Li_NLP_Project/reddit_data/reddit_annotated.csv"
    )

    # tokenize data
    train_encodings, test_encodings, valid_encodings = tokenize_data(
        x_train, x_test, x_valid, bert_tokenizer
    )

    train_dataset = RedditDataset(train_encodings, y_train)
    test_dataset = RedditDataset(test_encodings, y_test)
    val_dataset = RedditDataset(valid_encodings, y_valid)

    # fine-tune BERT model
    training_args = TrainingArguments(
        output_dir="./results",          # output directory
        num_train_epochs=3,              # total number of training epochs
        per_device_train_batch_size=16,  # batch size per device during training
        per_device_eval_batch_size=64,   # batch size for evaluation
        warmup_steps=500,                # number of warmup steps for learning rate scheduler
        weight_decay=0.01,               # strength of weight decay
        logging_dir="./logs",            # directory for storing logs
        logging_steps=10,
    )

    trainer = Trainer(
        model=model,                      # the instantiated 🤗 Transformers model to be trained
        args=training_args,               # training arguments, defined above
        train_dataset=train_dataset,      # training dataset
        eval_dataset=val_dataset,         # evaluation dataset
        compute_metrics=compute_metrics,  # compute metrics
    )

    trainer.train()

    # test model
    trainer.evaluate(test_dataset)

    # save model
    trainer.save_model("./models")
    # save tokenizer
    bert_tokenizer.save_pretrained("./models")
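Once trainer.save_model("./models") and bert_tokenizer.save_pretrained("./models") have run, the checkpoint can be reloaded for a quick smoke test. The sketch below is not part of this commit; the input string is made up, and reading label 0 as SFW and 1 as NSFW is an assumption (the mapping only becomes explicit in rename_labels.py further down).

# Hedged sketch: reload the fine-tuned checkpoint from ./models and score one string.
# Assumes fine_tune_bert.py has already written the directory, and that id 1 means NSFW.
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

tokenizer = DistilBertTokenizerFast.from_pretrained("./models")
model = DistilBertForSequenceClassification.from_pretrained("./models")
model.eval()

inputs = tokenizer("example reddit comment", return_tensors="pt", truncation=True)
with torch.no_grad():
    logits = model(**inputs).logits
probs = torch.softmax(logits, dim=-1)[0]
print({"SFW": probs[0].item(), "NSFW": probs[1].item()})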
rename_labels.py
ADDED
@@ -0,0 +1,13 @@
from transformers import AutoConfig, AutoModelForSequenceClassification

# define the mappings as dictionaries
label2id = {"NSFW": 1, "SFW": 0}
id2label = {1: "NSFW", 0: "SFW"}
# define model checkpoint - can be the same model that you already have on the hub
model_ckpt = "michellejieli/NSFW_text_classifier"
# define config
config = AutoConfig.from_pretrained(model_ckpt, label2id=label2id, id2label=id2label)
# load model with config
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, config=config)
# export model
model.save_pretrained("./models")
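A quick way to confirm the relabeling took effect is to reload the exported config and inspect the mappings; a minimal sketch, assuming the ./models directory written by the script above.

# Hedged check: after save_pretrained, the exported config should carry the
# human-readable labels instead of the default LABEL_0 / LABEL_1.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("./models")
print(cfg.id2label)  # expected to show SFW / NSFW
print(cfg.label2id)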
test_model.py
ADDED
@@ -0,0 +1,21 @@
from transformers import pipeline
from transformers import DistilBertTokenizerFast

# classification
text = "This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three."
text2 = "fucking hell pissing me off"
text3 = "I see you’ve set aside this special time to humiliate yourself in public."
text4 = "Wow, congratulations! So excited for you!"
# hugging-face model
classifier = pipeline("sentiment-analysis", model="michellejieli/NSFW_text_classifier")
# locally
local_classifier = pipeline("sentiment-analysis", model="/workspaces/Michelle_Li_NLP_Project/hugging-face/models")
# print results of local model
print(local_classifier(text))
print(local_classifier(text2))
print(local_classifier(text3))
print(local_classifier(text4))
print(classifier(text))
print(classifier(text2))
print(classifier(text3))
print(classifier(text4))
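The text-classification pipeline also accepts a list of inputs and returns one result per item, so the same comparison can be made in a single call per model; a minimal sketch reusing the strings and pipelines defined above.

# Hedged sketch: batch the four test strings through both pipelines in one call each.
samples = [text, text2, text3, text4]
for sample, local_out, hub_out in zip(samples, local_classifier(samples), classifier(samples)):
    print(sample[:40], "| local:", local_out, "| hub:", hub_out)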