File size: 2,214 Bytes
19a6d73 1eb87a5 d160dc6 1eb87a5 19a6d73 d85fde4 e0956f1 1557411 d160dc6 1557411 19a6d73 d160dc6 1557411 19a6d73 054a11a d85fde4 1557411 d160dc6 1557411 d160dc6 1557411 19a6d73 054a11a 1557411 054a11a 4a3fe77 d160dc6 4a3fe77 054a11a 19a6d73 d160dc6 1557411 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
# utils.py
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
def repeat_vmap(fun, in_axes=None):
if in_axes is None:
in_axes = [0]
for axes in in_axes:
fun = jax.vmap(fun, in_axes=axes)
return fun
def make_grid(target_shape: tuple[int, int]):
"""Gera grid de coordenadas com validação rigorosa"""
h, w = target_shape
# Garantir tamanho mínimo e divisibilidade
h = max(32, h)
w = max(32, w)
h = h if h % 8 == 0 else h + (8 - h % 8)
w = w if w % 8 == 0 else w + (8 - w % 8)
# Espaçamento preciso
y_coords = np.linspace(-0.5 + 1 / (2 * h), 0.5 - 1 / (2 * h), h)
x_coords = np.linspace(-0.5 + 1 / (2 * w), 0.5 - 1 / (2 * w), w)
# Criar grid 4D (1, H, W, 2)
grid = np.stack(np.meshgrid(y_coords, x_coords, indexing='ij'), axis=-1)
return grid[np.newaxis, ...]
def interpolate_grid(coords, grid, order=0):
"""Interpolação segura com verificações em tempo real"""
try:
# Converter e garantir formato 4D
coords = jnp.asarray(coords)
original_shape = coords.shape
# Adicionar dimensões faltantes
while coords.ndim < 4:
coords = coords[jnp.newaxis, ...]
# Validação final
if coords.shape[-1] != 2 or coords.ndim != 4:
raise ValueError(
f"Formato inválido: {original_shape} → {coords.shape}. "
f"Esperado (B, H, W, 2)"
)
# Transformação de coordenadas
coords = coords.transpose((0, 3, 1, 2))
coords = coords.at[:, 0].set(
coords[:, 0] * (grid.shape[-3] - 1) + (grid.shape[-3] - 1) / 2
)
coords = coords.at[:, 1].set(
coords[:, 1] * (grid.shape[-2] - 1) + (grid.shape[-2] - 1) / 2
)
# Interpolação vetorizada
map_coordinates = partial(jax.scipy.ndimage.map_coordinates,
order=order,
mode='nearest')
return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)
except Exception as e:
raise RuntimeError(f"Erro na interpolação: {str(e)}") from e |