Spaces:
Sleeping
Sleeping
from typing import Callable | |
import torch | |
from .time_helper_base import TimeWrapper | |
def get_cuda_time_wrapper() -> Callable[[], 'TimeWrapper']: | |
""" | |
Overview: | |
Return the ``TimeWrapperCuda`` class, this wrapper aims to ensure compatibility in no cuda device | |
Returns: | |
- TimeWrapperCuda(:obj:`class`): See ``TimeWrapperCuda`` class | |
.. note:: | |
Must use ``torch.cuda.synchronize()``, reference: <https://blog.csdn.net/u013548568/article/details/81368019> | |
""" | |
# TODO find a way to autodoc the class within method | |
class TimeWrapperCuda(TimeWrapper): | |
""" | |
Overview: | |
A class method that inherit from ``TimeWrapper`` class | |
Notes: | |
Must use torch.cuda.synchronize(), reference: \ | |
<https://blog.csdn.net/u013548568/article/details/81368019> | |
Interfaces: | |
``start_time``, ``end_time`` | |
""" | |
# cls variable is initialized on loading this class | |
start_record = torch.cuda.Event(enable_timing=True) | |
end_record = torch.cuda.Event(enable_timing=True) | |
# overwrite | |
def start_time(cls): | |
""" | |
Overview: | |
Implement and overide the ``start_time`` method in ``TimeWrapper`` class | |
""" | |
torch.cuda.synchronize() | |
cls.start = cls.start_record.record() | |
# overwrite | |
def end_time(cls): | |
""" | |
Overview: | |
Implement and overide the end_time method in ``TimeWrapper`` class | |
Returns: | |
- time(:obj:`float`): The time between ``start_time`` and ``end_time`` | |
""" | |
cls.end = cls.end_record.record() | |
torch.cuda.synchronize() | |
return cls.start_record.elapsed_time(cls.end_record) / 1000 | |
return TimeWrapperCuda | |