from abc import ABC, abstractmethod | |
import logging | |
from os import path | |
import os | |
from threading import Thread | |
from time import sleep, time | |
from typing import Callable, Optional | |
import uuid | |
import torch.multiprocessing as mp | |
import torch | |
from ding.data.storage.file import FileModelStorage | |
from ding.data.storage.storage import Storage | |
from ding.framework import Supervisor | |
from ding.framework.supervisor import ChildType, SendPayload | |
class ModelWorker():

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Overview:
            Hold a torch module on behalf of a supervisor child process.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
        """
        self._module = model

    def save(self, storage: Storage) -> Storage:
        """
        Overview:
            Dump the wrapped module's state dict into the given storage.
        Arguments:
            - storage (:obj:`Storage`): Target storage object.
        Returns:
            - storage (:obj:`Storage`): The same storage object, now holding the state dict.
        """
        storage.save(self._module.state_dict())
        return storage
class ModelLoader(Supervisor, ABC):

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Overview:
            Save and send models asynchronously and load them synchronously.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
        """
        # CUDA tensors can only be shared with a child process under the
        # "spawn" start method, hence the explicit mp context.
        if next(model.parameters()).is_cuda:
            super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
        else:
            super().__init__(type_=ChildType.PROCESS)
        self._model = model
        self._send_callback_loop = None
        # Maps request id -> user callback, consumed by _loop_send_callback.
        self._send_callbacks = {}
        self._model_worker = ModelWorker(self._model)

    def start(self):
        """
        Overview:
            Start the child worker process and the callback-dispatch thread.
            Idempotent: does nothing when the loader is already running.
        """
        if not self._running:
            # Share parameters so the worker process sees in-place updates of
            # the model without re-sending the whole state on every save.
            self._model.share_memory()
            self.register(self._model_worker)
            self.start_link()
            self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
            self._send_callback_loop.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the supervisor and drop any pending send callbacks.
        Arguments:
            - timeout (:obj:`Optional[float]`): Seconds to wait for child shutdown.
        """
        super().shutdown(timeout)
        self._send_callback_loop = None
        self._send_callbacks = {}

    def _loop_send_callback(self):
        # Daemon loop: dispatch each worker reply to the callback registered
        # under its request id; on error, log and drop the callback.
        while True:
            payload = self.recv(ignore_err=True)
            if payload.err:
                logging.warning("Got error when loading data: {}".format(payload.err))
                if payload.req_id in self._send_callbacks:
                    del self._send_callbacks[payload.req_id]
            else:
                if payload.req_id in self._send_callbacks:
                    callback = self._send_callbacks.pop(payload.req_id)
                    callback(payload.data)

    def load(self, storage: Storage) -> object:
        """
        Overview:
            Load model synchronously.
        Arguments:
            - storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
        Returns:
            - object (:obj:): The loaded model.
        """
        return storage.load()

    @abstractmethod
    def save(self, callback: Callable) -> Storage:
        """
        Overview:
            Save model asynchronously.
        Arguments:
            - callback (:obj:`Callable`): The callback function after saving model.
        Returns:
            - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
        """
        raise NotImplementedError
class FileModelLoader(ModelLoader):

    def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
        """
        Overview:
            Model loader using files as storage media.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
            - dirname (:obj:`str`): The directory for saving files.
            - ttl (:obj:`int`): Files will be automatically cleaned after ttl. Note that \
                files that do not time out when the process is stopped are not cleaned up \
                (to avoid errors when other processes read the file), so you may need to \
                clean up the remaining files manually.
        """
        super().__init__(model)
        self._dirname = dirname
        self._ttl = ttl
        # Each entry is [save_timestamp, file_path]; consumed FIFO by _loop_cleanup.
        self._files = []
        self._cleanup_thread = None

    def _start_cleanup(self):
        """
        Overview:
            Start a cleanup thread to clean up files that are taking up too much time on the disk.
        """
        if self._cleanup_thread is None:
            self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
            self._cleanup_thread.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the loader and stop the cleanup loop.
        Arguments:
            - timeout (:obj:`Optional[float]`): Seconds to wait for child shutdown.
        """
        super().shutdown(timeout)
        # Clearing the reference also signals _loop_cleanup to exit (it polls
        # this attribute), so a later restart cannot leak a second cleanup thread.
        self._cleanup_thread = None

    def _loop_cleanup(self):
        # Delete saved files once they are older than ttl seconds; exit when
        # shutdown clears self._cleanup_thread.
        while self._cleanup_thread is not None:
            if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
                sleep(1)
                continue
            _, file_path = self._files.pop(0)
            if path.exists(file_path):
                os.remove(file_path)

    def save(self, callback: Callable) -> FileModelStorage:
        """
        Overview:
            Ask the worker process to dump the model into a fresh file asynchronously.
        Arguments:
            - callback (:obj:`Callable`): Called with the storage object after the worker finished saving.
        Returns:
            - storage (:obj:`FileModelStorage`): The storage object is created synchronously, so it can be \
                returned before the file exists; ``None`` when the loader is not running.
        """
        if not self._running:
            logging.warning("Please start model loader before saving model.")
            return
        # makedirs with exist_ok is race-free and also creates missing parent
        # directories, unlike the exists()+mkdir pattern it replaces.
        os.makedirs(self._dirname, exist_ok=True)
        file_path = path.join(self._dirname, "model_{}.pth.tar".format(uuid.uuid1()))
        model_storage = FileModelStorage(file_path)
        payload = SendPayload(proc_id=0, method="save", args=[model_storage])
        self.send(payload)

        def clean_callback(storage: Storage):
            # Schedule the file for ttl-based deletion, then run the user callback.
            self._files.append([time(), file_path])
            callback(storage)

        self._send_callbacks[payload.req_id] = clean_callback
        self._start_cleanup()
        return model_storage