Spaces:
Running
Running
import os

from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase

from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps

# Shared SQLite handle for the NAS-Bench-101 tables. The database file is
# ``nasbench101.db`` inside ``DATABASE_DIR``; every model in this module binds
# to this handle through its ``Meta.database``.
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True)


class Nb101TrialConfig(Model):
    """
    Trial config for NAS-Bench-101.

    Attributes
    ----------
    arch : dict
        A dict with keys ``op1``, ``op2``, ... and ``input1``, ``input2``, ... Vertices are
        enumerated from 0. Since node 0 is the input node, it is skipped in this dict. Each ``op``
        is one of :const:`nni.nas.benchmarks.nasbench101.CONV3X3_BN_RELU`,
        :const:`nni.nas.benchmarks.nasbench101.CONV1X1_BN_RELU`, and
        :const:`nni.nas.benchmarks.nasbench101.MAXPOOL3X3`.
        Each ``input`` is a list of previous nodes. For example ``input5`` can be ``[0, 1, 3]``.
    num_vertices : int
        Number of vertices (nodes) in one cell. Should be less than or equal to 7 in default setup.
    hash : str
        Graph-invariant MD5 string for this architecture.
    num_epochs : int
        Number of epochs planned for this trial. Should be one of 4, 12, 36, 108 in default setup.
    """

    # Stored as JSON text, serialized with the project's ``json_dumps`` helper.
    arch = JSONField(json_dumps=json_dumps, index=True)
    # All remaining columns are indexed as well, so each can be used as a query filter.
    num_vertices = IntegerField(index=True)
    hash = CharField(max_length=64, index=True)
    num_epochs = IntegerField(index=True)

    class Meta:
        # Bind the table to the module-level NAS-Bench-101 database.
        database = db


class Nb101TrialStats(Model):
    """
    Computation statistics for NAS-Bench-101. Each row corresponds to one trial.

    Each config has multiple trials with different random seeds, but unfortunately the seed
    of each trial is unavailable.

    NAS-Bench-101 trains and evaluates on CIFAR-10 by default. The original training set is
    divided into 40k training images and 10k validation images, and the original validation
    set is used for test only.

    Attributes
    ----------
    config : Nb101TrialConfig
        Setup for this trial data.
    train_acc : float
        Final accuracy on training data, ranging from 0 to 100.
    valid_acc : float
        Final accuracy on validation data, ranging from 0 to 100.
    test_acc : float
        Final accuracy on test data, ranging from 0 to 100.
    parameters : float
        Number of trainable parameters, in millions.
    training_time : float
        Duration of training in seconds.
    """

    # Many trials point at one config; ``backref`` names the reverse accessor,
    # so the trials of a config are reachable as ``config.trial_stats``.
    config = ForeignKeyField(Nb101TrialConfig, backref='trial_stats', index=True)
    train_acc = FloatField()
    valid_acc = FloatField()
    test_acc = FloatField()
    parameters = FloatField()
    training_time = FloatField()

    class Meta:
        # Bind the table to the module-level NAS-Bench-101 database.
        database = db


class Nb101IntermediateStats(Model):
    """
    Intermediate statistics for NAS-Bench-101.

    Attributes
    ----------
    trial : Nb101TrialStats
        The exact trial where the intermediate result is produced.
    current_epoch : int
        Elapsed epochs when evaluation is done.
    train_acc : float
        Intermediate accuracy on training data, ranging from 0 to 100.
    valid_acc : float
        Intermediate accuracy on validation data, ranging from 0 to 100.
    test_acc : float
        Intermediate accuracy on test data, ranging from 0 to 100.
    training_time : float
        Time elapsed in seconds.
    """

    # Many intermediate records point at one trial; ``backref`` names the reverse
    # accessor, so they are reachable from the trial as ``trial.intermediates``.
    trial = ForeignKeyField(Nb101TrialStats, backref='intermediates', index=True)
    # Indexed so records can be filtered by the epoch at which they were taken.
    current_epoch = IntegerField(index=True)
    train_acc = FloatField()
    valid_acc = FloatField()
    test_acc = FloatField()
    training_time = FloatField()

    class Meta:
        # Bind the table to the module-level NAS-Bench-101 database.
        database = db