ashawkey committed
Commit 3cb3552 · 1 Parent(s): 8ba0984

update app
Files changed (1): app.py (+76, -61)
app.py CHANGED
@@ -60,32 +60,40 @@ model = Model(model_config).eval().cuda().bfloat16()
 ckpt_dict = torch.load(flow_ckpt_path, weights_only=True)
 model.load_state_dict(ckpt_dict, strict=True)

-# process function
-@spaces.GPU(duration=120)
-def process(input_image, num_steps=30, cfg_scale=7.5, grid_res=384, seed=42, randomize_seed=True, simplify_mesh=False, target_num_faces=100000):
-
-    # seed
+# get random seed
+def get_random_seed(randomize_seed, seed):
     if randomize_seed:
         seed = np.random.randint(0, MAX_SEED)
-    kiui.seed_everything(seed)
+    return seed

-    # output path
-    os.makedirs("output", exist_ok=True)
-    output_glb_path = f"output/partpacker_{datetime.now().strftime('%Y%m%d_%H%M%S')}.glb"
-
-    # input image
+# process image
+@spaces.GPU(duration=10)
+def process_image(input_image):
     input_image = np.array(input_image) # uint8
-
     # bg removal if there is no alpha channel
     if input_image.shape[-1] == 3:
         input_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4]
     mask = input_image[..., -1] > 0
     image = recenter_foreground(input_image, mask, border_ratio=0.1)
     image = cv2.resize(image, (518, 518), interpolation=cv2.INTER_LINEAR)
-    image = image.astype(np.float32) / 255.0
-    image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) # white background
+    return image

+# process generation
+@spaces.GPU(duration=120)
+def process_3d(input_image, num_steps=30, cfg_scale=7.5, grid_res=384, seed=42, simplify_mesh=False, target_num_faces=100000):
+
+    # seed
+    kiui.seed_everything(seed)
+
+    # output path
+    os.makedirs("output", exist_ok=True)
+    output_glb_path = f"output/partpacker_{datetime.now().strftime('%Y%m%d_%H%M%S')}.glb"
+
+    # input image (assume processed to RGBA uint8)
+    image = input_image.astype(np.float32) / 255.0
+    image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) # white background
     image_tensor = torch.from_numpy(image).permute(2, 0, 1).contiguous().unsqueeze(0).float().cuda()
+
     data = {"cond_images": image_tensor}

     with torch.inference_mode():
@@ -126,7 +134,7 @@ def process(input_image, num_steps=30, cfg_scale=7.5, grid_res=384, seed=42, randomize_seed=True, simplify_mesh=False, target_num_faces=100000):
     # export the whole mesh
     mesh.export(output_glb_path)

-    return seed, image, output_glb_path
+    return output_glb_path

 # gradio UI

@@ -145,57 +153,64 @@ _DESCRIPTION = '''
 block = gr.Blocks(title=_TITLE).queue()
 with block:
     with gr.Row():
-        with gr.Column(scale=1):
+        with gr.Column():
             gr.Markdown('# ' + _TITLE)
             gr.Markdown(_DESCRIPTION)

     with gr.Row():
-        with gr.Column(scale=4):
-            # input image
-            input_image = gr.Image(label="Image", type='pil')
-            # inference steps
-            num_steps = gr.Slider(label="Inference steps", minimum=1, maximum=100, step=1, value=30)
-            # cfg scale
-            cfg_scale = gr.Slider(label="CFG scale", minimum=2, maximum=10, step=0.1, value=7.0)
-            # grid resolution
-            input_grid_res = gr.Slider(label="Grid resolution", minimum=256, maximum=512, step=1, value=384)
-            # random seed
-            seed = gr.Slider(label="Random seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-            # simplify mesh
-            simplify_mesh = gr.Checkbox(label="Simplify mesh", value=False)
-            target_num_faces = gr.Slider(label="Face number", minimum=10000, maximum=1000000, step=1000, value=100000)
-            # gen button
-            button_gen = gr.Button("Generate")
-
-
-        with gr.Column(scale=8):
-            with gr.Tab("3D Model"):
-                # glb file
-                output_model = gr.Model3D(label="Geometry", height=512)
-
-            with gr.Tab("Input Image"):
-                # background removed image
-                output_image = gr.Image(interactive=False, show_label=False)
-
+        with gr.Column(scale=1):
+            with gr.Row():
+                # input image
+                input_image = gr.Image(label="Input Image", type="filepath")
+                seg_image = gr.Image(label="Segmentation Result", type="numpy", format="png", interactive=False)
+            with gr.Accordion("Settings", open=True):
+                # inference steps
+                num_steps = gr.Slider(label="Inference steps", minimum=1, maximum=100, step=1, value=30)
+                # cfg scale
+                cfg_scale = gr.Slider(label="CFG scale", minimum=2, maximum=10, step=0.1, value=7.0)
+                # grid resolution
+                input_grid_res = gr.Slider(label="Grid resolution", minimum=256, maximum=512, step=1, value=384)
+                # random seed
+                with gr.Row():
+                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                # simplify mesh
+                with gr.Row():
+                    simplify_mesh = gr.Checkbox(label="Simplify mesh", value=False)
+                    target_num_faces = gr.Slider(label="Face number", minimum=10000, maximum=1000000, step=1000, value=100000)
+            # gen button
+            button_gen = gr.Button("Generate")

         with gr.Column(scale=1):
-            gr.Examples(
-                examples=[
-                    ["examples/rabbit.png"],
-                    ["examples/robot.png"],
-                    ["examples/teapot.png"],
-                    ["examples/barrel.png"],
-                    ["examples/cactus.png"],
-                    ["examples/cyan_car.png"],
-                    ["examples/pickup.png"],
-                    ["examples/swivelchair.png"],
-                    ["examples/warhammer.png"],
-                ],
-                inputs=[input_image],
-                cache_examples=False,
-            )
-
-    button_gen.click(process, inputs=[input_image, num_steps, cfg_scale, input_grid_res, seed, randomize_seed, simplify_mesh, target_num_faces], outputs=[seed, output_image, output_model])
+            # glb file
+            output_model = gr.Model3D(label="Geometry", height=512)
+
+
+    with gr.Row():
+        gr.Examples(
+            examples=[
+                ["examples/rabbit.png"],
+                ["examples/robot.png"],
+                ["examples/teapot.png"],
+                ["examples/barrel.png"],
+                ["examples/cactus.png"],
+                ["examples/cyan_car.png"],
+                ["examples/pickup.png"],
+                ["examples/swivelchair.png"],
+                ["examples/warhammer.png"],
+            ],
+            fn=process_image, # still need to click button_gen to get the 3d
+            inputs=[input_image],
+            outputs=[seg_image],
+            cache_examples=False,
+        )
+
+    button_gen.click(
+        process_image, inputs=[input_image], outputs=[seg_image]
+    ).then(
+        get_random_seed, inputs=[randomize_seed, seed], outputs=[seed]
+    ).then(
+        process_3d, inputs=[seg_image, num_steps, cfg_scale, input_grid_res, seed, simplify_mesh, target_num_faces], outputs=[output_model]
+    )

 block.launch()
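
The net effect of the change is to split the old single `process` callback into three chained steps (`process_image`, `get_random_seed`, `process_3d`), so that background removal runs under a short GPU allocation (`@spaces.GPU(duration=10)`) and only the 3D generation holds the long one (`duration=120`). The snippet below is a minimal, self-contained sketch of the `.click(...).then(...)` event chaining the new wiring relies on; it is not part of the commit, and `preprocess`/`generate` are hypothetical placeholders rather than the PartPacker functions.

import gradio as gr

# Toy stand-ins for process_image / process_3d: each .then() step starts only
# after the previous step has returned, and each updates its own outputs.
def preprocess(text):
    return text.strip().lower()

def generate(clean_text):
    return f"generated from: {clean_text}"

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    mid = gr.Textbox(label="Preprocessed", interactive=False)
    out = gr.Textbox(label="Result", interactive=False)
    btn = gr.Button("Run")

    # Same chaining pattern as button_gen.click(...).then(...).then(...) above.
    btn.click(preprocess, inputs=[inp], outputs=[mid]).then(
        generate, inputs=[mid], outputs=[out]
    )

demo.launch()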