Michelle Li committed on
Commit cfafbd3 · unverified · 1 Parent(s): 1b4e06d

Add fine-tuning scripts for the Reddit models

Files changed (3)
  1. fine_tune_bert.py +115 -0
  2. rename_labels.py +13 -0
  3. test_model.py +21 -0
fine_tune_bert.py ADDED
@@ -0,0 +1,115 @@
+ # Convert Reddit data into a format that can be used by the BERT model
+
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import precision_recall_fscore_support, accuracy_score
+ import torch
+ from transformers import (
+     DistilBertTokenizerFast,
+     DistilBertForSequenceClassification,
+     Trainer,
+     TrainingArguments,
+ )
+
+ # train/test/validation split
+ def read_reddit_split(reddit_csv):
+     df = pd.read_csv(reddit_csv)
+     texts = df["body"].tolist()
+     labels = df["Class"].tolist()
+     # 70% train, 20% test, 10% validation
+     train_texts, test_texts, train_labels, test_labels = train_test_split(
+         texts, labels, test_size=0.2, stratify=labels
+     )
+     train_texts, val_texts, train_labels, val_labels = train_test_split(
+         train_texts, train_labels, test_size=1.0 / 8.0, stratify=train_labels
+     )
+     print(f"size of train data is {len(train_texts)}")
+     print(f"size of test data is {len(test_texts)}")
+     print(f"size of valid data is {len(val_texts)}")
+     return train_texts, test_texts, val_texts, train_labels, test_labels, val_labels
+
+
+ # tokenize data
+ def tokenize_data(train_texts, test_texts, val_texts, tokenizer):
+     train_enc = tokenizer(train_texts, truncation=True, padding=True)
+     test_enc = tokenizer(test_texts, truncation=True, padding=True)
+     valid_enc = tokenizer(val_texts, truncation=True, padding=True)
+     return train_enc, test_enc, valid_enc
+
+
+ # convert to Dataset object
+ class RedditDataset(torch.utils.data.Dataset):
+     def __init__(self, encodings, labels):
+         self.encodings = encodings
+         self.labels = labels
+
+     def __getitem__(self, idx):
+         item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
+         item["labels"] = torch.tensor(self.labels[idx])
+         return item
+
+     def __len__(self):
+         return len(self.labels)
+
+
+ def compute_metrics(pred):
+     labels = pred.label_ids
+     preds = pred.predictions.argmax(-1)
+     precision, recall, f1, _ = precision_recall_fscore_support(
+         labels, preds, average="binary"
+     )
+     acc = accuracy_score(labels, preds)
+     return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
+
+
+ # run on main
+ if __name__ == "__main__":
+     # tokenizer
+     bert_tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
+     # model
+     model = DistilBertForSequenceClassification.from_pretrained(
+         "distilbert-base-uncased"
+     )
+     # read data
+     x_train, x_test, x_valid, y_train, y_test, y_valid = read_reddit_split(
+         "/workspaces/Michelle_Li_NLP_Project/reddit_data/reddit_annotated.csv"
+     )
+
+     # tokenize data
+     train_encodings, test_encodings, valid_encodings = tokenize_data(
+         x_train, x_test, x_valid, bert_tokenizer
+     )
+
+     train_dataset = RedditDataset(train_encodings, y_train)
+     test_dataset = RedditDataset(test_encodings, y_test)
+     val_dataset = RedditDataset(valid_encodings, y_valid)
+
+     # fine-tune BERT model
+     training_args = TrainingArguments(
+         output_dir="./results",          # output directory
+         num_train_epochs=3,              # total number of training epochs
+         per_device_train_batch_size=16,  # batch size per device during training
+         per_device_eval_batch_size=64,   # batch size for evaluation
+         warmup_steps=500,                # number of warmup steps for learning rate scheduler
+         weight_decay=0.01,               # strength of weight decay
+         logging_dir="./logs",            # directory for storing logs
+         logging_steps=10,
+     )
+
+     trainer = Trainer(
+         model=model,                      # the instantiated 🤗 Transformers model to be trained
+         args=training_args,               # training arguments, defined above
+         train_dataset=train_dataset,      # training dataset
+         eval_dataset=val_dataset,         # evaluation dataset
+         compute_metrics=compute_metrics,  # metrics computed during evaluation
+     )
+
+     trainer.train()
+
+     # evaluate on the held-out test set
+     trainer.evaluate(test_dataset)
+
+     # save model
+     trainer.save_model("./models")
+     # save tokenizer
+     bert_tokenizer.save_pretrained("./models")
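
fine_tune_bert.py expects reddit_annotated.csv to provide a text column named body and a Class column that is already numeric, since RedditDataset calls torch.tensor() on each label. A minimal preparation sketch, assuming the annotations start out as "SFW"/"NSFW" strings (the raw file name reddit_annotated_raw.csv is hypothetical; only the body and Class column names come from the script):

import pandas as pd

# Hypothetical preprocessing step: map string annotations to the 0/1 labels
# that fine_tune_bert.py's RedditDataset expects.
raw = pd.read_csv("reddit_annotated_raw.csv")  # assumed raw annotation file
raw["Class"] = raw["Class"].map({"SFW": 0, "NSFW": 1})
raw[["body", "Class"]].to_csv("reddit_annotated.csv", index=False)

The 0/1 encoding here matches the label2id mapping defined in rename_labels.py below.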
rename_labels.py ADDED
@@ -0,0 +1,13 @@
+ from transformers import AutoConfig, AutoModelForSequenceClassification
+
+ # define the mappings as dictionaries
+ label2id = {"NSFW": 1, "SFW": 0}
+ id2label = {1: "NSFW", 0: "SFW"}
+ # define model checkpoint - can be the same model that you already have on the hub
+ model_ckpt = "michellejieli/NSFW_text_classifier"
+ # define config
+ config = AutoConfig.from_pretrained(model_ckpt, label2id=label2id, id2label=id2label)
+ # load model with config
+ model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, config=config)
+ # export model
+ model.save_pretrained("./models")
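
rename_labels.py only rewrites the label names in the model config, so its effect shows up when the re-exported model is loaded again: predictions are reported as "NSFW"/"SFW" instead of the generic LABEL_0/LABEL_1. A quick sanity check, sketched under the assumption that the tokenizer was also saved to ./models (as fine_tune_bert.py does):

from transformers import pipeline

# Load the re-exported model; the pipeline reads id2label from the saved config.
clf = pipeline("sentiment-analysis", model="./models")
print(clf("Wow, congratulations! So excited for you!"))
# expected output shape: [{'label': 'SFW' or 'NSFW', 'score': ...}]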
test_model.py ADDED
@@ -0,0 +1,21 @@
+ from transformers import pipeline
+ from transformers import DistilBertTokenizerFast
+
+ # classification
+ text = "This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three."
+ text2 = "fucking hell pissing me off"
+ text3 = "I see you’ve set aside this special time to humiliate yourself in public."
+ text4 = "Wow, congratulations! So excited for you!"
+ # hugging-face model
+ classifier = pipeline("sentiment-analysis", model="michellejieli/NSFW_text_classifier")
+ # locally
+ local_classifier = pipeline("sentiment-analysis", model="/workspaces/Michelle_Li_NLP_Project/hugging-face/models")
+ # print results of local model
+ print(local_classifier(text))
+ print(local_classifier(text2))
+ print(local_classifier(text3))
+ print(local_classifier(text4))
+ print(classifier(text))
+ print(classifier(text2))
+ print(classifier(text3))
+ print(classifier(text4))
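
test_model.py calls the classifiers one string at a time; the pipeline also accepts a list and returns one {'label': ..., 'score': ...} dict per input, which keeps the comparison terser. A small sketch using two of the example strings above (not part of the commit):

from transformers import pipeline

classifier = pipeline("sentiment-analysis", model="michellejieli/NSFW_text_classifier")
texts = [
    "Wow, congratulations! So excited for you!",
    "fucking hell pissing me off",
]
# one result dict per input text
for example, result in zip(texts, classifier(texts)):
    print(f"{result['label']} ({result['score']:.3f}): {example}")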