MohamedRashad committed on
Commit
8f812c4
·
1 Parent(s): 2f3fed1

Update CUDA device usage in app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -74,7 +74,7 @@ def encode_cropped_prompt_77tokens(txt: str):
74
  padding="max_length",
75
  max_length=tokenizer.model_max_length,
76
  truncation=True,
77
- return_tensors="pt").input_ids.to(device=text_encoder.device)
78
  text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
79
  return text_cond
80
 
@@ -117,15 +117,15 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
117
  rng = torch.Generator(device="cuda").manual_seed(int(seed))
118
 
119
  fg = resize_and_center_crop(input_fg, image_width, image_height)
120
- concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
121
  concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
122
 
123
  conds = encode_cropped_prompt_77tokens(prompt)
124
  unconds = encode_cropped_prompt_77tokens(n_prompt)
125
 
126
- fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long)
127
  initial_latents = torch.zeros_like(concat_conds)
128
- concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype)
129
  latents = k_sampler(
130
  initial_latent=initial_latents,
131
  strength=1.0,
@@ -169,13 +169,13 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
169
  positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
170
  negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")
171
 
172
- input_frames = input_frames.to(device=video_pipe.image_encoder.device, dtype=video_pipe.image_encoder.dtype)
173
  positive_image_cond = video_pipe.encode_clip_vision(input_frames)
174
  positive_image_cond = video_pipe.image_projection(positive_image_cond)
175
  negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
176
  negative_image_cond = video_pipe.image_projection(negative_image_cond)
177
 
178
- input_frames = input_frames.to(device=video_pipe.vae.device, dtype=video_pipe.vae.dtype)
179
  input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
180
  first_frame = input_frame_latents[:, :, 0]
181
  last_frame = input_frame_latents[:, :, 1]
 
74
  padding="max_length",
75
  max_length=tokenizer.model_max_length,
76
  truncation=True,
77
+ return_tensors="pt").input_ids.to(device="cuda")
78
  text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
79
  return text_cond
80
 
 
117
  rng = torch.Generator(device="cuda").manual_seed(int(seed))
118
 
119
  fg = resize_and_center_crop(input_fg, image_width, image_height)
120
+ concat_conds = numpy2pytorch([fg]).to(device="cuda", dtype=vae.dtype)
121
  concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
122
 
123
  conds = encode_cropped_prompt_77tokens(prompt)
124
  unconds = encode_cropped_prompt_77tokens(n_prompt)
125
 
126
+ fs = torch.tensor(input_undo_steps).to(device="cuda", dtype=torch.long)
127
  initial_latents = torch.zeros_like(concat_conds)
128
+ concat_conds = concat_conds.to(device="cuda", dtype=unet.dtype)
129
  latents = k_sampler(
130
  initial_latent=initial_latents,
131
  strength=1.0,
 
169
  positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
170
  negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")
171
 
172
+ input_frames = input_frames.to(device="cuda", dtype=video_pipe.image_encoder.dtype)
173
  positive_image_cond = video_pipe.encode_clip_vision(input_frames)
174
  positive_image_cond = video_pipe.image_projection(positive_image_cond)
175
  negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
176
  negative_image_cond = video_pipe.image_projection(negative_image_cond)
177
 
178
+ input_frames = input_frames.to(device="cuda", dtype=video_pipe.vae.dtype)
179
  input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
180
  first_frame = input_frame_latents[:, :, 0]
181
  last_frame = input_frame_latents[:, :, 1]