Anonymous-sub committed on
Commit
9c1dc83
·
1 Parent(s): 251e479

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -15
app.py CHANGED
@@ -30,12 +30,29 @@ from src.img_util import find_flat_region, numpy2tensor
30
  from src.video_util import (frame_to_video, get_fps, get_frame_count,
31
  prepare_frames)
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  inversed_model_dict = dict()
34
  for k, v in model_dict.items():
35
  inversed_model_dict[v] = k
36
 
37
  to_tensor = T.PILToTensor()
38
  blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
 
39
 
40
 
41
  class ProcessingState(Enum):
@@ -64,7 +81,7 @@ class GlobalState:
64
  attention_type='swin',
65
  ffn_dim_expansion=4,
66
  num_transformer_layers=6,
67
- ).to('cuda')
68
 
69
  checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
70
  map_location=lambda storage, loc: storage)
@@ -86,25 +103,32 @@ class GlobalState:
86
  model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
87
  if control_type == 'HED':
88
  model.load_state_dict(
89
- load_state_dict('./models/control_sd15_hed.pth',
90
- location='cuda'))
 
91
  elif control_type == 'canny':
92
  model.load_state_dict(
93
- load_state_dict('./models/control_sd15_canny.pth',
94
- location='cuda'))
95
- model = model.cuda()
 
96
  sd_model_path = model_dict[sd_model]
97
  if len(sd_model_path) > 0:
98
  model_ext = os.path.splitext(sd_model_path)[1]
 
 
99
  if model_ext == '.safetensors':
100
- model.load_state_dict(load_file(sd_model_path), strict=False)
101
- elif model_ext == '.ckpt' or model_ext == '.pth':
102
- model.load_state_dict(torch.load(sd_model_path)['state_dict'],
103
  strict=False)
 
 
 
104
 
105
  try:
106
  model.first_stage_model.load_state_dict(torch.load(
107
- './models/vae-ft-mse-840000-ema-pruned.ckpt')['state_dict'],
 
 
108
  strict=False)
109
  except Exception:
110
  print('Warning: We suggest you download the fine-tuned VAE',
@@ -115,7 +139,8 @@ class GlobalState:
115
  def clear_sd_model(self):
116
  self.sd_model = None
117
  self.ddim_v_sampler = None
118
- torch.cuda.empty_cache()
 
119
 
120
  def update_detector(self, control_type, canny_low=100, canny_high=200):
121
  if self.detector_type == control_type:
@@ -286,14 +311,14 @@ def process1(*args):
286
  img_ = numpy2tensor(img)
287
 
288
  def generate_first_img(img_, strength):
289
- encoder_posterior = model.encode_first_stage(img_.cuda())
290
  x0 = model.get_first_stage_encoding(encoder_posterior).detach()
291
 
292
  detected_map = detector(img)
293
  detected_map = HWC3(detected_map)
294
 
295
  control = torch.from_numpy(
296
- detected_map.copy()).float().cuda() / 255.0
297
  control = torch.stack([control for _ in range(num_samples)], dim=0)
298
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
299
  cond = {
@@ -411,13 +436,14 @@ def process2(*args):
411
  img_ = apply_color_correction(global_state.color_corrections,
412
  Image.fromarray(img))
413
  img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
414
- encoder_posterior = model.encode_first_stage(img_.cuda())
415
  x0 = model.get_first_stage_encoding(encoder_posterior).detach()
416
 
417
  detected_map = detector(img)
418
  detected_map = HWC3(detected_map)
419
 
420
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
 
421
  control = torch.stack([control for _ in range(num_samples)], dim=0)
422
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
423
  cond = {
 
30
  from src.video_util import (frame_to_video, get_fps, get_frame_count,
31
  prepare_frames)
32
 
33
+ import huggingface_hub
34
+
35
+ repo_name = 'Anonymous-sub/Rerender'
36
+
37
+ huggingface_hub.hf_hub_download(repo_name,
38
+ 'pexels-koolshooters-7322716.mp4',
39
+ local_dir='videos')
40
+ huggingface_hub.hf_hub_download(
41
+ repo_name,
42
+ 'pexels-antoni-shkraba-8048492-540x960-25fps.mp4',
43
+ local_dir='videos')
44
+ huggingface_hub.hf_hub_download(
45
+ repo_name,
46
+ 'pexels-cottonbro-studio-6649832-960x506-25fps.mp4',
47
+ local_dir='videos')
48
+
49
  inversed_model_dict = dict()
50
  for k, v in model_dict.items():
51
  inversed_model_dict[v] = k
52
 
53
  to_tensor = T.PILToTensor()
54
  blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
55
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
56
 
57
 
58
  class ProcessingState(Enum):
 
81
  attention_type='swin',
82
  ffn_dim_expansion=4,
83
  num_transformer_layers=6,
84
+ ).to(device)
85
 
86
  checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
87
  map_location=lambda storage, loc: storage)
 
103
  model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
104
  if control_type == 'HED':
105
  model.load_state_dict(
106
+ load_state_dict(huggingface_hub.hf_hub_download(
107
+ 'lllyasviel/ControlNet', './models/control_sd15_hed.pth'),
108
+ location=device))
109
  elif control_type == 'canny':
110
  model.load_state_dict(
111
+ load_state_dict(huggingface_hub.hf_hub_download(
112
+ 'lllyasviel/ControlNet', 'models/control_sd15_canny.pth'),
113
+ location=device))
114
+ model.to(device)
115
  sd_model_path = model_dict[sd_model]
116
  if len(sd_model_path) > 0:
117
  model_ext = os.path.splitext(sd_model_path)[1]
118
+ downloaded_model = huggingface_hub.hf_hub_download(
119
+ repo_name, sd_model_path)
120
  if model_ext == '.safetensors':
121
+ model.load_state_dict(load_file(downloaded_model),
 
 
122
  strict=False)
123
+ elif model_ext == '.ckpt' or model_ext == '.pth':
124
+ model.load_state_dict(
125
+ torch.load(downloaded_model)['state_dict'], strict=False)
126
 
127
  try:
128
  model.first_stage_model.load_state_dict(torch.load(
129
+ huggingface_hub.hf_hub_download(
130
+ 'stabilityai/sd-vae-ft-mse-original',
131
+ 'vae-ft-mse-840000-ema-pruned.ckpt'))['state_dict'],
132
  strict=False)
133
  except Exception:
134
  print('Warning: We suggest you download the fine-tuned VAE',
 
139
  def clear_sd_model(self):
140
  self.sd_model = None
141
  self.ddim_v_sampler = None
142
+ if device == 'cuda':
143
+ torch.cuda.empty_cache()
144
 
145
  def update_detector(self, control_type, canny_low=100, canny_high=200):
146
  if self.detector_type == control_type:
 
311
  img_ = numpy2tensor(img)
312
 
313
  def generate_first_img(img_, strength):
314
+ encoder_posterior = model.encode_first_stage(img_.to(device))
315
  x0 = model.get_first_stage_encoding(encoder_posterior).detach()
316
 
317
  detected_map = detector(img)
318
  detected_map = HWC3(detected_map)
319
 
320
  control = torch.from_numpy(
321
+ detected_map.copy()).float().to(device) / 255.0
322
  control = torch.stack([control for _ in range(num_samples)], dim=0)
323
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
324
  cond = {
 
436
  img_ = apply_color_correction(global_state.color_corrections,
437
  Image.fromarray(img))
438
  img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
439
+ encoder_posterior = model.encode_first_stage(img_.to(device))
440
  x0 = model.get_first_stage_encoding(encoder_posterior).detach()
441
 
442
  detected_map = detector(img)
443
  detected_map = HWC3(detected_map)
444
 
445
+ control = torch.from_numpy(
446
+ detected_map.copy()).float().to(device) / 255.0
447
  control = torch.stack([control for _ in range(num_samples)], dim=0)
448
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
449
  cond = {