# File size: 2,122 Bytes
# 9123ba9
import sys
sys.path.append('./')
import pandas as pd
import numpy as np
import argparse
import torch
from torch.utils.data import DataLoader
from vq_gan_3d import load_VQGAN, get_single_device, VQGANDataset
from tqdm import tqdm
# Device returned by get_single_device(cpu=False); the model and every input
# batch are moved onto it before feature extraction.
DEVICE = get_single_device(cpu=False)
def extract_embeddings(vqgan, dataloader):
    """Run every batch of ``dataloader`` through ``vqgan`` and join the outputs.

    Each batch is moved to ``DEVICE`` and passed to
    ``vqgan.feature_extraction`` with gradient tracking disabled. The
    per-batch outputs are then concatenated along dimension 0.

    Args:
        vqgan: Model object exposing a ``feature_extraction`` method.
        dataloader: Iterable yielding batches that support ``.to(DEVICE)``.

    Returns:
        A single tensor holding the concatenated per-batch outputs.
    """
    per_batch = []
    with torch.no_grad():
        # tqdm only wraps the iterable to display progress
        for batch in tqdm(dataloader):
            features = vqgan.feature_extraction(batch.to(DEVICE))
            per_batch.append(features)
        stacked = torch.cat(per_batch, dim=0)
    return stacked
def main(args):
    """Compute VQGAN embeddings for a CSV dataset and write the augmented CSV.

    Loads the checkpoint ``args.ckpt_filename`` from
    ``../data/checkpoints/pretrained``, reads the CSV at
    ``args.dataset_path`` (which must contain a ``3d_grid`` column), runs
    ``extract_embeddings`` over a ``VQGANDataset`` built from that column and
    ``args.data_dir``, and saves the original columns side by side with the
    embedding columns to ``args.save_dataset_path``.

    Args:
        args: Namespace providing ``ckpt_filename``, ``dataset_path``,
            ``data_dir``, ``batch_size``, ``num_workers`` and
            ``save_dataset_path``.
    """
    # restore the pretrained model, switch to eval mode, move it to DEVICE
    vqgan = load_VQGAN(
        folder='../data/checkpoints/pretrained',
        ckpt_filename=args.ckpt_filename,
    )
    vqgan = vqgan.eval().to(DEVICE)

    # read the dataset table and build the loader over its '3d_grid' column
    df = pd.read_csv(args.dataset_path)
    grid_entries = df['3d_grid'].to_list()
    resolution = vqgan.config['model']['internal_resolution']
    dataset = VQGANDataset(args.data_dir, grid_entries, resolution)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,  # keep loader order aligned with the rows of df
        pin_memory=False,
    )

    # debug
    print('VQGAN model and data loaded!')
    print('\tCheckpoint:', args.ckpt_filename)
    print('\tDataset size:', df.shape)

    # compute the embeddings and bring them back as a NumPy array
    embeddings = extract_embeddings(vqgan, dataloader)
    embeddings = embeddings.cpu().numpy()

    # place the embedding columns to the right of the original columns
    df_full = pd.concat([df, pd.DataFrame(embeddings)], axis=1)

    # save to disk, without the index column
    df_full.to_csv(args.save_dataset_path, index=False)
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run main().
    cli = argparse.ArgumentParser()

    # (flag, value type, remaining add_argument keyword arguments)
    option_specs = [
        ('--dataset_path', str, {'required': True}),
        ('--save_dataset_path', str, {'required': True}),
        ('--ckpt_filename', str, {'default': 'VQGAN_43.pt', 'required': False}),
        ('--data_dir', str, {'default': '/scratch/vyukio/npy_datasets/qm9_npy/', 'required': False}),
        ('--batch_size', int, {'default': 1, 'required': False}),
        ('--num_workers', int, {'default': 0, 'required': False}),
    ]
    for flag, value_type, extra in option_specs:
        cli.add_argument(flag, type=value_type, **extra)

    args = cli.parse_args()
    main(args)