ds1david commited on
Commit
42a2e7b
·
1 Parent(s): 0652978
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -25,7 +25,9 @@ def load_thera_model(repo_id, filename):
25
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
26
  with open(model_path, 'rb') as fh:
27
  check = pickle.load(fh)
28
- params, backbone, size = check['model'], check['backbone'], check['size']
 
 
29
  model = build_thera(3, backbone, size)
30
  return model, params
31
 
@@ -58,9 +60,9 @@ def full_pipeline(image, prompt, scale_factor=2.0):
58
  source_jax = jax.device_put(source, JAX_DEVICE)
59
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
60
 
61
- # Chamada corrigida sem 'do_ensemble'
62
  upscaled = model_edsr.apply(
63
- {'params': params_edsr}, # Estrutura de parâmetros correta
64
  source_jax,
65
  t,
66
  target_shape
@@ -125,4 +127,4 @@ with gr.Blocks(title="Super Res + Bas-Relief") as app:
125
  )
126
 
127
  if __name__ == "__main__":
128
- app.launch()
 
25
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
26
  with open(model_path, 'rb') as fh:
27
  check = pickle.load(fh)
28
+ # Ajustar a estrutura dos parâmetros
29
+ params = check['model']['params'] # Acessar os parâmetros corretamente
30
+ backbone, size = check['backbone'], check['size']
31
  model = build_thera(3, backbone, size)
32
  return model, params
33
 
 
60
  source_jax = jax.device_put(source, JAX_DEVICE)
61
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
62
 
63
+ # Chamada corrigida com estrutura de parâmetros adequada
64
  upscaled = model_edsr.apply(
65
+ params_edsr, # Parâmetros estruturados corretamente
66
  source_jax,
67
  t,
68
  target_shape
 
127
  )
128
 
129
  if __name__ == "__main__":
130
+ app.launch(share=False) # Ativando link público