import unittest
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional

import torch

from fairseq.models.ema import EMA


class DummyModule(torch.nn.Module):
    def __init__(self) -> None:
        """Simple two-layer module used as a fixture for the EMA tests."""
        super().__init__()
        self.layer = torch.nn.Linear(in_features=32, out_features=2)
        self.another_layer = torch.nn.Linear(in_features=2, out_features=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer(x)
        return self.another_layer(x)


@dataclass
class EMAConfig(object):
    ema_decay: float = 0.99
    ema_start_update: int = 0
    ema_fp32: bool = False
    ema_seed_model: Optional[str] = None


class TestEMAGPU(unittest.TestCase):
    def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None):
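        # allclose-style check on norms: fail unless ||x - y|| <= atol + rtol * ||y||.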
        diff = x.float() - y.float()
        diff_norm = torch.norm(diff)
        other_norm = torch.norm(y.float())

        if msg is None:
            msg = "|input - other| > {} + {} * |other|".format(
                atol, rtol
            )

        self.assertLessEqual(
            diff_norm,
            atol + rtol * other_norm,
            msg=msg,
        )

    def test_ema(self):
        model = DummyModule()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        state = deepcopy(model.state_dict())
        config = EMAConfig()
        ema = EMA(model, config)

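        # Setting the decay should be reflected by get_decay().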
        ema._set_decay(config.ema_decay)
        self.assertEqual(ema.get_decay(), config.ema_decay)

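        # get_model() returns the shadow copy of the model tracked by EMA.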
        self.assertEqual(ema.get_model(), ema.model)

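        # ema_fp32 is disabled by default, so no fp32 copies of the parameters are kept.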
        self.assertEqual(len(ema.fp32_params), 0)

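        # One forward/backward/optimizer step on the model, then an EMA update.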
        x = torch.randn(32)
        y = model(x)
        loss = y.sum()
        loss.backward()
        optimizer.step()

        ema.step(model)

        ema_state_dict = ema.get_model().state_dict()

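        # After one EMA step, each parameter should equal decay * old + (1 - decay) * new.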
        for key, param in model.state_dict().items():
            prev_param = state[key]
            ema_param = ema_state_dict[key]

            if "version" in key:
                # "version" entries are bookkeeping values, not averaged parameters.
                continue
            self.assertTorchAllClose(
                ema_param,
                config.ema_decay * prev_param + (1 - config.ema_decay) * param,
            )

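        # Still no fp32 copies after the EMA step, since ema_fp32 is disabled.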
        self.assertEqual(len(ema.fp32_params), 0)

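        # reverse() loads the EMA weights back into a freshly initialized model.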
        model2 = DummyModule()
        ema.reverse(model2)

        for key, param in model2.state_dict().items():
            ema_param = ema_state_dict[key]
            self.assertTrue(
                torch.allclose(ema_param, param)
            )

    def test_ema_fp32(self):
        model = DummyModule().half()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        state = deepcopy(model.state_dict())
        config = EMAConfig(ema_fp32=True)
        ema = EMA(model, config)

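        # One fp16 training step followed by an EMA update.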
        x = torch.randn(32)
        y = model(x.half())
        loss = y.sum()
        loss.backward()
        optimizer.step()

        ema.step(model)

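        # With ema_fp32 enabled, every parameter should get an fp32 copy and the
        # averaged value should match an update computed in fp32.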
        for key, param in model.state_dict().items():
            prev_param = state[key]
            ema_param = ema.get_model().state_dict()[key]

            if "version" in key:
                # "version" entries are bookkeeping values, not averaged parameters.
                continue
            self.assertIn(key, ema.fp32_params)

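            # The update is accumulated in fp32 and cast back to fp16 when stored, so the
            # stored EMA param should be at least as close to the fp32 update (cast to
            # half) as to an update computed directly in fp16.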
            self.assertLessEqual(
                torch.norm(
                    ema_param.float() -
                    (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float()
                ),
                torch.norm(
                    ema_param.float() -
                    (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float()
                ),
            )
            self.assertTorchAllClose(
                ema_param,
                (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half(),
            )

    def test_ema_fp16(self):
        model = DummyModule().half()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        state = deepcopy(model.state_dict())
        config = EMAConfig(ema_fp32=False)
        ema = EMA(model, config)

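        # With ema_fp32 disabled, no fp32 copies of the parameters are created.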
        self.assertEqual(len(ema.fp32_params), 0)

        x = torch.randn(32)
        y = model(x.half())
        loss = y.sum()
        loss.backward()
        optimizer.step()

        ema.step(model)

        for key, param in model.state_dict().items():
            prev_param = state[key]
            ema_param = ema.get_model().state_dict()[key]

            if "version" in key:
                # "version" entries are bookkeeping values, not averaged parameters.
                continue

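            # The update is accumulated directly in fp16 here, so the stored EMA param
            # should be at least as close to the fp16 update as to an fp32 update cast
            # back to half.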
            self.assertLessEqual(
                torch.norm(
                    ema_param.float() -
                    (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float()
                ),
                torch.norm(
                    ema_param.float() -
                    (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float()
                ),
            )
            self.assertTorchAllClose(
                ema_param,
                config.ema_decay * prev_param + (1 - config.ema_decay) * param,
            )

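        # Still no fp32 copies after the EMA step.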
        self.assertEqual(len(ema.fp32_params), 0)


if __name__ == "__main__":
    unittest.main()