Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
from copy import deepcopy | |
from pathlib import Path | |
from typing import Any, Dict, List | |
# from logger import logger | |
import numpy as np | |
# import torch | |
# import torch.utils.data as torchdata | |
# import torchvision.transforms as tvf | |
from omegaconf import DictConfig, OmegaConf | |
import pytorch_lightning as pl | |
from dataset.UAV.dataset import UavMapPair | |
# from torch.utils.data import Dataset, DataLoader | |
# from torchvision import transforms | |
from torch.utils.data import Dataset, ConcatDataset | |
from torch.utils.data import Dataset, DataLoader, random_split | |
import torchvision.transforms as tvf | |
# Custom data module class; subclasses pl.LightningDataModule.
class UavMapDatasetModule(pl.LightningDataModule):
    """Lightning data module serving ``UavMapPair`` samples for several cities.

    ``cfg`` is read with attribute access (``cfg.root``, ``cfg.train.batch_size``),
    so it is expected to be an OmegaConf ``DictConfig`` despite the ``Dict`` hint.
    Keys read by this class:

    * ``root`` -- dataset root directory (wrapped in ``pathlib.Path``).
    * ``image_size`` -- forwarded to ``torchvision.transforms.Resize``.
    * ``augmentation.image.apply`` -- when true, a ``ColorJitter`` built from
      ``augmentation.image.{brightness,contrast,saturation,hue}`` is added to
      the training transforms only.
    * ``train_citys`` / ``val_citys`` -- city names, one ``UavMapPair`` each.
    * ``train.batch_size``, ``train.num_workers``, ``val.batch_size``,
      ``val.num_workers`` -- DataLoader settings.
    """

    def __init__(self, cfg: Dict[str, Any]):
        super().__init__()
        self.cfg = cfg

        # Deterministic preprocessing shared by both splits.
        base_tfs = [tvf.ToTensor(), tvf.Resize(self.cfg.image_size)]

        # tvf.Compose keeps a reference to the list it receives rather than a
        # copy. Each split therefore gets its own list; appending the
        # training-only ColorJitter to a shared list would also silently add
        # it to the validation pipeline.
        self.val_tfs = tvf.Compose(list(base_tfs))

        train_tfs = list(base_tfs)
        if cfg.augmentation.image.apply:
            jitter_args = OmegaConf.masked_copy(
                cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"]
            )
            train_tfs.append(tvf.ColorJitter(**jitter_args))
        self.train_tfs = tvf.Compose(train_tfs)

        # Datasets are built eagerly at construction time.
        self.init()

    def init(self):
        """Build ``self.train_dataset`` and ``self.val_dataset``.

        One ``UavMapPair`` is created per city in ``cfg.train_citys``
        (``training=True``, training transforms) and per city in
        ``cfg.val_citys`` (``training=False``, validation transforms). The
        per-city datasets of each split are concatenated with ``ConcatDataset``.
        """
        root = Path(self.cfg.root)
        self.train_dataset = ConcatDataset([
            UavMapPair(root=root, city=city, training=True, transform=self.train_tfs)
            for city in self.cfg.train_citys
        ])
        self.val_dataset = ConcatDataset([
            UavMapPair(root=root, city=city, training=False, transform=self.val_tfs)
            for city in self.cfg.val_citys
        ])

    def train_dataloader(self):
        """Return a shuffled, memory-pinned DataLoader over the training set."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.cfg.train.batch_size,
            num_workers=self.cfg.train.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return a memory-pinned DataLoader over the validation set.

        Validation is not shuffled. Lightning warns against shuffling
        validation loaders, and a fixed order keeps validation runs
        reproducible and comparable.
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.cfg.val.batch_size,
            num_workers=self.cfg.val.num_workers,
            shuffle=False,
            pin_memory=True,
        )