Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Image classification configuration definition.""" | |
import dataclasses | |
import os | |
from typing import List, Optional, Tuple, Union, Sequence | |
from official.core import config_definitions as cfg | |
from official.core import exp_factory | |
from official.modeling import hyperparams | |
from official.modeling import optimization | |
from official.vision.configs import common | |
from official.vision.configs import backbones | |
# `@dataclasses.dataclass` is required so the annotated fields below become
# dataclass fields / constructor keyword arguments (the experiment factories
# build e.g. `DataConfig(aug_type=..., mixup_and_cutmix=...)`).
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training.

  The same class is used for both `train_data` (is_training=True) and
  `validation_data` (is_training=False) of an image classification task.
  """
  input_path: Union[Sequence[str], str, hyperparams.Config] = ''
  weights: Optional[hyperparams.base_config.Config] = None
  global_batch_size: int = 0
  is_training: bool = True
  dtype: str = 'float32'
  shuffle_buffer_size: int = 10000
  cycle_length: int = 10
  is_multilabel: bool = False
  aug_rand_hflip: bool = True
  aug_crop: Optional[bool] = True
  crop_area_range: Optional[Tuple[float, float]] = (0.08, 1.0)
  aug_type: Optional[
      common.Augmentation] = None  # Choose from AutoAugment and RandAugment.
  three_augment: bool = False
  color_jitter: float = 0.
  random_erasing: Optional[common.RandomErasing] = None
  file_type: str = 'tfrecord'
  # Feature keys read from each serialized example.
  image_field_key: str = 'image/encoded'
  label_field_key: str = 'image/class/label'
  decode_jpeg_only: bool = True
  mixup_and_cutmix: Optional[common.MixupAndCutmix] = None
  # default_factory gives every DataConfig its own DataDecoder instance.
  decoder: Optional[common.DataDecoder] = dataclasses.field(
      default_factory=common.DataDecoder
  )
  # Keep for backward compatibility.
  aug_policy: Optional[str] = None  # None, 'autoaug', or 'randaug'.
  randaug_magnitude: Optional[int] = 10
  # Determines ratio between the side of the cropped image and the short side of
  # the original image.
  center_crop_fraction: Optional[float] = 0.875
  # Interpolation method for resizing image in Parser for both training and eval
  tf_resize_method: str = 'bilinear'
  # Repeat augmentation puts multiple augmentations of the same image in a batch
  # https://arxiv.org/abs/1902.05509
  repeated_augment: Optional[int] = None
# Decorated so the fields are constructor keyword arguments, as used by the
# factories (e.g. `ImageClassificationModel(num_classes=1001, ...)`).
@dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config):
  """The model config.

  Defaults to a ResNet backbone with non-synchronized batch norm.
  """
  num_classes: int = 0
  # Input image size; the factories in this module pass [height, width, 3].
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(  # pylint: disable=g-long-lambda
          type='resnet', resnet=backbones.ResNet()
      )
  )
  dropout_rate: float = 0.0
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(use_sync_bn=False)
  )
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
  add_head_batch_norm: bool = False
  kernel_initializer: str = 'random_uniform'
  # Whether to output softmax results instead of logits.
  output_softmax: bool = False
# Decorated so e.g. `Losses(l2_weight_decay=1e-4)` is a valid constructor call.
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Loss config for image classification."""
  loss_weight: float = 1.0
  one_hot: bool = True
  label_smoothing: float = 0.0
  l2_weight_decay: float = 0.0
  soft_labels: bool = False
  # Converts multi-class classification to multi-label classification. Weights
  # each object class equally in the loss function, ignoring their size.
  use_binary_cross_entropy: bool = False
# Decorated so the fields are dataclass fields / constructor keyword arguments.
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
  """Evaluation metric config for image classification."""
  # k for the top-k metric (presumably top-k accuracy; confirm in the task).
  top_k: int = 5
  # None leaves thresholded precision/recall metrics unconfigured.
  precision_and_recall_thresholds: Optional[List[float]] = None
  report_per_class_precision_and_recall: bool = False
# Decorated so the factories can build `ImageClassificationTask(model=...,
# losses=..., train_data=..., validation_data=...)`.
@dataclasses.dataclass
class ImageClassificationTask(cfg.TaskConfig):
  """The task config.

  Bundles the model, train/validation inputs, losses and evaluation settings.
  Every nested config uses `default_factory`, so each task instance owns fresh
  sub-configs rather than sharing mutable defaults.
  """
  model: ImageClassificationModel = dataclasses.field(
      default_factory=ImageClassificationModel
  )
  train_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=True)
  )
  validation_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=False)
  )
  losses: Losses = dataclasses.field(default_factory=Losses)
  evaluation: Evaluation = dataclasses.field(default_factory=Evaluation)
  train_input_partition_dims: Optional[List[int]] = dataclasses.field(
      default_factory=list)
  eval_input_partition_dims: Optional[List[int]] = dataclasses.field(
      default_factory=list)
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: str = 'all'  # all or backbone
  model_output_keys: Optional[List[int]] = dataclasses.field(
      default_factory=list)
  freeze_backbone: bool = False
@exp_factory.register_config_factory('image_classification')
def image_classification() -> cfg.ExperimentConfig:
  """Image classification general.

  Returns:
    An experiment config with a default `ImageClassificationTask` and default
    trainer. It requires both `is_training` flags to be set, and is registered
    with `exp_factory` as 'image_classification'.
  """
  return cfg.ExperimentConfig(
      task=ImageClassificationTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
# ImageNet-2012 split sizes; used by the factories below to derive
# steps_per_epoch and validation_steps.
IMAGENET_TRAIN_EXAMPLES = 1281167
IMAGENET_VAL_EXAMPLES = 50000
# Directory prefix joined with 'train*' / 'valid*' file patterns.
IMAGENET_INPUT_PATH_BASE = 'imagenet-2012-tfrecord'
@exp_factory.register_config_factory('resnet_imagenet')
def image_classification_imagenet() -> cfg.ExperimentConfig:
  """Image classification on imagenet with resnet.

  The config uses ResNet-50 at 224x224 with 1001 classes and XLA enabled. The
  global batch size is 4096 for both training and evaluation. Training runs for
  90 epochs with SGD (momentum 0.9). The learning rate uses a 5-epoch linear
  warmup from 0 and a stepwise schedule (0.1, 0.01, 0.001, 0.0001, each scaled
  by batch_size / 256) with boundaries at epochs 30, 60 and 80. One "epoch" is
  IMAGENET_TRAIN_EXAMPLES // train_batch_size steps.

  Returns:
    The experiment config, registered with `exp_factory` as 'resnet_imagenet'.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(enable_xla=True),
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=1001,
              input_size=[224, 224, 3],
              backbone=backbones.Backbone(
                  type='resnet', resnet=backbones.ResNet(model_id=50)),
              norm_activation=common.NormActivation(
                  norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=90 * steps_per_epoch,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'stepwise',
                  'stepwise': {
                      'boundaries': [
                          30 * steps_per_epoch, 60 * steps_per_epoch,
                          80 * steps_per_epoch
                      ],
                      # Linear scaling rule relative to a batch size of 256.
                      'values': [
                          0.1 * train_batch_size / 256,
                          0.01 * train_batch_size / 256,
                          0.001 * train_batch_size / 256,
                          0.0001 * train_batch_size / 256,
                      ]
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
@exp_factory.register_config_factory('resnet_rs_imagenet')
def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig:
  """Image classification on imagenet with resnet-rs.

  The backbone is ResNet-50 with a 'v1' stem, ResNet-D shortcuts, the stem max
  pool replaced, squeeze-excite ratio 0.25 and swish activation. Inputs are
  160x160 with RandAugment (magnitude 10), and the loss uses label smoothing
  0.1. Training runs for 350 epochs at batch size 4096 with SGD (momentum
  0.9). The learning rate follows a cosine schedule from 1.6 after a 5-epoch
  linear warmup, and weights use an EMA with decay 0.9999.

  Returns:
    The experiment config, registered with `exp_factory` as
    'resnet_rs_imagenet'.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=1001,
              input_size=[160, 160, 3],
              backbone=backbones.Backbone(
                  type='resnet',
                  resnet=backbones.ResNet(
                      model_id=50,
                      stem_type='v1',
                      resnetd_shortcut=True,
                      replace_stem_max_pool=True,
                      se_ratio=0.25,
                      stochastic_depth_drop_rate=0.0)),
              dropout_rate=0.25,
              # NOTE(review): norm_momentum=0.0 differs from the 0.9 used by
              # the plain ResNet config; presumably overridden by experiment
              # YAMLs — confirm before relying on this default.
              norm_activation=common.NormActivation(
                  norm_momentum=0.0,
                  norm_epsilon=1e-5,
                  use_sync_bn=False,
                  activation='swish')),
          losses=Losses(l2_weight_decay=4e-5, label_smoothing=0.1),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size,
              aug_type=common.Augmentation(
                  type='randaug', randaug=common.RandAugment(magnitude=10))),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=350 * steps_per_epoch,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'ema': {
                  'average_decay': 0.9999,
                  'trainable_weights_only': False,
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      'initial_learning_rate': 1.6,
                      'decay_steps': 350 * steps_per_epoch
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
@exp_factory.register_config_factory('revnet_imagenet')
def image_classification_imagenet_revnet() -> cfg.ExperimentConfig:
  """Returns a revnet config for image classification on imagenet.

  The model is RevNet-56 at 224x224 with a BatchNormalization layer added
  before the head's global average pooling. Training runs for 90 epochs at
  batch size 4096 with SGD (momentum 0.9). The learning rate uses a 5-epoch
  linear warmup and stepwise values 0.8 / 0.08 / 0.008 / 0.0008 with
  boundaries at epochs 30, 60 and 80.

  Returns:
    The experiment config, registered with `exp_factory` as 'revnet_imagenet'.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=1001,
              input_size=[224, 224, 3],
              backbone=backbones.Backbone(
                  type='revnet', revnet=backbones.RevNet(model_id=56)),
              norm_activation=common.NormActivation(
                  norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False),
              add_head_batch_norm=True),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=90 * steps_per_epoch,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'stepwise',
                  'stepwise': {
                      'boundaries': [
                          30 * steps_per_epoch, 60 * steps_per_epoch,
                          80 * steps_per_epoch
                      ],
                      # Absolute values (not batch-scaled), tuned for the
                      # 4096 batch size above.
                      'values': [0.8, 0.08, 0.008, 0.0008]
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
@exp_factory.register_config_factory('mobilenet_imagenet')
def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
  """Image classification on imagenet with mobilenet.

  The model is MobileNetV2 (filter scale 1.0) at 224x224 with dropout 0.2,
  label smoothing 0.1 and L2 weight decay 1e-5. Training runs for 500 epochs
  at batch size 4096 with RMSProp (rho 0.9, momentum 0.9, epsilon 0.002).
  The learning rate uses a 5-epoch linear warmup and a staircase exponential
  decay: initial rate 0.008 * (batch_size // 128), multiplied by 0.98 every
  2.5 epochs.

  Returns:
    The experiment config, registered with `exp_factory` as
    'mobilenet_imagenet'.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=1001,
              dropout_rate=0.2,
              input_size=[224, 224, 3],
              backbone=backbones.Backbone(
                  type='mobilenet',
                  mobilenet=backbones.MobileNet(
                      model_id='MobileNetV2', filter_size_scale=1.0)),
              norm_activation=common.NormActivation(
                  norm_momentum=0.997, norm_epsilon=1e-3, use_sync_bn=False)),
          losses=Losses(l2_weight_decay=1e-5, label_smoothing=0.1),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=500 * steps_per_epoch,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'rmsprop',
                  'rmsprop': {
                      'rho': 0.9,
                      'momentum': 0.9,
                      'epsilon': 0.002,
                  }
              },
              'learning_rate': {
                  'type': 'exponential',
                  'exponential': {
                      # Integer division: scales in whole multiples of 128.
                      'initial_learning_rate':
                          0.008 * (train_batch_size // 128),
                      'decay_steps':
                          int(2.5 * steps_per_epoch),
                      'decay_rate':
                          0.98,
                      'staircase':
                          True
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              },
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
@exp_factory.register_config_factory('deit_imagenet_pretrain')
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer.

  This is a DeiT-style pretraining recipe for ViT-B/16 at 224x224. The
  backbone uses representation size 768, stochastic depth 0.1 and no
  dropout, and the head kernels are zero-initialized. Labels are soft, with
  label smoothing 0.1 shared by the loss and mixup/cutmix. Inputs use
  RandAugment (magnitude 9, Cutout excluded). Training runs for 300 epochs
  with AdamW (weight decay 0.05 on kernel/weight variables only). The
  learning rate uses a 5-epoch linear warmup, then a cosine schedule from
  0.0005 * batch_size / 512.

  Returns:
    The experiment config, registered with `exp_factory` as
    'deit_imagenet_pretrain'.
  """
  train_batch_size = 4096  # originally was 1024 but 4096 better for tpu v3-32
  eval_batch_size = 4096  # originally was 1024 but 4096 better for tpu v3-32
  # Shared by the loss and by mixup/cutmix so both smooth labels identically.
  label_smoothing = 0.1
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=1001,
              input_size=[224, 224, 3],
              kernel_initializer='zeros',
              backbone=backbones.Backbone(
                  type='vit',
                  vit=backbones.VisionTransformer(
                      model_name='vit-b16',
                      representation_size=768,
                      init_stochastic_depth_rate=0.1,
                      original_init=False,
                      transformer=backbones.Transformer(
                          dropout_rate=0.0, attention_dropout_rate=0.0)))),
          losses=Losses(
              l2_weight_decay=0.0,
              label_smoothing=label_smoothing,
              one_hot=False,
              soft_labels=True),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size,
              aug_type=common.Augmentation(
                  type='randaug',
                  randaug=common.RandAugment(
                      magnitude=9, exclude_ops=['Cutout'])),
              mixup_and_cutmix=common.MixupAndCutmix(
                  label_smoothing=label_smoothing)),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=300 * steps_per_epoch,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate': 0.05,
                      # Regex: decay only variables named *kernel:0/*weight:0.
                      'include_in_weight_decay': r'.*(kernel|weight):0$',
                      'gradient_clip_norm': 0.0
                  }
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      'initial_learning_rate': 0.0005 * train_batch_size / 512,
                      'decay_steps': 300 * steps_per_epoch,
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
@exp_factory.register_config_factory('vit_imagenet_pretrain')
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer.

  This is a ViT-B/16 pretraining recipe at 224x224 with representation size
  768 and zero-initialized head kernels. Training runs for 300 epochs at
  batch size 4096 with AdamW (weight decay 0.3 on kernel/weight variables
  only). The learning rate uses a fixed 10000-step linear warmup, then a
  cosine schedule from 0.003 * batch_size / 4096.

  Returns:
    The experiment config, registered with `exp_factory` as
    'vit_imagenet_pretrain'.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=1001,
              input_size=[224, 224, 3],
              kernel_initializer='zeros',
              backbone=backbones.Backbone(
                  type='vit',
                  vit=backbones.VisionTransformer(
                      model_name='vit-b16', representation_size=768))),
          losses=Losses(l2_weight_decay=0.0),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=300 * steps_per_epoch,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate': 0.3,
                      # Regex: decay only variables named *kernel:0/*weight:0.
                      'include_in_weight_decay': r'.*(kernel|weight):0$',
                      'gradient_clip_norm': 0.0
                  }
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      'initial_learning_rate': 0.003 * train_batch_size / 4096,
                      'decay_steps': 300 * steps_per_epoch,
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      # Fixed step count, unlike the epoch-based warmups above.
                      'warmup_steps': 10000,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
@exp_factory.register_config_factory('vit_imagenet_finetune')
def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer.

  This is a ViT-B/16 fine-tuning recipe at 384x384 with batch size 512.
  Training runs for a fixed 20000 steps with SGD (momentum 0.9, global clip
  norm 1.0) and a cosine learning-rate schedule from 0.003 over those 20000
  steps. No warmup is configured.

  Returns:
    The experiment config, registered with `exp_factory` as
    'vit_imagenet_finetune'.
  """
  train_batch_size = 512
  eval_batch_size = 512
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=1001,
              input_size=[384, 384, 3],
              backbone=backbones.Backbone(
                  type='vit',
                  vit=backbones.VisionTransformer(model_name='vit-b16'))),
          losses=Losses(l2_weight_decay=0.0),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          # Step-based (not epoch-based); kept equal to the cosine decay_steps.
          train_steps=20000,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9,
                      'global_clipnorm': 1.0,
                  }
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      'initial_learning_rate': 0.003,
                      'decay_steps': 20000,
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config