File size: 2,214 Bytes
19a6d73
1eb87a5
d160dc6
1eb87a5
19a6d73
d85fde4
 
e0956f1
 
 
 
 
 
 
 
1557411
 
 
d160dc6
1557411
 
 
 
 
 
 
19a6d73
 
d160dc6
1557411
19a6d73
054a11a
d85fde4
 
 
1557411
d160dc6
1557411
d160dc6
1557411
 
 
19a6d73
 
 
 
 
054a11a
1557411
 
054a11a
4a3fe77
d160dc6
4a3fe77
054a11a
 
 
 
 
 
 
 
19a6d73
 
 
 
d160dc6
 
1557411
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# utils.py
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial


def repeat_vmap(fun, in_axes=None):
    if in_axes is None:
        in_axes = [0]
    for axes in in_axes:
        fun = jax.vmap(fun, in_axes=axes)
    return fun


def make_grid(target_shape: tuple[int, int]):
    """Gera grid de coordenadas com validação rigorosa"""
    h, w = target_shape

    # Garantir tamanho mínimo e divisibilidade
    h = max(32, h)
    w = max(32, w)
    h = h if h % 8 == 0 else h + (8 - h % 8)
    w = w if w % 8 == 0 else w + (8 - w % 8)

    # Espaçamento preciso
    y_coords = np.linspace(-0.5 + 1 / (2 * h), 0.5 - 1 / (2 * h), h)
    x_coords = np.linspace(-0.5 + 1 / (2 * w), 0.5 - 1 / (2 * w), w)

    # Criar grid 4D (1, H, W, 2)
    grid = np.stack(np.meshgrid(y_coords, x_coords, indexing='ij'), axis=-1)
    return grid[np.newaxis, ...]


def interpolate_grid(coords, grid, order=0):
    """Interpolação segura com verificações em tempo real"""
    try:
        # Converter e garantir formato 4D
        coords = jnp.asarray(coords)
        original_shape = coords.shape

        # Adicionar dimensões faltantes
        while coords.ndim < 4:
            coords = coords[jnp.newaxis, ...]

        # Validação final
        if coords.shape[-1] != 2 or coords.ndim != 4:
            raise ValueError(
                f"Formato inválido: {original_shape}{coords.shape}. "
                f"Esperado (B, H, W, 2)"
            )

        # Transformação de coordenadas
        coords = coords.transpose((0, 3, 1, 2))
        coords = coords.at[:, 0].set(
            coords[:, 0] * (grid.shape[-3] - 1) + (grid.shape[-3] - 1) / 2
        )
        coords = coords.at[:, 1].set(
            coords[:, 1] * (grid.shape[-2] - 1) + (grid.shape[-2] - 1) / 2
        )

        # Interpolação vetorizada
        map_coordinates = partial(jax.scipy.ndimage.map_coordinates,
                                  order=order,
                                  mode='nearest')
        return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)

    except Exception as e:
        raise RuntimeError(f"Erro na interpolação: {str(e)}") from e