Spaces:
vilarin
/
Running on Zero

vilarin committed on
Commit
ccc0607
·
verified ·
1 Parent(s): d79f1a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -96,7 +96,7 @@ class ModelWrapper:
96
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
97
  current_timesteps = current_timesteps.to(torch.float16)
98
  print(f'current_timestpes: {current_timesteps.dtype}')
99
- eval_images = self.model(noise.to(torch.float16), current_timesteps, prompt_embed.to(torch.float16), added_cond_kwargs=unet_added_conditions).sample
100
  print(type(eval_images))
101
 
102
  eval_images = get_x0_from_noise(noise, eval_images, alphas_cumprod, current_timesteps).to(self.DTYPE)
@@ -123,7 +123,7 @@ class ModelWrapper:
123
 
124
  add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
125
 
126
- noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device="cuda", dtype=torch.float16)
127
 
128
  prompt_inputs = self._encode_prompt(prompt)
129
 
@@ -132,13 +132,13 @@ class ModelWrapper:
132
  prompt_embeds, pooled_prompt_embeds = self.text_encoder(prompt_inputs)
133
 
134
  batch_prompt_embeds, batch_pooled_prompt_embeds = (
135
- prompt_embeds.repeat(num_images, 1, 1),
136
- pooled_prompt_embeds.repeat(num_images, 1, 1)
137
  )
138
 
139
  unet_added_conditions = {
140
  "time_ids": add_time_ids,
141
- "text_embeds": batch_pooled_prompt_embeds.squeeze(1).to(torch.float16)
142
  }
143
 
144
 
 
96
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
97
  current_timesteps = current_timesteps.to(torch.float16)
98
  print(f'current_timestpes: {current_timesteps.dtype}')
99
+ eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
100
  print(type(eval_images))
101
 
102
  eval_images = get_x0_from_noise(noise, eval_images, alphas_cumprod, current_timesteps).to(self.DTYPE)
 
123
 
124
  add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
125
 
126
+ noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device="cuda").to(torch.float16)
127
 
128
  prompt_inputs = self._encode_prompt(prompt)
129
 
 
132
  prompt_embeds, pooled_prompt_embeds = self.text_encoder(prompt_inputs)
133
 
134
  batch_prompt_embeds, batch_pooled_prompt_embeds = (
135
+ prompt_embeds.repeat(num_images, 1, 1).to(torch.float16),
136
+ pooled_prompt_embeds.repeat(num_images, 1, 1).to(torch.float16)
137
  )
138
 
139
  unet_added_conditions = {
140
  "time_ids": add_time_ids,
141
+ "text_embeds": batch_pooled_prompt_embeds.squeeze(1)
142
  }
143
 
144