xiaozaa commited on
Commit
a6d7aa6
·
1 Parent(s): 4ae4b3e

add spaces

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. app.py +158 -96
  3. example/garment/00396_00.jpg +0 -0
  4. requirements.txt +1 -14
README.md CHANGED
@@ -8,7 +8,7 @@ Also inspired by [In-Context LoRA](https://arxiv.org/abs/2410.23775) for prompt
8
  ---
9
  **Latest Achievement**
10
  (2024/11/25):
11
- - Released lora weights.
12
 
13
  (2024/11/24):
14
  - Released FID score and gradio demo
 
8
  ---
9
  **Latest Achievement**
10
  (2024/11/25):
11
+ - Released lora weights. FID: 6.0675811767578125 on VITON-HD dataset. Test configuration: scale 30, step 30.
12
 
13
  (2024/11/24):
14
  - Released FID score and gradio demo
app.py CHANGED
@@ -1,17 +1,63 @@
 
 
1
  import gradio as gr
2
  from tryon_inference import run_inference
3
  import os
4
  import numpy as np
5
  from PIL import Image
6
  import tempfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def gradio_inference(
9
  image_data,
10
  garment,
11
  num_steps=50,
12
  guidance_scale=30.0,
13
  seed=-1,
14
- size=(768,1024)
 
15
  ):
16
  """Wrapper function for Gradio interface"""
17
  # Use temporary directory
@@ -38,116 +84,132 @@ def gradio_inference(
38
  try:
39
  # Run inference
40
  _, tryon_result = run_inference(
 
41
  image_path=temp_image,
42
  mask_path=temp_mask,
43
  garment_path=temp_garment,
44
  num_steps=num_steps,
45
  guidance_scale=guidance_scale,
46
  seed=seed,
47
- size=size
48
  )
49
  return tryon_result
50
  except Exception as e:
51
  raise gr.Error(f"Error during inference: {str(e)}")
52
 
53
- def create_demo():
54
- with gr.Blocks() as demo:
55
- gr.Markdown("""
56
- # CATVTON FLUX Virtual Try-On Demo
57
- Upload a model image, an agnostic mask, and a garment image to generate virtual try-on results.
58
-
59
- [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/xiaozaa/catvton-flux-alpha)
60
- [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/nftblackmagic/catvton-flux)
61
- """)
62
-
63
- with gr.Column():
64
- with gr.Row():
65
- with gr.Column():
66
- image_input = gr.ImageMask(
67
- label="Model Image (Draw mask where garment should go)",
68
- type="pil",
69
- height=600,
70
- )
71
- gr.Examples(
72
- examples=[
73
- ["./example/person/00008_00.jpg"],
74
- ["./example/person/00055_00.jpg"],
75
- ["./example/person/00057_00.jpg"],
76
- ["./example/person/00067_00.jpg"],
77
- ["./example/person/00069_00.jpg"],
78
- ],
79
- inputs=[image_input],
80
- label="Person Images",
81
- )
82
- with gr.Column():
83
- garment_input = gr.Image(label="Garment Image", type="pil", height=600)
84
- gr.Examples(
85
- examples=[
86
- ["./example/garment/04564_00.jpg"],
87
- ["./example/garment/00055_00.jpg"],
88
- ["./example/garment/00057_00.jpg"],
89
- ["./example/garment/00067_00.jpg"],
90
- ["./example/garment/00069_00.jpg"],
91
- ],
92
- inputs=[garment_input],
93
- label="Garment Images",
94
- )
95
-
96
- with gr.Row():
97
- num_steps = gr.Slider(
98
- minimum=1,
99
- maximum=100,
100
- value=50,
101
- step=1,
102
- label="Number of Steps"
103
- )
104
- guidance_scale = gr.Slider(
105
- minimum=1.0,
106
- maximum=50.0,
107
- value=30.0,
108
- step=0.5,
109
- label="Guidance Scale"
110
  )
111
- seed = gr.Slider(
112
- minimum=-1,
113
- maximum=2147483647,
114
- step=1,
115
- value=-1,
116
- label="Seed (-1 for random)"
 
 
 
 
117
  )
118
-
119
- submit_btn = gr.Button("Generate Try-On", variant="primary")
120
-
121
  with gr.Column():
122
- tryon_output = gr.Image(label="Try-On Result")
123
-
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  with gr.Row():
125
- gr.Markdown("""
126
- ### Notes:
127
- - The model image should be a full-body photo
128
- - The mask should indicate the region where the garment will be placed
129
- - The garment image should be on a clean background
130
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- submit_btn.click(
133
- fn=gradio_inference,
134
- inputs=[
135
- image_input,
136
- garment_input,
137
- num_steps,
138
- guidance_scale,
139
- seed
140
- ],
141
- outputs=[tryon_output],
142
- api_name="try-on"
143
- )
144
 
145
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- if __name__ == "__main__":
148
- demo = create_demo()
149
- demo.queue() # Enable queuing for multiple users
150
- demo.launch(
151
- share=True,
152
- server_name="0.0.0.0" # Makes the server accessible from other machines
153
- )
 
1
+ import spaces
2
+
3
  import gradio as gr
4
  from tryon_inference import run_inference
5
  import os
6
  import numpy as np
7
  from PIL import Image
8
  import tempfile
9
+ import torch
10
+ from diffusers import FluxTransformer2DModel, FluxFillPipeline
11
+
12
+ import shutil
13
+
14
+ def find_cuda():
15
+ # Check if CUDA_HOME or CUDA_PATH environment variables are set
16
+ cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
17
+
18
+ if cuda_home and os.path.exists(cuda_home):
19
+ return cuda_home
20
+
21
+ # Search for the nvcc executable in the system's PATH
22
+ nvcc_path = shutil.which('nvcc')
23
+
24
+ if nvcc_path:
25
+ # Remove the 'bin/nvcc' part to get the CUDA installation path
26
+ cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
27
+ return cuda_path
28
+
29
+ return None
30
+
31
+ cuda_path = find_cuda()
32
+
33
+ if cuda_path:
34
+ print(f"CUDA installation found at: {cuda_path}")
35
+ else:
36
+ print("CUDA installation not found")
37
 
38
+ device = torch.device('cuda')
39
+
40
+ print('Loading diffusion model ...')
41
+ transformer = FluxTransformer2DModel.from_pretrained(
42
+ "xiaozaa/catvton-flux-alpha",
43
+ torch_dtype=torch.bfloat16
44
+ )
45
+ pipe = FluxFillPipeline.from_pretrained(
46
+ "black-forest-labs/FLUX.1-dev",
47
+ transformer=transformer,
48
+ torch_dtype=torch.bfloat16
49
+ ).to(device)
50
+ print('Loading Finished!')
51
+
52
+ @spaces.GPU
53
  def gradio_inference(
54
  image_data,
55
  garment,
56
  num_steps=50,
57
  guidance_scale=30.0,
58
  seed=-1,
59
+ width=768,
60
+ height=1024
61
  ):
62
  """Wrapper function for Gradio interface"""
63
  # Use temporary directory
 
84
  try:
85
  # Run inference
86
  _, tryon_result = run_inference(
87
+ pipe=pipe,
88
  image_path=temp_image,
89
  mask_path=temp_mask,
90
  garment_path=temp_garment,
91
  num_steps=num_steps,
92
  guidance_scale=guidance_scale,
93
  seed=seed,
94
+ size=(width, height)
95
  )
96
  return tryon_result
97
  except Exception as e:
98
  raise gr.Error(f"Error during inference: {str(e)}")
99
 
100
+ with gr.Blocks() as demo:
101
+ gr.Markdown("""
102
+ # CATVTON FLUX Virtual Try-On Demo
103
+ Upload a model image, draw a mask, and a garment image to generate virtual try-on results.
104
+
105
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/xiaozaa/catvton-flux-alpha)
106
+ [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/nftblackmagic/catvton-flux)
107
+ """)
108
+
109
+ gr.Video("example/github.mp4", label="Demo Video: How to use the tool")
110
+
111
+ with gr.Column():
112
+ with gr.Row():
113
+ with gr.Column():
114
+ image_input = gr.ImageMask(
115
+ label="Model Image (Click 'Edit' and draw mask over the clothing area)",
116
+ type="pil",
117
+ height=600,
118
+ width=300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  )
120
+ gr.Examples(
121
+ examples=[
122
+ ["./example/person/00008_00.jpg"],
123
+ ["./example/person/00055_00.jpg"],
124
+ ["./example/person/00057_00.jpg"],
125
+ ["./example/person/00067_00.jpg"],
126
+ ["./example/person/00069_00.jpg"],
127
+ ],
128
+ inputs=[image_input],
129
+ label="Person Images",
130
  )
 
 
 
131
  with gr.Column():
132
+ garment_input = gr.Image(label="Garment Image", type="pil", height=600, width=300)
133
+ gr.Examples(
134
+ examples=[
135
+ ["./example/garment/04564_00.jpg"],
136
+ ["./example/garment/00055_00.jpg"],
137
+ ["./example/garment/00396_00.jpg"],
138
+ ["./example/garment/00067_00.jpg"],
139
+ ["./example/garment/00069_00.jpg"],
140
+ ],
141
+ inputs=[garment_input],
142
+ label="Garment Images",
143
+ )
144
+ with gr.Column():
145
+ tryon_output = gr.Image(label="Try-On Result", height=600, width=300)
146
+
147
  with gr.Row():
148
+ num_steps = gr.Slider(
149
+ minimum=1,
150
+ maximum=100,
151
+ value=30,
152
+ step=1,
153
+ label="Number of Steps"
154
+ )
155
+ guidance_scale = gr.Slider(
156
+ minimum=1.0,
157
+ maximum=50.0,
158
+ value=30.0,
159
+ step=0.5,
160
+ label="Guidance Scale"
161
+ )
162
+ seed = gr.Slider(
163
+ minimum=-1,
164
+ maximum=2147483647,
165
+ step=1,
166
+ value=-1,
167
+ label="Seed (-1 for random)"
168
+ )
169
+ width = gr.Slider(
170
+ minimum=256,
171
+ maximum=1024,
172
+ step=64,
173
+ value=768,
174
+ label="Width"
175
+ )
176
+ height = gr.Slider(
177
+ minimum=256,
178
+ maximum=1024,
179
+ step=64,
180
+ value=1024,
181
+ label="Height"
182
+ )
183
+
184
+
185
+ submit_btn = gr.Button("Generate Try-On", variant="primary")
186
 
187
+
188
+ with gr.Row():
189
+ gr.Markdown("""
190
+ ### Notes:
191
+ - The model is trained on VITON-HD dataset. It focuses on the woman upper body try-on generation.
192
+ - The mask should indicate the region where the garment will be placed.
193
+ - The garment image should be on a clean background.
194
+ - The model is not perfect. It may generate some artifacts.
195
+ - The model is slow. Please be patient.
196
+ - The model is just for research purpose.
197
+ """)
 
198
 
199
+ submit_btn.click(
200
+ fn=gradio_inference,
201
+ inputs=[
202
+ image_input,
203
+ garment_input,
204
+ num_steps,
205
+ guidance_scale,
206
+ seed,
207
+ width,
208
+ height
209
+ ],
210
+ outputs=[tryon_output],
211
+ api_name="try-on"
212
+ )
213
+
214
 
215
+ demo.launch()
 
 
 
 
 
 
example/garment/00396_00.jpg ADDED
requirements.txt CHANGED
@@ -37,19 +37,6 @@ multiprocess==0.70.16
37
  networkx==3.3
38
  ninja==1.11.1.1
39
  numpy==1.26.4
40
- nvidia-cublas-cu12==12.1.3.1
41
- nvidia-cuda-cupti-cu12==12.1.105
42
- nvidia-cuda-nvrtc-cu12==12.1.105
43
- nvidia-cuda-runtime-cu12==12.1.105
44
- nvidia-cudnn-cu12==9.1.0.70
45
- nvidia-cufft-cu12==11.0.2.54
46
- nvidia-curand-cu12==10.3.2.106
47
- nvidia-cusolver-cu12==11.4.5.107
48
- nvidia-cusparse-cu12==12.1.0.106
49
- nvidia-ml-py==12.555.43
50
- nvidia-nccl-cu12==2.20.5
51
- nvidia-nvjitlink-cu12==12.6.20
52
- nvidia-nvtx-cu12==12.1.105
53
  omegaconf==2.3.0
54
  onnxruntime-gpu==1.18.1
55
  opencv-python==4.10.0.84
@@ -59,7 +46,6 @@ pandas==2.2.2
59
  pillow==10.4.0
60
  platformdirs==4.2.2
61
  protobuf==5.27.3
62
- psutil==6.0.0
63
  py-cpuinfo==9.0.0
64
  pyarrow==17.0.0
65
  pydantic==2.8.2
@@ -97,4 +83,5 @@ gradio==5.6.0
97
  gradio_client==1.4.3
98
  prodigyopt
99
  huggingface-hub
 
100
  git+https://github.com/huggingface/diffusers.git
 
37
  networkx==3.3
38
  ninja==1.11.1.1
39
  numpy==1.26.4
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  omegaconf==2.3.0
41
  onnxruntime-gpu==1.18.1
42
  opencv-python==4.10.0.84
 
46
  pillow==10.4.0
47
  platformdirs==4.2.2
48
  protobuf==5.27.3
 
49
  py-cpuinfo==9.0.0
50
  pyarrow==17.0.0
51
  pydantic==2.8.2
 
83
  gradio_client==1.4.3
84
  prodigyopt
85
  huggingface-hub
86
+ spaces
87
  git+https://github.com/huggingface/diffusers.git