ironjr commited on
Commit
7ec6cab
·
verified ·
1 Parent(s): a98f79f

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +1 -1
model.py CHANGED
@@ -1125,7 +1125,7 @@ class StreamMultiDiffusion(nn.Module):
1125
  model_pred = self.unet(
1126
  x_t_latent_plus_uc.to(self.dtype), # (B, 4, h, w)
1127
  t_list, # (B,)
1128
- encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
1129
  return_dict=False,
1130
  )[0] # (B, 4, h, w)
1131
  print('222222222222222', model_pred.dtype)
 
1125
  model_pred = self.unet(
1126
  x_t_latent_plus_uc.to(self.dtype), # (B, 4, h, w)
1127
  t_list, # (B,)
1128
+ encoder_hidden_states=self.prompt_embeds.float(), # (B, 77, 768)
1129
  return_dict=False,
1130
  )[0] # (B, 4, h, w)
1131
  print('222222222222222', model_pred.dtype)