|
from __future__ import annotations |
|
|
|
from collections import UserDict |
|
from typing import Any, Union |
|
|
|
import torch |
|
from lightning.fabric.utilities import move_data_to_device |
|
|
|
from relik.common.log import get_logger |
|
|
|
# Module-level logger, named after this module.
logger = get_logger(__name__)
|
|
|
|
|
class ModelInputs(UserDict):
    """Model input dictionary wrapper.

    A ``UserDict`` whose entries can also be read as attributes
    (``inputs.input_ids`` is equivalent to ``inputs["input_ids"]``) and whose
    tensor values can be moved to a device in one call via :meth:`to`.
    """

    def __getattr__(self, item: str) -> Any:
        """Expose the entries of ``self.data`` as attributes.

        Python only calls this when normal attribute lookup fails, so real
        attributes and methods always take precedence over dictionary keys.

        Args:
            item (`str`): The attribute name, looked up as a key of ``data``.

        Returns:
            The value stored under ``item``.

        Raises:
            AttributeError: If ``item`` is not a key of ``data``, or if the
                instance has no ``data`` attribute yet.
        """
        # Read ``data`` from the instance ``__dict__`` rather than via
        # ``self.data``: on an instance whose ``data`` is not set yet (e.g. one
        # created with ``__new__`` by pickle/copy), ``self.data`` would re-enter
        # this method and recurse until RecursionError instead of cleanly
        # raising AttributeError.
        data = self.__dict__.get("data")
        if data is not None:
            try:
                return data[item]
            except KeyError:
                pass
        raise AttributeError(f"`ModelInputs` has no attribute `{item}`")

    def __getitem__(self, item: str) -> Any:
        """Return the value stored under ``item``.

        Raises:
            KeyError: If ``item`` is not present in ``data``.
        """
        return self.data[item]

    def __getstate__(self) -> dict:
        """Return the pickling state, which is only the wrapped ``data`` dict.

        Any other instance attributes are not included in the state.
        """
        return {"data": self.data}

    def __setstate__(self, state: dict) -> None:
        """Restore ``data`` from a state produced by :meth:`__getstate__`.

        If ``state`` has no ``"data"`` key, nothing is restored. Accessing
        entries of such an instance then raises ``AttributeError``.
        """
        if "data" in state:
            self.data = state["data"]

    def keys(self):
        """A set-like object providing a view on D's keys."""
        return self.data.keys()

    def values(self):
        """An object providing a view on D's values."""
        return self.data.values()

    def items(self):
        """A set-like object providing a view on D's items."""
        return self.data.items()

    def to(self, device: Union[str, torch.device]) -> ModelInputs:
        """
        Move all tensor values to ``device``.

        ``self.data`` is replaced by the result of Lightning's
        ``move_data_to_device`` applied to it. The instance itself is modified
        and returned, so calls can be chained.

        Args:
            device (`str` or `torch.device`): The device to put the tensors on.

        Returns:
            :class:`ModelInputs`: This same instance, after modification.
        """
        self.data = move_data_to_device(self.data, device)
        return self
|
|