fffiloni commited on
Commit
04793a7
·
verified ·
1 Parent(s): 64bb281

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -89
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import sys
2
  import os
3
  from pathlib import Path
 
4
 
5
  # Add the StableCascade and CSD directories to the Python path
6
  app_dir = Path(__file__).parent
@@ -27,12 +28,29 @@ from gdf.schedulers import CosineSchedule
27
  from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
28
  from gdf.targets import EpsilonTarget
29
 
 
 
 
 
30
  # Device configuration
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
  print(device)
33
 
34
  # Flag for low VRAM usage
35
- low_vram = False
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  # Function definition for low VRAM usage
38
  if low_vram:
@@ -53,84 +71,12 @@ if low_vram:
53
  print(f"Change device of '{attr_name}' to {device}")
54
  attr_value.to(device)
55
 
56
- torch.cuda.empty_cache()
57
-
58
- # Stage C model configuration
59
- config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
60
- with open(config_file, "r", encoding="utf-8") as file:
61
- loaded_config = yaml.safe_load(file)
62
-
63
- core = WurstCoreCRBM(config_dict=loaded_config, device=device, training=False)
64
 
65
- # Stage B model configuration
66
- config_file_b = 'third_party/StableCascade/configs/inference/stage_b_3b.yaml'
67
- with open(config_file_b, "r", encoding="utf-8") as file:
68
- config_file_b = yaml.safe_load(file)
69
-
70
- core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
71
-
72
- # Setup extras and models for Stage C
73
- extras = core.setup_extras_pre()
74
-
75
- gdf_rbm = RBM(
76
- schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
77
- input_scaler=VPScaler(), target=EpsilonTarget(),
78
- noise_cond=CosineTNoiseCond(),
79
- loss_weight=AdaptiveLossWeight(),
80
- )
81
-
82
- sampling_configs = {
83
- "cfg": 5,
84
- "sampler": DDPMSampler(gdf_rbm),
85
- "shift": 1,
86
- "timesteps": 20
87
- }
88
-
89
- extras = core.Extras(
90
- gdf=gdf_rbm,
91
- sampling_configs=sampling_configs,
92
- transforms=extras.transforms,
93
- effnet_preprocess=extras.effnet_preprocess,
94
- clip_preprocess=extras.clip_preprocess
95
- )
96
-
97
- models = core.setup_models(extras)
98
- models.generator.eval().requires_grad_(False)
99
-
100
- # Setup extras and models for Stage B
101
- extras_b = core_b.setup_extras_pre()
102
- models_b = core_b.setup_models(extras_b, skip_clip=True)
103
- models_b = WurstCoreB.Models(
104
- **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
105
- )
106
- models_b.generator.bfloat16().eval().requires_grad_(False)
107
-
108
- # Off-load old generator (low VRAM mode)
109
- if low_vram:
110
- models.generator.to("cpu")
111
- torch.cuda.empty_cache()
112
-
113
- # Load and configure new generator
114
- generator_rbm = StageCRBM()
115
- for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
116
- set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
117
-
118
- generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
119
- generator_rbm = core.load_model(generator_rbm, 'generator')
120
-
121
- # Create models_rbm instance
122
- models_rbm = core.Models(
123
- effnet=models.effnet,
124
- previewer=models.previewer,
125
- generator=generator_rbm,
126
- generator_ema=models.generator_ema,
127
- tokenizer=models.tokenizer,
128
- text_model=models.text_model,
129
- image_model=models.image_model
130
- )
131
- models_rbm.generator.eval().requires_grad_(False)
132
 
133
  def infer(style_description, ref_style_file, caption):
 
134
 
135
  height=1024
136
  width=1024
@@ -166,19 +112,22 @@ def infer(style_description, ref_style_file, caption):
166
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
167
 
168
  # Stage C reverse process.
169
- sampling_c = extras.gdf.sample(
170
- models_rbm.generator, conditions, stage_c_latent_shape,
171
- unconditions, device=device,
172
- **extras.sampling_configs,
173
- x0_style_forward=x0_style_forward,
174
- apply_pushforward=False, tau_pushforward=8,
175
- num_iter=3, eta=0.1, tau=20, eval_csd=True,
176
- extras=extras, models=models_rbm,
177
- lam_style=1, lam_txt_alignment=1.0,
178
- use_ddim_sampler=True,
179
- )
180
- for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
181
- sampled_c = sampled_c
 
 
 
182
 
183
  # Stage B reverse process.
184
  with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
@@ -203,6 +152,8 @@ def infer(style_description, ref_style_file, caption):
203
  sampled_image = T.ToPILImage()(sampled.squeeze(0)) # Convert tensor to PIL image
204
  sampled_image.save(output_file) # Save the image
205
 
 
 
206
  return output_file # Return the path to the saved image
207
 
208
  import gradio as gr
 
1
  import sys
2
  import os
3
  from pathlib import Path
4
+ import gc
5
 
6
  # Add the StableCascade and CSD directories to the Python path
7
  app_dir = Path(__file__).parent
 
28
  from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
29
  from gdf.targets import EpsilonTarget
30
 
31
+ # Enable mixed precision
32
+ torch.backends.cuda.matmul.allow_tf32 = True
33
+ torch.backends.cudnn.allow_tf32 = True
34
+
35
  # Device configuration
36
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
37
  print(device)
38
 
39
  # Flag for low VRAM usage
40
+ low_vram = True # Set to True to enable low VRAM optimizations
41
+
42
+ # Function to clear GPU cache
43
+ def clear_gpu_cache():
44
+ torch.cuda.empty_cache()
45
+ gc.collect()
46
+
47
+ # Function to move model to CPU
48
+ def to_cpu(model):
49
+ return model.cpu()
50
+
51
+ # Function to move model to GPU
52
+ def to_gpu(model):
53
+ return model.cuda()
54
 
55
  # Function definition for low VRAM usage
56
  if low_vram:
 
71
  print(f"Change device of '{attr_name}' to {device}")
72
  attr_value.to(device)
73
 
74
+ clear_gpu_cache()
 
 
 
 
 
 
 
75
 
76
+ # ... (rest of your setup code remains the same)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  def infer(style_description, ref_style_file, caption):
79
+ clear_gpu_cache() # Clear cache before inference
80
 
81
  height=1024
82
  width=1024
 
112
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
113
 
114
  # Stage C reverse process.
115
+ with torch.cuda.amp.autocast(): # Use mixed precision
116
+ sampling_c = extras.gdf.sample(
117
+ models_rbm.generator, conditions, stage_c_latent_shape,
118
+ unconditions, device=device,
119
+ **extras.sampling_configs,
120
+ x0_style_forward=x0_style_forward,
121
+ apply_pushforward=False, tau_pushforward=8,
122
+ num_iter=3, eta=0.1, tau=20, eval_csd=True,
123
+ extras=extras, models=models_rbm,
124
+ lam_style=1, lam_txt_alignment=1.0,
125
+ use_ddim_sampler=True,
126
+ )
127
+ for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
128
+ sampled_c = sampled_c
129
+
130
+ clear_gpu_cache() # Clear cache between stages
131
 
132
  # Stage B reverse process.
133
  with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
 
152
  sampled_image = T.ToPILImage()(sampled.squeeze(0)) # Convert tensor to PIL image
153
  sampled_image.save(output_file) # Save the image
154
 
155
+ clear_gpu_cache() # Clear cache after inference
156
+
157
  return output_file # Return the path to the saved image
158
 
159
  import gradio as gr