# Scraped page header (not part of the program), kept for reference:
#   Spaces: Runtime error
#   File size: 6,315 Bytes
#   8044721
# The page's line-number gutter (1-167) was dropped.
#!/usr/bin/env python3
# coding=utf-8
from utility.loading_bar import LoadingBar
import time
import torch
class Log:
    """Console progress display and optional Weights & Biases logging for a training run.

    Usage pattern visible in this class: call :meth:`train` or :meth:`eval` to
    start a phase, feed every batch through :meth:`__call__`, and let the next
    phase switch (or an explicit :meth:`flush`) finish the previous phase.
    Loss values are summed per key between two log events.
    """

    def __init__(self, dataset, model, optimizer, args, directory, log_each: int, initial_epoch=-1, log_wandb=True):
        """Store the training objects and prepare output paths and counters.

        Args:
            dataset: Object whose ``state_dict()`` is written into checkpoints.
            model: Object whose ``state_dict()`` is written into checkpoints.
            optimizer: Object whose ``state_dict()`` is written into checkpoints.
            args: Run configuration; this class reads ``args.epochs``,
                ``args.save_checkpoints`` and ``args.state_dict()``.
            directory: Prefix directory for result files and checkpoints.
            log_each: Number of training steps between two progress/wandb updates.
            initial_epoch: Epoch counter value before the first :meth:`train` call.
            log_wandb: When true, ``wandb`` is imported and metrics are sent to it.
        """
        # Lazy import published as a module global, so wandb is only required when used.
        global wandb

        self.dataset = dataset
        self.model = model
        self.args = args
        self.optimizer = optimizer
        self.loading_bar = LoadingBar(length=27)
        self.best_f1_score = 0.0
        self.log_each = log_each
        self.epoch = initial_epoch
        self.log_wandb = log_wandb

        if self.log_wandb:
            import wandb

        self.directory = directory
        # The doubled braces leave literal "{0}" / "{1}" placeholders for a later str.format().
        self.evaluation_results = f"{directory}/results_{{0}}_{{1}}.json"
        self.full_evaluation_results = f"{directory}/full_results_{{0}}_{{1}}.json"
        self.best_full_evaluation_results = f"{directory}/best_full_results_{{0}}_{{1}}.json"
        self.result_history = {epoch: {} for epoch in range(args.epochs)}

        self.best_checkpoint_filename = f"{self.directory}/best_checkpoint.h5"
        self.last_checkpoint_filename = f"{self.directory}/last_checkpoint.h5"

        # Per-phase state; defined here so the logger is usable (and flushable)
        # even before the first train()/eval() call.
        self.step = 0
        self.total_batch_size = 0
        self.len_dataset = 0
        self.losses = None
        self.is_train = True
        self.start_time = time.time()
        self.flushed = True

    def train(self, len_dataset: int) -> None:
        """Finish the previous phase and start a new training epoch.

        Args:
            len_dataset: Number of examples in the epoch; used for the progress ratio.
        """
        self.flush()

        self.epoch += 1
        if self.epoch == 0:
            self._print_header()

        self.is_train = True
        self._reset(len_dataset)

    def eval(self, len_dataset: int) -> None:
        """Finish the previous phase and start an evaluation pass.

        Args:
            len_dataset: Number of examples in the evaluation set.
        """
        self.flush()
        self.is_train = False
        self._reset(len_dataset)

    def __call__(self, batch_size, losses, grad_norm: float = None, learning_rates: float = None) -> None:
        """Record one batch for the current phase.

        Args:
            batch_size: Number of examples in the batch.
            losses: Mapping from loss name to value; it is not modified.
            grad_norm: Gradient norm logged to wandb during training.
            learning_rates: Indexed like a sequence during training (``[0]``,
                ``[-2]``, ``[-1]``) despite the ``float`` annotation; the
                learning-rate entries are skipped when it is ``None``.
        """
        if self.is_train:
            self._train_step(batch_size, losses, grad_norm, learning_rates)
        else:
            self._eval_step(batch_size, losses)

        self.flushed = False

    def flush(self) -> None:
        """Close the current phase: print the final training line or log validation means."""
        if self.flushed:
            return
        self.flushed = True

        if self.is_train:
            print(f"\rβ{self.epoch:12d} β{self._time():>12} β", end="", flush=True)
            return

        if self.losses is not None and self.log_wandb:
            # Validation losses are averaged over the number of evaluation steps.
            dictionary = {f"validation/{key}": value / self.step for key, value in self.losses.items()}
            dictionary["epoch"] = self.epoch
            wandb.log(dictionary)

        self.losses = None
        # NOTE: saving a "last" checkpoint at this point is currently disabled.

    def log_evaluation(self, scores, mode, epoch):
        """Log evaluation scores and checkpoint the model on a new best validation F1.

        Args:
            scores: Mapping of metric name to value; must contain ``"sentiment_tuple/f1"``.
            mode: Prefix for the logged metric names; ``"validation"`` enables
                best-score tracking.
            epoch: Epoch number attached to the wandb record.

        Raises:
            KeyError: If ``scores`` has no ``"sentiment_tuple/f1"`` entry.
        """
        f1_score = scores["sentiment_tuple/f1"]

        if self.log_wandb:
            prefixed = {f"{mode}/{name}": value for name, value in scores.items()}
            wandb.log({"epoch": epoch, **prefixed})

        if mode == "validation" and f1_score > self.best_f1_score:
            if self.log_wandb:
                wandb.run.summary["best sentiment tuple f1 score"] = f1_score
            self.best_f1_score = f1_score
            self._save_model(save_as_best=True, f1_score=f1_score)

    def _save_model(self, save_as_best: bool, f1_score: float):
        """Write a checkpoint with ``torch.save`` unless ``args.save_checkpoints`` is falsy.

        Args:
            save_as_best: Selects the "best" checkpoint path, otherwise the "last" one.
            f1_score: Score stored alongside the state dictionaries.
        """
        if not self.args.save_checkpoints:
            return

        state = {
            "epoch": self.epoch,
            "dataset": self.dataset.state_dict(),
            "f1_score": f1_score,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "args": self.args.state_dict(),
        }
        filename = self.best_checkpoint_filename if save_as_best else self.last_checkpoint_filename

        torch.save(state, filename)
        if self.log_wandb:
            wandb.save(filename)

    def _accumulate(self, losses) -> None:
        """Add one batch's loss values to the running per-key sums."""
        if self.losses is None:
            # Copy: the caller's dict must not be aliased, later sums would leak into it.
            self.losses = dict(losses)
            return

        for key, value in losses.items():
            if key in self.losses:
                # Out-of-place add so value objects owned by the caller are never mutated.
                self.losses[key] = self.losses[key] + value
            else:
                self.losses[key] = value

    def _train_step(self, batch_size, losses, grad_norm: float, learning_rates) -> None:
        """Accumulate one training batch and, every ``log_each`` steps, report progress."""
        self.total_batch_size += batch_size
        self.step += 1
        self._accumulate(losses)

        if self.step % self.log_each != 0:
            return

        # Guard against an empty dataset so the progress display cannot crash training.
        progress = self.total_batch_size / self.len_dataset if self.len_dataset else 0.0
        print(f"\rβ{self.epoch:12d} β{self._time():>12} {self.loading_bar(progress)}", end="", flush=True)

        if self.log_wandb:
            # Sums cover exactly ``log_each`` steps because they are reset below.
            dictionary = {
                key if key.startswith("weight/") else f"train/{key}": value / self.log_each
                for key, value in self.losses.items()
            }
            dictionary["epoch"] = self.epoch
            if learning_rates is not None:
                # __call__ allows learning_rates to be omitted; only index it when given.
                dictionary["learning_rate/encoder"] = learning_rates[0]
                dictionary["learning_rate/decoder"] = learning_rates[-2]
                dictionary["learning_rate/grad_norm"] = learning_rates[-1]
            dictionary["gradient norm"] = grad_norm

            wandb.log(dictionary)

        self.losses = None

    def _eval_step(self, batch_size, losses) -> None:
        """Accumulate one evaluation batch; ``batch_size`` is accepted but unused."""
        self.step += 1
        self._accumulate(losses)

    def _reset(self, len_dataset: int) -> None:
        """Restart the timer and clear the per-phase counters and loss sums."""
        self.start_time = time.time()
        self.step = 0
        self.total_batch_size = 0
        self.len_dataset = len_dataset
        self.losses = None

    def _time(self) -> str:
        """Return the time elapsed since the last reset, formatted as ``"MM:SS min"``."""
        time_seconds = int(time.time() - self.start_time)
        minutes, seconds = divmod(time_seconds, 60)
        return f"{minutes:02d}:{seconds:02d} min"

    def _print_header(self) -> None:
        """Print the table header shown once before the first epoch."""
        print(f"ββββββββββββββββ³ββββΈSβΊβΈEβΊβΈMβΊβΈAβΊβΈNβΊβΈTβΊβΈIβΊβΈSβΊβΈKβΊβββββββββββββββ")
        print(f"β β β· β")
        print(f"β epoch β elapsed β progress bar β")
        print(f"β ββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββββ¨")
|