Update app.py
Browse files
app.py
CHANGED
@@ -155,7 +155,7 @@ state = train_state.TrainState.create(
|
|
155 |
)
|
156 |
|
157 |
# Modify the train_step function
|
158 |
-
|
159 |
def compute_loss(params, pixel_values, rng):
|
160 |
# Ensure pixel_values are float32
|
161 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
|
|
155 |
)
|
156 |
|
157 |
# Modify the train_step function
|
158 |
+
def train_step(state, batch, rng):
|
159 |
def compute_loss(params, pixel_values, rng):
|
160 |
# Ensure pixel_values are float32
|
161 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|