Spaces:
Sleeping
Sleeping
import sys | |
import torch | |
from tqdm import tqdm as tqdm | |
from .meter import AverageValueMeter | |
class Epoch:
    """Base class that runs one pass over a dataloader and tracks running averages.

    Subclasses implement :meth:`batch_update` (the per-batch step) and may
    override :meth:`on_epoch_start` (e.g. to switch the model between train
    and eval mode).
    """

    def __init__(self, model, loss, metrics, stage_name, device="cpu", verbose=True):
        """
        Args:
            model: module used by ``batch_update``; moved to ``device``.
            loss: loss callable; moved to ``device``. Its ``__name__`` is the
                log key for the loss.
            metrics: sequence of metric callables ``metric(y_pred, y)``; each is
                moved to ``device`` and logged under its ``__name__``. It is
                iterated more than once, so it must not be a one-shot iterator.
            stage_name: label shown on the progress bar (e.g. "train", "valid").
            device: device the model, loss, metrics and every batch are moved to.
            verbose: when true, show the progress bar with the running logs.
        """
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.stage_name = stage_name
        self.verbose = verbose
        self.device = device
        self._to_device()

    def _to_device(self):
        """Move the model, the loss and every metric to ``self.device``."""
        self.model.to(self.device)
        self.loss.to(self.device)
        for metric in self.metrics:
            metric.to(self.device)

    def _format_logs(self, logs):
        """Render ``logs`` as ``"name - value, ..."`` with 4 significant digits."""
        str_logs = ["{} - {:.4}".format(k, v) for k, v in logs.items()]
        return ", ".join(str_logs)

    def batch_update(self, x, y):
        """Process one batch and return ``(loss_tensor, prediction)``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook called once at the start of :meth:`run`; does nothing by default."""
        pass

    def run(self, dataloader):
        """Run one epoch over ``dataloader`` and return the accumulated logs.

        Args:
            dataloader: iterable yielding ``(x, y)`` tensor pairs.

        Returns:
            dict mapping the loss name and each metric name to the ``mean``
            of its ``AverageValueMeter`` after the last batch (presumably the
            running per-batch average — see ``AverageValueMeter``). The dict
            is empty if ``dataloader`` yields nothing.
        """
        self.on_epoch_start()

        logs = {}
        loss_meter = AverageValueMeter()
        metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics}

        with tqdm(
            dataloader,
            desc=self.stage_name,
            file=sys.stdout,
            disable=not self.verbose,
        ) as iterator:
            for x, y in iterator:
                x, y = x.to(self.device), y.to(self.device)
                loss, y_pred = self.batch_update(x, y)

                # Update the loss logs. Detach before copying to the CPU so the
                # device copy is not recorded in the autograd graph.
                loss_value = loss.detach().cpu().numpy()
                loss_meter.add(loss_value)
                logs[self.loss.__name__] = loss_meter.mean

                # Update the metric logs. Metrics are only reported, never
                # back-propagated, so don't build an autograd graph for them
                # (y_pred still requires grad during training).
                with torch.no_grad():
                    for metric_fn in self.metrics:
                        metric_value = metric_fn(y_pred, y).detach().cpu().numpy()
                        metrics_meters[metric_fn.__name__].add(metric_value)
                logs.update({k: v.mean for k, v in metrics_meters.items()})

                if self.verbose:
                    iterator.set_postfix_str(self._format_logs(logs))

        return logs
class TrainEpoch(Epoch):
    """Training pass: train-mode model, one optimizer step per batch."""

    def __init__(self, model, loss, metrics, optimizer, device="cpu", verbose=True):
        """
        Args:
            model, loss, metrics, device, verbose: forwarded to ``Epoch``.
            optimizer: optimizer whose ``zero_grad``/``step`` are called once
                per batch.
        """
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name="train",
            device=device,
            verbose=verbose,
        )
        self.optimizer = optimizer

    def on_epoch_start(self):
        """Put the model in training mode before iterating."""
        self.model.train()

    def batch_update(self, x, y):
        """Forward, backward and one optimizer step on a batch.

        Returns:
            ``(loss, prediction)``. The loss is the tensor returned by
            ``self.loss``; the prediction is the model output (still attached
            to the autograd graph).
        """
        self.optimizer.zero_grad()
        # Call the module rather than .forward() so registered hooks run.
        prediction = self.model(x)
        loss = self.loss(prediction, y)
        loss.backward()
        self.optimizer.step()
        return loss, prediction
class ValidEpoch(Epoch):
    """Validation pass: eval-mode model, no gradient computation or updates."""

    def __init__(self, model, loss, metrics, device="cpu", verbose=True):
        """Forward all arguments to ``Epoch`` with ``stage_name="valid"``."""
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name="valid",
            device=device,
            verbose=verbose,
        )

    def on_epoch_start(self):
        """Put the model in evaluation mode before iterating."""
        self.model.eval()

    def batch_update(self, x, y):
        """Compute the prediction and loss for a batch with autograd disabled.

        Returns:
            ``(loss, prediction)``. Both are produced inside ``torch.no_grad()``.
        """
        with torch.no_grad():
            # Call the module rather than .forward() so registered hooks run.
            prediction = self.model(x)
            loss = self.loss(prediction, y)
        return loss, prediction