Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -149,6 +149,10 @@ models_rbm = core.Models(
|
|
149 |
models_rbm.generator.eval().requires_grad_(False)
|
150 |
|
151 |
def infer(style_description, ref_style_file, caption):
|
|
|
|
|
|
|
|
|
152 |
clear_gpu_cache() # Clear cache before inference
|
153 |
|
154 |
height=1024
|
@@ -181,10 +185,10 @@ def infer(style_description, ref_style_file, caption):
|
|
181 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
182 |
|
183 |
if low_vram:
|
184 |
-
#
|
185 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
186 |
|
187 |
-
# Stage C reverse process
|
188 |
with torch.cuda.amp.autocast(): # Use mixed precision
|
189 |
sampling_c = extras.gdf.sample(
|
190 |
models_rbm.generator, conditions, stage_c_latent_shape,
|
@@ -202,7 +206,10 @@ def infer(style_description, ref_style_file, caption):
|
|
202 |
|
203 |
clear_gpu_cache() # Clear cache between stages
|
204 |
|
205 |
-
#
|
|
|
|
|
|
|
206 |
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
207 |
conditions_b['effnet'] = sampled_c
|
208 |
unconditions_b['effnet'] = torch.zeros_like(sampled_c)
|
@@ -215,6 +222,7 @@ def infer(style_description, ref_style_file, caption):
|
|
215 |
sampled_b = sampled_b
|
216 |
sampled = models_b.stage_a.decode(sampled_b).float()
|
217 |
|
|
|
218 |
sampled = torch.cat([
|
219 |
torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
|
220 |
sampled.cpu(),
|
@@ -234,6 +242,7 @@ def infer(style_description, ref_style_file, caption):
|
|
234 |
|
235 |
return output_file # Return the path to the saved image
|
236 |
|
|
|
237 |
import gradio as gr
|
238 |
|
239 |
gr.Interface(
|
|
|
149 |
models_rbm.generator.eval().requires_grad_(False)
|
150 |
|
151 |
def infer(style_description, ref_style_file, caption):
|
152 |
+
# Ensure all models are moved back to the correct device
|
153 |
+
core_b.generator.to(device)
|
154 |
+
models_rbm.generator.to(device)
|
155 |
+
|
156 |
clear_gpu_cache() # Clear cache before inference
|
157 |
|
158 |
height=1024
|
|
|
185 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
186 |
|
187 |
if low_vram:
|
188 |
+
# Offload non-essential models to CPU for memory savings
|
189 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
190 |
|
191 |
+
# Stage C reverse process
|
192 |
with torch.cuda.amp.autocast(): # Use mixed precision
|
193 |
sampling_c = extras.gdf.sample(
|
194 |
models_rbm.generator, conditions, stage_c_latent_shape,
|
|
|
206 |
|
207 |
clear_gpu_cache() # Clear cache between stages
|
208 |
|
209 |
+
# Ensure all models are on the right device again
|
210 |
+
models_b.generator.to(device)
|
211 |
+
|
212 |
+
# Stage B reverse process
|
213 |
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
214 |
conditions_b['effnet'] = sampled_c
|
215 |
unconditions_b['effnet'] = torch.zeros_like(sampled_c)
|
|
|
222 |
sampled_b = sampled_b
|
223 |
sampled = models_b.stage_a.decode(sampled_b).float()
|
224 |
|
225 |
+
# Post-process and save the image
|
226 |
sampled = torch.cat([
|
227 |
torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
|
228 |
sampled.cpu(),
|
|
|
242 |
|
243 |
return output_file # Return the path to the saved image
|
244 |
|
245 |
+
|
246 |
import gradio as gr
|
247 |
|
248 |
gr.Interface(
|