File size: 2,004 Bytes
d90acf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from omegaconf import OmegaConf
import numpy as np
from scipy import ndimage
import torch.nn as nn
from skimage.transform import resize


def load_conf(config_path):
    """Load an OmegaConf config and copy shared settings into the sections that use them.

    Values that are declared once (mostly under ``common``, plus a couple under
    ``model``) are duplicated into the ``data``, ``trainer``, ``scheduler``,
    ``logger`` and ``model.encoders`` sections. Downstream code can then read
    each setting from its own section.

    Args:
        config_path: path handed straight to ``OmegaConf.load``.

    Returns:
        The loaded config, mutated in place with the propagated fields.
    """
    conf = OmegaConf.load(config_path)
    shared = conf.common
    data_cfg = conf.data
    model_cfg = conf.model

    # Data pipeline settings.
    data_cfg.tokens_length = shared.tokens_length
    data_cfg.processor_names = model_cfg.encoders.model_names
    data_cfg.dataset.seed = shared.seed
    data_cfg.dataset.image_size = shared.image_size

    # Training schedule and experiment logging.
    train_steps = shared.train_steps
    conf.trainer.trainer_params.max_steps = train_steps
    conf.scheduler.params.total_steps = train_steps
    conf.logger.tensorboard.name = shared.experiment_name

    # Keep the encoders' context_dim in sync with the UNet's.
    model_cfg.encoders.context_dim = model_cfg.unet_params.context_dim
    return conf


def freeze(model):
    """Turn off gradient tracking for every parameter of ``model``.

    The module is modified in place. It is also returned so the call can be
    used inline, e.g. ``encoder = freeze(encoder)``.
    """
    for param in model.parameters():
        param.requires_grad_(False)
    return model

def unfreeze(model):
    """Turn gradient tracking back on for every parameter of ``model``.

    The module is modified in place. It is also returned so the call can be
    used inline, e.g. ``encoder = unfreeze(encoder)``.
    """
    for param in model.parameters():
        param.requires_grad_(True)
    return model

def zero_module(module):
    """Overwrite every parameter of ``module`` with zeros, in place.

    Zeroing goes through ``nn.init.zeros_``. The same module object is
    returned so the call can wrap a constructor inline.
    """
    params = list(module.parameters())
    for weight in params:
        nn.init.zeros_(weight)
    return module

def resize_mask_for_diffusion(mask):
    """Resize a mask to at most ~1 megapixel with both sides a multiple of 64.

    This uses the same sizing rule as ``resize_image_for_diffusion``:
    1. If rows * cols exceeds 1024**2, both sides are scaled down by
       sqrt(rows * cols / 1024**2).
    2. Each side is then rounded and floored to a multiple of 64.

    Args:
        mask: numpy array whose first two axes are (rows, cols). Any trailing
            axes (e.g. channels) are kept as-is by skimage's ``resize`` when
            given a 2-tuple output shape.

    Returns:
        The array produced by ``skimage.transform.resize`` with
        ``preserve_range=True`` and ``anti_aliasing=False``.

    NOTE(review): skimage's default interpolation for non-bool input is not
    nearest-neighbour, so a binary mask may gain fractional values along
    edges. A side shorter than 64 px after reduction floors to 0. Confirm
    callers expect both behaviours.
    """
    rows, cols = mask.shape[0], mask.shape[1]
    # Budget on the spatial extent only. ``mask.size`` would also count
    # channels and over-shrink multi-channel masks relative to the image.
    reduce_factor = max(1, (rows * cols / 1024**2) ** 0.5)
    target_shape = (
        (round(rows / reduce_factor) // 64) * 64,
        (round(cols / reduce_factor) // 64) * 64,
    )
    resized_mask = resize(
        mask,
        target_shape,
        preserve_range=True,
        anti_aliasing=False,
    )
    return resized_mask

def resize_image_for_diffusion(image):
    """Shrink an image to at most ~1 megapixel with both sides a multiple of 64.

    Args:
        image: object exposing ``size`` as (width, height) and
            ``resize((width, height))`` (e.g. a PIL image).

    Returns:
        Whatever ``image.resize`` returns for the new (width, height).
    """
    width, height = image.size[0], image.size[1]
    scale = max(1, (width * height / 1024**2) ** 0.5)

    def _snap(side):
        # Apply the downscale, then round down to the nearest multiple of 64.
        return (round(side / scale) // 64) * 64

    return image.resize((_snap(width), _snap(height)))

def prepare_mask(mask):
    """Grow (dilate) a mask by blurring it three times and re-binarising.

    The 5x5 kernel below has strictly positive weights summing to 1. Each
    convolution pass spreads a positive pixel 2 pixels in every direction.
    For non-negative input, after three passes every pixel within 6
    rows/columns of a positive pixel is positive (a 13x13 neighbourhood).
    Thresholding at ``> 0`` turns the result back into a hard mask.

    Args:
        mask: 2-D array-like (the kernel is 2-D, so ``ndimage.convolve``
            needs a 2-D input). Any numeric or boolean dtype is accepted.
            Presumably holds non-negative values such as 0/1 — TODO confirm
            against callers.

    Returns:
        Integer ndarray of the same shape containing only 0 and 1.
    """
    ker = np.array([[1, 1,  1, 1, 1],
        [1, 5,  5, 5, 1],
        [1, 5, 44, 5, 1],
        [1, 5,  5, 5, 1],
        [1, 1,  1, 1, 1]]) / 100
    # ndimage.convolve allocates its output with the *input* dtype. An
    # integer mask would have every fractional result (0.01..0.44)
    # truncated to 0, silently erasing the mask instead of dilating it.
    # Work in float64 throughout.
    out = np.asarray(mask, dtype=np.float64)
    for _ in range(3):
        out = ndimage.convolve(out, ker)

    mask = (out > 0).astype(int)
    return mask