chongzhou commited on
Commit
c306abe
·
1 Parent(s): 3da2a0c
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -120,7 +120,11 @@ def reset(
120
  )
121
 
122
 
123
- def clear_points(session_input_points, session_input_labels, request: gr.Request,):
 
 
 
 
124
  session_id = request.session_id
125
  predictor.to("cpu")
126
  session_input_points = []
@@ -237,9 +241,7 @@ def segment_with_points(
237
  print(f"TRACKING INPUT LABEL: {session_input_labels}")
238
 
239
  # Open the image and get its dimensions
240
- transparent_background Image.fromarray(session_first_frame).convert(
241
- "RGBA"
242
- )
243
  w, h = transparent_background.size
244
 
245
  # Define the circle radius as a fraction of the smaller dimension
@@ -277,7 +279,12 @@ def segment_with_points(
277
  first_frame_output = Image.alpha_composite(transparent_background, mask_image)
278
 
279
  torch.cuda.empty_cache()
280
- return selected_point_map, first_frame_output, session_input_points, session_input_labels
 
 
 
 
 
281
 
282
 
283
  def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
@@ -308,7 +315,7 @@ def propagate_to_all(
308
  torch.backends.cudnn.allow_tf32 = True
309
  with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
310
  if (
311
- len (session_input_points) == 0
312
  or video_in is None
313
  or global_inference_states[session_id] is None
314
  ):
 
120
  )
121
 
122
 
123
+ def clear_points(
124
+ session_input_points,
125
+ session_input_labels,
126
+ request: gr.Request,
127
+ ):
128
  session_id = request.session_id
129
  predictor.to("cpu")
130
  session_input_points = []
 
241
  print(f"TRACKING INPUT LABEL: {session_input_labels}")
242
 
243
  # Open the image and get its dimensions
244
+ transparent_background = Image.fromarray(session_first_frame).convert("RGBA")
 
 
245
  w, h = transparent_background.size
246
 
247
  # Define the circle radius as a fraction of the smaller dimension
 
279
  first_frame_output = Image.alpha_composite(transparent_background, mask_image)
280
 
281
  torch.cuda.empty_cache()
282
+ return (
283
+ selected_point_map,
284
+ first_frame_output,
285
+ session_input_points,
286
+ session_input_labels,
287
+ )
288
 
289
 
290
  def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
 
315
  torch.backends.cudnn.allow_tf32 = True
316
  with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
317
  if (
318
+ len(session_input_points) == 0
319
  or video_in is None
320
  or global_inference_states[session_id] is None
321
  ):