File size: 552 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass

from .common import TrainingServiceConfig

# Public names of this module: only ``AmlConfig`` is exported by a
# ``from <module> import *``; the imported helpers above are not re-exported.
__all__ = ['AmlConfig']

@dataclass(init=False)
class AmlConfig(TrainingServiceConfig):
    """Training service configuration whose ``platform`` is fixed to ``'aml'``.

    This class only declares fields and a single validation rule. How
    instances are constructed and how ``_validation_rules`` is consumed is
    determined by ``TrainingServiceConfig``, which is not defined in this
    module.

    ``init=False`` stops ``@dataclass`` from generating ``__init__``, so the
    inherited constructor is used. Because no ``__init__`` is generated,
    dataclass also does not enforce its "fields without defaults must not
    follow fields with defaults" ordering rule, which is why the fields
    without defaults below may come after ``platform``.

    NOTE(review): 'aml' presumably stands for Azure Machine Learning, and
    ``subscription_id`` / ``resource_group`` / ``workspace_name`` /
    ``compute_target`` presumably identify the Azure workspace and compute
    resource to run on -- confirm against the training service that reads
    this config.
    """
    # Platform identifier; the validation rule below only accepts 'aml'.
    platform: str = 'aml'
    # No default value is declared for the next four fields.
    subscription_id: str
    resource_group: str
    workspace_name: str
    compute_target: str
    # Defaults to the 'msranni/nni:latest' image when not overridden.
    docker_image: str = 'msranni/nni:latest'
    # Defaults to 1.
    # NOTE(review): name suggests a per-GPU cap on concurrent trials -- confirm.
    max_trial_number_per_gpu: int = 1

    # Unannotated, so @dataclass does not treat it as a field; it stays a
    # plain class attribute. Maps a field name to a callable that takes the
    # field's value and returns a ``(is_valid, message)`` tuple.
    # NOTE(review): presumably looked up by the base class during validation
    # -- confirm in TrainingServiceConfig.
    _validation_rules = {
        'platform': lambda value: (value == 'aml', 'cannot be modified')
    }