er1t0 committed
Commit
aab55b1
1 Parent(s): 4450c32

initial commit

Files changed (3)
  1. app.py +205 -0
  2. checkpoints/test.txt +0 -0
  3. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,205 @@
+
+ import os
+ import torch
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+ import cv2
+ import traceback
+ import matplotlib.pyplot as plt
+
+ # CUDA optimizations
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+ if torch.cuda.get_device_properties(0).major >= 8:
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+
+ # Initialize models
+ sam2_checkpoint = "../checkpoints/sam2_hiera_large.pt"
+ model_cfg = "sam2_hiera_l.yaml"
+
+ video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
+ image_predictor = SAM2ImagePredictor(sam2_model)
+
+ model_id = 'microsoft/Florence-2-large'
+ florence_model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16).eval().cuda()
+ florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+ def apply_color_mask(frame, mask, obj_id):
+     cmap = plt.get_cmap("tab10")
+     color = np.array(cmap(obj_id % 10)[:3])  # Use modulo 10 to cycle through colors
+
+     # Ensure mask has the correct shape
+     if mask.ndim == 4:
+         mask = mask.squeeze()  # Remove singleton dimensions
+     if mask.ndim == 3 and mask.shape[0] == 1:
+         mask = mask[0]  # Take the first channel if it's a single-channel 3D array
+
+     # Resize mask to match frame dimensions
+     mask = cv2.resize(mask.astype(np.float32), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_LINEAR)
+
+     # Expand dimensions of mask and color for broadcasting
+     mask = np.expand_dims(mask, axis=2)
+     color = color.reshape(1, 1, 3)
+
+     colored_mask = mask * color
+     return frame * (1 - mask) + colored_mask * 255
+
+ def run_florence(image, text_input):
+     with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+         task_prompt = '<OPEN_VOCABULARY_DETECTION>'
+         prompt = task_prompt + text_input
+         inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to('cuda', torch.bfloat16)
+         generated_ids = florence_model.generate(
+             input_ids=inputs["input_ids"].cuda(),
+             pixel_values=inputs["pixel_values"].cuda(),
+             max_new_tokens=1024,
+             early_stopping=False,
+             do_sample=False,
+             num_beams=3,
+         )
+         generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+         parsed_answer = florence_processor.post_process_generation(
+             generated_text,
+             task=task_prompt,
+             image_size=(image.width, image.height)
+         )
+         return parsed_answer[task_prompt]['bboxes'][0]
+
+ def remove_directory_contents(directory):
+     for root, dirs, files in os.walk(directory, topdown=False):
+         for name in files:
+             os.remove(os.path.join(root, name))
+         for name in dirs:
+             os.rmdir(os.path.join(root, name))
+
+ def process_video(video_path, prompt, chunk_size=30):
+     try:
+         video = cv2.VideoCapture(video_path)
+         if not video.isOpened():
+             raise ValueError("Unable to open video file")
+
+         fps = video.get(cv2.CAP_PROP_FPS)
+         frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+
+         # Process video in chunks
+         all_segmented_frames = []
+         for chunk_start in range(0, frame_count, chunk_size):
+             chunk_end = min(chunk_start + chunk_size, frame_count)
+
+             frames = []
+             video.set(cv2.CAP_PROP_POS_FRAMES, chunk_start)
+             for _ in range(chunk_end - chunk_start):
+                 ret, frame = video.read()
+                 if not ret:
+                     break
+                 frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+             if not frames:
+                 print(f"No frames extracted for chunk starting at {chunk_start}")
+                 continue
+
+             # Florence detection on first frame of the chunk
+             first_frame = Image.fromarray(frames[0])
+             mask_box = run_florence(first_frame, prompt)
+             print("Original mask box:", mask_box)
+
+             # Convert mask_box to numpy array and ensure it's in the correct format
+             mask_box = np.array(mask_box)
+             print("Reshaped mask box:", mask_box)
+
+             # SAM2 segmentation on first frame
+             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                 image_predictor.set_image(first_frame)
+                 masks, _, _ = image_predictor.predict(
+                     point_coords=None,
+                     point_labels=None,
+                     box=mask_box[None, :],
+                     multimask_output=False,
+                 )
+             print("masks.shape", masks.shape)
+
+             mask = masks.squeeze().astype(bool)
+             print("Mask shape:", mask.shape)
+             print("Frame shape:", frames[0].shape)
+
+             # SAM2 video propagation
+             temp_dir = f"temp_frames_{chunk_start}"
+             os.makedirs(temp_dir, exist_ok=True)
+             for i, frame in enumerate(frames):
+                 cv2.imwrite(os.path.join(temp_dir, f"{i:04d}.jpg"), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+
+             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                 inference_state = video_predictor.init_state(video_path=temp_dir)
+                 _, _, _ = video_predictor.add_new_mask(
+                     inference_state=inference_state,
+                     frame_idx=0,
+                     obj_id=1,
+                     mask=mask
+                 )
+
+                 video_segments = {}
+                 for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
+                     video_segments[out_frame_idx] = {
+                         out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                         for i, out_obj_id in enumerate(out_obj_ids)
+                     }
+
+             print('segmenting for main vid done')
+
+             # Apply segmentation masks to frames
+             for i, frame in enumerate(frames):
+                 if i in video_segments:
+                     for out_obj_id, mask in video_segments[i].items():
+                         frame = apply_color_mask(frame, mask, out_obj_id)
+                     all_segmented_frames.append(frame.astype(np.uint8))
+                 else:
+                     all_segmented_frames.append(frame)
+
+             # Clean up temporary files
+             remove_directory_contents(temp_dir)
+             os.rmdir(temp_dir)
+
+         video.release()
+
+         if not all_segmented_frames:
+             raise ValueError("No frames were processed successfully")
+
+         # Create video from segmented frames
+         output_path = "segmented_video.mp4"
+         out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps,
+                               (all_segmented_frames[0].shape[1], all_segmented_frames[0].shape[0]))
+         for frame in all_segmented_frames:
+             out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+         out.release()
+
+         return output_path
+
+     except Exception as e:
+         print(f"Error in process_video: {str(e)}")
+         print(traceback.format_exc())  # This will print the full stack trace
+         return None
+
+ def segment_video(video_file, prompt, chunk_size):
+     if video_file is None:
+         return None
+     output_video = process_video(video_file, prompt, int(chunk_size))
+     return output_video
+
+ demo = gr.Interface(
+     fn=segment_video,
+     inputs=[
+         gr.Video(label="Upload Video"),
+         gr.Textbox(label="Enter prompt (e.g., 'a gymnast')"),
+         gr.Slider(minimum=10, maximum=100, step=10, value=30, label="Chunk Size (frames)")
+     ],
+     outputs=gr.Video(label="Segmented Video"),
+     title="Video Object Segmentation with Florence and SAM2",
+     description="Upload a video and provide a text prompt to segment a specific object throughout the video."
+ )
+
+ demo.launch()
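
Note: the committed app drives everything through the Gradio UI above. As a minimal sketch of exercising the same pipeline without the UI, assuming demo.launch() were moved behind an `if __name__ == "__main__":` guard (it is not in this commit) and that a local clip named input.mp4 exists, one could call process_video directly:

    # Hypothetical batch usage of the process_video pipeline defined in app.py.
    # Assumes the Gradio launch is guarded so importing app does not start the
    # server, and that input.mp4 is a video present in the working directory.
    from app import process_video

    output_path = process_video("input.mp4", "a gymnast", chunk_size=30)
    if output_path is None:
        print("Segmentation failed; process_video printed the traceback.")
    else:
        print(f"Segmented video written to {output_path}")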
checkpoints/test.txt ADDED
File without changes
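
Note: checkpoints/ is committed with only a placeholder test.txt, while app.py loads the SAM 2 weights from ../checkpoints/sam2_hiera_large.pt. A minimal sketch for fetching them, assuming the public SAM 2 release URL below (not part of this commit), could look like:

    # Hypothetical helper to fetch the SAM 2 checkpoint that app.py loads.
    # The URL is an assumption based on the public SAM 2 release; adjust the
    # target path so it matches the "../checkpoints/..." path used in app.py.
    import os
    import urllib.request

    CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"  # assumed
    CHECKPOINT_PATH = os.path.join("checkpoints", "sam2_hiera_large.pt")

    os.makedirs("checkpoints", exist_ok=True)
    if not os.path.exists(CHECKPOINT_PATH):
        urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_PATH)
        print(f"Downloaded SAM 2 checkpoint to {CHECKPOINT_PATH}")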
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ numpy
+ samv2
+ gradio
+ Pillow
+ transformers
+ opencv-python
+ matplotlib
+ einops
+ timm