MohamedRashad committed on
Commit 2f3fed1 · 1 Parent(s): 453686a

chore: Update CUDA device usage in app.py

Files changed (2)
  1. app.py +7 -23
  2. memory_management.py +0 -67
app.py CHANGED
@@ -12,7 +12,6 @@ import gradio as gr
 import numpy as np
 import torch
 import wd14tagger
-import memory_management
 import uuid
 
 from PIL import Image
@@ -37,9 +36,9 @@ class ModifiedUNet(UNet2DConditionModel):
 
 model_name = 'lllyasviel/paints_undo_single_frame'
 tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
-text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16)
-vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16)  # bfloat16 vae
-unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16)
+text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16).to("cuda")
+vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16).to("cuda")  # bfloat16 vae
+unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16).to("cuda")
 
 unet.set_attn_processor(AttnProcessor2_0())
 vae.set_attn_processor(AttnProcessor2_0())
@@ -47,12 +46,7 @@ vae.set_attn_processor(AttnProcessor2_0())
 video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
     'lllyasviel/paints_undo_multi_frame',
     fp16=True
-)
-
-memory_management.unload_all_models([
-    video_pipe.unet, video_pipe.vae, video_pipe.text_encoder, video_pipe.image_projection, video_pipe.image_encoder,
-    unet, vae, text_encoder
-])
+).to("cuda")
 
 k_sampler = KDiffusionSampler(
     unet=unet,
@@ -76,7 +70,6 @@ def find_best_bucket(h, w, options):
 
 @torch.inference_mode()
 def encode_cropped_prompt_77tokens(txt: str):
-    memory_management.load_models_to_gpu(text_encoder)
     cond_ids = tokenizer(txt,
                          padding="max_length",
                          max_length=tokenizer.model_max_length,
@@ -111,28 +104,25 @@ def resize_without_crop(image, target_width, target_height):
 
 
 @torch.inference_mode()
-@spaces.GPU(duration=360)
+@spaces.GPU()
 def interrogator_process(x):
     image_description = wd14tagger.default_interrogator(x)
     return image_description, image_description
 
 
 @torch.inference_mode()
-@spaces.GPU(duration=360)
+@spaces.GPU()
 def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
             progress=gr.Progress()):
-    rng = torch.Generator(device=memory_management.gpu).manual_seed(int(seed))
+    rng = torch.Generator(device="cuda").manual_seed(int(seed))
 
-    memory_management.load_models_to_gpu(vae)
     fg = resize_and_center_crop(input_fg, image_width, image_height)
     concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
 
-    memory_management.load_models_to_gpu(text_encoder)
     conds = encode_cropped_prompt_77tokens(prompt)
     unconds = encode_cropped_prompt_77tokens(n_prompt)
 
-    memory_management.load_models_to_gpu(unet)
     fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long)
     initial_latents = torch.zeros_like(concat_conds)
     concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype)
@@ -150,7 +140,6 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
         progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames')
     ).to(vae.dtype) / vae.config.scaling_factor
 
-    memory_management.load_models_to_gpu(vae)
    pixels = vae.decode(latents).sample
    pixels = pytorch2numpy(pixels)
    pixels = [fg] + pixels + [np.zeros_like(fg) + 255]
@@ -177,25 +166,21 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
     input_frames = numpy2pytorch([image_1, image_2])
     input_frames = input_frames.unsqueeze(0).movedim(1, 2)
 
-    memory_management.load_models_to_gpu(video_pipe.text_encoder)
     positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
     negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")
 
-    memory_management.load_models_to_gpu([video_pipe.image_projection, video_pipe.image_encoder])
     input_frames = input_frames.to(device=video_pipe.image_encoder.device, dtype=video_pipe.image_encoder.dtype)
     positive_image_cond = video_pipe.encode_clip_vision(input_frames)
     positive_image_cond = video_pipe.image_projection(positive_image_cond)
     negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
     negative_image_cond = video_pipe.image_projection(negative_image_cond)
 
-    memory_management.load_models_to_gpu([video_pipe.vae])
     input_frames = input_frames.to(device=video_pipe.vae.device, dtype=video_pipe.vae.dtype)
     input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
     first_frame = input_frame_latents[:, :, 0]
     last_frame = input_frame_latents[:, :, 1]
     concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2)
 
-    memory_management.load_models_to_gpu([video_pipe.unet])
     latents = video_pipe(
         batch_size=1,
         steps=int(steps),
@@ -209,7 +194,6 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
         progress_tqdm=progress_tqdm
     )
 
-    memory_management.load_models_to_gpu([video_pipe.vae])
     video = video_pipe.decode_latents(latents, vae_hidden_states)
     return video, image_1, image_2
 
 
memory_management.py DELETED
@@ -1,67 +0,0 @@
-import torch
-from contextlib import contextmanager
-
-
-high_vram = False
-gpu = torch.device('cuda')
-cpu = torch.device('cpu')
-
-torch.zeros((1, 1)).to(gpu, torch.float32)
-torch.cuda.empty_cache()
-
-models_in_gpu = []
-
-
-@contextmanager
-def movable_bnb_model(m):
-    if hasattr(m, 'quantization_method'):
-        m.quantization_method_backup = m.quantization_method
-        del m.quantization_method
-    try:
-        yield None
-    finally:
-        if hasattr(m, 'quantization_method_backup'):
-            m.quantization_method = m.quantization_method_backup
-            del m.quantization_method_backup
-    return
-
-
-def load_models_to_gpu(models):
-    global models_in_gpu
-
-    if not isinstance(models, (tuple, list)):
-        models = [models]
-
-    models_to_remain = [m for m in set(models) if m in models_in_gpu]
-    models_to_load = [m for m in set(models) if m not in models_in_gpu]
-    models_to_unload = [m for m in set(models_in_gpu) if m not in models_to_remain]
-
-    if not high_vram:
-        for m in models_to_unload:
-            with movable_bnb_model(m):
-                m.to(cpu)
-            print('Unload to CPU:', m.__class__.__name__)
-        models_in_gpu = models_to_remain
-
-    for m in models_to_load:
-        with movable_bnb_model(m):
-            m.to(gpu)
-        print('Load to GPU:', m.__class__.__name__)
-
-    models_in_gpu = list(set(models_in_gpu + models))
-    torch.cuda.empty_cache()
-    return
-
-
-def unload_all_models(extra_models=None):
-    global models_in_gpu
-
-    if extra_models is None:
-        extra_models = []
-
-    if not isinstance(extra_models, (tuple, list)):
-        extra_models = [extra_models]
-
-    models_in_gpu = list(set(models_in_gpu + extra_models))
-
-    return load_models_to_gpu([])