ds1david commited on
Commit
0652978
·
1 Parent(s): a7111d1
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -55,18 +55,17 @@ def full_pipeline(image, prompt, scale_factor=2.0):
55
  source = np.array(image) / 255.0
56
  target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
57
 
58
- # Preparar parâmetros para JAX
59
  source_jax = jax.device_put(source, JAX_DEVICE)
60
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
61
 
62
- # Processar com Thera
63
  upscaled = model_edsr.apply(
64
- params_edsr,
65
  source_jax,
66
  t,
67
- target_shape,
68
- do_ensemble=True
69
  )
 
70
  upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
71
 
72
  # 2. Gerar Bas-Relief
 
55
  source = np.array(image) / 255.0
56
  target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
57
 
 
58
  source_jax = jax.device_put(source, JAX_DEVICE)
59
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
60
 
61
+ # Chamada corrigida sem 'do_ensemble'
62
  upscaled = model_edsr.apply(
63
+ {'params': params_edsr}, # Estrutura de parâmetros correta
64
  source_jax,
65
  t,
66
+ target_shape
 
67
  )
68
+
69
  upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
70
 
71
  # 2. Gerar Bas-Relief