Michelle Li
committed on
add fine-tune reddit models
- fine_tune_bert.py +115 -0
- rename_labels.py +13 -0
- test_model.py +21 -0
fine_tune_bert.py
ADDED
@@ -0,0 +1,115 @@
# Convert Reddit data into a format that can be used by the BERT model

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import torch
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
)

# train test split
def read_reddit_split(reddit_csv):
    df = pd.read_csv(reddit_csv)
    texts = df["body"].tolist()
    labels = df["Class"].tolist()
    # 70% train, 20% test, 10% valid
    train_texts, test_texts, train_labels, test_labels = train_test_split(
        texts, labels, test_size=0.2, stratify=labels
    )
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        train_texts, train_labels, test_size=1.0 / 8.0, stratify=train_labels
    )
    print(f"size of train data is {len(train_texts)}")
    print(f"size of test data is {len(test_texts)}")
    print(f"size of valid data is {len(val_texts)}")
    return train_texts, test_texts, val_texts, train_labels, test_labels, val_labels


# tokenize data
def tokenize_data(train_texts, test_texts, val_texts, tokenizer):
    train_enc = tokenizer(train_texts, truncation=True, padding=True)
    test_enc = tokenizer(test_texts, truncation=True, padding=True)
    valid_enc = tokenizer(val_texts, truncation=True, padding=True)
    return train_enc, test_enc, valid_enc


# convert to Dataset object
class RedditDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="binary"
    )
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}


# run on main
if __name__ == "__main__":
    # tokenizer
    bert_tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    # model
    model = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased"
    )
    # read data
    x_train, x_test, x_valid, y_train, y_test, y_valid = read_reddit_split(
        "/workspaces/Michelle_Li_NLP_Project/reddit_data/reddit_annotated.csv"
    )

    # tokenize data
    train_encodings, test_encodings, valid_encodings = tokenize_data(
        x_train, x_test, x_valid, bert_tokenizer
    )

    train_dataset = RedditDataset(train_encodings, y_train)
    test_dataset = RedditDataset(test_encodings, y_test)
    val_dataset = RedditDataset(valid_encodings, y_valid)

    # fine-tune BERT model
    training_args = TrainingArguments(
        output_dir="./results",          # output directory
        num_train_epochs=3,              # total number of training epochs
        per_device_train_batch_size=16,  # batch size per device during training
        per_device_eval_batch_size=64,   # batch size for evaluation
        warmup_steps=500,                # number of warmup steps for learning rate scheduler
        weight_decay=0.01,               # strength of weight decay
        logging_dir="./logs",            # directory for storing logs
        logging_steps=10,
    )

    trainer = Trainer(
        model=model,                      # the instantiated 🤗 Transformers model to be trained
        args=training_args,               # training arguments, defined above
        train_dataset=train_dataset,      # training dataset
        eval_dataset=val_dataset,         # evaluation dataset
        compute_metrics=compute_metrics,  # compute metrics
    )

    trainer.train()

    # test model
    trainer.evaluate(test_dataset)

    # save model
    trainer.save_model("./models")
    # save tokenizer
    bert_tokenizer.save_pretrained("./models")
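
Note (not part of the commit): the Trainer and compute_metrics (average="binary") above assume the "Class" column already holds integer labels (0/1). If the annotations were instead stored as the "NSFW"/"SFW" strings used in rename_labels.py, they would have to be mapped to ids before being wrapped in RedditDataset. A minimal sketch, using a hypothetical encode_labels helper:

label2id = {"SFW": 0, "NSFW": 1}  # same mapping as rename_labels.py

def encode_labels(labels):
    # map string class labels to integer ids; pass integers through unchanged
    return [label2id[label] if isinstance(label, str) else label for label in labels]

print(encode_labels(["NSFW", "SFW", 1, 0]))  # [1, 0, 1, 0]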
rename_labels.py
ADDED
@@ -0,0 +1,13 @@
from transformers import AutoConfig, AutoModelForSequenceClassification

# define the mappings as dictionaries
label2id = {"NSFW": 1, "SFW": 0}
id2label = {1: "NSFW", 0: "SFW"}
# define model checkpoint - can be the same model that you already have on the hub
model_ckpt = "michellejieli/NSFW_text_classifier"
# define config
config = AutoConfig.from_pretrained(model_ckpt, label2id=label2id, id2label=id2label)
# load model with config
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, config=config)
# export model
model.save_pretrained("./models")
test_model.py
ADDED
@@ -0,0 +1,21 @@
from transformers import pipeline
from transformers import DistilBertTokenizerFast

# classification
text = "This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three."
text2 = "fucking hell pissing me off"
text3 = "I see you’ve set aside this special time to humiliate yourself in public."
text4 = "Wow, congratulations! So excited for you!"
# hugging-face model
classifier = pipeline("sentiment-analysis", model="michellejieli/NSFW_text_classifier")
# locally
local_classifier = pipeline("sentiment-analysis", model="/workspaces/Michelle_Li_NLP_Project/hugging-face/models")
# print results of local model
print(local_classifier(text))
print(local_classifier(text2))
print(local_classifier(text3))
print(local_classifier(text4))
# print results of hugging-face hub model
print(classifier(text))
print(classifier(text2))
print(classifier(text3))
print(classifier(text4))
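
Note (not part of the commit): DistilBertTokenizerFast is imported above but never used. A hedged sketch of how the locally saved tokenizer could be passed to the pipeline explicitly, assuming the tokenizer was saved into the same local model directory by fine_tune_bert.py:

from transformers import DistilBertTokenizerFast, pipeline

local_dir = "/workspaces/Michelle_Li_NLP_Project/hugging-face/models"  # path from test_model.py
local_tokenizer = DistilBertTokenizerFast.from_pretrained(local_dir)
local_classifier = pipeline("sentiment-analysis", model=local_dir, tokenizer=local_tokenizer)
print(local_classifier("Wow, congratulations! So excited for you!"))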