LightChen2333 committed
Commit 82144dc
1 Parent(s): 05f3947

Delete model_manager.py

Files changed (1)
  1. model_manager.py +0 -324
model_manager.py DELETED
@@ -1,324 +0,0 @@
- '''
- Author: Qiguang Chen
- Date: 2023-01-11 10:39:26
- LastEditors: Qiguang Chen
- LastEditTime: 2023-02-08 00:42:56
- Description: manage all processes of model training and prediction.
-
- '''
- import os
- import random
-
- import numpy as np
- import torch
- from tqdm import tqdm
-
-
- from common import utils
- from common.loader import DataFactory
- from common.logger import Logger
- from common.metric import Evaluator
- from common.tokenizer import get_tokenizer, get_tokenizer_class, load_embedding
- from common.utils import InputData, instantiate
- from common.utils import OutputData
- from common.config import Config
- import dill
-
-
- class ModelManager(object):
-     def __init__(self, config: Config):
-         """create model manager by config
-
-         Args:
-             config (Config): configuration to manage all processes in OpenSLU
-         """
-         # init config
-         self.config = config
-         self.__set_seed(self.config.base.get("seed"))
-         self.device = self.config.base.get("device")
-
-         # enable accelerator
-         if "accelerator" in self.config and self.config["accelerator"].get("use_accelerator"):
-             from accelerate import Accelerator
-             self.accelerator = Accelerator(log_with="wandb")
-         else:
-             self.accelerator = None
-         if self.config.base.get("train"):
-             self.tokenizer = get_tokenizer(
-                 self.config.tokenizer.get("_tokenizer_name_"))
-             self.logger = Logger(
-                 "wandb", self.config.base["name"], start_time=config.start_time, accelerator=self.accelerator)
-
-             # init dataloader & load data
-             if self.config.base.get("save_dir"):
-                 self.model_save_dir = self.config.base["save_dir"]
-             else:
-                 if not os.path.exists("save/"):
-                     os.mkdir("save/")
-                 self.model_save_dir = "save/" + config.start_time
-             if not os.path.exists(self.model_save_dir):
-                 os.mkdir(self.model_save_dir)
-             batch_size = self.config.base["batch_size"]
-             df = DataFactory(tokenizer=self.tokenizer,
-                              use_multi_intent=self.config.base.get("multi_intent"),
-                              to_lower_case=self.config.base.get("_to_lower_case_"))
-             train_dataset = df.load_dataset(self.config.dataset, split="train")
-
-             # update label and vocabulary
-             df.update_label_names(train_dataset)
-             df.update_vocabulary(train_dataset)
-
-             # init tokenizer config and dataloaders
-             tokenizer_config = {key: self.config.tokenizer[key]
-                                 for key in self.config.tokenizer if key[0] != "_" and key[-1] != "_"}
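-             # only plain keys reach the dataloader; keys wrapped in single
-             # underscores (e.g. "_align_mode_") are OpenSLU-internal flags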
-             self.train_dataloader = df.get_data_loader(train_dataset,
-                                                        batch_size,
-                                                        shuffle=True,
-                                                        device=self.device,
-                                                        enable_label=True,
-                                                        align_mode=self.config.tokenizer.get("_align_mode_"),
-                                                        label2tensor=True,
-                                                        **tokenizer_config)
-             dev_dataset = df.load_dataset(
-                 self.config.dataset, split="validation")
-             self.dev_dataloader = df.get_data_loader(dev_dataset,
-                                                      batch_size,
-                                                      shuffle=False,
-                                                      device=self.device,
-                                                      enable_label=True,
-                                                      align_mode=self.config.tokenizer.get("_align_mode_"),
-                                                      label2tensor=False,
-                                                      **tokenizer_config)
-             df.update_vocabulary(dev_dataset)
-             # add intent label num and slot label num to config
-             if int(self.config.get_intent_label_num()) == 0 or int(self.config.get_slot_label_num()) == 0:
-                 self.intent_list = df.intent_label_list
-                 self.intent_dict = df.intent_label_dict
-                 self.config.set_intent_label_num(len(self.intent_list))
-                 self.slot_list = df.slot_label_list
-                 self.slot_dict = df.slot_label_dict
-                 self.config.set_slot_label_num(len(self.slot_list))
-                 self.config.set_vocab_size(self.tokenizer.vocab_size)
-
-             # autoload embedding for non-pretrained encoder
-             if self.config["model"]["encoder"].get("embedding") and self.config["model"]["encoder"]["embedding"].get("load_embedding_name"):
-                 self.config["model"]["encoder"]["embedding"]["embedding_matrix"] = load_embedding(
-                     self.tokenizer,
-                     self.config["model"]["encoder"]["embedding"].get("load_embedding_name"))
-             # fill template in config
-             self.config.autoload_template()
-             # save config
-             self.logger.set_config(self.config)
-
-             self.model = None
-             self.optimizer = None
-             self.total_step = None
-             self.lr_scheduler = None
-             if self.config.tokenizer.get("_tokenizer_name_") == "word_tokenizer":
-                 self.tokenizer.save(os.path.join(self.model_save_dir, "tokenizer.json"))
-             utils.save_json(os.path.join(
-                 self.model_save_dir, "label.json"), {"intent": self.intent_list, "slot": self.slot_list})
-             if self.config.base.get("test"):
-                 self.test_dataset = df.load_dataset(
-                     self.config.dataset, split="test")
-                 self.test_dataloader = df.get_data_loader(self.test_dataset,
-                                                           batch_size,
-                                                           shuffle=False,
-                                                           device=self.device,
-                                                           enable_label=True,
-                                                           align_mode=self.config.tokenizer.get("_align_mode_"),
-                                                           label2tensor=False,
-                                                           **tokenizer_config)
-
-     def init_model(self, model):
-         """init model, optimizer, lr_scheduler
-
-         Args:
-             model (Any): pytorch model
-         """
-         self.model = model
-         self.model.to(self.device)
-         if self.config.base.get("train"):
-             self.optimizer = instantiate(
-                 self.config["optimizer"])(self.model.parameters())
-             self.total_step = int(self.config.base.get(
-                 "epoch_num")) * len(self.train_dataloader)
-             self.lr_scheduler = instantiate(self.config["scheduler"])(
-                 optimizer=self.optimizer,
-                 num_training_steps=self.total_step
-             )
-             if self.accelerator is not None:
-                 self.model, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
-                     self.model, self.optimizer, self.train_dataloader, self.lr_scheduler)
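-                 # prepare() may wrap the model (e.g. in DistributedDataParallel),
-                 # which is why later code reaches the raw model via self.model.module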
-                 if self.config.base.get("load_dir_path"):
-                     self.accelerator.load_state(self.config.base.get("load_dir_path"))
-                 # self.dev_dataloader = self.accelerator.prepare(self.dev_dataloader)
-
-     def eval(self, step: int, best_metric: float) -> float:
-         """evaluate the model.
-
-         Args:
-             step (int): the training step at which evaluation is run
-             best_metric (float): previous best metric value, used to decide
-                 whether to run the test set and save the model
-
-         Returns:
-             float: updated best metric value
-         """
-         # TODO: save dev
-         _, res = self.__evaluate(self.model, self.dev_dataloader)
-         self.logger.log_metric(res, metric_split="dev", step=step)
-         if res[self.config.base.get("best_key")] > best_metric:
-             best_metric = res[self.config.base.get("best_key")]
-             outputs, test_res = self.__evaluate(
-                 self.model, self.test_dataloader)
-             if not os.path.exists(self.model_save_dir):
-                 os.mkdir(self.model_save_dir)
-             if self.accelerator is None:
-                 torch.save(self.model, os.path.join(
-                     self.model_save_dir, "model.pkl"))
-                 torch.save(self.optimizer, os.path.join(
-                     self.model_save_dir, "optimizer.pkl"))
-                 torch.save(self.lr_scheduler, os.path.join(
-                     self.model_save_dir, "lr_scheduler.pkl"), pickle_module=dill)
-                 torch.save(step, os.path.join(
-                     self.model_save_dir, "step.pkl"))
-             else:
-                 self.accelerator.wait_for_everyone()
-                 unwrapped_model = self.accelerator.unwrap_model(self.model)
-                 self.accelerator.save(unwrapped_model.state_dict(),
-                                       os.path.join(self.model_save_dir, "model.pkl"))
-                 self.accelerator.save_state(output_dir=self.model_save_dir)
-             outputs.save(self.model_save_dir, self.test_dataset)
-             self.logger.log_metric(test_res, metric_split="test", step=step)
-         return best_metric
-
-     def train(self) -> float:
-         """train the model.
-
-         Returns:
-             float: updated best metric value
-         """
-         step = 0
-         best_metric = 0
-         progress_bar = tqdm(range(self.total_step))
-         for _ in range(int(self.config.base.get("epoch_num"))):
-             for data in self.train_dataloader:
-                 if step == 0:
-                     self.logger.info(data.get_item(
-                         0, tokenizer=self.tokenizer, intent_map=self.intent_list, slot_map=self.slot_list))
-                 output = self.model(data)
-                 if self.accelerator is not None and hasattr(self.model, "module"):
-                     loss, intent_loss, slot_loss = self.model.module.compute_loss(
-                         pred=output, target=data)
-                 else:
-                     loss, intent_loss, slot_loss = self.model.compute_loss(
-                         pred=output, target=data)
-                 self.logger.log_loss(loss, "Loss", step=step)
-                 self.logger.log_loss(intent_loss, "Intent Loss", step=step)
-                 self.logger.log_loss(slot_loss, "Slot Loss", step=step)
-                 self.optimizer.zero_grad()
-
-                 if self.accelerator is not None:
-                     self.accelerator.backward(loss)
-                 else:
-                     loss.backward()
-                 self.optimizer.step()
-                 self.lr_scheduler.step()
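-                 # unless evaluation is per-epoch (checked after the inner loop),
-                 # evaluate every eval_step optimization steps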
-                 if not self.config.base.get("eval_by_epoch") and step % self.config.base.get(
-                         "eval_step") == 0 and step != 0:
-                     best_metric = self.eval(step, best_metric)
-                 step += 1
-                 progress_bar.update(1)
-             if self.config.base.get("eval_by_epoch"):
-                 best_metric = self.eval(step, best_metric)
-         self.logger.finish()
-         return best_metric
-
-     def __set_seed(self, seed_value: int):
-         """Manually set random seeds.
-
-         Args:
-             seed_value (int): random seed
-         """
-         random.seed(seed_value)
-         np.random.seed(seed_value)
-         torch.manual_seed(seed_value)
-         torch.random.manual_seed(seed_value)
-         os.environ['PYTHONHASHSEED'] = str(seed_value)
-         if torch.cuda.is_available():
-             torch.cuda.manual_seed(seed_value)
-             torch.cuda.manual_seed_all(seed_value)
-             torch.backends.cudnn.deterministic = True
-             # benchmark off so cuDNN algorithm selection stays deterministic
-             torch.backends.cudnn.benchmark = False
-         return
-
-     def __evaluate(self, model, dataloader):
-         model.eval()
-         inps = InputData()
-         outputs = OutputData()
-         for data in dataloader:
-             torch.cuda.empty_cache()
-             output = model(data)
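-             # an accelerate-wrapped model exposes decode() on its inner module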
-             if self.accelerator is not None and hasattr(self.model, "module"):
-                 decode_output = model.module.decode(output, data)
-             else:
-                 decode_output = model.decode(output, data)
-
-             decode_output.map_output(slot_map=self.slot_list,
-                                      intent_map=self.intent_list)
-             data, decode_output = utils.remove_slot_ignore_index(
-                 data, decode_output, ignore_index="#")
-
-             inps.merge_input_data(data)
-             outputs.merge_output_data(decode_output)
-         if "metric" in self.config:
-             res = Evaluator.compute_all_metric(
-                 inps, outputs, intent_label_map=self.intent_dict, metric_list=self.config.metric)
-         else:
-             res = Evaluator.compute_all_metric(
-                 inps, outputs, intent_label_map=self.intent_dict)
-         model.train()
-         return outputs, res
-
-     def load(self):
-
-         self.model = torch.load(os.path.join(self.config.base["model_dir"], "model.pkl"),
-                                 map_location=self.config.base["device"])
-         if self.config.tokenizer["_tokenizer_name_"] == "word_tokenizer":
-             self.tokenizer = get_tokenizer_class(self.config.tokenizer["_tokenizer_name_"]).from_file(
-                 os.path.join(self.config.base["model_dir"], "tokenizer.json"))
-         else:
-             self.tokenizer = get_tokenizer(self.config.tokenizer["_tokenizer_name_"])
-         self.model.to(self.device)
-         label = utils.load_json(os.path.join(self.config.base["model_dir"], "label.json"))
-         self.intent_list = label["intent"]
-         self.slot_list = label["slot"]
-         self.data_factory = DataFactory(tokenizer=self.tokenizer,
-                                         use_multi_intent=self.config.base.get("multi_intent"),
-                                         to_lower_case=self.config.tokenizer.get("_to_lower_case_"))
-
-     def predict(self, text_data):
-         self.model.eval()
-         tokenizer_config = {key: self.config.tokenizer[key]
-                             for key in self.config.tokenizer if key[0] != "_" and key[-1] != "_"}
-         align_mode = self.config.tokenizer.get("_align_mode_")
-         inputs = self.data_factory.batch_fn(batch=[{"text": text_data.split(" ")}],
-                                             device=self.device,
-                                             config=tokenizer_config,
-                                             enable_label=False,
-                                             align_mode=align_mode if align_mode is not None else "general",
-                                             label2tensor=False)
-         output = self.model(inputs)
-         decode_output = self.model.decode(output, inputs)
-         decode_output.map_output(slot_map=self.slot_list,
-                                  intent_map=self.intent_list)
-         if self.config.base.get("multi_intent"):
-             intent = decode_output.intent_ids[0]
-         else:
-             intent = [decode_output.intent_ids[0]]
-         return {"intent": intent, "slot": decode_output.slot_ids[0], "text": self.tokenizer.decode(inputs.input_ids[0])}