|
import json |
|
import os |
|
import shutil |
|
from collections import defaultdict |
|
import random |
|
from tqdm import tqdm |
|
|
|
def create_balanced_dataset(source_json, source_img_dir, target_dir, min_samples=50, seed=None):
    """
    Create a balanced dataset for parts detection by sampling images per category.

    For each category (part) present in the COCO annotations, up to
    ``min_samples`` images containing that category are sampled.  The union of
    the sampled images is copied into ``target_dir/images`` and a YOLO-style
    segmentation label file (one ``<class_id> <x1> <y1> ...`` line per polygon,
    coordinates normalised to [0, 1]) is written into ``target_dir/labels``.

    Args:
        source_json (str): Path to source COCO JSON file.
        source_img_dir (str): Path to source images directory.
        target_dir (str): Path to target directory for the balanced dataset.
        min_samples (int): Maximum number of images sampled per class.
        seed (int | None): Optional random seed for reproducible sampling.
            ``None`` (the default) preserves the previous non-deterministic
            behaviour.
    """
    if seed is not None:
        random.seed(seed)

    os.makedirs(os.path.join(target_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(target_dir, 'labels'), exist_ok=True)

    with open(source_json, 'r') as f:
        coco = json.load(f)

    # Index the annotations two ways: which images contain each part, and
    # which annotations belong to each image.
    images_by_part = defaultdict(set)
    image_to_anns = defaultdict(list)
    for ann in coco['annotations']:
        images_by_part[ann['category_id']].add(ann['image_id'])
        image_to_anns[ann['image_id']].append(ann)

    # Sample up to min_samples images per part.  An image containing several
    # parts may be drawn more than once; the set de-duplicates.
    selected_images = set()
    for part_images in images_by_part.values():
        sample_size = min(min_samples, len(part_images))
        selected_images.update(random.sample(list(part_images), sample_size))

    # Per-image metadata from the COCO header.  COCO image entries normally
    # include width/height, which lets us avoid decoding every image file.
    id_to_info = {img['id']: img for img in coco['images']}

    print(f"Creating balanced dataset with {len(selected_images)} images...")
    for img_id in tqdm(selected_images):
        info = id_to_info[img_id]
        file_name = info['file_name']
        src_img = os.path.join(source_img_dir, file_name)
        dst_img = os.path.join(target_dir, 'images', file_name)
        shutil.copy2(src_img, dst_img)

        # Prefer the dimensions recorded in the JSON; only fall back to
        # decoding the image when they are missing.
        w = info.get('width')
        h = info.get('height')
        if not w or not h:
            # Import locally so PIL is only required on this fallback path,
            # and use a context manager so the image handle is closed
            # (the original left it open).
            from PIL import Image
            with Image.open(src_img) as im:
                w, h = im.size

        label_lines = []
        for ann in image_to_anns[img_id]:
            # NOTE(review): COCO category ids are usually 1-based while YOLO
            # expects 0-based class ids — confirm against the training config.
            cat_id = ann['category_id']
            seg = ann['segmentation']
            if not isinstance(seg, list):
                # RLE-encoded masks (dict form) cannot be written as polygon
                # lines; skip them rather than iterate the dict's keys.
                continue
            for polygon in seg:
                # Even indices are x coordinates (normalise by width),
                # odd indices are y coordinates (normalise by height).
                seg_norm = [str(x / w) if i % 2 == 0 else str(x / h)
                            for i, x in enumerate(polygon)]
                label_lines.append(f"{cat_id} {' '.join(seg_norm)}")

        base_name = os.path.splitext(file_name)[0]
        label_file = os.path.join(target_dir, 'labels', f"{base_name}.txt")
        with open(label_file, 'w') as f:
            f.write('\n'.join(label_lines))
|
|
|
if __name__ == "__main__":
    current_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(current_dir)

    # Build a balanced train split and a balanced val split from the same
    # source image pool, sampling more images per class for training.
    dataset_root = os.path.join(base_dir, "damage_detection_dataset")
    for split, annos_file, n_samples in (
        ("train", "COCO_mul_train_annos.json", 50),
        ("val", "COCO_mul_val_annos.json", 10),
    ):
        create_balanced_dataset(
            source_json=os.path.join(dataset_root, split, annos_file),
            source_img_dir=os.path.join(dataset_root, "img"),
            target_dir=os.path.join(base_dir, "data", "parts", "balanced", split),
            min_samples=n_samples,
        )