|
import argparse |
|
import csv |
|
import os |
|
|
|
from torchvision.datasets import ImageNet |
|
|
|
|
|
def get_filelist(file_path):
    """Recursively collect the paths of all files under ``file_path``.

    Directories and files are visited in sorted order, so the returned list
    (and any CSV built from it) is identical across runs and filesystems.
    ``os.walk`` on its own yields entries in arbitrary, OS-dependent order.

    Args:
        file_path: Root directory to walk. A path that does not exist yields
            an empty list, because ``os.walk`` ignores listing errors by default.

    Returns:
        A list of file paths, each formed as ``os.path.join(dirpath, name)``.
    """
    file_list = []
    for dirpath, dirnames, filenames in os.walk(file_path):
        # A top-down os.walk honours in-place edits of dirnames, so sorting
        # here makes the descent order into subdirectories deterministic.
        dirnames.sort()
        file_list.extend(os.path.join(dirpath, name) for name in sorted(filenames))
    return file_list
|
|
|
|
|
def split_by_capital(name):
    """Insert a space before every uppercase letter except the first character.

    This turns CamelCase identifiers into space-separated words, e.g.
    ``"ApplyEyeMakeup"`` -> ``"Apply Eye Makeup"``. Every uppercase character
    after position 0 starts a new word, so a run of capitals is split letter
    by letter (``"ABC"`` -> ``"A B C"``).

    Args:
        name: The string to split.

    Returns:
        The string with spaces inserted. An empty input returns ``""``.
    """
    # Build the pieces and join once, instead of repeated string concatenation.
    return "".join(
        " " + ch if idx and ch.isupper() else ch
        for idx, ch in enumerate(name)
    )
|
|
|
|
|
def process_imagenet(root, split):
    """Write an ImageNet split's (image path, class name) pairs to a CSV file.

    Loads the split with ``torchvision.datasets.ImageNet`` and writes one row
    per entry of ``data.samples`` to ``imagenet_{split}.csv`` in the current
    working directory. Indexing ``data.classes[label][0]`` keeps only the
    first name listed for each class.

    Args:
        root: Dataset root directory. ``~`` is expanded.
        split: Split name passed through to ``ImageNet`` (e.g. ``"train"``).
    """
    root = os.path.expanduser(root)
    data = ImageNet(root, split=split)
    samples = [(path, data.classes[label][0]) for path, label in data.samples]
    output = f"imagenet_{split}.csv"

    # The csv module requires newline="" so it controls line endings itself;
    # without it, Windows output gets "\r\r\n". An explicit utf-8 encoding
    # keeps the file independent of the platform's default locale.
    with open(output, "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(samples)

    print(f"Saved {len(samples)} samples to {output}.")
|
|
|
|
|
def process_ucf101(root, split):
    """Write a UCF101 split's (video path, class name) pairs to a CSV file.

    Every file under ``<root>/<split>`` is listed. Its class is the name of
    the directory that directly contains it, converted from CamelCase to
    space-separated words (e.g. ``ApplyEyeMakeup`` -> ``Apply Eye Makeup``).
    Rows are written to ``ucf101_{split}.csv`` in the current working
    directory.

    Args:
        root: Dataset root directory. ``~`` is expanded.
        split: Subdirectory of ``root`` that holds the split (e.g. ``"train"``).
    """
    root = os.path.expanduser(root)
    video_lists = get_filelist(os.path.join(root, split))

    # Take the parent directory name with os.path rather than splitting on "/".
    # Splitting on "/" picks the wrong component, or raises IndexError, when
    # paths use "\" separators (Windows).
    samples = [
        (path, split_by_capital(os.path.basename(os.path.dirname(path))))
        for path in video_lists
    ]
    output = f"ucf101_{split}.csv"

    # The csv module requires newline="" so it controls line endings itself.
    with open(output, "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(samples)

    print(f"Saved {len(samples)} samples to {output}.")
|
|
|
|
|
def main(argv=None):
    """Parse command-line arguments and build the CSV index for one dataset.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.
    """
    # Single source of truth: the CLI choices are derived from this table, so
    # every accepted dataset name is guaranteed to have a handler.
    processors = {
        "imagenet": process_imagenet,
        "ucf101": process_ucf101,
    }

    parser = argparse.ArgumentParser(
        description="Write a CSV of (file path, class name) rows for a dataset split."
    )
    parser.add_argument("dataset", type=str, choices=list(processors))
    parser.add_argument("root", type=str, help="dataset root directory (~ is expanded)")
    parser.add_argument(
        "--split", type=str, default="train", help="dataset split (default: train)"
    )
    args = parser.parse_args(argv)

    # argparse has already rejected unknown dataset names, so this lookup
    # cannot fail.
    processors[args.dataset](args.root, args.split)


if __name__ == "__main__":
    main()
|
|