|
import os |
|
from glob import glob |
|
from typing import List, Literal |
|
import shutil |
|
from PIL import Image |
|
import json |
|
import numpy as np |
|
from rich.progress import track |
|
import cv2 |
|
from vegseg.datasets import GrassDataset |
|
from sklearn.model_selection import train_test_split |
|
import argparse |
|
|
|
|
|
def give_color_to_mask(mask: np.ndarray, palette: List[int]) -> Image.Image:
    """
    Convert a class-index mask into a palettized ("P" mode) color image.

    Args:
        mask (np.ndarray): numpy array of shape (H, W) holding class indices.
        palette (List[int]): flat list of RGB values [r0, g0, b0, r1, g1, b1, ...].

    Returns:
        Image.Image: palettized PIL image of shape (H, W).
    """
    # PIL's Image.fromarray cannot handle signed int8 arrays (the dtype that
    # get_mask_by_json historically produced), so normalize to uint8 first.
    # This is a no-op for masks that are already uint8.
    im = Image.fromarray(mask.astype(np.uint8)).convert("P")
    im.putpalette(palette)
    return im
|
|
|
|
|
def get_mask_by_json(filename: str) -> np.ndarray:
    """
    Rasterize a labelme-style JSON annotation into a class-index mask.

    Args:
        filename (str): path to the JSON file. It must contain "imageHeight",
            "imageWidth" and a "shapes" list whose entries carry an integer
            "label" and polygon "points".

    Returns:
        np.ndarray: uint8 array of shape (H, W); background is 0 and each
        polygon is filled with (label - 1), clamped at 0.
    """
    # Context manager instead of json.load(open(...)): the original form
    # leaked the file handle until garbage collection.
    with open(filename, encoding="utf-8") as f:
        json_file = json.load(f)
    img_height = json_file["imageHeight"]
    img_width = json_file["imageWidth"]
    # uint8 rather than int8: labels are non-negative, and downstream
    # PIL.Image.fromarray rejects signed 8-bit arrays.
    mask = np.zeros((img_height, img_width), dtype="uint8")
    for shape in json_file["shapes"]:
        # Annotation labels are 1-based; shift to 0-based and clamp so a
        # raw label of 0 does not wrap below zero.
        label = max(int(shape["label"]) - 1, 0)
        points = np.array(shape["points"]).astype(np.int32)
        cv2.fillPoly(mask, [points], label)
    return mask
|
|
|
|
|
def json_to_image(json_path, image_path):
    """
    Render a JSON annotation as a color-palettized mask image on disk.

    Args:
        json_path (str): path to the JSON annotation file.
        image_path (str): destination path for the rendered image.

    Returns:
        None
    """
    index_mask = get_mask_by_json(json_path)
    # Flatten the per-class [R, G, B] triplets into the flat
    # [r0, g0, b0, r1, ...] layout expected by PIL's putpalette.
    flat_palette = [
        channel
        for rgb in GrassDataset.METAINFO["palette"]
        for channel in rgb
    ]
    give_color_to_mask(index_mask, flat_palette).save(image_path)
|
|
|
|
|
def create_dataset(
    image_paths: List[str],
    ann_paths: List[str],
    phase: Literal["train", "val"],
    output_dir: str,
):
    """
    Copy images and render their JSON annotations into the split layout.

    Images land in <output_dir>/img_dir/<phase>/ under their original file
    names; annotations are rendered to <output_dir>/ann_dir/<phase>/ with
    the ".json" suffix swapped for ".png".

    Args:
        image_paths (List[str]): list of image paths.
        ann_paths (List[str]): list of annotation paths, aligned with images.
        phase (Literal["train", "val"]): which split is being written.
        output_dir (str): root directory of the generated dataset.

    Returns:
        None
    """
    # Destination directories are invariant across the loop; compute once.
    img_dst_dir = os.path.join(output_dir, "img_dir", phase)
    ann_dst_dir = os.path.join(output_dir, "ann_dir", phase)

    pairs = zip(image_paths, ann_paths)
    for image_path, ann_path in track(
        pairs, description=f"{phase} dataset", total=len(image_paths)
    ):
        # Copy the raw image first, keeping its original base name.
        shutil.copy(image_path, os.path.join(img_dst_dir, os.path.basename(image_path)))

        # Then rasterize the matching annotation next to it as a PNG.
        mask_name = os.path.basename(ann_path).replace(".json", ".png")
        json_to_image(ann_path, os.path.join(ann_dst_dir, mask_name))
|
|
|
|
|
def split_dataset(
    root_path: str,
    output_path: str,
    split_ratio: float = 0.8,
    shuffle: bool = True,
    seed: int = 42,
) -> None:
    """
    Split a dataset of .tif images with .json annotations into train/val sets.

    Args:
        root_path (str): Path to the dataset. The directory is expected to
            contain image files (image1.tif ... imageN.tif) with a matching
            labelme-style annotation next to each (image1.json ...).
        output_path (str): Path to the output directory; the split is written
            as img_dir/{train,val} and ann_dir/{train,val}.
        split_ratio (float, optional): Fraction of the data used for
            training. Defaults to 0.8.
        shuffle (bool, optional): Whether to shuffle before splitting.
            Defaults to True.
        seed (int, optional): Seed for the random number generator.
            Defaults to 42.
    """
    # Sort so the split is reproducible for a fixed seed regardless of the
    # platform-dependent order in which glob returns files.
    image_paths = sorted(glob(os.path.join(root_path, "*.tif")))
    # Swap only the file extension. The previous str.replace("tif", "json")
    # replaced the FIRST occurrence of "tif" anywhere in the path (breaking
    # e.g. a parent directory named "tiff_data").
    ann_paths = [os.path.splitext(p)[0] + ".json" for p in image_paths]
    # Fail fast on images without annotations; the previous length assert
    # was vacuous because ann_paths is derived from image_paths.
    missing = [p for p in ann_paths if not os.path.exists(p)]
    assert not missing, f"Missing annotation files, e.g.: {missing[:5]}"
    print(f"images: {len(image_paths)}, annotations: {len(ann_paths)}")

    image_train, image_test, ann_train, ann_test = train_test_split(
        image_paths,
        ann_paths,
        train_size=split_ratio,
        random_state=seed,
        shuffle=shuffle,
    )
    print(f"train: {len(image_train)}, test: {len(image_test)}")

    # Create the mmseg-style directory layout up front.
    for subdir in ("img_dir", "ann_dir"):
        for phase in ("train", "val"):
            os.makedirs(os.path.join(output_path, subdir, phase), exist_ok=True)

    create_dataset(image_train, ann_train, "train", output_path)
    create_dataset(image_test, ann_test, "val", output_path)
|
|
|
|
|
def main():
    """CLI entry point: parse arguments and run the train/val dataset split."""

    def _str2bool(value: str) -> bool:
        # argparse's type=bool is a classic trap: bool("False") is True,
        # because any non-empty string is truthy — so the original
        # "--shuffle False" silently still shuffled. Parse the text instead.
        return str(value).strip().lower() in ("1", "true", "yes", "y")

    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="data/raw_data")
    parser.add_argument("--output", type=str, default="data/grass")
    parser.add_argument("--split_ratio", type=float, default=0.8)
    parser.add_argument("--seed", type=int, default=42)
    # Same "--shuffle VALUE" CLI shape as before, but now "False"/"0"/"no"
    # actually disable shuffling.
    parser.add_argument("--shuffle", type=_str2bool, default=True)
    args = parser.parse_args()

    root: str = args.root
    output_path: str = args.output
    split_ratio: float = args.split_ratio
    seed: int = args.seed
    shuffle: bool = args.shuffle

    split_dataset(
        root_path=root,
        output_path=output_path,
        split_ratio=split_ratio,
        shuffle=shuffle,
        seed=seed,
    )

    # "Dataset split complete" — user-facing output, kept verbatim.
    print("数据集划分完成")
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
|
|