# NOTE: the next three lines are web-page residue, not Python source.
# LINC-BIT's picture
# Upload 1912 files
# b84549f verified
import os
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps
# Module-level handle to the ``nasbench101.db`` SQLite file located under
# DATABASE_DIR. The model classes in this file attach to it via ``Meta.database``.
db = SqliteExtDatabase(
    os.path.join(DATABASE_DIR, 'nasbench101.db'),
    autoconnect=True,
)
class Nb101TrialConfig(Model):
    """
    Configuration of one NAS-Bench-101 trial.

    Attributes
    ----------
    arch : dict
        Architecture description keyed by ``op1``, ``op2``, ... and ``input1``, ``input2``, ...
        Vertices are enumerated from 0. Node 0 is the input node, so it has no entry in this dict.
        Every ``op`` value is one of
        :const:`nni.nas.benchmarks.nasbench101.CONV3X3_BN_RELU`,
        :const:`nni.nas.benchmarks.nasbench101.CONV1X1_BN_RELU`, and
        :const:`nni.nas.benchmarks.nasbench101.MAXPOOL3X3`.
        Every ``input`` value is a list of previous nodes; for example, ``input5`` can be
        ``[0, 1, 3]``.
    num_vertices : int
        How many vertices (nodes) one cell has. At most 7 in the default setup.
    hash : str
        Graph-invariant MD5 string of this architecture.
    num_epochs : int
        How many epochs this trial is planned to run. One of 4, 12, 36, 108 in the default setup.
    """

    # All four columns are declared with ``index=True``.
    arch = JSONField(index=True, json_dumps=json_dumps)
    num_vertices = IntegerField(index=True)
    # NOTE: shadows the builtin ``hash`` inside the class body; the attribute
    # name is part of the model's interface and is therefore kept.
    hash = CharField(index=True, max_length=64)
    num_epochs = IntegerField(index=True)

    class Meta:
        # Attach this model to the module-level ``db`` handle.
        database = db
class Nb101TrialStats(Model):
    """
    Final computation statistics of NAS-Bench-101, one row per trial.

    A single config is run as several trials with different random seeds; the seed used by
    each trial is unfortunately not available.

    By default NAS-Bench-101 trains and evaluates on CIFAR-10. The original training set is
    split into 40k training images and 10k validation images, while the original validation
    set serves for test only.

    Attributes
    ----------
    config : Nb101TrialConfig
        The setup this trial data belongs to.
    train_acc : float
        Final training accuracy, ranging from 0 to 100.
    valid_acc : float
        Final validation accuracy, ranging from 0 to 100.
    test_acc : float
        Final test accuracy, ranging from 0 to 100.
    parameters : float
        Number of trainable parameters, in million.
    training_time : float
        How long training took, in seconds.
    """

    # Indexed link to the owning config; the reverse accessor on the config
    # side is named ``trial_stats``.
    config = ForeignKeyField(Nb101TrialConfig, index=True, backref='trial_stats')

    # Accuracy metrics.
    train_acc = FloatField()
    valid_acc = FloatField()
    test_acc = FloatField()

    # Model size and training cost.
    parameters = FloatField()
    training_time = FloatField()

    class Meta:
        # Attach this model to the module-level ``db`` handle.
        database = db
class Nb101IntermediateStats(Model):
    """
    Intermediate (mid-training) statistics of NAS-Bench-101.

    Attributes
    ----------
    trial : Nb101TrialStats
        The exact trial that produced this intermediate result.
    current_epoch : int
        Number of epochs elapsed when the evaluation was done.
    train_acc : float
        Intermediate training accuracy, ranging from 0 to 100.
    valid_acc : float
        Intermediate validation accuracy, ranging from 0 to 100.
    test_acc : float
        Intermediate test accuracy, ranging from 0 to 100.
    training_time : float
        Elapsed time in seconds.
    """

    # Indexed link to the owning trial; the reverse accessor on the trial
    # side is named ``intermediates``.
    trial = ForeignKeyField(Nb101TrialStats, index=True, backref='intermediates')
    current_epoch = IntegerField(index=True)

    # Metrics recorded at ``current_epoch``.
    train_acc = FloatField()
    valid_acc = FloatField()
    test_acc = FloatField()
    training_time = FloatField()

    class Meta:
        # Attach this model to the module-level ``db`` handle.
        database = db