philipp-zettl committed · verified
Commit 4706395 · 1 Parent(s): 94bf5e1

Create model.py

Files changed (1): model.py (+342, -0)
model.py ADDED
@@ -0,0 +1,342 @@
+ # Description: Classification models
+ from transformers import AutoModel, AutoTokenizer, BatchEncoding, TrainingArguments, Trainer
+ from functools import partial
+ from huggingface_hub import snapshot_download
+ from huggingface_hub.constants import HF_HUB_CACHE
+ from accelerate import Accelerator
+ from accelerate.utils import find_executable_batch_size as auto_find_batch_size
+ from datasets import load_dataset, Dataset
+ from torch.utils.data import DataLoader
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import numpy as np
+ import json
+ import os
+ from tqdm import tqdm
+ import pandas as pd
+
+ import matplotlib.pyplot as plt
+ from sklearn.metrics import (
+     ConfusionMatrixDisplay,
+     accuracy_score,
+     classification_report,
+     confusion_matrix,
+     f1_score,
+     recall_score
+ )
+
+ BASE_PATH = os.path.dirname(os.path.abspath(__file__))
+
+
+ class MultiHeadClassification(nn.Module):
+     """
+     MultiHeadClassification
+
+     An easy-to-use multi-head classification model. It takes a backbone model and a dictionary of head
+     configurations, and can be used to train multiple classification tasks at once with a single backbone model.
+
+     Apart from joint training, it also supports training individual heads separately, providing a simple way to
+     freeze and unfreeze heads.
+
+     Example:
+         >>> from transformers import AutoModel, AutoTokenizer
+         >>> from torch.optim import AdamW
+         >>> import torch
+         >>> import torch.nn as nn
+         >>>
+         >>> # Manually load the backbone model to create the classifier
+         >>> backbone = AutoModel.from_pretrained('BAAI/bge-m3')
+         >>> model = MultiHeadClassification(backbone, {'binary': 2, 'sentiment': 3, 'something': 4}).to('cuda')
+         >>> print(model)
+         >>> # Load the matching tokenizer for data preprocessing
+         >>> tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
+         >>> optimizer = AdamW(model.parameters(), lr=5e-4)
+         >>> # Some training data
+         >>> samples = tokenizer(["Hello, my dog is cute", "Hello, my dog is cute", "I like turtles"], return_tensors="pt", padding=True, truncation=True).to('cuda')
+         >>> labels = {'binary': torch.tensor([0, 0, 1]), 'sentiment': torch.tensor([0, 1, 2]), 'something': torch.tensor([0, 1, 2])}
+         >>> model.freeze_backbone()
+         >>> model.train(True)
+         >>> for i in range(10):
+         ...     optimizer.zero_grad()
+         ...     outputs = model(samples)
+         ...     loss = sum(nn.CrossEntropyLoss()(outputs[name].cpu(), labels[name]) for name in model.heads.keys())
+         ...     loss.backward()
+         ...     optimizer.step()
+         ...     print(loss.item())
+         >>> print(model(samples))
+         >>> # Save the full model
+         >>> model.save('model.pth')
+         >>> # Save a head only
+         >>> model.save_head('binary', 'binary.pth')
+         >>> # Load the full model
+         >>> model = MultiHeadClassification(backbone, {}).to('cuda')
+         >>> model.load('model.pth')
+         >>> # Load a head only
+         >>> model = MultiHeadClassification(backbone, {}).to('cuda')
+         >>> model.load_head('binary', 'binary.pth')
+         >>> # Add a new head
+         >>> model.add_head('new_head', 3)
+         >>> print(model)
+         >>> # Extend the dataset with labels for the new head
+         >>> labels['new_head'] = torch.tensor([0, 1, 2])
+         >>> # Freeze the backbone and all heads, then unfreeze only the new head
+         >>> model.freeze_all()
+         >>> model.unfreeze_head('new_head')
+         >>> # Recreate the optimizer so the new head's parameters are registered with it
+         >>> optimizer = AdamW(model.parameters(), lr=5e-4)
+         >>> model.train(True)
+         >>> for i in range(10):
+         ...     optimizer.zero_grad()
+         ...     outputs = model(samples)
+         ...     loss = sum(nn.CrossEntropyLoss()(outputs[name].cpu(), labels[name]) for name in model.heads.keys())
+         ...     loss.backward()
+         ...     optimizer.step()
+         ...     print(loss.item())
+         >>> print(model(samples))
+
+     Args:
+         backbone (transformers.PreTrainedModel): A pretrained transformer model
+         head_config (dict): A dictionary with head configurations. The key is the head name and the value is the
+             number of classes for that head.
+     """
+     def __init__(self, backbone, head_config, dropout=0.1, l2_reg=0.01):
+         super().__init__()
+         self.backbone = backbone
+         self.num_heads = len(head_config)
+         self.heads = nn.ModuleDict({
+             name: nn.Linear(backbone.config.hidden_size, num_classes)
+             for name, num_classes in head_config.items()
+         })
+         self.do = nn.Dropout(dropout)
+         self.l2_reg = l2_reg
+         self.device = 'cpu'
+         self.torch_dtype = torch.float16
+         self.head_config = head_config
+
+     def forward(self, x, head_names=None) -> dict:
+         """
+         Forward pass of the model.
+
+         Requires tokenizer output as input. The input should be a dictionary with the keys 'input_ids' and
+         'attention_mask'.
+
+         Args:
+             x (dict): Tokenizer output
+             head_names (list): (optional) List of head names to return logits for. If None, returns logits for
+                 all heads.
+
+         Returns:
+             dict: A dictionary with head names as keys and logits as values
+         """
+         # Use the [CLS] token embedding as the pooled sequence representation
+         x = self.backbone(**x, return_dict=True, output_hidden_states=True).last_hidden_state[:, 0, :]
+         x = self.do(x)
+         if head_names is None:
+             return {name: head(x) for name, head in self.heads.items()}
+         return {name: head(x) for name, head in self.heads.items() if name in head_names}
+
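+     # Illustrative call shapes (a sketch; `batch` stands for tokenizer output moved to the model's device):
+     #   model(batch)                         -> {'binary': logits, 'sentiment': logits, ...}
+     #   model(batch, head_names=['binary'])  -> {'binary': logits}
+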
+     def get_l2_loss(self):
+         """
+         Getter for L2 regularization loss
+
+         Returns:
+             torch.Tensor: L2 regularization loss
+         """
+         l2_loss = torch.tensor(0.).to(self.device)
+         for param in self.parameters():
+             if param.requires_grad:
+                 l2_loss += torch.norm(param, 2)
+         return (self.l2_reg * l2_loss).to(self.device)
+
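+     # Note: get_l2_loss() is not applied anywhere in this file; a training loop would have to add
+     # the term to the task loss itself, e.g. (sketch):
+     #   loss = sum(criterion(outputs[n], labels[n]) for n in outputs) + model.get_l2_loss()
+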
+     def to(self, *args, **kwargs):
+         super().to(*args, **kwargs)
+         # Track the requested dtype/device so later loads can restore them; guard against
+         # keyword-only calls such as .to(device=...), which would otherwise raise an IndexError
+         if args:
+             if isinstance(args[0], torch.dtype):
+                 self.torch_dtype = args[0]
+             elif isinstance(args[0], (str, torch.device)):
+                 self.device = args[0]
+         return self
+
+     def load_head(self, head_name, path):
+         """
+         Load head from a file
+
+         Args:
+             head_name (str): Name of the head
+             path (str): Path to the file
+
+         Returns:
+             None
+         """
+         model = torch.load(path)
+         num_classes = model['weight'].shape[0]
+         if head_name in self.heads:
+             self.heads[head_name].load_state_dict(model)
+             self.to(self.torch_dtype).to(self.device)
+             self.head_config[head_name] = num_classes
+             return
+
+         assert model['weight'].shape[1] == self.backbone.config.hidden_size
+         self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes)
+         self.heads[head_name].load_state_dict(model)
+         self.head_config[head_name] = num_classes
+
+         self.to(self.torch_dtype).to(self.device)
+
+     def save_head(self, head_name, path):
+         """
+         Save head to a file
+
+         Args:
+             head_name (str): Name of the head
+             path (str): Path to the file
+         """
+         torch.save(self.heads[head_name].state_dict(), path)
+
+     def save(self, path):
+         """
+         Save the full model to a file
+
+         Args:
+             path (str): Path to the file
+         """
+         torch.save(self.state_dict(), path)
+
+     def load(self, path):
+         """
+         Load the full model from a file
+
+         Args:
+             path (str): Path to the file
+         """
+         self.load_state_dict(torch.load(path))
+         self.to(self.torch_dtype).to(self.device)
+
+     def save_backbone(self, path):
+         """
+         Save the backbone to a file
+
+         Args:
+             path (str): Path to the file
+         """
+         self.backbone.save_pretrained(path)
+
+     def load_backbone(self, path):
+         """
+         Load the backbone from a file
+
+         Args:
+             path (str): Path to the file
+         """
+         self.backbone = AutoModel.from_pretrained(path)
+         self.to(self.torch_dtype).to(self.device)
+
+     def freeze_backbone(self):
+         """ Freeze the backbone """
+         for param in self.backbone.parameters():
+             param.requires_grad = False
+
+     def unfreeze_backbone(self):
+         """ Unfreeze the backbone """
+         for param in self.backbone.parameters():
+             param.requires_grad = True
+
+     def freeze_head(self, head_name):
+         """
+         Freeze a head by name
+
+         Args:
+             head_name (str): Name of the head
+         """
+         for param in self.heads[head_name].parameters():
+             param.requires_grad = False
+
+     def unfreeze_head(self, head_name):
+         """
+         Unfreeze a head by name
+
+         Args:
+             head_name (str): Name of the head
+         """
+         for param in self.heads[head_name].parameters():
+             param.requires_grad = True
+
+     def freeze_all_heads(self):
+         """ Freeze all heads """
+         for head_name in self.heads.keys():
+             self.freeze_head(head_name)
+
+     def unfreeze_all_heads(self):
+         """ Unfreeze all heads """
+         for head_name in self.heads.keys():
+             self.unfreeze_head(head_name)
+
+     def freeze_all(self):
+         """ Freeze all """
+         self.freeze_backbone()
+         self.freeze_all_heads()
+
+     def unfreeze_all(self):
+         """ Unfreeze all """
+         self.unfreeze_backbone()
+         self.unfreeze_all_heads()
+
+     def add_head(self, head_name, num_classes):
+         """
+         Add a new head to the model
+
+         Args:
+             head_name (str): Name of the head
+             num_classes (int): Number of classes for the head
+         """
+         self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes)
+         self.heads[head_name].to(self.torch_dtype).to(self.device)
+         self.head_config[head_name] = num_classes
+
+     def remove_head(self, head_name):
+         """
+         Remove a head from the model
+
+         Args:
+             head_name (str): Name of the head
+         """
+         if head_name not in self.heads:
+             raise ValueError(f'Head {head_name} not found')
+         del self.heads[head_name]
+         del self.head_config[head_name]
+
+     @classmethod
+     def from_pretrained(cls, model_name, head_config=None, dropout=0.1, l2_reg=0.01):
+         """
+         Load a pretrained model from the Huggingface model hub
+
+         Args:
+             model_name (str): Name of the model
+             head_config (dict): Head configuration
+             dropout (float): Dropout rate
+             l2_reg (float): L2 regularization rate
+         """
+         if head_config is None:
+             head_config = {}
+         # check if the model already exists in the local cache
+         hf_cache_dir = HF_HUB_CACHE
+         model_path = os.path.join(hf_cache_dir, model_name)
+         if os.path.exists(model_path):
+             return cls._from_directory(model_path, head_config, dropout, l2_reg)
+
+         model_path = snapshot_download(repo_id=model_name, cache_dir=hf_cache_dir)
+         return cls._from_directory(model_path, head_config, dropout, l2_reg)
+
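+     # Illustrative usage (the repo id below is a placeholder, not a real checkpoint):
+     #   model = MultiHeadClassification.from_pretrained('some-org/some-multihead-model',
+     #                                                   head_config={'binary': 2})
+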
+     @classmethod
+     def _from_directory(cls, model_path, head_config, dropout=0.1, l2_reg=0.01):
+         """
+         Load a model from a directory
+
+         Args:
+             model_path (str): Path to the model directory
+             head_config (dict): Head configuration
+             dropout (float): Dropout rate
+             l2_reg (float): L2 regularization rate
+         """
+         backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone.pth'))
+         instance = cls(backbone, head_config, dropout, l2_reg)
+         instance.load(os.path.join(model_path, 'pretrained/model.pth'))
+         # Rebuild the head configuration from the heads that were actually loaded
+         instance.head_config = {k: v.out_features for k, v in instance.heads.items()}
+         return instance
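+
+
+ if __name__ == '__main__':
+     # Minimal round-trip sketch. The directory layout below is an assumption read off
+     # `_from_directory`, which expects '<model_dir>/pretrained/backbone.pth' (a directory in
+     # `save_pretrained` format, despite the .pth suffix) and '<model_dir>/pretrained/model.pth';
+     # nothing else in this file creates that layout. Backbone and head names are illustrative.
+     backbone = AutoModel.from_pretrained('BAAI/bge-m3')
+     model = MultiHeadClassification(backbone, {'binary': 2})
+
+     tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
+     batch = tokenizer(['Hello, my dog is cute'], return_tensors='pt', padding=True, truncation=True)
+     model.train(False)
+     with torch.no_grad():
+         print(model(batch))
+
+     # Persist in the layout _from_directory expects, then restore
+     os.makedirs('my_model/pretrained', exist_ok=True)
+     model.save_backbone('my_model/pretrained/backbone.pth')
+     model.save('my_model/pretrained/model.pth')
+     restored = MultiHeadClassification._from_directory('my_model', {'binary': 2})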