Spaces:
Runtime error
Runtime error
Commit
·
d9edb33
1
Parent(s):
a6df2dc
Update app.py
Browse files
app.py
CHANGED
@@ -23,8 +23,8 @@ def infer(prompts, negative_prompts):
|
|
23 |
rng = create_key(0)
|
24 |
rng = jax.random.split(rng, jax.device_count())
|
25 |
|
26 |
-
prompt_ids = pipe.
|
27 |
-
negative_prompt_ids = pipe.
|
28 |
|
29 |
p_params = replicate(params)
|
30 |
prompt_ids = shard(prompt_ids)
|
|
|
23 |
rng = create_key(0)
|
24 |
rng = jax.random.split(rng, jax.device_count())
|
25 |
|
26 |
+
prompt_ids = pipe.prepare_inputs([prompts] * num_samples)
|
27 |
+
negative_prompt_ids = pipe.prepare_inputs([negative_prompts] * num_samples)
|
28 |
|
29 |
p_params = replicate(params)
|
30 |
prompt_ids = shard(prompt_ids)
|