Spaces:
Build error
Build error
import os | |
import csv | |
import numpy as np | |
# Root of the Pororo PNG dataset: the .npy inputs are read from here and the
# per-split CSVs are written here. Set PORORO_IMG_FOLDER to use another location.
# os.path.join(..., '') guarantees a trailing separator, so removing this prefix
# from a frame path leaves a relative image id.
img_folder = os.path.join(
    os.environ.get('PORORO_IMG_FOLDER', '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/'),
    '',
)
def _path_to_image_id(raw_path, img_dir):
    """Convert one ``followings`` entry into an image id relative to ``img_dir``.

    The entry's ``img_dir`` prefix and ``.png`` suffix are removed when present.
    """
    # np.load of a bytes array yields np.bytes_ (a bytes subclass). Decoding is
    # safer than slicing str(b'...')[2:-1]: that slice also cuts characters off
    # an entry that is already a str, and keeps '\x..' escapes for non-ASCII
    # bytes. NOTE(review): assumes the stored paths are UTF-8 encoded.
    if isinstance(raw_path, bytes):
        path = raw_path.decode('utf-8')
    else:
        path = str(raw_path)
    # Strip only a leading folder and a trailing extension. str.replace would
    # also delete matches in the middle of the path.
    if path.startswith(img_dir):
        path = path[len(img_dir):]
    if path.endswith('.png'):
        path = path[:-len('.png')]
    return path


def _write_caption_csv(csv_path, image_ids, descriptions):
    """Write ``[image_id, descriptions[image_id][0]]`` rows to ``csv_path``.

    Raises:
        KeyError: if an image id has no entry in ``descriptions``.
    """
    # The csv module requires newline=''; without it, extra blank lines appear
    # on Windows. Use an explicit encoding so non-ASCII text is written the same
    # way on every platform.
    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows([image_id, descriptions[image_id][0]] for image_id in image_ids)


def get_captions_by_split(img_dir=None, video_len=4):
    """Write one CSV of image ids and descriptions per dataset split.

    For every id in the train / val / test id lists, the first ``video_len``
    entries of ``followings[id]`` are converted to image ids. The
    de-duplicated, sorted ids of each split are written to
    ``descriptions_train.csv``, ``descriptions_val.csv`` and
    ``descriptions_test.csv`` inside ``img_dir``. Each row is
    ``[image_id, descriptions[image_id][0]]``.

    Args:
        img_dir: Directory holding ``descriptions.npy``, ``following_cache4.npy``
            and ``train_seen_unseen_ids.npy``. The CSVs are written here too.
            Defaults to the module-level ``img_folder``. It should end with a
            path separator so that stripping it from a frame path leaves a
            relative id.
        video_len: Number of entries taken from each ``followings`` row.

    Raises:
        KeyError: if an image id has no entry in ``descriptions.npy``.
        IndexError: if a ``followings`` row has fewer than ``video_len`` entries.
    """
    if img_dir is None:
        img_dir = img_folder
    # NOTE: allow_pickle=True unpickles arbitrary objects. Only load trusted local files.
    descriptions = np.load(os.path.join(img_dir, 'descriptions.npy'),
                           allow_pickle=True, encoding='latin1').item()
    followings = np.load(os.path.join(img_dir, 'following_cache4.npy'))
    train_ids, val_ids, test_ids = np.load(
        os.path.join(img_dir, 'train_seen_unseen_ids.npy'), allow_pickle=True)

    splits = zip(
        [train_ids, val_ids, test_ids],
        ['descriptions_train.csv', 'descriptions_val.csv', 'descriptions_test.csv'],
    )
    for src_ids, filename in splits:
        image_ids = set()
        for src_img_id in src_ids:
            row = followings[src_img_id]
            # Explicit indexing (not slicing) keeps the IndexError on short rows.
            image_ids.update(_path_to_image_id(row[i], img_dir) for i in range(video_len))
        _write_caption_csv(os.path.join(img_dir, filename), sorted(image_ids), descriptions)
if __name__ == '__main__':
    # Run the CSV export only when executed as a script, not when imported.
    get_captions_by_split()