File size: 3,824 Bytes
560a1b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.

import json
from pathlib import Path

import torch
from torch.utils.data import DataLoader
import torchvision
from PIL import Image, ImageDraw, ImageFont
from expansion_utils import consts


def max_num_not_in_list(max_num, lst):
    """Return the largest integer in ``[1, max_num]`` that is not in ``lst``.

    Candidates are tried from ``max_num`` downwards and ``0`` is never
    considered. ``None`` is returned when every candidate already appears in
    ``lst`` (or when ``max_num`` is smaller than 1).
    """
    descending = range(max_num, 0, -1)
    return next((candidate for candidate in descending if candidate not in lst), None)


def process_config(expansion_cfg_path):
    """Load an expansion config from JSON and give every task a dimension.

    Tasks may pin a ``"dimension"`` explicitly. Tasks that do not (missing key
    or ``null``) are assigned one automatically: the search starts at
    ``consts.LATENT_DIM - 1`` and walks downwards via
    ``max_num_not_in_list``, skipping every dimension that is already taken.
    The task dicts are updated in place.

    Args:
        expansion_cfg_path: Path (str or ``Path``) to the JSON config file.

    Returns:
        The parsed config dict, with ``"dimension"`` set on every task.

    Raises:
        ValueError: If ``"tasks"`` or ``"tasks_losses"`` is missing, if the
            same dimension is requested by more than one task, or if no free
            dimension is left for a task that needs one.
    """
    with Path(expansion_cfg_path).open() as fp:
        expansion_cfg = json.load(fp)

    # Raise explicitly instead of using ``assert`` so the check is not
    # stripped when Python runs with -O.
    missing_keys = [key for key in ("tasks", "tasks_losses") if key not in expansion_cfg]
    if missing_keys:
        raise ValueError(f"Config is missing required key(s): {missing_keys}")

    tasks = expansion_cfg["tasks"]
    curr_max_dim = consts.LATENT_DIM - 1

    # One entry per task; ``None`` marks a task that still needs a dimension.
    used_dims = [task.get("dimension") for task in tasks]

    for dim in used_dims:
        if dim is not None and used_dims.count(dim) > 1:
            raise ValueError(f"Config tries to repurpose the same dim {dim} more than once, unsupported...")

    for task in tasks:
        if task.get("dimension") is None:
            # Continue the downward search from the last dimension handed out.
            curr_max_dim = max_num_not_in_list(curr_max_dim, used_dims)

            if curr_max_dim is None:
                raise ValueError("No available dimension was found")

            task["dimension"] = curr_max_dim
            used_dims.append(curr_max_dim)

    # Every task now owns exactly one distinct dimension, so the task count is
    # the number of repurposed dims. ``used_dims`` would over-count because it
    # still contains the ``None`` placeholders alongside the appended dims.
    print(f"Parsed config successfuly! Repurposing {len(tasks)} dims, good luck!")

    return expansion_cfg


def label_image(img: torch.Tensor, label: str = None):
    """Tile a batch of images into one row and optionally prepend a text label.

    The batch is laid out side by side with ``torchvision.utils.make_grid``.
    When ``label`` is given, the text is rendered in black on a white strip,
    rotated by 90 degrees, and concatenated to the left of the row.

    Args:
        img: Batch of images; the first axis is the batch axis.
            Assumes values are in ``[-1, 1]`` (the label strip is scaled to
            that range) -- TODO confirm against callers.
        label: Text to render, or ``None`` to return the plain row.

    Returns:
        The image row as a tensor, with the label strip attached on the left
        when ``label`` is not ``None``.
    """
    batch_size = img.shape[0]
    img = torchvision.utils.make_grid(img, batch_size)  # concat over W
    if label is None:
        return img

    H, W = img.shape[-2:]
    W = W // (4 * batch_size)  # strip thickness: a quarter of one tile's share of the row
    H, W = W, H  # will be rotated, so H is W
    font = ImageFont.truetype("DejaVuSans.ttf", 60)  # TODO: use different font sizes for different resolutions.
    label_img = Image.new('RGB', (W, H), color='white')
    draw = ImageDraw.Draw(label_img)

    # ImageDraw.textsize() was removed in Pillow 10. For a box anchored at the
    # origin, the right/bottom edges from textbbox() give the same extent the
    # old call reported, so the text placement is unchanged. Older Pillow
    # releases without textbbox() keep using textsize().
    if hasattr(draw, "textbbox"):
        _, _, w, h = draw.textbbox((0, 0), label, font=font)
    else:
        w, h = draw.textsize(label, font=font)

    draw.text(((W - w) / 2, (H - h) / 2), label, font=font, fill=(0, 0, 0))
    label_img = torchvision.transforms.functional.pil_to_tensor(label_img.rotate(90, expand=True))
    label_img = label_img.to(torch.float32) / 127.5 - 1  # uint8 [0, 255] -> float [-1, 1]
    return torch.cat([label_img, img], dim=-1)

def save_batched_images(images: torch.Tensor, output_path: Path, labels: list = None, max_row_in_img=5):
    """Save rows of images, optionally labelled, split over several files.

    Each entry along the first axis of ``images`` is turned into one row by
    ``label_image``. Rows are grouped ``max_row_in_img`` at a time and every
    group is written through ``save_images`` with one row per grid line, under
    the name ``<stem>_batch_<index>`` next to ``output_path``.

    Args:
        images: Tensor whose first axis enumerates the rows.
        output_path: Base path; its stem and parent are used for the outputs.
        labels: Optional list with one label per row.
        max_row_in_img: Maximum number of rows written into a single file.

    Raises:
        ValueError: If ``labels`` is given and its length differs from the
            number of rows.
    """
    row_count = images.shape[0]

    if labels is None:
        labels = [None] * row_count
    elif len(labels) != row_count:
        raise ValueError('Number of labels should match number of batches')

    labelled_rows = torch.stack(
        [label_image(row, caption) for row, caption in zip(images, labels)]
    )

    # DataLoader is used purely as a chunker over the first axis.
    chunks = DataLoader(labelled_rows, batch_size=max_row_in_img)
    for chunk_idx, chunk in enumerate(chunks):
        chunk_path = output_path.with_name(f"{output_path.stem}_batch_{chunk_idx}")
        save_images(chunk, chunk_path, 1)


def save_images(frames: torch.Tensor, output_path: Path, nrow=None, size=None, separate=False):
    """Write ``frames`` to disk as JPEG, either as one grid or one file each.

    Pixel values are normalized from the ``(-1, 1)`` range on save. The parent
    directory of ``output_path`` is created if needed.

    Args:
        frames: Batch of images to save.
        output_path: Target path. In grid mode its suffix is replaced with
            ``.jpg``; in separate mode its stem is used to build per-frame
            names of the form ``<index>_<stem>.jpg``.
        nrow: Images per grid row in grid mode; defaults to ``len(frames)``,
            i.e. a single row.
        size: Optional target size passed to
            ``torch.nn.functional.interpolate`` before saving.
        separate: When true, save every entry of ``frames`` to its own file.
    """
    import inspect

    parent_dir = output_path.parent
    parent_dir.mkdir(exist_ok=True, parents=True)

    if size:
        frames = torch.nn.functional.interpolate(frames, size)

    # torchvision renamed the ``range`` keyword of make_grid/save_image to
    # ``value_range`` and later removed the old name. Pick whichever keyword
    # the installed version accepts so both old and new releases work.
    grid_params = inspect.signature(torchvision.utils.make_grid).parameters
    range_kwarg = "value_range" if "value_range" in grid_params else "range"
    normalize_kwargs = {"normalize": True, range_kwarg: (-1, 1)}

    if separate:
        base_name = output_path.stem
        for i, frame in enumerate(frames):
            torchvision.utils.save_image(
                frame,
                output_path.with_name(f"{i:05d}_{base_name}.jpg"),
                nrow=len(frame),
                **normalize_kwargs,
            )
    else:
        torchvision.utils.save_image(
            frames,
            output_path.with_suffix(".jpg"),
            nrow=nrow if nrow else len(frames),
            **normalize_kwargs,
        )