Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. | |
import json | |
import os | |
import tempfile | |
import time | |
import unittest | |
import torch | |
from torch import nn | |
from detectron2.config import configurable, get_cfg | |
from detectron2.engine import DefaultTrainer, SimpleTrainer, hooks | |
from detectron2.modeling.meta_arch import META_ARCH_REGISTRY | |
from detectron2.utils.events import CommonMetricPrinter, JSONWriter | |
@META_ARCH_REGISTRY.register()
class _SimpleModel(nn.Module):
    """Minimal model used to exercise the trainers in this file.

    It owns one ``nn.Linear`` whose parameters contribute to the loss, so the
    optimizer has something to update. It can optionally sleep at the start of
    every ``forward`` to make iterations take measurable wall-clock time.

    The class is registered in ``META_ARCH_REGISTRY`` and its ``__init__`` is
    ``@configurable`` (paired with the ``from_config`` classmethod). This lets
    it be built from a config whose ``MODEL.META_ARCHITECTURE`` is
    ``"_SimpleModel"``, and still be constructed directly with keyword args.
    """

    @configurable
    def __init__(self, sleep_sec=0):
        """
        Args:
            sleep_sec: seconds to sleep at the start of each forward pass;
                values <= 0 disable sleeping.
        """
        super().__init__()
        self.mod = nn.Linear(10, 20)
        self.sleep_sec = sleep_sec

    @classmethod
    def from_config(cls, cfg):
        # Nothing is read from the config: the model is always built with its
        # default ``sleep_sec``.
        return {}

    def forward(self, x):
        """Return ``{"loss": scalar}`` computed from ``x`` and all parameters.

        The loss is ``x.sum()`` plus the mean of every parameter tensor, so
        every parameter receives a gradient.
        """
        if self.sleep_sec > 0:
            time.sleep(self.sleep_sec)
        # Use ``p`` (not ``x``) so the comprehension does not shadow the input.
        param_term = sum(p.mean() for p in self.parameters())
        return {"loss": x.sum() + param_term}
class TestTrainer(unittest.TestCase):
    """Tests for ``SimpleTrainer``, its writer hooks, and ``DefaultTrainer``."""

    def _data_loader(self, device):
        """Yield an endless stream of random 3x3 tensors placed on ``device``."""
        device = torch.device(device)
        while True:
            yield torch.rand(3, 3).to(device)

    def test_simple_trainer(self, device="cpu"):
        """Train ``_SimpleModel`` with ``SimpleTrainer`` via ``train(0, 10)``."""
        model = _SimpleModel().to(device=device)
        trainer = SimpleTrainer(
            model, self._data_loader(device), torch.optim.SGD(model.parameters(), 0.1)
        )
        trainer.train(0, 10)

    # Skip instead of erroring on machines without a GPU.
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_simple_trainer_cuda(self):
        self.test_simple_trainer(device="cuda")

    def test_writer_hooks(self):
        """Check what ``PeriodicWriter`` writes to JSON and to the log.

        The test also checks that the ``EvalHook`` metric lands in the last
        JSON record.
        """
        # Each forward sleeps, so the ETA logged by CommonMetricPrinter is
        # driven by real elapsed time.
        model = _SimpleModel(sleep_sec=0.1)
        trainer = SimpleTrainer(
            model, self._data_loader("cpu"), torch.optim.SGD(model.parameters(), 0.1)
        )

        max_iter = 50
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            json_file = os.path.join(d, "metrics.json")
            writers = [CommonMetricPrinter(max_iter), JSONWriter(json_file)]

            trainer.register_hooks(
                [hooks.EvalHook(0, lambda: {"metric": 100}), hooks.PeriodicWriter(writers)]
            )
            with self.assertLogs(writers[0].logger) as logs:
                trainer.train(0, max_iter)

            # Read the metrics while the temporary directory still exists.
            with open(json_file, "r") as f:
                data = [json.loads(line.strip()) for line in f]
            self.assertEqual([x["iteration"] for x in data], [19, 39, 49, 50])
            # the eval metric is in the last line with iter 50
            self.assertIn("metric", data[-1], "Eval metric must be in last line of JSON!")

            # test logged messages from CommonMetricPrinter
            self.assertEqual(len(logs.output), 3)
            for log, it in zip(logs.output, [19, 39, 49]):
                self.assertIn(f"iter: {it}", log)

            self.assertIn("eta: 0:00:00", logs.output[-1], "Last ETA must be 0!")

    def test_default_trainer(self):
        """``DefaultTrainer.model`` should read and write the inner trainer's model."""
        cfg = get_cfg()
        # Build the model on CPU so this test does not require a GPU.
        cfg.MODEL.DEVICE = "cpu"
        cfg.MODEL.META_ARCHITECTURE = "_SimpleModel"
        cfg.DATASETS.TRAIN = ("coco_2017_val_100",)
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            cfg.OUTPUT_DIR = d
            trainer = DefaultTrainer(cfg)

            # test property
            self.assertIs(trainer.model, trainer._trainer.model)
            trainer.model = _SimpleModel()
            self.assertIs(trainer.model, trainer._trainer.model)