File size: 1,707 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
"""
The abstract ``Task`` base class, which dispatches train/valid/test steps and reductions by mode name.
Authors
* Leo 2022
"""
import abc
from collections import defaultdict
from typing import List
import torch
# Only the Task base class is part of this module's public API.
__all__ = ["Task"]
class Task(torch.nn.Module):
    """Base class for tasks driven by a train/valid/test loop.

    Subclasses provide ``{mode}_step`` and ``{mode}_reduction`` methods for
    each mode (``train``, ``valid``, ``test``). :meth:`forward` and
    :meth:`reduction` look these methods up by name and call them.

    NOTE: ``torch.nn.Module`` does not use ``abc.ABCMeta`` as its metaclass,
    so the ``@abc.abstractmethod`` markers below are not enforced at
    instantiation time: ``Task`` and incomplete subclasses can still be
    constructed, and a missing override surfaces as ``NotImplementedError``
    only when it is called.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_state(self) -> dict:
        """Return extra (non-model) state to checkpoint.

        ``self.model`` is saved separately, so do not include
        ``self.model.state_dict()`` here. The base class has no extra state.
        """
        return {}

    def set_state(self, state: dict) -> None:
        """Restore extra state (presumably the dict from :meth:`get_state`).

        The base class ignores ``state``.
        """

    def parse_cached_results(self, cached_results: List[dict]) -> dict:
        """Merge a list of result dicts into a single dict of lists.

        List and tuple values are concatenated onto the merged list; any
        other value is appended as a single element.

        Args:
            cached_results: result dicts that must all share the same keys.

        Returns:
            A plain ``dict`` mapping each key to its merged list, with keys in
            the iteration order of the first dict. Empty input yields ``{}``.

        Raises:
            ValueError: if any dict's keys differ from those of the first dict.
        """
        if not cached_results:
            return {}
        expected_keys = set(cached_results[0])
        merged = defaultdict(list)
        for index, result in enumerate(cached_results):
            # Raise (not ``assert``) so the check survives ``python -O``, and
            # compare as sets so non-orderable keys don't break ``sorted``.
            if set(result) != expected_keys:
                raise ValueError(
                    f"cached_results[{index}] has keys "
                    f"{sorted(result, key=repr)}, expected "
                    f"{sorted(expected_keys, key=repr)}"
                )
            for key, value in result.items():
                if isinstance(value, (tuple, list)):
                    merged[key].extend(value)
                else:
                    merged[key].append(value)
        return dict(merged)

    @abc.abstractmethod
    def predict(self):
        """Prediction entry point; subclasses must override."""
        raise NotImplementedError

    def forward(self, mode: str, *args, **kwargs):
        """Call ``self.{mode}_step(*args, **kwargs)`` and return its result.

        Raises ``AttributeError`` if no method with that name exists.
        """
        return getattr(self, f"{mode}_step")(*args, **kwargs)

    def reduction(self, mode: str, *args, **kwargs):
        """Call ``self.{mode}_reduction(*args, **kwargs)`` and return its result.

        Raises ``AttributeError`` if no method with that name exists.
        """
        return getattr(self, f"{mode}_reduction")(*args, **kwargs)

    @abc.abstractmethod
    def train_step(self):
        """Dispatched by ``forward("train", ...)``; subclasses must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def valid_step(self):
        """Dispatched by ``forward("valid", ...)``; subclasses must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def test_step(self):
        """Dispatched by ``forward("test", ...)``; subclasses must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def train_reduction(self):
        """Dispatched by ``reduction("train", ...)``; subclasses must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def valid_reduction(self):
        """Dispatched by ``reduction("valid", ...)``; subclasses must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def test_reduction(self):
        """Dispatched by ``reduction("test", ...)``; subclasses must override."""
        raise NotImplementedError
|