# Copyright (c) Microsoft Corporation. # Licensed under the MIT license. from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Union import warnings from .base import ConfigBase, PathLike from .common import TrainingServiceConfig from . import util __all__ = ['RemoteConfig', 'RemoteMachineConfig'] @dataclass(init=False) class RemoteMachineConfig(ConfigBase): host: str port: int = 22 user: str password: Optional[str] = None ssh_key_file: PathLike = None #'~/.ssh/id_rsa' ssh_passphrase: Optional[str] = None use_active_gpu: bool = False max_trial_number_per_gpu: int = 1 gpu_indices: Union[List[int], str, int, None] = None python_path: Optional[str] = None _canonical_rules = { 'ssh_key_file': util.canonical_path, 'gpu_indices': util.canonical_gpu_indices } _validation_rules = { 'port': lambda value: 0 < value < 65536, 'max_trial_number_per_gpu': lambda value: value > 0, 'gpu_indices': lambda value: all(idx >= 0 for idx in value) and len(value) == len(set(value)) } def validate(self): super().validate() if self.password is None and not Path(self.ssh_key_file).is_file(): raise ValueError(f'Password is not provided and cannot find SSH key file "{self.ssh_key_file}"') if self.password: warnings.warn('Password will be exposed through web UI in plain text. We recommend to use SSH key file.') @dataclass(init=False) class RemoteConfig(TrainingServiceConfig): platform: str = 'remote' reuse_mode: bool = True machine_list: List[RemoteMachineConfig] def __init__(self, **kwargs): kwargs = util.case_insensitive(kwargs) kwargs['machinelist'] = util.load_config(RemoteMachineConfig, kwargs.get('machinelist')) super().__init__(**kwargs) _canonical_rules = { 'machine_list': lambda value: [config.canonical() for config in value] } _validation_rules = { 'platform': lambda value: (value == 'remote', 'cannot be modified') }