fffiloni commited on
Commit
967d0a0
·
verified ·
1 Parent(s): 024ee6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -221,10 +221,16 @@ def infer(style_description, ref_style_file, caption):
221
  ],
222
  dim=0)
223
 
224
- # Save the sampled image to a file
225
- sampled_image = T.ToPILImage()(sampled.squeeze(0)) # Convert tensor to PIL image
226
- sampled_image.save(output_file) # Save the image
227
-
 
 
 
 
 
 
228
  clear_gpu_cache() # Clear cache after inference
229
 
230
  return output_file # Return the path to the saved image
 
221
  ],
222
  dim=0)
223
 
224
+ # Remove batch dimension if it exists
225
+ if sampled.dim() == 4 and sampled.size(0) == 1:
226
+ sampled = sampled.squeeze(0)
227
+
228
+ # Ensure the tensor is in [C, H, W] format
229
+ if sampled.dim() == 3:
230
+ sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
231
+ sampled_image.save(output_file) # Save the image as a PNG
232
+ else:
233
+ raise ValueError(f"Expected tensor of shape [C, H, W] but got {sampled.shape}")
234
  clear_gpu_cache() # Clear cache after inference
235
 
236
  return output_file # Return the path to the saved image