LINC-BIT's picture
Upload 1912 files
b84549f verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
from typing import Dict, Any, List
from ..graph import Model
from ..mutator import Mutator, Sampler
class _FixedSampler(Sampler):
    """Sampler that replays pre-determined decisions instead of choosing.

    Every call to ``choice`` is answered by looking up the ``(mutator, index)``
    pair in the ``sample`` mapping given at construction time.
    """

    def __init__(self, sample):
        # Mapping keyed by (mutator, index) tuples; values are handed back
        # unchanged from ``choice``.
        self.sample = sample

    def choice(self, candidates, mutator, model, index):
        """Return the recorded decision for ``(mutator, index)``.

        ``candidates`` and ``model`` are part of the sampler call signature but
        are not consulted here. A pair absent from ``sample`` raises ``KeyError``.
        """
        lookup_key = (mutator, index)
        return self.sample[lookup_key]
def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, List[Any]]:
    """Collect the candidates each mutator would choose from, without sampling.

    The mutators are dry-run in order. The model returned by one mutator's
    ``dry_run`` becomes the input to the next mutator.

    Returns an ``OrderedDict`` mapping ``(mutator, choice_index)`` to the
    candidates recorded for that choice, in the order the choices were seen.
    """
    space = collections.OrderedDict()
    current = model
    for mut in mutators:
        # dry_run yields the recorded candidate lists plus the (possibly new) model.
        choices, current = mut.dry_run(current)
        space.update(((mut, idx), options) for idx, options in enumerate(choices))
    return space
def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model:
    """Build a concrete model by replaying the decisions in ``sample``.

    ``sample`` is keyed by ``(mutator, index)`` pairs, the same key shape that
    ``dry_run_for_search_space`` produces. A single replay sampler is shared by
    all mutators. Each mutator is bound to that sampler and applied to the output
    of the previous mutator, starting from ``base_model``.
    """
    replay = _FixedSampler(sample)
    result = base_model
    for mut in mutators:
        # NOTE(review): whether bind_sampler mutates ``mut`` in place is not
        # visible here; only its return value's ``apply`` is used.
        bound = mut.bind_sampler(replay)
        result = bound.apply(result)
    return result