File size: 3,573 Bytes
a02c6d7
 
 
 
 
1eb87a5
a02c6d7
1eb87a5
 
 
a02c6d7
1eb87a5
a02c6d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eb87a5
a02c6d7
 
 
1eb87a5
a02c6d7
 
 
 
 
 
1eb87a5
a02c6d7
 
1eb87a5
a02c6d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eb87a5
a02c6d7
 
1eb87a5
 
a02c6d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python

from argparse import ArgumentParser, Namespace
import pickle

import jax
from jax import jit
import jax.numpy as jnp
import numpy as np
from PIL import Image

from model import build_thera
from utils import make_grid, interpolate_grid

# Per-channel statistics used to standardize the network input and to undo the
# standardization on the network output (three entries, one per channel).
MEAN = jnp.array([0.4488, 0.4371, 0.4040])
VAR = jnp.array([0.25, 0.25, 0.25])

# Side length, in target pixels, of the coordinate tiles fed to the decoder.
PATCH_SIZE = 256


def process_single(source, apply_encoder, apply_decoder, params, target_shape):
    """Super-resolve one image to ``target_shape`` in a single pass.

    The standardized image is encoded once. The decoder is then evaluated on
    tiles of at most ``PATCH_SIZE`` x ``PATCH_SIZE`` target coordinates, its
    output is de-standardized and added to a grid-interpolated upsampling of
    the source.

    Args:
        source: Unbatched image indexed as (height, width, ...); a batch axis
            is added internally. Presumably channels-last with three channels
            to match ``MEAN``/``VAR`` -- confirm against the caller.
        apply_encoder: Callable invoked as ``apply_encoder(params, source)``.
        apply_decoder: Callable invoked as
            ``apply_decoder(params, encoding, coords_patch, t)``.
        params: Model parameters, forwarded unchanged to both callables.
        target_shape: ``(height, width)`` of the output.

    Returns:
        Array shaped like the interpolated source (including the batch axis
        added here), holding the interpolated source plus the decoded output.
    """
    # Inverse squared scale factor. Both extents are measured along the height
    # axis (axis 0 of the unbatched source) so that non-square images and the
    # 90-degree-rotated ensemble passes get the true factor.
    t = jnp.float32((target_shape[0] / source.shape[0])**-2)[None]
    coords_nearest = jnp.asarray(make_grid(target_shape)[None])
    source_up = interpolate_grid(coords_nearest, source[None])
    source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]

    encoding = apply_encoder(params, source)
    coords = jnp.asarray(make_grid(source_up.shape[1:3])[None])  # global sampling coords
    # Start from NaN -- presumably so a tile the loop failed to write would
    # stand out instead of silently reading as zero.
    out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)

    # Decode tile by tile; the last tile in each direction may be smaller.
    for h_min in range(0, coords.shape[1], PATCH_SIZE):
        h_max = min(h_min + PATCH_SIZE, coords.shape[1])
        for w_min in range(0, coords.shape[2], PATCH_SIZE):
            # apply decoder with one patch of coordinates
            w_max = min(w_min + PATCH_SIZE, coords.shape[2])
            coords_patch = coords[:, h_min:h_max, w_min:w_max]
            out_patch = apply_decoder(params, encoding, coords_patch, t)
            out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)

    # Undo the input standardization, then add the interpolated source back.
    out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
    out += source_up
    return out


def process(source, model, params, target_shape, do_ensemble=True):
    """Super-resolve ``source`` and return the result as a ``uint8`` array.

    With ``do_ensemble`` enabled the image is processed at four quarter-turn
    rotations; each prediction is rotated back and the predictions are
    averaged. Otherwise only the unrotated image is processed.

    Args:
        source: Input image, rotated in the plane of its axes -3 and -2.
        model: Object providing ``apply_encoder`` and ``apply_decoder``.
        params: Model parameters forwarded to ``process_single``.
        target_shape: ``(height, width)`` of the output for the unrotated image.
        do_ensemble: Whether to average over the four rotations.

    Returns:
        First batch element of the averaged prediction, clipped to [0, 1],
        scaled by 255, rounded and cast to ``uint8``.
    """
    encode = jit(model.apply_encoder)
    decode = jit(model.apply_decoder)
    n_rotations = 4 if do_ensemble else 1

    predictions = []
    for quarter_turns in range(n_rotations):
        rotated = jnp.rot90(source, k=quarter_turns, axes=(-3, -2))
        # An odd number of quarter turns swaps height and width.
        if quarter_turns % 2:
            shape_for_pass = tuple(reversed(target_shape))
        else:
            shape_for_pass = target_shape
        prediction = process_single(rotated, encode, decode, params, shape_for_pass)
        # Swapping the axes order rotates in the opposite direction.
        predictions.append(jnp.rot90(prediction, k=quarter_turns, axes=(-2, -3)))

    stacked = jnp.stack(predictions)
    averaged = stacked.mean(0)
    clipped = averaged.clip(0., 1.)
    return jnp.rint(clipped[0] * 255).astype(jnp.uint8)


def main(args: Namespace):
    """Super-resolve ``args.in_file`` and write the result to ``args.out_file``.

    Args:
        args: Parsed command line. Reads ``in_file``, ``out_file``, ``scale``,
            ``size``, ``checkpoint`` and ``no_ensemble``.

    Raises:
        ValueError: If both or neither of ``scale`` and ``size`` are given.
    """
    # Validate the arguments first so misuse fails before any file is read.
    if args.scale is not None and args.size is not None:
        raise ValueError('Cannot specify both size and scale')
    if args.scale is None and args.size is None:
        raise ValueError('Must specify either size or scale')

    with Image.open(args.in_file) as image:
        # Force three channels: grayscale, palette or RGBA inputs would not
        # broadcast against the 3-element MEAN/VAR used for standardization.
        source = np.asarray(image.convert('RGB')) / 255.

    if args.scale is not None:
        target_shape = (
            round(source.shape[0] * args.scale),
            round(source.shape[1] * args.scale),
        )
    else:
        # argparse yields a list; use a tuple like the scale branch does.
        target_shape = tuple(args.size)

    # NOTE(security): pickle.load can execute arbitrary code. Only load
    # checkpoints that come from a trusted source.
    with open(args.checkpoint, 'rb') as fh:
        check = pickle.load(fh)
    params, backbone, size = check['model'], check['backbone'], check['size']

    model = build_thera(3, backbone, size)

    out = process(source, model, params, target_shape, not args.no_ensemble)

    Image.fromarray(np.asarray(out)).save(args.out_file)


def parse_args() -> Namespace:
    """Parse the command-line arguments of the super-resolution script.

    Exactly one of ``--scale`` and ``--size`` must be given, and
    ``--checkpoint`` is mandatory; argparse reports a violation as a usage
    error and exits.

    Returns:
        Namespace with ``in_file``, ``out_file``, ``scale``, ``size``,
        ``checkpoint`` and ``no_ensemble``.
    """
    parser = ArgumentParser()
    parser.add_argument('in_file')
    parser.add_argument('out_file')

    # The output resolution is given either relatively or absolutely.
    target = parser.add_mutually_exclusive_group(required=True)
    target.add_argument('--scale', type=float, help='Scale factor for super-resolution')
    target.add_argument('--size', type=int, nargs=2,
                        help='Target size (h, w), mutually exclusive with --scale')

    # Required: the checkpoint path is opened unconditionally downstream.
    parser.add_argument('--checkpoint', required=True, help='Path to checkpoint file')
    parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble')
    return parser.parse_args()


if __name__ == '__main__':
    # Script entry point: parse the command line and run inference.
    main(parse_args())