|
|
|
|
|
|
|
from typing import (Any, Iterable, List, Optional) |
|
|
|
from .graph import Model, Mutation, ModelStatus |
|
|
|
|
|
__all__ = ['Sampler', 'Mutator'] |
|
|
|
|
|
Choice = Any |
|
|
|
|
|
class Sampler: |
|
""" |
|
Handles `Mutator.choice()` calls. |
|
""" |
|
|
|
def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice: |
|
raise NotImplementedError() |
|
|
|
def mutation_start(self, mutator: 'Mutator', model: Model) -> None: |
|
pass |
|
|
|
def mutation_end(self, mutator: 'Mutator', model: Model) -> None: |
|
pass |
|
|
|
|
|
class Mutator: |
|
""" |
|
Mutates graphs in model to generate new model. |
|
`Mutator` class will be used in two places: |
|
|
|
1. Inherit `Mutator` to implement graph mutation logic. |
|
2. Use `Mutator` subclass to implement NAS strategy. |
|
|
|
In scenario 1, the subclass should implement `Mutator.mutate()` interface with `Mutator.choice()`. |
|
In scenario 2, strategy should use constructor or `Mutator.bind_sampler()` to initialize subclass, |
|
and then use `Mutator.apply()` to mutate model. |
|
For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates. |
|
# Method names are open for discussion. |
|
|
|
If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label. |
|
""" |
|
|
|
def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None): |
|
self.sampler: Optional[Sampler] = sampler |
|
self.label: Optional[str] = label |
|
self._cur_model: Optional[Model] = None |
|
self._cur_choice_idx: Optional[int] = None |
|
|
|
def bind_sampler(self, sampler: Sampler) -> 'Mutator': |
|
""" |
|
Set the sampler which will handle `Mutator.choice` calls. |
|
""" |
|
self.sampler = sampler |
|
return self |
|
|
|
def apply(self, model: Model) -> Model: |
|
""" |
|
Apply this mutator on a model. |
|
Returns mutated model. |
|
The model will be copied before mutation and the original model will not be modified. |
|
""" |
|
assert self.sampler is not None |
|
copy = model.fork() |
|
self._cur_model = copy |
|
self._cur_choice_idx = 0 |
|
self._cur_samples = [] |
|
self.sampler.mutation_start(self, copy) |
|
self.mutate(copy) |
|
self.sampler.mutation_end(self, copy) |
|
copy.history.append(Mutation(self, self._cur_samples, model, copy)) |
|
copy.status = ModelStatus.Frozen |
|
self._cur_model = None |
|
self._cur_choice_idx = None |
|
return copy |
|
|
|
def dry_run(self, model: Model) -> List[List[Choice]]: |
|
""" |
|
Dry run mutator on a model to collect choice candidates. |
|
If you invoke this method multiple times on same or different models, |
|
it may or may not return identical results, depending on how the subclass implements `Mutator.mutate()`. |
|
""" |
|
sampler_backup = self.sampler |
|
recorder = _RecorderSampler() |
|
self.sampler = recorder |
|
new_model = self.apply(model) |
|
self.sampler = sampler_backup |
|
return recorder.recorded_candidates, new_model |
|
|
|
def mutate(self, model: Model) -> None: |
|
""" |
|
Abstract method to be implemented by subclass. |
|
Mutate a model in place. |
|
""" |
|
raise NotImplementedError() |
|
|
|
def choice(self, candidates: Iterable[Choice]) -> Choice: |
|
""" |
|
Ask sampler to make a choice. |
|
""" |
|
assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None |
|
ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx) |
|
self._cur_samples.append(ret) |
|
self._cur_choice_idx += 1 |
|
return ret |
|
|
|
|
|
class _RecorderSampler(Sampler): |
|
def __init__(self): |
|
self.recorded_candidates: List[List[Choice]] = [] |
|
|
|
def choice(self, candidates: List[Choice], *args) -> Choice: |
|
self.recorded_candidates.append(candidates) |
|
return candidates[0] |
|
|