File size: 7,767 Bytes
c614b0f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import os
import cv2
from PIL import Image
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm
from .config import Config
from .image_proc import preproc
from .utils import path_to_image
# Disable Pillow's decompression-bomb pixel limit so very large images can be
# opened without the size check.
Image.MAX_IMAGE_PIXELS = None # remove DecompressionBombWarning
# Module-level configuration object; MyData below reads its data root, task,
# size, load_all, device, preprocessing and auxiliary-classification settings.
config = Config()
# Category names used for the auxiliary classification target. A name's
# position in `class_labels_TR_sorted` is the integer id assigned to it when
# the dataset builds its name -> id lookup, so the order must not change.
_class_labels_TR_sorted = (
    "Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, "
    "BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, "
    "CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, "
    "Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, "
    "Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, "
    "Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, "
    "KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, "
    "Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, "
    "OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, "
    "RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, "
    "ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, "
    "Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, "
    "TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, "
    "UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht"
)
# Split on the comma delimiter and trim the separating space from each name.
class_labels_TR_sorted = [
    label.strip() for label in _class_labels_TR_sorted.split(",")
]
class MyData(data.Dataset):
    """Dataset of (image, ground-truth mask) pairs.

    For every dataset name in ``datasets`` (several may be joined with
    ``"+"``), image files are listed from
    ``<config.data_root_dir>/<config.task>/<dataset>/im`` and each image is
    paired with the file of the same stem in the sibling ``gt`` directory.

    ``__getitem__`` returns ``(image, label, class_label)`` in training mode
    and ``(image, label, label_path)`` otherwise.
    """

    # Accepted file extensions; also the order in which ground-truth
    # extensions are tried when looking up the label for an image.
    _VALID_EXTENSIONS = (".png", ".jpg", ".PNG", ".JPG", ".JPEG")

    def __init__(self, datasets, image_size, is_train=True):
        """Index (and, if ``config.load_all`` is set, preload) all pairs.

        Args:
            datasets: A dataset directory name, or several joined with
                ``"+"`` to train on their union.
            image_size: Stored as ``size_train`` / ``size_test``; the
                resizing done inside this class is driven by ``config.size``.
            is_train: Training mode: applies ``preproc`` augmentation and
                returns class labels instead of label paths.

        Raises:
            ValueError: If some image has no matching ground-truth file.
        """
        self.size_train = image_size
        self.size_test = image_size
        self.keep_size = not config.size
        self.data_size = config.size
        self.is_train = is_train
        self.load_all = config.load_all
        self.device = config.device
        if self.is_train and config.auxiliary_classification:
            # Class id = position of the name in class_labels_TR_sorted.
            self.cls_name2id = {
                _name: _id for _id, _name in enumerate(class_labels_TR_sorted)
            }
        self.transform_image, self.transform_label = self._build_transforms()

        dataset_root = os.path.join(config.data_root_dir, config.task)
        self.image_paths = self._collect_image_paths(dataset_root, datasets)
        self.label_paths = self._match_label_paths(self.image_paths)
        self._check_all_labelled(self.image_paths, self.label_paths)

        if self.load_all:
            self._preload()

    def _build_transforms(self):
        """Return the ``(image, label)`` transform pipelines.

        Resize is left out when ``load_all`` is set or ``keep_size`` is true.
        It is only constructed when it is actually used, so a falsy
        ``config.size`` such as ``None`` no longer breaks construction.
        """
        image_ops = [
            transforms.ToTensor(),
            # The usual ImageNet per-channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        label_ops = [transforms.ToTensor()]
        if not (self.load_all or self.keep_size):
            # Resize takes (h, w); data_size is reversed, so it is presumably
            # stored as (w, h) -- confirm against Config.
            image_ops.insert(0, transforms.Resize(self.data_size[::-1]))
            label_ops.insert(0, transforms.Resize(self.data_size[::-1]))
        return transforms.Compose(image_ops), transforms.Compose(label_ops)

    @classmethod
    def _collect_image_paths(cls, dataset_root, datasets):
        """List image files in ``<dataset_root>/<name>/im`` for each name in ``datasets``."""
        image_paths = []
        for dataset in datasets.split("+"):
            image_root = os.path.join(dataset_root, dataset, "im")
            image_paths.extend(
                os.path.join(image_root, name)
                for name in os.listdir(image_root)
                if name.endswith(cls._VALID_EXTENSIONS)
            )
        return image_paths

    @staticmethod
    def _label_candidates(image_path, extensions):
        """Return the possible ground-truth paths for ``image_path``, in ``extensions`` order.

        Only the image's own directory (``.../im``) is swapped for its
        ``gt`` sibling. Other path components named ``im`` are untouched,
        and the platform's path separator is respected.
        """
        ## 'im' and 'gt' may need modifying
        image_dir, file_name = os.path.split(image_path)
        label_dir = os.path.join(os.path.dirname(image_dir), "gt")
        stem = os.path.splitext(file_name)[0]
        return [os.path.join(label_dir, stem + ext) for ext in extensions]

    @classmethod
    def _match_label_paths(cls, image_paths):
        """Return the first existing ground-truth path for each image.

        Images without a label are reported and skipped, so the result may be
        shorter than ``image_paths``.
        """
        label_paths = []
        for image_path in image_paths:
            candidates = cls._label_candidates(image_path, cls._VALID_EXTENSIONS)
            for candidate in candidates:
                if os.path.exists(candidate):
                    label_paths.append(candidate)
                    break
            else:
                print("Not exists:", candidates[-1])
        return label_paths

    @staticmethod
    def _check_all_labelled(image_paths, label_paths):
        """Raise ``ValueError`` (after printing the unmatched stems) if any image lacks a label."""
        if len(label_paths) == len(image_paths):
            return

        def stems(paths):
            return {os.path.splitext(os.path.basename(p))[0] for p in paths}

        print("Path diff:", stems(image_paths) - stems(label_paths))
        raise ValueError(
            f"There are different numbers of images ({len(image_paths)}) and labels ({len(label_paths)})"
        )

    @staticmethod
    def _class_name_of(label_path):
        """Return the 4th ``#``-separated field of the label's file name (its class name)."""
        return os.path.basename(label_path).split("#")[3]

    def _class_label(self, label_path):
        """Return the class id for ``label_path``, or -1 when classification is off."""
        if self.is_train and config.auxiliary_classification:
            return self.cls_name2id[self._class_name_of(label_path)]
        return -1

    def _preload(self):
        """Read every image/label pair (and its class id) into memory."""
        self.images_loaded, self.labels_loaded = [], []
        self.class_labels_loaded = []
        for image_path, label_path in tqdm(
            zip(self.image_paths, self.label_paths), total=len(self.image_paths)
        ):
            _image = path_to_image(image_path, size=config.size, color_type="rgb")
            _label = path_to_image(label_path, size=config.size, color_type="gray")
            self.images_loaded.append(_image)
            self.labels_loaded.append(_label)
            self.class_labels_loaded.append(self._class_label(label_path))

    def __getitem__(self, index):
        """Load, augment (training only) and transform sample ``index``."""
        label_path = self.label_paths[index]
        if self.load_all:
            image = self.images_loaded[index]
            label = self.labels_loaded[index]
            # Already -1 when classification is off (see _preload).
            class_label = self.class_labels_loaded[index]
        else:
            image = path_to_image(
                self.image_paths[index], size=config.size, color_type="rgb"
            )
            label = path_to_image(label_path, size=config.size, color_type="gray")
            class_label = self._class_label(label_path)

        if self.is_train:
            image, label = preproc(image, label, preproc_methods=config.preproc_methods)
        image, label = self.transform_image(image), self.transform_label(label)

        if self.is_train:
            return image, label, class_label
        return image, label, label_path

    def __len__(self):
        return len(self.image_paths)
|