File size: 2,757 Bytes
fa0f216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
from pathlib import Path

import torch
import torch.utils.data

from data.dataset import FidDataset
from generate.writer import Writer


def _dataset_tag(dataset_path):
    """Return the file name of *dataset_path* with '.pickle' and '-' removed.

    ``Path(...).name`` is used instead of ``split("/")[-1]`` so that a trailing
    separator does not produce an empty tag.
    """
    return Path(dataset_path).name.replace(".pickle", "").replace("-", "")


def _model_tag(checkpoint, style_dataset_path):
    """Build the tag that identifies the model/style-dataset pair in output names.

    The model folder is the parent directory when *checkpoint* points at a
    ``.pth`` file, otherwise the checkpoint directory itself. The part after the
    last '-' of that folder name is used, falling back to "vatr" when the
    folder name contains no '-'.
    """
    checkpoint_path = Path(checkpoint)
    if checkpoint.endswith(".pth"):
        model_folder = checkpoint_path.parent.name
    else:
        model_folder = checkpoint_path.name
    prefix = model_folder.split("-")[-1] if "-" in model_folder else "vatr"
    return prefix + "_" + _dataset_tag(style_dataset_path)


def _make_loader(dataset, batch_size, num_workers):
    """Wrap *dataset* in a non-shuffling DataLoader that uses the dataset's own collate_fn."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=dataset.collate_fn,
    )


def _epoch_checkpoints(checkpoint_dir):
    """Return ``[(epoch, path), ...]`` sorted by epoch for ``<digits>_model.pth`` files.

    Only files named exactly ``<digits>_model.pth`` are considered, so unrelated
    entries containing an underscore (e.g. ``best_model.pth`` or
    ``0010_optimizer.pth``) neither crash the integer parsing nor produce
    duplicate epochs. The path found on disk is returned instead of being
    rebuilt, so the zero-padding width does not matter.
    """
    found = {}
    # Sorted listing + setdefault makes the choice deterministic if the same
    # epoch appears with different paddings (the first name in sorted order wins).
    for file_name in sorted(os.listdir(checkpoint_dir)):
        prefix, _, rest = file_name.partition("_")
        if rest == "model.pth" and prefix.isdecimal():
            found.setdefault(int(prefix), os.path.join(checkpoint_dir, file_name))
    return sorted(found.items())


def generate_fid(args):
    """Generate the image sets used for FID evaluation.

    Reads from *args*: ``target_dataset_path``, ``dataset_path``, ``alphabet``,
    ``num_examples``, ``resolution``, ``batch_size``, ``num_workers``,
    ``output``, ``checkpoint``, ``all_epochs``, ``test_only``, ``fake_only``
    and ``long_tail``.

    Side effects on *args*: sets ``num_writers`` (339 for a target path
    containing "iam", 283 for "cvl"), ``vocab_size`` (``len(args.alphabet)``)
    and replaces ``output`` with
    ``Path(output or 'saved_images') / 'fid' / <target dataset tag>``.

    When ``args.all_epochs`` is false, a single ``Writer`` is built from
    ``args.checkpoint`` and run on the train split (unless ``args.test_only``)
    and on the test split. Otherwise ``args.checkpoint`` is treated as a
    directory, and every ``<digits>_model.pth`` inside it is run on the test
    split in ascending epoch order; ``fake_only`` is False only for the first
    epoch and True afterwards (presumably so real images are written once).

    Raises:
        ValueError: if the target dataset path contains neither "iam" nor "cvl".
        FileNotFoundError: if ``args.all_epochs`` is set and the checkpoint
            directory holds no ``<digits>_model.pth`` file.
    """
    target_name = args.target_dataset_path.lower()
    if 'iam' in target_name:
        args.num_writers = 339
    elif 'cvl' in target_name:
        args.num_writers = 283
    else:
        raise ValueError(
            "Cannot infer the number of writers: expected 'iam' or 'cvl' in "
            f"target_dataset_path, got {args.target_dataset_path!r}"
        )

    args.vocab_size = len(args.alphabet)

    dataset_train = FidDataset(base_path=args.target_dataset_path, num_examples=args.num_examples, collator_resolution=args.resolution, mode='train', style_dataset=args.dataset_path)
    train_loader = _make_loader(dataset_train, args.batch_size, args.num_workers)

    dataset_test = FidDataset(base_path=args.target_dataset_path, num_examples=args.num_examples, collator_resolution=args.resolution, mode='test', style_dataset=args.dataset_path)
    # The test loader always loads in the main process (num_workers=0).
    test_loader = _make_loader(dataset_test, args.batch_size, 0)

    args.output = 'saved_images' if args.output is None else args.output
    args.output = Path(args.output) / 'fid' / _dataset_tag(args.target_dataset_path)

    model_tag = _model_tag(args.checkpoint, args.dataset_path)

    if not args.all_epochs:
        writer = Writer(args.checkpoint, args, only_generator=True)
        if not args.test_only:
            writer.generate_fid(args.output, train_loader, model_tag=model_tag, split='train', fake_only=args.fake_only, long_tail_only=args.long_tail)
        writer.generate_fid(args.output, test_loader, model_tag=model_tag, split='test', fake_only=args.fake_only, long_tail_only=args.long_tail)
    else:
        checkpoints = _epoch_checkpoints(args.checkpoint)
        if not checkpoints:
            raise FileNotFoundError(
                f"No '<epoch>_model.pth' checkpoints found in {args.checkpoint!r}"
            )

        generate_real = True
        for epoch, checkpoint_path in checkpoints:
            writer = Writer(checkpoint_path, args, only_generator=True)
            writer.generate_fid(args.output, test_loader, model_tag=f"{model_tag}_{epoch}", split='test', fake_only=not generate_real, long_tail_only=args.long_tail)
            # Only the first epoch is run with fake_only=False.
            generate_real = False

    print('Done')