fixing bugs
Browse files
utils.py
CHANGED
@@ -5,6 +5,14 @@ import numpy as np
|
|
5 |
from functools import partial
|
6 |
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
def make_grid(patch_size: int | tuple[int, int]):
|
9 |
"""Gera grid de coordenadas com validação robusta"""
|
10 |
if isinstance(patch_size, int):
|
|
|
5 |
from functools import partial
|
6 |
|
7 |
|
8 |
+
def repeat_vmap(fun, in_axes=None):
|
9 |
+
if in_axes is None:
|
10 |
+
in_axes = [0]
|
11 |
+
for axes in in_axes:
|
12 |
+
fun = jax.vmap(fun, in_axes=axes)
|
13 |
+
return fun
|
14 |
+
|
15 |
+
|
16 |
def make_grid(patch_size: int | tuple[int, int]):
|
17 |
"""Gera grid de coordenadas com validação robusta"""
|
18 |
if isinstance(patch_size, int):
|