import json
import os
import random
import shutil
from collections import defaultdict

from PIL import Image
from tqdm import tqdm


def create_balanced_dataset(source_json, source_img_dir, target_dir, min_samples=50):
    """
    Create a balanced dataset for parts detection by sampling images with different parts.

    Args:
        source_json (str): Path to source COCO JSON file
        source_img_dir (str): Path to source images directory
        target_dir (str): Path to target directory for balanced dataset
        min_samples (int): Minimum number of samples per class
    """
    # Create target directories
    os.makedirs(os.path.join(target_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(target_dir, 'labels'), exist_ok=True)

    # Load COCO annotations
    with open(source_json, 'r') as f:
        coco = json.load(f)

    # Group images by the parts (categories) they contain
    images_by_part = defaultdict(set)
    image_to_anns = defaultdict(list)
    for ann in coco['annotations']:
        img_id = ann['image_id']
        cat_id = ann['category_id']
        images_by_part[cat_id].add(img_id)
        image_to_anns[img_id].append(ann)

    # Sample up to min_samples images per part so every class is represented
    selected_images = set()
    for part_images in images_by_part.values():
        sample_size = min(min_samples, len(part_images))
        selected_images.update(random.sample(list(part_images), sample_size))

    # Copy selected images and create YOLO labels
    id_to_filename = {img['id']: img['file_name'] for img in coco['images']}

    print(f"Creating balanced dataset with {len(selected_images)} images...")
    for img_id in tqdm(selected_images):
        # Copy image
        src_img = os.path.join(source_img_dir, id_to_filename[img_id])
        dst_img = os.path.join(target_dir, 'images', id_to_filename[img_id])
        shutil.copy2(src_img, dst_img)

        # Create YOLO label file path
        base_name = os.path.splitext(id_to_filename[img_id])[0]
        label_file = os.path.join(target_dir, 'labels', f"{base_name}.txt")

        # Get image dimensions for coordinate normalization
        with Image.open(src_img) as im:
            w, h = im.size

        # Convert annotations to YOLO segmentation format:
        # "<class_id> x1 y1 x2 y2 ..." with coordinates normalized to [0, 1]
        label_lines = []
        for ann in image_to_anns[img_id]:
            cat_id = ann['category_id']
            for seg in ann['segmentation']:
                seg_norm = [str(x / w) if i % 2 == 0 else str(x / h) for i, x in enumerate(seg)]
                label_lines.append(f"{cat_id} {' '.join(seg_norm)}")

        # Write label file
        with open(label_file, 'w') as f:
            f.write('\n'.join(label_lines))


if __name__ == "__main__":
    current_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(current_dir)

    # Process training set
    create_balanced_dataset(
        source_json=os.path.join(base_dir, "damage_detection_dataset", "train", "COCO_mul_train_annos.json"),
        source_img_dir=os.path.join(base_dir, "damage_detection_dataset", "img"),
        target_dir=os.path.join(base_dir, "data", "parts", "balanced", "train"),
        min_samples=50
    )

    # Process validation set
    create_balanced_dataset(
        source_json=os.path.join(base_dir, "damage_detection_dataset", "val", "COCO_mul_val_annos.json"),
        source_img_dir=os.path.join(base_dir, "damage_detection_dataset", "img"),
        target_dir=os.path.join(base_dir, "data", "parts", "balanced", "val"),
        min_samples=10
    )