Spaces:
Sleeping
Sleeping
import os | |
import albumentations | |
import numpy as np | |
from PIL import Image | |
from torch.utils.data import Dataset | |
class HRCWHU(Dataset):
    """HRC_WHU cloud-detection dataset with two classes: clear sky and cloud.

    Directory layout read by this class::

        root/train.txt, root/test.txt     # one image file name per line
        root/img_dir/<split>/<file name>  # input images
        root/ann_dir/<split>/<file name>  # annotations, same file name as the image

    Transforms are called albumentations-style, i.e. ``t(image=..., mask=...)`` /
    ``t(image=...)`` returning a dict.
    """

    METAINFO = dict(
        classes=('clear sky', 'cloud'),
        palette=((128, 192, 128), (255, 255, 255)),
        img_size=(3, 256, 256),  # C, H, W
        ann_size=(256, 256),  # H, W
        train_size=120,
        test_size=30,
    )

    def __init__(self, root, phase, all_transform: albumentations.Compose = None,
                 img_transform: albumentations.Compose = None,
                 ann_transform: albumentations.Compose = None, seed: int = 42):
        """Index the samples of one split.

        Args:
            root: Dataset root directory (layout described on the class).
            phase: ``'train'`` selects the train split; any other value selects ``'test'``.
            all_transform: Applied jointly to image and mask (``image=``, ``mask=``)
                so that random geometric augmentations stay aligned.
            img_transform: Applied to the image only, after ``all_transform``.
            ann_transform: Applied to the annotation only (passed as ``image=``),
                after ``all_transform``.
            seed: Stored on the instance; not otherwise used by this class.
        """
        self.root = root
        self.phase = phase
        self.all_transform = all_transform
        self.img_transform = img_transform
        self.ann_transform = ann_transform
        self.seed = seed
        self.data = self.load_data()

    def load_data(self):
        """Read the split file and return a list of ``(img_path, ann_path, lac_type)``.

        ``lac_type`` is the part of the file name before the first ``'_'``
        (shown as the land-cover type in the demo plot).
        """
        split = 'train' if self.phase == 'train' else 'test'
        split_file = os.path.join(self.root, f'{split}.txt')
        data_list = []
        with open(split_file, 'r') as f:
            for line in f:
                image_file = line.strip()
                if not image_file:
                    # Blank lines would otherwise yield paths pointing at the split directory.
                    continue
                img_path = os.path.join(self.root, 'img_dir', split, image_file)
                ann_path = os.path.join(self.root, 'ann_dir', split, image_file)
                lac_type = image_file.split('_')[0]
                data_list.append((img_path, ann_path, lac_type))
        return data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load, transform and return one sample.

        Returns:
            dict with ``'img'`` (transformed image), ``'ann'`` (annotation cast
            with ``np.int64``), ``'img_path'``, ``'ann_path'`` and ``'lac_type'``.
        """
        img_path, ann_path, lac_type = self.data[idx]
        # Decode inside `with` so the underlying file handles are closed promptly.
        with Image.open(img_path) as img_file:
            img = np.array(img_file)
        with Image.open(ann_path) as ann_file:
            ann = np.array(ann_file)

        if self.all_transform is not None:
            # Joint call keeps random crops/flips identical for image and mask.
            transformed = self.all_transform(image=img, mask=ann)
            img = transformed['image']
            ann = transformed['mask']
        if self.img_transform is not None:
            img = self.img_transform(image=img)['image']
        if self.ann_transform is not None:
            # Must transform the annotation itself (previously `img` was passed here,
            # which replaced the annotation with a transformed copy of the image).
            ann = self.ann_transform(image=ann)['image']

        return {
            'img': img,
            'ann': np.int64(ann),
            'img_path': img_path,
            'ann_path': ann_path,
            'lac_type': lac_type,
        }
if __name__ == '__main__':
    # Smoke test: build both splits, check sizes against METAINFO, then render
    # every training sample and save a preview of the last one.
    import matplotlib.pyplot as plt
    import torch
    from albumentations.pytorch import ToTensorV2

    # HRCWHU calls its transforms albumentations-style (t(image=..., mask=...)),
    # so torchvision transforms cannot be plugged in here.
    all_transform = albumentations.Compose([albumentations.RandomCrop(256, 256)])
    img_transform = albumentations.Compose([ToTensorV2()])  # HWC ndarray -> CHW tensor
    ann_transform = None  # keep the mask as an (H, W) array to match METAINFO['ann_size']

    train_dataset = HRCWHU(root='data/hrcwhu', phase='train', all_transform=all_transform,
                           img_transform=img_transform, ann_transform=ann_transform)
    test_dataset = HRCWHU(root='data/hrcwhu', phase='test', all_transform=all_transform,
                          img_transform=img_transform, ann_transform=ann_transform)
    assert len(train_dataset) == train_dataset.METAINFO['train_size']
    assert len(test_dataset) == test_dataset.METAINFO['test_size']

    train_sample = train_dataset[0]
    test_sample = test_dataset[0]
    assert tuple(train_sample['img'].shape) == tuple(test_sample['img'].shape) == train_dataset.METAINFO['img_size']
    assert tuple(train_sample['ann'].shape) == tuple(test_sample['ann'].shape) == train_dataset.METAINFO['ann_size']

    palette = torch.tensor(train_dataset.METAINFO['palette'])
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    for train_sample in train_dataset:
        # Clear so images do not pile up on the same axes across iterations.
        for ax in axs:
            ax.clear()
        axs[0].imshow(train_sample['img'].permute(1, 2, 0))
        axs[0].set_title('Image')
        # NOTE(review): assumes mask values are class indices (0/1) — TODO confirm against the data.
        axs[1].imshow(palette[torch.from_numpy(train_sample['ann'])])
        axs[1].set_title('Annotation')
        plt.suptitle(f'Land Cover Type: {train_sample["lac_type"].capitalize()}', y=0.8)
    plt.tight_layout()
    plt.savefig('HRCWHU_sample.png', bbox_inches="tight")