vibs08 commited on
Commit
2bad595
·
verified ·
1 Parent(s): ec2b233

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -6
app.py CHANGED
@@ -151,40 +151,56 @@ def check_input_image(input_image):
151
  raise gr.Error("No image uploaded!")
152
 
153
  def preprocess(input_image, do_remove_background, foreground_ratio):
154
- torch.cuda.synchronize()
155
  def fill_background(image):
 
156
  image = np.array(image).astype(np.float32) / 255.0
157
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
158
  image = Image.fromarray((image * 255.0).astype(np.uint8))
159
  return image
160
 
161
  if do_remove_background:
 
162
  image = input_image.convert("RGB")
163
  image = remove_background(image, rembg_session)
164
  image = resize_foreground(image, foreground_ratio)
165
  image = fill_background(image)
 
 
166
  else:
167
  image = input_image
168
  if image.mode == "RGBA":
169
  image = fill_background(image)
170
- torch.cuda.synchronize()
 
171
  return image
172
 
 
 
173
  # @spaces.GPU
174
  def generate(image, mc_resolution, formats=["obj", "glb"]):
 
175
  scene_codes = model(image, device=device)
 
176
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
 
177
  mesh = to_gradio_3d_orientation(mesh)
178
-
 
179
  mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
 
180
  mesh.export(mesh_path_glb.name)
181
-
 
182
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
183
- mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
 
184
  mesh.export(mesh_path_obj.name)
185
-
 
186
  return mesh_path_obj.name, mesh_path_glb.name
187
 
 
 
188
  def run_example(text_prompt,seed ,do_remove_background, foreground_ratio, mc_resolution):
189
  image_pil = generate_image_from_text(text_prompt, seed)
190
 
 
151
  raise gr.Error("No image uploaded!")
152
 
153
  def preprocess(input_image, do_remove_background, foreground_ratio):
 
154
  def fill_background(image):
155
+ torch.cuda.synchronize() # Ensure previous CUDA operations are complete
156
  image = np.array(image).astype(np.float32) / 255.0
157
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
158
  image = Image.fromarray((image * 255.0).astype(np.uint8))
159
  return image
160
 
161
  if do_remove_background:
162
+ torch.cuda.synchronize()
163
  image = input_image.convert("RGB")
164
  image = remove_background(image, rembg_session)
165
  image = resize_foreground(image, foreground_ratio)
166
  image = fill_background(image)
167
+
168
+ torch.cuda.synchronize()
169
  else:
170
  image = input_image
171
  if image.mode == "RGBA":
172
  image = fill_background(image)
173
+ torch.cuda.synchronize() # Wait for all CUDA operations to complete
174
+ torch.cuda.empty_cache()
175
  return image
176
 
177
+
178
+
179
  # @spaces.GPU
180
  def generate(image, mc_resolution, formats=["obj", "glb"]):
181
+ torch.cuda.synchronize()
182
  scene_codes = model(image, device=device)
183
+ torch.cuda.synchronize()
184
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
185
+ torch.cuda.synchronize()
186
  mesh = to_gradio_3d_orientation(mesh)
187
+ torch.cuda.synchronize()
188
+
189
  mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
190
+ torch.cuda.synchronize()
191
  mesh.export(mesh_path_glb.name)
192
+ torch.cuda.synchronize()
193
+
194
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
195
+ torch.cuda.synchronize()
196
+ mesh.apply_scale([-1, 1, 1])
197
  mesh.export(mesh_path_obj.name)
198
+ torch.cuda.synchronize()
199
+ torch.cuda.empty_cache()
200
  return mesh_path_obj.name, mesh_path_glb.name
201
 
202
+
203
+
204
  def run_example(text_prompt,seed ,do_remove_background, foreground_ratio, mc_resolution):
205
  image_pil = generate_image_from_text(text_prompt, seed)
206