New logic
Browse files
app.py
CHANGED
@@ -55,18 +55,17 @@ def full_pipeline(image, prompt, scale_factor=2.0):
|
|
55 |
source = np.array(image) / 255.0
|
56 |
target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
|
57 |
|
58 |
-
# Preparar parâmetros para JAX
|
59 |
source_jax = jax.device_put(source, JAX_DEVICE)
|
60 |
t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
|
61 |
|
62 |
-
#
|
63 |
upscaled = model_edsr.apply(
|
64 |
-
params_edsr,
|
65 |
source_jax,
|
66 |
t,
|
67 |
-
target_shape
|
68 |
-
do_ensemble=True
|
69 |
)
|
|
|
70 |
upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
|
71 |
|
72 |
# 2. Gerar Bas-Relief
|
|
|
55 |
source = np.array(image) / 255.0
|
56 |
target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
|
57 |
|
|
|
58 |
source_jax = jax.device_put(source, JAX_DEVICE)
|
59 |
t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
|
60 |
|
61 |
+
# Chamada corrigida sem 'do_ensemble'
|
62 |
upscaled = model_edsr.apply(
|
63 |
+
{'params': params_edsr}, # Estrutura de parâmetros correta
|
64 |
source_jax,
|
65 |
t,
|
66 |
+
target_shape
|
|
|
67 |
)
|
68 |
+
|
69 |
upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
|
70 |
|
71 |
# 2. Gerar Bas-Relief
|