Tonic commited on
Commit
8f09b36
β€’
1 Parent(s): db96f54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -16,8 +16,8 @@ def run(image, src_style, src_prompt, prompts, shared_score_shift, shared_score_
16
  image = image.resize((dim, dim))
17
  x0 = np.array(image)
18
  zts = inversion.ddim_inversion(pipeline, x0, src_prompt, num_inference_steps, 2)
 
19
  prompts.insert(0, src_prompt)
20
-
21
  shared_score_shift = np.log(shared_score_shift)
22
  handler = sa_handler.Handler(pipeline)
23
  sa_args = sa_handler.StyleAlignedArgs(
@@ -25,12 +25,11 @@ def run(image, src_style, src_prompt, prompts, shared_score_shift, shared_score_
25
  adain_queries=True, adain_keys=True, adain_values=False,
26
  shared_score_shift=shared_score_shift, shared_score_scale=shared_score_scale,)
27
  handler.register(sa_args)
28
-
29
  for i in range(1, len(prompts)):
30
  prompts[i] = f'{prompts[i]}, {src_style}.'
31
-
32
- zT, inversion_callback = inversion.make_inversion_callback(zts, offset=5)
33
  g_cpu = torch.Generator(device='cpu')
 
34
  if seed > 0:
35
  g_cpu.manual_seed(seed)
36
  latents = torch.randn(len(prompts), 4, d, d, device='cpu', generator=g_cpu, dtype=pipeline.unet.dtype,).to(device)
 
16
  image = image.resize((dim, dim))
17
  x0 = np.array(image)
18
  zts = inversion.ddim_inversion(pipeline, x0, src_prompt, num_inference_steps, 2)
19
+ offset = min(5, len(zts) - 1)
20
  prompts.insert(0, src_prompt)
 
21
  shared_score_shift = np.log(shared_score_shift)
22
  handler = sa_handler.Handler(pipeline)
23
  sa_args = sa_handler.StyleAlignedArgs(
 
25
  adain_queries=True, adain_keys=True, adain_values=False,
26
  shared_score_shift=shared_score_shift, shared_score_scale=shared_score_scale,)
27
  handler.register(sa_args)
 
28
  for i in range(1, len(prompts)):
29
  prompts[i] = f'{prompts[i]}, {src_style}.'
30
+ zT, inversion_callback = inversion.make_inversion_callback(zts, offset=offset)
 
31
  g_cpu = torch.Generator(device='cpu')
32
+
33
  if seed > 0:
34
  g_cpu.manual_seed(seed)
35
  latents = torch.randn(len(prompts), 4, d, d, device='cpu', generator=g_cpu, dtype=pipeline.unet.dtype,).to(device)