nftblackmagic committed on
Commit
939cd7a
·
unverified ·
2 Parent(s): cefa68e 7a233fe

Merge pull request #29 from asutermo/user/asutermo/improvements-to-tryon-cog

Browse files

Cog: Fix up a couple of issues with the hybrid try-on/try-off predictor. It is working on Replicate now :-)

Files changed (1) hide show
  1. predict.py +9 -5
predict.py CHANGED
@@ -19,8 +19,8 @@ class Predictor(BasePredictor):
19
  hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/token. You also need to approve the Flux Dev terms."),
20
  image: Path = Input(description="Image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"),
21
  mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true"),
22
- try_on: bool = Input(True, description="Try on or try off"),
23
- garment: Path = Input(description="Garment file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg"),
24
  num_steps: int = Input(50, description="Number of steps to run the model for"),
25
  guidance_scale: float = Input(30, description="Guidance scale for the model"),
26
  seed: int = Input(0, description="Seed for the model"),
@@ -30,9 +30,9 @@ class Predictor(BasePredictor):
30
  size = (width, height)
31
  i = load_image(str(image)).convert("RGB").resize(size)
32
  m = load_image(str(mask)).convert("RGB").resize(size)
33
- g = load_image(str(garment)).convert("RGB").resize(size)
34
 
35
  if try_on:
 
36
  self.transformer = self.try_on_transformer
37
  else:
38
  self.transformer = self.try_off_transformer
@@ -57,7 +57,11 @@ class Predictor(BasePredictor):
57
  # Transform images using the new preprocessing
58
  image_tensor = transform(i)
59
  mask_tensor = mask_transform(m)[:1] # Take only first channel
60
- garment_tensor = transform(g)
 
 
 
 
61
 
62
  # Create concatenated images
63
  inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width
@@ -66,7 +70,7 @@ class Predictor(BasePredictor):
66
  if try_on:
67
  extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)
68
  else:
69
- extended_mask = torch.cat([1 - garment_mask, mask_tensor], dim=2)
70
 
71
  prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
72
  f"[IMAGE1] Detailed product shot of a clothing" \
 
19
  hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/token. You also need to approve the Flux Dev terms."),
20
  image: Path = Input(description="Image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"),
21
  mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true"),
22
+ try_on: bool = Input(False, description="Try on or try off"),
23
+ garment: Path = Input(description="Garment file path like https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg", default=None),
24
  num_steps: int = Input(50, description="Number of steps to run the model for"),
25
  guidance_scale: float = Input(30, description="Guidance scale for the model"),
26
  seed: int = Input(0, description="Seed for the model"),
 
30
  size = (width, height)
31
  i = load_image(str(image)).convert("RGB").resize(size)
32
  m = load_image(str(mask)).convert("RGB").resize(size)
 
33
 
34
  if try_on:
35
+ g = load_image(str(garment)).convert("RGB").resize(size)
36
  self.transformer = self.try_on_transformer
37
  else:
38
  self.transformer = self.try_off_transformer
 
57
  # Transform images using the new preprocessing
58
  image_tensor = transform(i)
59
  mask_tensor = mask_transform(m)[:1] # Take only first channel
60
+ if try_on:
61
+ garment_tensor = transform(g)
62
+ else:
63
+ garment_tensor = torch.zeros_like(image_tensor)
64
+ image_tensor = image_tensor * mask_tensor
65
 
66
  # Create concatenated images
67
  inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width
 
70
  if try_on:
71
  extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)
72
  else:
73
+ extended_mask = torch.cat([1 - garment_mask, garment_mask], dim=2)
74
 
75
  prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
76
  f"[IMAGE1] Detailed product shot of a clothing" \