Spaces:
Running
Running
import glob | |
import json | |
import webdataset as wds | |
def split_dataset(path, n_train, n_val, n_test, label, domain_label): | |
max_file_size = 1000 | |
input_files = glob.glob(path + "/*.tar") | |
src = wds.WebDataset(input_files) | |
train_path_prefix = path + "/train" | |
val_path_prefix = path + "/val" | |
test_path_prefix = path + "/test" | |
def write_split(dataset, prefix, start, end): | |
n_split = end - start | |
output_files = [ | |
f"{prefix}_{i}.tar" for i in range(n_split // max_file_size + 1) | |
] | |
for i, output_file in enumerate(output_files): | |
print(f"Writing {output_file}") | |
with wds.TarWriter(output_file) as dst: | |
for sample in dataset.slice( | |
start + i * max_file_size, | |
min(start + (i + 1) * max_file_size, end), | |
): | |
new_sample = { | |
"__key__": sample["__key__"], | |
"jpg": sample["jpg"], | |
"label.cls": label, | |
"domain_label.cls": domain_label, | |
} | |
dst.write(new_sample) | |
write_split(src, train_path_prefix, 0, n_train) | |
write_split(src, val_path_prefix, n_train, n_train + n_val) | |
write_split( | |
src, | |
test_path_prefix, | |
n_train + n_val, | |
n_train + n_val + n_test, | |
) | |
def calculate_sizes(path): | |
stat_files = glob.glob(path + "/*_stats.json") | |
total = 0 | |
for f in stat_files: | |
with open(f) as stats: | |
total += json.load(stats)["successes"] | |
n_train = int(total * 0.8) | |
n_val = int(total * 0.1) | |
n_test = total - n_train - n_val | |
return n_train, n_val, n_test | |
if __name__ == "__main__": | |
paths = [ | |
"./data/laion400m_data", | |
"./data/genai-images/StableDiffusion", | |
"./data/genai-images/midjourney", | |
"./data/genai-images/dalle2", | |
"./data/genai-images/dalle3", | |
] | |
sizes = [] | |
for p in paths: | |
res = calculate_sizes(p) | |
sizes.append(res) | |
domain_labels = [0, 1, 4, 2, 3] | |
for i, p in enumerate(paths): | |
print(f"{p}: {sizes[i]}") | |
label = 0 if i == 0 else 1 | |
print(label, domain_labels[i]) | |
split_dataset(p, *calculate_sizes(p), label, domain_labels[i]) | |