# Source: poi_Engineering/src/sts/tests/test_engine.py
# (uploaded by Juartaurus via huggingface_hub, revision 1865436)
# Copyright (c) Facebook, Inc. and its affiliates.
import json
import os
import tempfile
import time
import unittest
import torch
from torch import nn
from detectron2.config import configurable, get_cfg
from detectron2.engine import DefaultTrainer, SimpleTrainer, hooks
from detectron2.modeling.meta_arch import META_ARCH_REGISTRY
from detectron2.utils.events import CommonMetricPrinter, JSONWriter
@META_ARCH_REGISTRY.register()
class _SimpleModel(nn.Module):
    """Minimal meta-architecture used to drive the trainers in these tests.

    It owns a single ``nn.Linear(10, 20)`` so the optimizer has parameters to
    update, and returns a cheap differentiable scalar loss.
    """

    @configurable
    def __init__(self, sleep_sec=0):
        """
        Args:
            sleep_sec: seconds to block at the start of every forward pass;
                any value <= 0 disables the delay.
        """
        super().__init__()
        self.mod = nn.Linear(10, 20)
        self.sleep_sec = sleep_sec

    @classmethod
    def from_config(cls, cfg):
        # No keyword arguments are taken from the config, so ``__init__``
        # falls back to its defaults when built via the registry.
        return {}

    def forward(self, x):
        """Return ``{"loss": x.sum() + sum of each parameter tensor's mean}``."""
        if self.sleep_sec > 0:
            time.sleep(self.sleep_sec)
        input_term = x.sum()
        param_term = sum([param.mean() for param in self.parameters()])
        return {"loss": input_term + param_term}
class TestTrainer(unittest.TestCase):
    """Tests for SimpleTrainer, DefaultTrainer and the metric-writer hooks,
    driven by ``_SimpleModel`` and a synthetic random-tensor data stream."""

    def _data_loader(self, device):
        """Yield random ``(3, 3)`` tensors on ``device`` indefinitely.

        The stream never ends; how many batches are consumed is controlled
        by the ``train(start_iter, max_iter)`` range each test passes.
        """
        device = torch.device(device)
        while True:
            yield torch.rand(3, 3).to(device)

    def test_simple_trainer(self, device="cpu"):
        """Run SimpleTrainer with plain SGD for iterations 0..9 on ``device``."""
        model = _SimpleModel().to(device=device)
        trainer = SimpleTrainer(
            model, self._data_loader(device), torch.optim.SGD(model.parameters(), 0.1)
        )
        trainer.train(0, 10)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_simple_trainer_cuda(self):
        """Same as ``test_simple_trainer``, but on the GPU."""
        self.test_simple_trainer(device="cuda")

    def test_writer_hooks(self):
        """Check the lines PeriodicWriter sends to the JSON file and the logger.

        ``sleep_sec=0.1`` gives every iteration a measurable duration, so the
        console printer reports a real ETA; by the final write it must be 0.
        """
        model = _SimpleModel(sleep_sec=0.1)
        trainer = SimpleTrainer(
            model, self._data_loader("cpu"), torch.optim.SGD(model.parameters(), 0.1)
        )

        max_iter = 50

        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            json_file = os.path.join(d, "metrics.json")
            writers = [CommonMetricPrinter(max_iter), JSONWriter(json_file)]

            trainer.register_hooks(
                [hooks.EvalHook(0, lambda: {"metric": 100}), hooks.PeriodicWriter(writers)]
            )
            with self.assertLogs(writers[0].logger) as logs:
                trainer.train(0, max_iter)

            # NOTE(review): the expected iterations imply periodic writes every
            # 20 iterations plus one at the last iteration (49); the extra JSON
            # line at 50 carries the EvalHook result written after training.
            with open(json_file, "r") as f:
                data = [json.loads(line.strip()) for line in f]
            self.assertEqual([x["iteration"] for x in data], [19, 39, 49, 50])
            # the eval metric is in the last line with iter 50
            self.assertIn("metric", data[-1], "Eval metric must be in last line of JSON!")

            # test logged messages from CommonMetricPrinter
            self.assertEqual(len(logs.output), 3)
            # ``iteration`` rather than ``iter`` to avoid shadowing the builtin.
            for log, iteration in zip(logs.output, [19, 39, 49]):
                self.assertIn(f"iter: {iteration}", log)

            self.assertIn("eta: 0:00:00", logs.output[-1], "Last ETA must be 0!")

    @unittest.skipIf(os.environ.get("CI"), "Require COCO data.")
    def test_default_trainer(self):
        """``DefaultTrainer.model`` must read and write through to the
        wrapped ``_trainer.model``."""
        cfg = get_cfg()
        # Only the ``model`` property is exercised, so no GPU is required.
        # Pinning the device keeps the test from failing on CPU-only machines
        # (``get_cfg()`` defaults ``MODEL.DEVICE`` to "cuda"), since this test
        # is only skipped on CI, not when CUDA is unavailable.
        cfg.MODEL.DEVICE = "cpu"
        cfg.MODEL.META_ARCHITECTURE = "_SimpleModel"
        cfg.DATASETS.TRAIN = ("coco_2017_val_100",)
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            cfg.OUTPUT_DIR = d
            trainer = DefaultTrainer(cfg)

            # test property
            self.assertIs(trainer.model, trainer._trainer.model)
            trainer.model = _SimpleModel()
            self.assertIs(trainer.model, trainer._trainer.model)