uruguayai commited on
Commit
e9745d9
·
verified ·
1 Parent(s): c8658d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -136,7 +136,7 @@ def train_step(state, batch, rng):
136
  )
137
 
138
  # Predict noise
139
- model_output = state.apply_fn(
140
  {'params': params["unet"]},
141
  noisy_latents,
142
  timesteps,
@@ -160,7 +160,7 @@ learning_rate = 1e-5
160
  optimizer = optax.adam(learning_rate)
161
  float32_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
162
  state = train_state.TrainState.create(
163
- apply_fn=unet.__call__,
164
  params=float32_params,
165
  tx=optimizer,
166
  )
 
136
  )
137
 
138
  # Predict noise
139
+ model_output = unet.apply(
140
  {'params': params["unet"]},
141
  noisy_latents,
142
  timesteps,
 
160
  optimizer = optax.adam(learning_rate)
161
  float32_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
162
  state = train_state.TrainState.create(
163
+ apply_fn=unet.apply,
164
  params=float32_params,
165
  tx=optimizer,
166
  )