asutermo commited on
Commit
7a233fe
·
1 Parent(s): b87db4b

try off working as expected

Browse files
Files changed (1) hide show
  1. predict.py +6 -2
predict.py CHANGED
@@ -57,7 +57,11 @@ class Predictor(BasePredictor):
57
  # Transform images using the new preprocessing
58
  image_tensor = transform(i)
59
  mask_tensor = mask_transform(m)[:1] # Take only first channel
60
- garment_tensor = transform(g)
 
 
 
 
61
 
62
  # Create concatenated images
63
  inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width
@@ -66,7 +70,7 @@ class Predictor(BasePredictor):
66
  if try_on:
67
  extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)
68
  else:
69
- extended_mask = torch.cat([1 - garment_mask, mask_tensor], dim=2)
70
 
71
  prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
72
  f"[IMAGE1] Detailed product shot of a clothing" \
 
57
  # Transform images using the new preprocessing
58
  image_tensor = transform(i)
59
  mask_tensor = mask_transform(m)[:1] # Take only first channel
60
+ if try_on:
61
+ garment_tensor = transform(g)
62
+ else:
63
+ garment_tensor = torch.zeros_like(image_tensor)
64
+ image_tensor = image_tensor * mask_tensor
65
 
66
  # Create concatenated images
67
  inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width
 
70
  if try_on:
71
  extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)
72
  else:
73
+ extended_mask = torch.cat([1 - garment_mask, garment_mask], dim=2)
74
 
75
  prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
76
  f"[IMAGE1] Detailed product shot of a clothing" \