File size: 2,026 Bytes
19a6d73
1eb87a5
d160dc6
1eb87a5
19a6d73
d85fde4
 
 
19a6d73
d85fde4
19a6d73
054a11a
19a6d73
d160dc6
19a6d73
 
 
d160dc6
19a6d73
 
054a11a
d85fde4
 
 
19a6d73
d160dc6
19a6d73
d160dc6
19a6d73
 
 
 
 
 
 
054a11a
19a6d73
054a11a
4a3fe77
d160dc6
4a3fe77
054a11a
 
 
 
 
 
 
 
19a6d73
 
 
 
d160dc6
 
054a11a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# utils.py
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial


def make_grid(patch_size: int | tuple[int, int]):
    """Gera grid de coordenadas com validação robusta"""
    if isinstance(patch_size, int):
        h = w = max(16, patch_size)  # Novo mínimo seguro
    else:
        h, w = (max(16, ps) for ps in patch_size)  # 16x16 mínimo

    # Cálculo preciso das coordenadas
    y_coords = np.linspace(-0.5 + 1 / (2 * h), 0.5 - 1 / (2 * h), h)
    x_coords = np.linspace(-0.5 + 1 / (2 * w), 0.5 - 1 / (2 * w), w)

    # Grid com dimensões (1, H, W, 2)
    grid = np.stack(np.meshgrid(y_coords, x_coords, indexing='ij'), axis=-1)
    return grid[np.newaxis, ...]


def interpolate_grid(coords, grid, order=0):
    """Interpolação com tratamento completo de dimensões"""
    try:
        # Converter e garantir 4D
        coords = jnp.asarray(coords)
        if coords.ndim == 1:  # Caso de erro reportado
            coords = coords.reshape(1, 1, 1, -1)
        while coords.ndim < 4:
            coords = coords[jnp.newaxis, ...]

        # Validação final
        if coords.shape[-1] != 2 or coords.ndim != 4:
            raise ValueError(
                f"Dimensões inválidas: {coords.shape}. Formato esperado: (B, H, W, 2)"
            )

        # Transformação de coordenadas
        coords = coords.transpose((0, 3, 1, 2))
        coords = coords.at[:, 0].set(
            coords[:, 0] * (grid.shape[-3] - 1) + (grid.shape[-3] - 1) / 2
        )
        coords = coords.at[:, 1].set(
            coords[:, 1] * (grid.shape[-2] - 1) + (grid.shape[-2] - 1) / 2
        )

        # Interpolação vetorizada
        map_coordinates = partial(jax.scipy.ndimage.map_coordinates,
                                  order=order,
                                  mode='nearest')
        return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)

    except Exception as e:
        raise RuntimeError(f"Erro de interpolação: {str(e)}") from e