Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -16,8 +16,8 @@ def run(image, src_style, src_prompt, prompts, shared_score_shift, shared_score_
     image = image.resize((dim, dim))
     x0 = np.array(image)
     zts = inversion.ddim_inversion(pipeline, x0, src_prompt, num_inference_steps, 2)
+    offset = min(5, len(zts) - 1)
     prompts.insert(0, src_prompt)
-
     shared_score_shift = np.log(shared_score_shift)
     handler = sa_handler.Handler(pipeline)
     sa_args = sa_handler.StyleAlignedArgs(
@@ -25,12 +25,11 @@ def run(image, src_style, src_prompt, prompts, shared_score_shift, shared_score_
         adain_queries=True, adain_keys=True, adain_values=False,
         shared_score_shift=shared_score_shift, shared_score_scale=shared_score_scale,)
     handler.register(sa_args)
-
     for i in range(1, len(prompts)):
         prompts[i] = f'{prompts[i]}, {src_style}.'
-
-    zT, inversion_callback = inversion.make_inversion_callback(zts, offset=5)
+    zT, inversion_callback = inversion.make_inversion_callback(zts, offset=offset)
     g_cpu = torch.Generator(device='cpu')
+
     if seed > 0:
         g_cpu.manual_seed(seed)
     latents = torch.randn(len(prompts), 4, d, d, device='cpu', generator=g_cpu, dtype=pipeline.unet.dtype,).to(device)