|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from pathlib import Path, PurePosixPath |
|
from typing import Any, Dict, Optional |
|
|
|
from .base import PathLike |
|
from .common import TrainingServiceConfig |
|
from . import util |
|
|
|
__all__ = ['OpenpaiConfig'] |
|
|
|
@dataclass(init=False) |
|
class OpenpaiConfig(TrainingServiceConfig): |
|
platform: str = 'openpai' |
|
host: str |
|
username: str |
|
token: str |
|
trial_cpu_number: int |
|
trial_memory_size: str |
|
storage_config_name: str |
|
docker_image: str = 'msranni/nni:latest' |
|
local_storage_mount_point: PathLike |
|
container_storage_mount_point: str |
|
reuse_mode: bool = True |
|
|
|
openpai_config: Optional[Dict[str, Any]] = None |
|
openpai_config_file: Optional[PathLike] = None |
|
|
|
_canonical_rules = { |
|
'host': lambda value: 'https://' + value if '://' not in value else value, |
|
'local_storage_mount_point': util.canonical_path, |
|
'openpai_config_file': util.canonical_path |
|
} |
|
|
|
_validation_rules = { |
|
'platform': lambda value: (value == 'openpai', 'cannot be modified'), |
|
'local_storage_mount_point': lambda value: Path(value).is_dir(), |
|
'container_storage_mount_point': lambda value: (PurePosixPath(value).is_absolute(), 'is not absolute'), |
|
'openpai_config_file': lambda value: Path(value).is_file() |
|
} |
|
|
|
def validate(self) -> None: |
|
super().validate() |
|
if self.openpai_config is not None and self.openpai_config_file is not None: |
|
raise ValueError('openpai_config and openpai_config_file can only be set one') |
|
|