fffiloni commited on
Commit
bc12e9f
·
verified ·
1 Parent(s): 967d0a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -216,21 +216,20 @@ def infer(style_description, ref_style_file, caption):
216
  sampled = models_b.stage_a.decode(sampled_b).float()
217
 
218
  sampled = torch.cat([
219
- torch.nn.functional.interpolate(ref_style.cpu(), size=height),
220
  sampled.cpu(),
221
- ],
222
- dim=0)
223
 
224
- # Remove batch dimension if it exists
225
- if sampled.dim() == 4 and sampled.size(0) == 1:
226
- sampled = sampled.squeeze(0)
227
 
228
  # Ensure the tensor is in [C, H, W] format
229
- if sampled.dim() == 3:
230
  sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
231
  sampled_image.save(output_file) # Save the image as a PNG
232
  else:
233
- raise ValueError(f"Expected tensor of shape [C, H, W] but got {sampled.shape}")
 
234
  clear_gpu_cache() # Clear cache after inference
235
 
236
  return output_file # Return the path to the saved image
 
216
  sampled = models_b.stage_a.decode(sampled_b).float()
217
 
218
  sampled = torch.cat([
219
+ torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
220
  sampled.cpu(),
221
+ ], dim=0)
 
222
 
223
+ # Remove the batch dimension and keep only the generated image
224
+ sampled = sampled[1] # This selects the generated image, discarding the reference style image
 
225
 
226
  # Ensure the tensor is in [C, H, W] format
227
+ if sampled.dim() == 3 and sampled.shape[0] == 3:
228
  sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
229
  sampled_image.save(output_file) # Save the image as a PNG
230
  else:
231
+ raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
232
+
233
  clear_gpu_cache() # Clear cache after inference
234
 
235
  return output_file # Return the path to the saved image