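# Page residue from the original copy (Hugging Face Spaces status: "Spaces: Runtime error"),
# kept as a comment so this module parses.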
import functools
import os
from pathlib import Path

from tops.config import LazyCall as L
import torch

from dp2.data.datasets.coco_cse import CocoCSE
from dp2.data.build import get_dataloader
from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
from dp2.metrics.torch_metrics import compute_metrics_iteratively

from .utils import final_eval_fn

# Root directory containing the datasets; override with BASE_DATASET_DIR (default "data").
dataset_base_dir = os.environ.get("BASE_DATASET_DIR", "data")
# Directory used for metric cache files; override with FBA_METRICS_CACHE (default ".cache").
metrics_cache = os.environ.get("FBA_METRICS_CACHE", ".cache")
# COCO-CSE data root: <dataset_base_dir>/coco_cse.
data_dir = Path(dataset_base_dir, "coco_cse")
# COCO-CSE dataset settings. Entries wrapped in L(...) (tops LazyCall) are deferred
# constructor calls. The "${...}" strings are config interpolations, presumably resolved
# by the tops/OmegaConf config system when the config is loaded (TODO confirm).
data = dict(
    imsize=(288, 160),
    im_channels=3,
    semantic_nc=26,
    cse_nc=16,
    train=dict(
        dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
        loader=L(get_dataloader)(
            shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
            batch_size="${train.batch_size}",
            # Presumably a relative interpolation to this split's `dataset` entry (train.dataset).
            dataset="${..dataset}",
            infinite=True,
            # torch.nn.Sequential applies these transforms in list order.
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                # NOTE(review): saturation_std (.5) is far larger than the other *_std
                # values in this pipe (.05-.1); confirm this is intended.
                L(StyleGANAugmentPipe)(
                    rotate=0.5, rotate_max=.05,
                    xint=.5, xint_max=0.05,
                    scale=.5, scale_std=.05,
                    aniso=0.5, aniso_std=.05,
                    xfrac=.5, xfrac_std=.05,
                    brightness=.5, brightness_std=.05,
                    contrast=.5, contrast_std=.1,
                    hue=.5, hue_max=.05,
                    saturation=.5, saturation_std=.5,
                    imgfilter=.5, imgfilter_std=.1),
                L(RandomHorizontalFlip)(p=0.5),
                L(CreateEmbedding)(),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ]),
        ),
    ),
    val=dict(
        dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
        loader=L(get_dataloader)(
            # NOTE(review): validation reuses train.batch_size and sets drop_last=True.
            # If get_dataloader forwards drop_last to torch's DataLoader, a trailing
            # partial batch is never evaluated. Confirm this is intended.
            shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
            batch_size="${train.batch_size}",
            dataset="${..dataset}",
            infinite=False,
            # Same pipeline as train, minus StyleGANAugmentPipe and RandomHorizontalFlip.
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(CreateEmbedding)(),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ]),
        ),
    ),
    # Training-time evaluation may apply optimizations that reduce compute overhead
    # (e.g. AMP). It passes include_two_fake=False and uses its own cache directory.
    train_evaluation_fn=functools.partial(
        compute_metrics_iteratively,
        cache_directory=Path(metrics_cache, "coco_cse_val"),
        include_two_fake=False,
    ),
    # Final evaluation passes include_two_fake=True and caches under a separate directory.
    evaluation_fn=functools.partial(
        final_eval_fn,
        cache_directory=Path(metrics_cache, "coco_cse_val_final"),
        include_two_fake=True,
    ),
)