# NOTE: repository page header captured together with this file (not code):
#   vshirasuna's picture
#   Move code to 3dgrid_vqgan folder (a4c759f)
import sys
sys.path.append('./')
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
from vq_gan_3d.model.vqgan_DDP import load_VQGAN, VQGAN
from vq_gan_3d.metrics import ImageMetrics
from dataset.default import DEFAULTDataset
from tqdm import tqdm
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def infer_epoch_from_filename(filename):
    """Extract the epoch number encoded at the end of a checkpoint filename.

    The epoch is the text between the last underscore and the first dot
    that follows it, e.g. ``'VQGAN_43.pt'`` -> ``43``.

    Args:
        filename: Checkpoint file name as a string.

    Returns:
        The epoch as an ``int``, or ``None`` when that segment is not an
        integer or ``filename`` is not a string.
    """
    try:
        return int(filename.split('_')[-1].split('.')[0])
    except (AttributeError, TypeError, ValueError):
        # ValueError: the segment is not an integer.
        # AttributeError / TypeError: the input is not a str.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return None
def main(
    model_folder='./checkpoints',
    model_filename='VQGAN_43.pt',
    data_folder='../data/sample_data_schema',
    save_results_folder='./results',
    epoch=None,
):
    """Evaluate VQGAN reconstructions and save per-sample image metrics.

    Loads the checkpoint, encodes and decodes every sample of the dataset,
    computes ``ImageMetrics`` between each input and its reconstruction,
    writes one CSV row per sample and prints the column means.

    Args:
        model_folder: Folder that contains the checkpoint.
        model_filename: Checkpoint file name.
        data_folder: Root directory handed to ``DEFAULTDataset``.
        save_results_folder: Folder where the CSV is written; created
            (including parents) if it does not exist.
        epoch: Epoch number used in the CSV file name. When ``None`` it is
            inferred from ``model_filename``.

    Raises:
        ValueError: If the epoch is neither given nor inferable from the
            file name, or if the dataloader yields no samples.
    """
    # settings
    if epoch is None:
        epoch = infer_epoch_from_filename(model_filename)
    print(model_filename)

    # Validate before touching the filesystem or loading the model.
    if epoch is None:
        raise ValueError("Cannot find the epoch number from filename. Please, provide an epoch manually.")

    # makedirs handles nested paths and an already-existing folder.
    os.makedirs(save_results_folder, exist_ok=True)

    # load model
    vqgan = load_VQGAN(folder=model_folder, filename=model_filename).to(DEVICE)
    vqgan.eval()

    # load data
    eval_dataset = DEFAULTDataset(
        root_dir=data_folder,
        internal_resolution=vqgan.config['model']['internal_resolution'],
    )
    dataloader = DataLoader(
        eval_dataset,
        shuffle=False,
        batch_size=1,
        # Pinned host memory only helps (and only avoids a warning) with CUDA.
        pin_memory=DEVICE.type == 'cuda',
    )

    # decoder quantitative evaluation
    print('Starting evaluation...')
    n_batches = len(dataloader)
    dfs = []
    for idx, data in enumerate(tqdm(dataloader)):
        x = data['data'].to(DEVICE)

        # Pure inference: neither the forward pass nor the metrics need grads.
        with torch.no_grad():
            # tqdm.write prints without breaking the progress bar.
            tqdm.write(f'{idx}/{n_batches}')

            # forward
            encds = vqgan.encode(x)
            x_recon = vqgan.decode(encds)
            torch.cuda.empty_cache()

            # image metrics
            # squeeze() drops every size-1 dimension, including the batch
            # dimension (batch_size=1).
            # NOTE(review): presumably ImageMetrics expects un-batched volumes -- confirm.
            eval_metrics = ImageMetrics(x.squeeze(), x_recon.squeeze(), device=DEVICE)
            metrics_dict = eval_metrics.get_metrics()
            # One single-row frame per sample; concatenated after the loop.
            dfs.append(pd.DataFrame({k: [v] for k, v in metrics_dict.items()}))
            torch.cuda.empty_cache()

    # pd.concat rejects an empty list with an opaque message; fail clearly.
    if not dfs:
        raise ValueError(f'No samples found in {data_folder!r}; nothing to evaluate.')

    # save results
    df = pd.concat(dfs, axis=0)
    df.to_csv(os.path.join(save_results_folder, f'vqgan_image_evaluation_epoch={epoch}.csv'), index=False)
    print(df.mean())
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()