File size: 3,644 Bytes
f01c86d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import json
import os
import shutil
from collections import defaultdict
import random
from tqdm import tqdm

def create_balanced_dataset(source_json, source_img_dir, target_dir, min_samples=50):
    """
    Create a balanced dataset for parts detection by sampling images per class.

    For each category in the COCO annotations, up to ``min_samples`` images
    containing that category are sampled; the union of the sampled images is
    copied to ``target_dir/images`` and a YOLO-style polygon label file is
    written to ``target_dir/labels`` for each image.

    Args:
        source_json (str): Path to source COCO JSON file.
        source_img_dir (str): Path to source images directory.
        target_dir (str): Path to target directory for balanced dataset.
        min_samples (int): Maximum number of images sampled per class
            (fewer if a class occurs in fewer images).
    """
    # Hoisted out of the per-image loop — the original re-ran this import
    # on every iteration.
    from PIL import Image

    # Create target directories
    os.makedirs(os.path.join(target_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(target_dir, 'labels'), exist_ok=True)

    # Load COCO annotations
    with open(source_json, 'r') as f:
        coco = json.load(f)

    # Group image ids by the parts (categories) they contain, and collect
    # each image's annotations for label generation.
    images_by_part = defaultdict(set)
    image_to_anns = defaultdict(list)
    for ann in coco['annotations']:
        images_by_part[ann['category_id']].add(ann['image_id'])
        image_to_anns[ann['image_id']].append(ann)

    # Sample up to min_samples images per part; taking the union keeps
    # every class represented without duplicating shared images.
    # NOTE(review): random is unseeded, so the selection is not reproducible
    # across runs — seed upstream if determinism matters.
    selected_images = set()
    for part_images in images_by_part.values():
        sample_size = min(min_samples, len(part_images))
        selected_images.update(random.sample(list(part_images), sample_size))

    # Copy selected images and create labels
    id_to_filename = {img['id']: img['file_name'] for img in coco['images']}

    print(f"Creating balanced dataset with {len(selected_images)} images...")
    for img_id in tqdm(selected_images):
        file_name = id_to_filename[img_id]

        # Copy image
        src_img = os.path.join(source_img_dir, file_name)
        dst_img = os.path.join(target_dir, 'images', file_name)
        shutil.copy2(src_img, dst_img)

        # Image dimensions are needed to normalize polygon coordinates.
        # Context manager closes the file handle (the original leaked it).
        with Image.open(src_img) as im:
            w, h = im.size

        # Convert annotations to YOLO polygon format:
        # "<class_id> x1 y1 x2 y2 ..." with coordinates normalized to [0, 1].
        label_lines = []
        for ann in image_to_anns[img_id]:
            cat_id = ann['category_id']
            # COCO segmentations can be RLE dicts instead of polygon lists;
            # the original crashed on those (dividing dict keys by w/h) —
            # skip any non-polygon entry.
            for seg in ann['segmentation']:
                if not isinstance(seg, (list, tuple)):
                    continue
                # Polygon is [x1, y1, x2, y2, ...]: even indices are x
                # (normalize by width), odd indices are y (by height).
                seg_norm = [str(v / w) if i % 2 == 0 else str(v / h)
                            for i, v in enumerate(seg)]
                label_lines.append(f"{cat_id} {' '.join(seg_norm)}")

        # Write label file
        base_name = os.path.splitext(file_name)[0]
        label_file = os.path.join(target_dir, 'labels', f"{base_name}.txt")
        with open(label_file, 'w') as f:
            f.write('\n'.join(label_lines))

if __name__ == "__main__":
    current_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(current_dir)
    dataset_root = os.path.join(base_dir, "damage_detection_dataset")

    # Build the balanced train and val splits from their COCO annotation
    # files: (split name, annotation filename, samples per class).
    split_configs = [
        ("train", "COCO_mul_train_annos.json", 50),
        ("val", "COCO_mul_val_annos.json", 10),
    ]
    for split_name, annos_file, samples_per_class in split_configs:
        create_balanced_dataset(
            source_json=os.path.join(dataset_root, split_name, annos_file),
            source_img_dir=os.path.join(dataset_root, "img"),
            target_dir=os.path.join(base_dir, "data", "parts", "balanced", split_name),
            min_samples=samples_per_class,
        )