MuhammadHanif commited on
Commit
d9edb33
·
1 Parent(s): a6df2dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -23,8 +23,8 @@ def infer(prompts, negative_prompts):
23
  rng = create_key(0)
24
  rng = jax.random.split(rng, jax.device_count())
25
 
26
- prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
27
- negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
28
 
29
  p_params = replicate(params)
30
  prompt_ids = shard(prompt_ids)
 
23
  rng = create_key(0)
24
  rng = jax.random.split(rng, jax.device_count())
25
 
26
+ prompt_ids = pipe.prepare_inputs([prompts] * num_samples)
27
+ negative_prompt_ids = pipe.prepare_inputs([negative_prompts] * num_samples)
28
 
29
  p_params = replicate(params)
30
  prompt_ids = shard(prompt_ids)