chirmy committed on
Commit
863fd3a
·
verified ·
1 Parent(s): 39fc4a0

Upload 07_data_augmentation.py

Browse files
Files changed (1) hide show
  1. 07_data_augmentation.py +171 -0
07_data_augmentation.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ # import pandas as pd
3
+ # from PIL import Image, ImageOps
4
+ # import numpy as np
5
+ # from tqdm import tqdm
6
+ # from multiprocessing import Pool, cpu_count
7
+
8
+ # # 读取CSV文件
9
+ # csv_path = '/data/cjm/FungiCLEF2024/Dataset/06_new_train_valmetadata.csv'
10
+ # data = pd.read_csv(csv_path)
11
+
12
+ # # 设置根目录
13
+ # root_dir = '/data/cjm/FungiCLEF2024/Dataset/DF20_21_300'
14
+
15
+ # # 过滤poisonous为1的数据
16
+ # poisonous_data = data[data['poisonous'] == 1]
17
+
18
+ # # 创建保存增强数据的DataFrame,并包含原始数据
19
+ # new_data = data.copy()
20
+
21
+ # # 定义数据增强函数
22
+ # def augment_image(args):
23
+ # row, root_dir = args
24
+ # image_path = row['image_path']
25
+ # full_path = os.path.join(root_dir, image_path)
26
+ # augmented_rows = []
27
+
28
+ # if os.path.exists(full_path):
29
+ # image = Image.open(full_path)
30
+ # w, h = image.size
31
+
32
+ # # 定义旋转和翻转操作
33
+ # transformations = {
34
+ # 'r90': image.rotate(90, expand=True),
35
+ # 'r180': image.rotate(180, expand=True),
36
+ # 'r270': image.rotate(270, expand=True),
37
+ # 'fh': ImageOps.mirror(image),
38
+ # 'fv': ImageOps.flip(image),
39
+ # }
40
+
41
+ # for suffix, img in transformations.items():
42
+ # # 裁剪图片以去除旋转后的黑边
43
+ # if suffix in ['r90', 'r270']:
44
+ # img = img.crop((0, 0, h, w))
45
+
46
+ # new_image_path = os.path.splitext(image_path)[0] + f'_{suffix}.JPG'
47
+ # new_full_path = os.path.join(root_dir, new_image_path)
48
+ # img.save(new_full_path)
49
+
50
+ # new_row = row.copy()
51
+ # new_row['image_path'] = new_image_path
52
+ # augmented_rows.append(new_row)
53
+
54
+ # return augmented_rows
55
+
56
+ # # 准备多进程处理
57
+ # num_processes = cpu_count()
58
+ # pool = Pool(processes=num_processes)
59
+
60
+ # # 使用tqdm显示进度
61
+ # augmented_data = []
62
+ # for augmented_rows in tqdm(pool.imap_unordered(augment_image, [(row, root_dir) for _, row in poisonous_data.iterrows()]), total=len(poisonous_data)):
63
+ # augmented_data.extend(augmented_rows)
64
+
65
+ # # 关闭进程池
66
+ # pool.close()
67
+ # pool.join()
68
+
69
+ # # 将增强后的数据添加到new_data中
70
+ # new_data = new_data.append(augmented_data, ignore_index=True)
71
+
72
+ # # 将数据保存到新的CSV文件中
73
+ # new_csv_path = '/data/cjm/FungiCLEF2024/Dataset/07_new_train_valmetadata.csv'
74
+ # new_data.to_csv(new_csv_path, index=False)
75
+
76
+
77
+ import os
78
+ import pandas as pd
79
+ from PIL import Image, ImageOps
80
+ import numpy as np
81
+ from tqdm import tqdm
82
+ from multiprocessing import Pool, cpu_count
83
+ import random
84
+
85
# Root directory that the relative `image_path` values are joined onto.
root_dir = '/data/cjm/FungiCLEF2024/Dataset/DF20_21_300'

# Load the metadata CSV produced by the previous pipeline step.
csv_path = '/data/cjm/FungiCLEF2024/Dataset/06_new_train_valmetadata.csv'
data = pd.read_csv(csv_path)

# The output table starts as a copy of every original row; augmented rows
# are added to it later.
new_data = data.copy()

# Only rows whose `poisonous` column equals 1 are selected for augmentation.
is_poisonous = data['poisonous'] == 1
poisonous_data = data[is_poisonous]
97
+
98
# Data-augmentation worker and its helper.
def _save_variant(img, row, root_dir, image_path, suffix):
    """Save ``img`` beside the source image and return its metadata row.

    The new file name is the source ``image_path`` without its extension,
    followed by ``_<suffix>.JPG``. The returned row is a copy of ``row``
    whose ``image_path`` points at the new file.
    """
    new_image_path = os.path.splitext(image_path)[0] + f'_{suffix}.JPG'
    img.save(os.path.join(root_dir, new_image_path))

    new_row = row.copy()
    new_row['image_path'] = new_image_path
    return new_row


def augment_image(args):
    """Write augmented copies of one image and return their metadata rows.

    Args:
        args: A ``(row, root_dir)`` tuple (a single argument so the function
            can be used with ``Pool.imap_unordered``). ``row`` must support
            ``row['image_path']`` and ``row.copy()``; ``root_dir`` is the
            directory that ``image_path`` is relative to.

    Returns:
        A list with one row per written image: four random crops followed by
        the ``r90``, ``r180``, ``r270``, ``fh`` and ``fv`` variants. The list
        is empty when the source image does not exist.
    """
    row, root_dir = args
    image_path = row['image_path']
    full_path = os.path.join(root_dir, image_path)

    # Missing source images are skipped silently.
    if not os.path.exists(full_path):
        return []

    augmented_rows = []

    # The context manager closes the underlying file once all variants are
    # written; otherwise long-lived pool workers accumulate open handles.
    with Image.open(full_path) as image:
        w, h = image.size

        # Four random crops, each keeping 70-80 % of both dimensions.
        for _ in range(4):
            rand = random.uniform(0.7, 0.8)
            new_w = int(w * rand)
            new_h = int(h * rand)
            left = random.randint(0, w - new_w)
            top = random.randint(0, h - new_h)
            cropped_image = image.crop((left, top, left + new_w, top + new_h))
            # The scale factor itself is part of the file name.
            augmented_rows.append(
                _save_variant(cropped_image, row, root_dir, image_path, f'crop{rand}')
            )

        # Rotations and flips.
        transformations = {
            'r90': image.rotate(90, expand=True),
            'r180': image.rotate(180, expand=True),
            'r270': image.rotate(270, expand=True),
            'fh': ImageOps.mirror(image),
            'fv': ImageOps.flip(image),
        }

        for suffix, img in transformations.items():
            # Clamp quarter-turn results to (h, w) to drop any border left by
            # the rotation. NOTE(review): with expand=True the rotated image
            # should already be (h, w), so this is presumably a no-op.
            if suffix in ('r90', 'r270'):
                img = img.crop((0, 0, h, w))
            augmented_rows.append(
                _save_variant(img, row, root_dir, image_path, suffix)
            )

    return augmented_rows
152
+
153
# Run the augmentation only when executed as a script. Under the "spawn"
# start method every worker re-imports this module, so creating the pool at
# import time would recurse.
if __name__ == '__main__':
    # One worker process per CPU core.
    num_processes = cpu_count()
    tasks = [(row, root_dir) for _, row in poisonous_data.iterrows()]

    # Collect the rows of every augmented image, with a tqdm progress bar.
    augmented_data = []
    with Pool(processes=num_processes) as pool:
        for augmented_rows in tqdm(pool.imap_unordered(augment_image, tasks), total=len(tasks)):
            augmented_data.extend(augmented_rows)

        # Shut the pool down cleanly once all results are in.
        pool.close()
        pool.join()

    # Add the augmented rows to the original metadata. DataFrame.append was
    # removed in pandas 2.0, so build a frame from the rows and concatenate.
    if augmented_data:
        new_data = pd.concat([new_data, pd.DataFrame(augmented_data)], ignore_index=True)

    # Write the combined metadata to a new CSV file.
    new_csv_path = '/data/cjm/FungiCLEF2024/Dataset/07_new_train_valmetadata.csv'
    new_data.to_csv(new_csv_path, index=False)