|
import sys |
|
sys.path.append('./') |
|
|
|
import pandas as pd |
|
import numpy as np |
|
|
|
import argparse |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from vq_gan_3d import load_VQGAN, get_single_device, VQGANDataset |
|
from tqdm import tqdm |
|
|
|
# Single device that the model and every input batch are moved to via
# ``.to(DEVICE)``. Requested with cpu=False -- presumably this selects an
# accelerator when one is available; confirm against vq_gan_3d.
DEVICE = get_single_device(cpu=False)
|
|
|
|
|
def extract_embeddings(vqgan, dataloader):
    """Run each batch of ``dataloader`` through ``vqgan`` and join the outputs.

    Args:
        vqgan: Model object exposing ``feature_extraction(batch)``.
        dataloader: Iterable of batches; each batch is moved to ``DEVICE``
            with ``.to`` before it is handed to the model.

    Returns:
        One tensor holding the per-batch outputs concatenated along
        dimension 0, in the order the batches were yielded.
    """
    per_batch = []
    # Gradient tracking is switched off for the forward passes and the
    # final concatenation.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            on_device = batch.to(DEVICE)
            per_batch.append(vqgan.feature_extraction(on_device))
        stacked = torch.cat(per_batch, dim=0)
    return stacked
|
|
|
|
|
def main(args):
    """Embed the grids listed in a CSV with a pretrained VQGAN and save them.

    The checkpoint ``args.ckpt_filename`` is loaded from
    ``../data/checkpoints/pretrained``, put in eval mode and moved to
    ``DEVICE``. The CSV at ``args.dataset_path`` is read, the values of its
    ``3d_grid`` column are handed to ``VQGANDataset`` together with
    ``args.data_dir`` and the model's configured internal resolution, and the
    resulting embeddings are appended column-wise to the original table
    before it is written to ``args.save_dataset_path``.

    Args:
        args: Namespace providing ``ckpt_filename``, ``dataset_path``,
            ``data_dir``, ``batch_size``, ``num_workers`` and
            ``save_dataset_path``.
    """
    # Model: load the checkpoint, then switch to eval mode on the target device.
    model = load_VQGAN(
        folder='../data/checkpoints/pretrained',
        ckpt_filename=args.ckpt_filename,
    )
    model = model.eval().to(DEVICE)

    # Data: one dataset entry per row of the CSV's '3d_grid' column.
    table = pd.read_csv(args.dataset_path)
    grid_entries = table['3d_grid'].to_list()
    resolution = model.config['model']['internal_resolution']
    dataset = VQGANDataset(args.data_dir, grid_entries, resolution)
    # shuffle=False keeps the batches in CSV row order, which the
    # column-wise concat below relies on.
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=False,
    )

    print('VQGAN model and data loaded!')
    print('\tCheckpoint:', args.ckpt_filename)
    print('\tDataset size:', table.shape)

    # Inference, then back to host memory as a NumPy array.
    features = extract_embeddings(model, loader)
    feature_table = pd.DataFrame(features.cpu().numpy())

    # Place the embedding columns next to the original columns and persist.
    combined = pd.concat([table, feature_table], axis=1)
    combined.to_csv(args.save_dataset_path, index=False)
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the command-line interface and run main().
    parser = argparse.ArgumentParser()

    # Input CSV and output CSV paths must always be given.
    for mandatory_flag in ('--dataset_path', '--save_dataset_path'):
        parser.add_argument(mandatory_flag, type=str, required=True)

    # Optional settings as (flag, type, default), registered in this order.
    optional_settings = (
        ('--ckpt_filename', str, 'VQGAN_43.pt'),
        ('--data_dir', str, '/scratch/vyukio/npy_datasets/qm9_npy/'),
        ('--batch_size', int, 1),
        ('--num_workers', int, 0),
    )
    for flag, value_type, fallback in optional_settings:
        parser.add_argument(flag, type=value_type, default=fallback)

    args = parser.parse_args()
    main(args)