LINC-BIT's picture
Upload 1912 files
b84549f verified
import functools
from peewee import fn
from playhouse.shortcuts import model_to_dict
from .model import NdsTrialStats, NdsTrialConfig
def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset,
                          num_epochs=None, reduction=None, include_intermediates=False):
    """
    Query trial stats of NDS given conditions.

    Parameters
    ----------
    model_family : str or None
        If str, can be one of the model families available in :class:`nni.nas.benchmark.nds.NdsTrialConfig`.
        Otherwise a wildcard.
    proposer : str or None
        If str, can be one of the proposers available in :class:`nni.nas.benchmark.nds.NdsTrialConfig`. Otherwise a wildcard.
    generator : str or None
        If str, can be one of the generators available in :class:`nni.nas.benchmark.nds.NdsTrialConfig`. Otherwise a wildcard.
    model_spec : dict or None
        If specified, can be one of the model spec available in :class:`nni.nas.benchmark.nds.NdsTrialConfig`.
        Otherwise a wildcard.
    cell_spec : dict or None
        If specified, can be one of the cell spec available in :class:`nni.nas.benchmark.nds.NdsTrialConfig`.
        Otherwise a wildcard.
    dataset : str or None
        If str, can be one of the datasets available in :class:`nni.nas.benchmark.nds.NdsTrialConfig`. Otherwise a wildcard.
    num_epochs : number or None
        If not None, only trials whose config has exactly this ``num_epochs`` are matched.
        Otherwise a wildcard.
    reduction : str or None
        If 'none' or None, all trial stats will be returned directly.
        If 'mean', fields in trial stats (except ``id``, ``config`` and ``seed``) will be
        averaged given the same trial config.
    include_intermediates : boolean
        If true, intermediate results will be returned under the ``intermediates`` key
        of each dict (with their back-reference ``trial`` key removed).

    Returns
    -------
    generator of dict
        A generator of :class:`nni.nas.benchmark.nds.NdsTrialStats` objects,
        where each of them has been converted into a dict.

    Raises
    ------
    ValueError
        If ``reduction`` is not one of None, 'none' or 'mean'. Because this function is a
        generator, the error surfaces when iteration starts, not at call time.
    """
    if reduction == 'none':
        reduction = None
    if reduction == 'mean':
        # Average every stats column except the identity / grouping columns.
        fields = [fn.AVG(getattr(NdsTrialStats, name)).alias(name)
                  for name in NdsTrialStats._meta.sorted_field_names
                  if name not in ('id', 'config', 'seed')]
    elif reduction is None:
        fields = [NdsTrialStats]
    else:
        # Wrap in a tuple so tuple-valued input cannot break the %-formatting.
        raise ValueError('Unsupported reduction: \'%s\'' % (reduction,))

    query = NdsTrialStats.select(*fields, NdsTrialConfig).join(NdsTrialConfig)

    # Explicit name -> value mapping (instead of ``locals()[name]`` lookups) so the
    # filters cannot be confused by other local variables; insertion order is kept.
    filters = {
        'model_family': model_family,
        'proposer': proposer,
        'generator': generator,
        'model_spec': model_spec,
        'cell_spec': cell_spec,
        'dataset': dataset,
        'num_epochs': num_epochs,
    }
    conditions = [getattr(NdsTrialConfig, name) == value
                  for name, value in filters.items() if value is not None]
    if conditions:
        query = query.where(functools.reduce(lambda a, b: a & b, conditions))
    if reduction is not None:
        query = query.group_by(NdsTrialStats.config)

    for trial in query:
        data = model_to_dict(trial)
        if include_intermediates:
            # NOTE(review): with reduction='mean' the stats ``id`` is not selected, so this
            # back-reference lookup presumably matches no rows — confirm intended behaviour.
            # 'trial' is dropped from each intermediate since it is already available in data.
            data['intermediates'] = [
                {k: v for k, v in model_to_dict(t).items() if k != 'trial'}
                for t in trial.intermediates
            ]
        yield data