File size: 1,012 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import collections
from typing import Dict, Any, List
from ..graph import Model
from ..mutator import Mutator, Sampler


class _FixedSampler(Sampler):
    """Sampler that answers every choice from a pre-computed sample.

    The sample is indexed with ``(mutator, index)`` pairs; each lookup
    returns whatever value is stored under that pair.
    """

    def __init__(self, sample):
        # Object indexed by (mutator, index) that holds the decision to return.
        self.sample = sample

    def choice(self, candidates, mutator, model, index):
        """Return the decision stored for ``(mutator, index)``.

        ``candidates`` and ``model`` are accepted but not consulted here.
        Any error raised by indexing the sample is propagated unchanged.
        """
        lookup_key = (mutator, index)
        return self.sample[lookup_key]


def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, List[Any]]:
    """Collect the candidates recorded by dry-running each mutator in order.

    The first mutator's ``dry_run`` receives ``model``; every later mutator
    receives the model returned by the previous mutator's ``dry_run``.

    Returns an ordered mapping whose keys are ``(mutator, i)`` pairs and whose
    values are the ``i``-th entries of what that mutator's ``dry_run``
    recorded. Entries appear in mutator order, then in recorded order.
    """
    space = collections.OrderedDict()
    current = model
    for mut in mutators:
        # dry_run yields (recorded candidates, model to hand to the next mutator).
        recorded, current = mut.dry_run(current)
        space.update(((mut, idx), cands) for idx, cands in enumerate(recorded))
    return space


def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model:
    """Apply every mutator to ``base_model`` with decisions taken from ``sample``.

    One ``_FixedSampler`` built from ``sample`` is bound to each mutator
    before that mutator is applied. The output of one ``apply`` call is the
    input of the next.

    Returns the result of the last ``apply`` call, or ``base_model`` itself
    when ``mutators`` is empty.
    """
    fixed_sampler = _FixedSampler(sample)
    current = base_model
    for mut in mutators:
        # apply() is called on whatever bind_sampler() returns.
        bound = mut.bind_sampler(fixed_sampler)
        current = bound.apply(current)
    return current