File size: 2,100 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
import warnings

from .base import ConfigBase, PathLike
from .common import TrainingServiceConfig
from . import util

# Public API of this module: the two config dataclasses defined below.
__all__ = ['RemoteConfig', 'RemoteMachineConfig']

@dataclass(init=False)
class RemoteMachineConfig(ConfigBase):
    """Settings for a single machine listed in ``RemoteConfig.machine_list``.

    Authentication needs either ``password`` or an existing ``ssh_key_file``;
    ``validate()`` raises ``ValueError`` when neither is usable.
    """
    host: str
    port: int = 22
    user: str
    password: Optional[str] = None
    # Optional because the default is None, matching the sibling fields.
    # The commented-out alternative default is kept for reference.
    ssh_key_file: Optional[PathLike] = None  # '~/.ssh/id_rsa'
    ssh_passphrase: Optional[str] = None
    use_active_gpu: bool = False
    max_trial_number_per_gpu: int = 1
    gpu_indices: Union[List[int], str, int, None] = None
    python_path: Optional[str] = None

    # Field name -> function that normalizes the raw field value.
    _canonical_rules = {
        'ssh_key_file': util.canonical_path,
        'gpu_indices': util.canonical_gpu_indices
    }

    # Field name -> predicate on the field value.
    # NOTE(review): the gpu_indices rule iterates the value, so it presumably
    # runs on the canonical (list) form and is skipped for None - confirm
    # against ConfigBase.validate().
    _validation_rules = {
        'port': lambda value: 0 < value < 65536,
        'max_trial_number_per_gpu': lambda value: value > 0,
        'gpu_indices': lambda value: all(idx >= 0 for idx in value) and len(value) == len(set(value))
    }

    def validate(self):
        """Run the base validation, then check the authentication settings.

        Raises:
            ValueError: if ``password`` is None and ``ssh_key_file`` is unset
                or does not point to an existing file.

        Emits a warning when a non-empty password is configured.
        """
        super().validate()
        if self.password is None:
            key_file = self.ssh_key_file
            # Guard against the None default (Path(None) raises TypeError) and
            # expand "~" so paths like '~/.ssh/id_rsa' are checked correctly.
            if key_file is None or not Path(key_file).expanduser().is_file():
                raise ValueError(f'Password is not provided and cannot find SSH key file "{self.ssh_key_file}"')
        if self.password:
            warnings.warn('Password will be exposed through web UI in plain text. We recommend to use SSH key file.')

@dataclass(init=False)
class RemoteConfig(TrainingServiceConfig):
    """Training service config for the 'remote' platform.

    Holds a list of ``RemoteMachineConfig`` entries in ``machine_list``.
    """
    platform: str = 'remote'
    reuse_mode: bool = True
    machine_list: List[RemoteMachineConfig]

    # Canonicalize every machine entry individually.
    _canonical_rules = {
        'machine_list': lambda value: [machine.canonical() for machine in value]
    }

    # Each rule returns (is_valid, error_message); platform must stay 'remote'.
    _validation_rules = {
        'platform': lambda value: (value == 'remote', 'cannot be modified')
    }

    def __init__(self, **kwargs):
        """Build the config from keyword arguments with case-insensitive keys.

        The 'machinelist' entry is passed through ``util.load_config`` with
        ``RemoteMachineConfig`` before the base initializer is called.
        """
        options = util.case_insensitive(kwargs)
        raw_machines = options.get('machinelist')
        options['machinelist'] = util.load_config(RemoteMachineConfig, raw_machines)
        super().__init__(**options)