Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -129,30 +129,39 @@ def process(
|
|
129 |
shape = (1, 4, height // 8, width // 8)
|
130 |
x_T = torch.randn(shape, device=device, dtype=torch.float32)
|
131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
if not tile_diffusion and not tile_vae:
|
133 |
samples = sampler.sample_ccsr(
|
134 |
steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
|
135 |
-
positive_prompt=
|
136 |
cfg_scale=cfg_scale,
|
137 |
color_fix_type="adain" if use_color_fix else "none"
|
138 |
)
|
139 |
else:
|
140 |
if tile_vae:
|
141 |
-
#
|
142 |
-
# model._init_tiled_vae(encoder_tile_size=vae_encoder_tile_size, decoder_tile_size=vae_decoder_tile_size)
|
143 |
pass
|
144 |
if tile_diffusion:
|
145 |
samples = sampler.sample_with_tile_ccsr(
|
146 |
tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
|
147 |
steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
|
148 |
-
positive_prompt=
|
149 |
cfg_scale=cfg_scale,
|
150 |
color_fix_type="adain" if use_color_fix else "none"
|
151 |
)
|
152 |
else:
|
153 |
samples = sampler.sample_ccsr(
|
154 |
steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
|
155 |
-
positive_prompt=
|
156 |
cfg_scale=cfg_scale,
|
157 |
color_fix_type="adain" if use_color_fix else "none"
|
158 |
)
|
|
|
129 |
shape = (1, 4, height // 8, width // 8)
|
130 |
x_T = torch.randn(shape, device=device, dtype=torch.float32)
|
131 |
|
132 |
+
# Modify the get_learned_conditioning method to handle the attention mask issue
|
133 |
+
def modified_get_learned_conditioning(model, prompt):
|
134 |
+
tokens = model.cond_stage_model.tokenizer.encode(prompt)
|
135 |
+
tokens = torch.LongTensor(tokens).to(model.device).unsqueeze(0)
|
136 |
+
encoder_hidden_states = model.cond_stage_model.transformer(input_ids=tokens).last_hidden_state
|
137 |
+
return encoder_hidden_states
|
138 |
+
|
139 |
+
cond = modified_get_learned_conditioning(model, positive_prompt)
|
140 |
+
uncond = modified_get_learned_conditioning(model, negative_prompt)
|
141 |
+
|
142 |
if not tile_diffusion and not tile_vae:
|
143 |
samples = sampler.sample_ccsr(
|
144 |
steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
|
145 |
+
positive_prompt=cond, negative_prompt=uncond, x_T=x_T,
|
146 |
cfg_scale=cfg_scale,
|
147 |
color_fix_type="adain" if use_color_fix else "none"
|
148 |
)
|
149 |
else:
|
150 |
if tile_vae:
|
151 |
+
# Note: Tiled VAE is not implemented in this version
|
|
|
152 |
pass
|
153 |
if tile_diffusion:
|
154 |
samples = sampler.sample_with_tile_ccsr(
|
155 |
tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
|
156 |
steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
|
157 |
+
positive_prompt=cond, negative_prompt=uncond, x_T=x_T,
|
158 |
cfg_scale=cfg_scale,
|
159 |
color_fix_type="adain" if use_color_fix else "none"
|
160 |
)
|
161 |
else:
|
162 |
samples = sampler.sample_ccsr(
|
163 |
steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
|
164 |
+
positive_prompt=cond, negative_prompt=uncond, x_T=x_T,
|
165 |
cfg_scale=cfg_scale,
|
166 |
color_fix_type="adain" if use_color_fix else "none"
|
167 |
)
|