ameerazam08 committed
Commit 50db579 · verified · 1 Parent(s): 87f32d6

full change

Files changed (1): main.py (+57 -5)
main.py CHANGED
@@ -34,25 +34,77 @@ class Upscale_CaseCade:
         )
         self.models_b.generator.eval().requires_grad_(False)
         print("STAGE B READY")
+        self.caption = "a photo of image"
+        self.cnet_multiplier = 1.0 # 0.8 # 0.3
+        # Stage C Parameters
+        self.extras.sampling_configs['cfg'] = 1
+        self.extras.sampling_configs['shift'] = 2
+        self.extras.sampling_configs['timesteps'] = 20
+        self.extras.sampling_configs['t_start'] = 1.0
+        # Stage B Parameters
+        self.extras_b.sampling_configs['cfg'] = 1.1
+        self.extras_b.sampling_configs['shift'] = 1
+        self.extras_b.sampling_configs['timesteps'] = 10
+        self.extras_b.sampling_configs['t_start'] = 1.0
+        self.models = ControlNetCore.Models(
+            **{**self.models.to_dict(), 'generator': torch.compile(self.models.generator, mode="reduce-overhead", fullgraph=True)}
+        )
+
+        self.models_b = WurstCoreB.Models(
+            **{**self.models_b.to_dict(), 'generator': torch.compile(self.models_b.generator, mode="reduce-overhead", fullgraph=True)}
+        )
+


-    def upscale_image(self,image_pil,scale_fator):
+    def upscale_image(self,caption,image_pil,scale_fator):
         batch_size = 1
         cnet_override = None
         images = resize_image(image_pil).unsqueeze(0).expand(batch_size, -1, -1, -1)

         batch = {'images': images}

+
         with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
             effnet_latents = self.core.encode_latents(batch, self.models, self.extras)
             effnet_latents_up = torch.nn.functional.interpolate(effnet_latents, scale_factor=scale_fator, mode="nearest")
             cnet = self.models.controlnet(effnet_latents_up)
             cnet_uncond = cnet
             cnet_input = torch.nn.functional.interpolate(images, scale_factor=scale_fator, mode="nearest")
-            # cnet, cnet_input = core.get_cnet(batch, models, extras)
+            # cnet, cnet_input = self.core.get_cnet(batch, self.models, self.extras)
             # cnet_uncond = cnet
-        og=show_images(batch['images'],return_images=True)
-        upsclae=show_images(cnet_input,return_images=True)
-        return og,upsclae
+            height, width = int(cnet[0].size(-2)*32*4/3), int(cnet[0].size(-1)*32*4/3)
+            stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+            # PREPARE CONDITIONS
+            batch['captions'] = [caption] * batch_size
+            conditions = self.core.get_conditions(batch, self.models, self.extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+            unconditions = self.core.get_conditions(batch, self.models, self.extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+            conditions['cnet'] = [c.clone() * self.cnet_multiplier if c is not None else c for c in cnet]
+            unconditions['cnet'] = [c.clone() * self.cnet_multiplier if c is not None else c for c in cnet_uncond]
+            conditions_b = self.core_b.get_conditions(batch, self.models_b, self.extras_b, is_eval=True, is_unconditional=False)
+            unconditions_b = self.core_b.get_conditions(batch, self.models_b, self.extras_b, is_eval=True, is_unconditional=True)
+            # torch.manual_seed(42)
+            sampling_c = self.extras.gdf.sample(
+                self.models.generator, conditions, stage_c_latent_shape,
+                unconditions, device=device, **self.extras.sampling_configs,
+            )
+            for (sampled_c, _, _) in tqdm(sampling_c, total=self.extras.sampling_configs['timesteps']):
+                sampled_c = sampled_c
+
+            # preview_c = models.previewer(sampled_c).float()
+            # show_images(preview_c)
+
+            conditions_b['effnet'] = sampled_c
+            unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+
+            sampling_b = self.extras_b.gdf.sample(
+                self.models_b.generator, conditions_b, stage_b_latent_shape,
+                unconditions_b, device=device, **self.extras_b.sampling_configs
+            )
+            for (sampled_b, _, _) in tqdm(sampling_b, total=self.extras_b.sampling_configs['timesteps']):
+                sampled_b = sampled_b
+            sampled = self.models_b.stage_a.decode(sampled_b).float()
+        # og=show_images(batch['images'],return_images=True)
+        upscale=show_images(sampled,return_images=True)
+        return upscale

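A note on the `__init__` changes above: the `torch.compile` lines rebuild each `Models` container with a compiled generator swapped in, since the containers are constructed fresh rather than mutated in place. A minimal sketch of that swap pattern, using a hypothetical frozen dataclass in place of StableCascade's `ControlNetCore.Models` / `WurstCoreB.Models`:

```python
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass(frozen=True)
class Models:
    """Hypothetical stand-in for StableCascade's Models container."""
    generator: nn.Module
    controlnet: nn.Module

    def to_dict(self):
        return {"generator": self.generator, "controlnet": self.controlnet}


models = Models(generator=nn.Linear(8, 8), controlnet=nn.Identity())

# Rebuild the frozen container with only the generator replaced by its
# compiled wrapper; every other field is carried over unchanged.
models = Models(
    **{
        **models.to_dict(),
        "generator": torch.compile(
            models.generator, mode="reduce-overhead", fullgraph=True
        ),
    }
)
```

`mode="reduce-overhead"` targets small graphs that are invoked repeatedly (the generator runs once per sampling step), and `fullgraph=True` makes `torch.compile` raise on graph breaks instead of silently falling back to eager execution.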
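In the new sampling code, the `for (sampled_c, _, _) in tqdm(...)` loops have empty-looking bodies because `extras.gdf.sample(...)` is iterated as a Python generator: consuming it is what actually executes the denoising steps, and the loop variable simply retains the last yielded tuple. A self-contained sketch of that consumption pattern (`fake_sampler` is a hypothetical stand-in; the real yield contents differ):

```python
import torch
from tqdm import tqdm


def fake_sampler(steps: int):
    """Hypothetical stand-in for extras.gdf.sample: yields one
    (sample, aux1, aux2) tuple per denoising step."""
    x = torch.randn(1, 4)
    for _ in range(steps):
        x = 0.9 * x  # pretend denoising update
        yield x, None, None


timesteps = 20

# Draining the generator drives the whole sampling loop; only the
# final yield survives, mirroring `sampled_c = sampled_c` in the diff.
for (sampled, _, _) in tqdm(fake_sampler(timesteps), total=timesteps):
    pass

print(sampled.shape)  # tensor from the last step
```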
 
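For callers, the visible change is the signature: `upscale_image` now takes a caption and returns only the upscaled result instead of the old `(original, upscaled)` pair. A hypothetical usage sketch; the constructor and the exact return type of `show_images(..., return_images=True)` sit outside this hunk, so those details are assumptions:

```python
from PIL import Image

upscaler = Upscale_CaseCade()  # constructor arguments assumed, not shown in this hunk

image = Image.open("input.png").convert("RGB")

# `scale_fator` (spelling as in the source) scales both the effnet latents
# and the pixel-space control input, e.g. 2 for a 2x upscale.
result = upscaler.upscale_image(
    caption="a photo of image",  # the default caption set in __init__
    image_pil=image,
    scale_fator=2,
)
```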