LINC-BIT's picture
Upload 1912 files
b84549f verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, Iterable, List, Optional, Tuple)

from .graph import Model, Mutation, ModelStatus
# Names exported by `from <this module> import *`.
__all__ = ['Sampler', 'Mutator']

# Type alias for a single candidate value passed to / returned from `choice()`.
# It is an alias of `typing.Any`, so no restriction is placed on candidate types.
Choice = Any
class Sampler:
    """
    Decision backend consulted by `Mutator.choice()`.

    Subclasses override `choice()` to pick one of the offered candidates.
    The `mutation_start()` / `mutation_end()` hooks are optional and do
    nothing unless overridden.
    """

    def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
        """
        Pick one element out of ``candidates``.

        ``mutator`` is the mutator asking for the decision, ``model`` is the model
        it is currently mutating, and ``index`` is the zero-based position of this
        call among the `Mutator.choice()` calls of the current `Mutator.apply()`.

        The base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def mutation_start(self, mutator: 'Mutator', model: Model) -> None:
        """
        Hook invoked by `Mutator.apply()` right before `Mutator.mutate()` runs.
        Does nothing by default.
        """
        return None

    def mutation_end(self, mutator: 'Mutator', model: Model) -> None:
        """
        Hook invoked by `Mutator.apply()` right after `Mutator.mutate()` returns.
        Does nothing by default.
        """
        return None
class Mutator:
    """
    Mutates graphs in model to generate new model.

    `Mutator` class will be used in two places:

    1. Inherit `Mutator` to implement graph mutation logic.
    2. Use `Mutator` subclass to implement NAS strategy.

    In scenario 1, the subclass should implement `Mutator.mutate()` interface with `Mutator.choice()`.
    In scenario 2, strategy should use constructor or `Mutator.bind_sampler()` to initialize subclass,
    and then use `Mutator.apply()` to mutate model.
    For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
    (Method names are open for discussion.)

    If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
    """

    def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None):
        """
        Parameters
        ----------
        sampler
            Sampler that answers `choice()` calls. May be left as ``None`` and
            supplied later through `bind_sampler()`.
        label
            Optional label attached to this mutator.
        """
        self.sampler: Optional[Sampler] = sampler
        self.label: Optional[str] = label
        # Bookkeeping for the `apply()` call in progress.
        # `_cur_model` and `_cur_choice_idx` are None whenever no `apply()` is running.
        self._cur_model: Optional[Model] = None
        self._cur_choice_idx: Optional[int] = None
        # Choices returned by the sampler during the most recent `apply()`;
        # None until `apply()` has been called for the first time.
        self._cur_samples: Optional[List[Choice]] = None

    def bind_sampler(self, sampler: Sampler) -> 'Mutator':
        """
        Set the sampler which will handle `Mutator.choice` calls.

        Returns this mutator so that calls can be chained.
        """
        self.sampler = sampler
        return self

    def apply(self, model: Model) -> Model:
        """
        Apply this mutator on a model.
        Returns mutated model.
        The model will be copied before mutation and the original model will not be modified.

        A sampler must have been bound beforehand. If `mutate()` or one of the sampler
        hooks raises, the exception propagates and the per-call bookkeeping is still cleared,
        so the mutator can be reused afterwards.
        """
        assert self.sampler is not None, 'a sampler must be bound before calling apply()'
        copy = model.fork()
        self._cur_model = copy
        self._cur_choice_idx = 0
        # A fresh list per call: the `Mutation` record below keeps a reference to it.
        self._cur_samples = []
        try:
            self.sampler.mutation_start(self, copy)
            self.mutate(copy)
            self.sampler.mutation_end(self, copy)
            copy.history.append(Mutation(self, self._cur_samples, model, copy))
            copy.status = ModelStatus.Frozen
        finally:
            # Always leave the "no mutation in progress" state, even on failure,
            # so a stale model reference is not kept and `choice()` stays guarded.
            self._cur_model = None
            self._cur_choice_idx = None
        return copy

    def dry_run(self, model: Model) -> Tuple[List[List[Choice]], Model]:
        """
        Dry run mutator on a model to collect choice candidates.
        If you invoke this method multiple times on same or different models,
        it may or may not return identical results, depending on how the subclass implements `Mutator.mutate()`.

        Returns
        -------
        A tuple ``(recorded_candidates, new_model)``: the candidate list of every
        `choice()` call made during the run, in call order, and the model produced
        by that run. During the run every choice resolves to the first candidate.

        The sampler bound before the call is restored afterwards, even if the run fails.
        """
        sampler_backup = self.sampler
        recorder = _RecorderSampler()
        self.sampler = recorder
        try:
            new_model = self.apply(model)
        finally:
            # Never leave the internal recorder bound to the mutator.
            self.sampler = sampler_backup
        return recorder.recorded_candidates, new_model

    def mutate(self, model: Model) -> None:
        """
        Abstract method to be implemented by subclass.
        Mutate a model in place.
        """
        raise NotImplementedError()

    def choice(self, candidates: Iterable[Choice]) -> Choice:
        """
        Ask sampler to make a choice.

        ``candidates`` is materialized into a list before being handed to the sampler.
        The returned value is recorded and the choice counter advanced.
        Only valid while `apply()` is running (i.e. when called from `mutate()`).
        """
        assert (
            self.sampler is not None
            and self._cur_model is not None
            and self._cur_choice_idx is not None
            and self._cur_samples is not None
        ), 'choice() may only be called from mutate() while apply() is running'
        ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
        self._cur_samples.append(ret)
        self._cur_choice_idx += 1
        return ret
class _RecorderSampler(Sampler):
    """
    Sampler used by `Mutator.dry_run()`: it remembers the candidate list of every
    `choice()` request and always answers with the first candidate.
    """

    def __init__(self):
        # One entry per `choice()` call, in call order.
        self.recorded_candidates: List[List[Choice]] = []

    def choice(self, candidates: List[Choice], *args) -> Choice:
        """Record ``candidates`` and return its first element; extra arguments are ignored."""
        log = self.recorded_candidates
        log.append(candidates)
        picked = candidates[0]
        return picked