Spaces:
Runtime error
Runtime error
File size: 3,069 Bytes
97a6728 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import os
from pathlib import Path
from tops.config import LazyCall as L
import torch
import functools
from dp2.data.datasets.coco_cse import CocoCSE
from dp2.data.build import get_dataloader
from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
from dp2.metrics.torch_metrics import compute_metrics_iteratively
from .utils import final_eval_fn
# Root directory containing the datasets and directory used to cache metric
# statistics. Both can be overridden through environment variables and fall
# back to paths relative to the current working directory.
dataset_base_dir = os.environ.get("BASE_DATASET_DIR", "data")
metrics_cache = os.environ.get("FBA_METRICS_CACHE", ".cache")
# COCO-CSE lives in a "coco_cse" sub-directory with "train" / "val" splits below it.
data_dir = Path(dataset_base_dir, "coco_cse")
# Dataset configuration for COCO-CSE: image geometry / channel counts, the
# train and val dataset + dataloader specs (built lazily via LazyCall `L`),
# and the evaluation callables.
# Strings of the form "${...}" are config interpolations (presumably
# OmegaConf-style, resolved by tops.config at load time). "${..dataset}" is a
# relative reference to the sibling `dataset` entry, so the nesting of these
# dicts must not be changed.
data = dict(
    imsize=(288, 160),  # referenced by both Resize transforms via "${data.imsize}"
    im_channels=3,
    # NOTE(review): presumably the number of semantic classes and the CSE
    # embedding dimensionality respectively — confirm against the model config.
    semantic_nc=26,
    cse_nc=16,
    train=dict(
        # CocoCSE gets no CPU-side transform; preprocessing and augmentation are
        # supplied via the loader's `gpu_transform` below.
        dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
        loader=L(get_dataloader)(
            shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
            batch_size="${train.batch_size}",
            dataset="${..dataset}",  # points at train.dataset above
            infinite=True,
            # torch.nn.Sequential applies these in order:
            # ToFloat -> StyleGAN2 augmentation -> random h-flip -> CreateEmbedding
            # -> Resize to data.imsize -> Normalize -> CreateCondition.
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                # NOTE(review): presumably each bare `<aug>=p` is that augmentation's
                # application probability and `<aug>_max` / `<aug>_std` its
                # magnitude — confirm against StyleGANAugmentPipe.
                L(StyleGANAugmentPipe)(
                    rotate=0.5, rotate_max=.05,
                    xint=.5, xint_max=0.05,
                    scale=.5, scale_std=.05,
                    aniso=0.5, aniso_std=.05,
                    xfrac=.5, xfrac_std=.05,
                    brightness=.5, brightness_std=.05,
                    contrast=.5, contrast_std=.1,
                    hue=.5, hue_max=.05,
                    saturation=.5, saturation_std=.5,
                    imgfilter=.5, imgfilter_std=.1),
                L(RandomHorizontalFlip)(p=0.5),
                L(CreateEmbedding)(),
                L(Resize)(size="${data.imsize}"),
                # Maps [0, 1] inputs to [-1, 1] if ToFloat scales to [0, 1] — TODO confirm.
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    val=dict(
        dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
        loader=L(get_dataloader)(
            # NOTE(review): drop_last=True on the validation loader would skip the
            # final partial batch (if forwarded to torch DataLoader), so metrics
            # may not cover every val sample — confirm this is intended.
            shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
            # NOTE(review): validation reuses the training batch size.
            batch_size="${train.batch_size}",
            dataset="${..dataset}",  # points at val.dataset above
            infinite=False,
            # Same pipeline as training minus StyleGANAugmentPipe and
            # RandomHorizontalFlip (no augmentation at evaluation time).
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(CreateEmbedding)(),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    # Training-time evaluation may apply optimizations that reduce compute
    # overhead, e.g. computing with AMP. The two evaluation functions use
    # separate cache directories under `metrics_cache` and differ in
    # `include_two_fake`.
    train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False),
    evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True)
)
|