Spaces
Runtime error
Yang You committed · Commit 2cabcdf
Parent(s): b9f276b
upload files
Files changed:
- README.md +5 -5
- app.py +328 -0
- checkpoints/648Ai4i4i3n4s_1e-5m_c5c_stage3_from_kub_ns_wa_kk_lsh_dyk_46470/model-000600000.pth +3 -0
- data/244754_medium.mp4 +3 -0
- demo_dense_visualize.py +229 -0
- nets/blocks.py +1337 -0
- nets/net34.py +647 -0
- requirements.txt +17 -0
- utils/basic.py +429 -0
- utils/data.py +122 -0
- utils/geom.py +771 -0
- utils/improc.py +1645 -0
- utils/loss.py +220 -0
- utils/metric.py +176 -0
- utils/misc.py +1062 -0
- utils/py.py +755 -0
- utils/samp.py +213 -0
- utils/saveload.py +59 -0
- utils/test.py +194 -0
- utils/visualizer.py +365 -0
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: DenseTrack
+emoji: 🏃
+colorFrom: pink
+colorTo: pink
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.21.0
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,328 @@
import os
import random
import time
import datetime
import cv2
import numpy as np
import torch
import imageio
from pathlib import Path
from tqdm import tqdm
import gradio as gr

# Import your custom modules
import utils.loss
import utils.samp
import utils.data
import utils.improc
import utils.misc
import utils.saveload
from nets.blocks import InputPadder
from nets.net34 import Net
from tensorboardX import SummaryWriter
import imageio
from demo_dense_visualize import Tracker
import spaces

# -------------------- Utility Functions -------------------- #
def count_parameters(model):
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        total_params += parameter.numel()
    print('Total params: %.2f M' % (total_params/1e6))
    return total_params

def seed_everything(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------- Step 1: Extract the First Frame -------------------- #
def extract_first_frame(video_path, _, tracker):
    """
    Opens the video, extracts the first frame, resizes it (largest dimension 1024),
    and returns:
    - the frame for display (to be annotated),
    - the video file path (to store in state),
    - a copy of the original first frame (to be used when adding points),
    - an empty list of click points.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Return four values to match the four Gradio outputs.
        return None, None, None, []
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None, video_path, None, []
    # Convert BGR to RGB
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    scale = min(tracker.target_res / frame_rgb.shape[0], tracker.target_res / frame_rgb.shape[1])
    frame_resized = cv2.resize(frame_rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    # Return the displayed frame, the video file path, and a copy of the original frame for point drawing.
    return frame_resized, video_path, frame_resized.copy(), []

# -------------------- Callback to Add a Clicked Point -------------------- #
def add_point(orig_frame, points, evt: gr.SelectData):
    """
    Called when the user clicks on the displayed first frame.
    - orig_frame: The original first frame image (numpy array).
    - points: The current list of point coordinates.
    - evt: Event data from the image click (expects evt.index as (x, y)).

    Returns the updated image (with circles drawn at all points)
    and the updated list of points.
    """
    if points is None:
        points = []
    # evt.index contains the (x, y) coordinates of the click.
    x, y = evt.index
    new_points = points.copy()
    new_points.append([x, y])
    # Draw circles on a copy of the original image.
    updated_frame = orig_frame.copy()
    for (px, py) in new_points:
        cv2.circle(updated_frame, (int(round(px)), int(round(py))), radius=5, color=(0,255,0), thickness=-1)

    # print(updated_frame.shape)
    return updated_frame, new_points

# -------------------- Step 2: Process Video & Track Points -------------------- #
@torch.no_grad()
@spaces.GPU
def process_video_with_points(video_path, click_points, tracker):
    """
    Runs the dense flow prediction over the entire video, tracking the user-selected points.
    Args:
        video_path: Path to the uploaded video.
        click_points: List of [x, y] coordinates selected on the first frame.
            (Coordinates are in the same (resized) space as the displayed first frame.)
        tracker: The tracker instance to use for processing.
    Returns:
        A path to the output video with tracked points overlaid.
    """
    if len(click_points) == 0:
        print("No points selected for tracking.")
        return "Error: No points selected for tracking."

    # Open the video.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return "Error: Could not open video."
    fps = cap.get(cv2.CAP_PROP_FPS)

    # List to store frames with overlaid points.
    output_frames = []
    # Initialize the points with those selected on the first frame.

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    pbar = tqdm(total=total_frames, desc="Processing video")

    tracker.reset()

    frame_disps = []
    try:
        while True:
            torch.cuda.empty_cache()
            ret, frame = cap.read()
            if not ret:
                break

            # Convert frame from BGR to RGB and resize as in your original code.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            scale = min(tracker.target_res / frame_rgb.shape[0], tracker.target_res / frame_rgb.shape[1])
            frame_disp = cv2.resize(frame_rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
            frame_disps.append(frame_disp)

            flows = tracker.track(frame_rgb)

            if flows is not None:
                flows_np = flows[0].cpu().numpy()

                for i, flow_np in enumerate(flows_np):
                    # --- Update tracked points using the flow ---
                    current_points = []
                    for (x, y) in click_points:
                        xi = int(round(x))
                        yi = int(round(y))
                        # print('xi, yi', xi, yi)
                        if 0 <= yi < flow_np.shape[1] and 0 <= xi < flow_np.shape[2]:
                            dx = flow_np[0, yi, xi]
                            dy = flow_np[1, yi, xi]
                            # print('dx, dy', dx, dy)
                        else:
                            dx, dy = 0.0, 0.0
                        current_points.append([x + dx, y + dy])

                    # Draw the updated points on the frame.
                    for (x, y) in current_points:
                        cv2.circle(frame_disps[i], (int(round(x)), int(round(y))), radius=5, color=(0,255,0), thickness=-1)
                    output_frames.append(frame_disps[i])
                frame_disps = []
            pbar.update(1)

    except RuntimeError as e:
        # Check if the error message indicates an OOM error.
        if "out of memory" in str(e).lower():
            torch.cuda.empty_cache()
            pbar.close()
            cap.release()
            print("Error: Out of Memory during video processing.")
            return "Error: Out of Memory during video processing. Please try a smaller video or lower resolution."
        else:
            # Re-raise if it's another type of error.
            raise e
    pbar.close()
    cap.release()

    # -------------------- Save the Output Video -------------------- #
    output_path = "tracked_output.mp4"
    print(len(output_frames), output_frames[0].shape)
    imageio.mimwrite(output_path, output_frames, fps=fps)

    return output_path

# -------------------- Model Initialization -------------------- #
def load_model():
    """Initialize and load the tracking model."""
    # Adjust these paths as needed.
    init_dir = '648Ai4i4i3n4s_1e-5m_c5c_stage3_from_kub_ns_wa_kk_lsh_dyk_46470'
    ckpt_dir = 'checkpoints'
    load_dir = os.path.join(ckpt_dir, init_dir)

    # Create the model and load weights.
    model = Net(16)
    count_parameters(model)
    _ = utils.saveload.load(
        None,
        load_dir,
        model,
        optimizer=None,
        scheduler=None,
        ignore_load=None,
        strict=True,
        verbose=False,
        weights_only=False,
    )
    model.cuda()
    for n, p in model.named_parameters():
        p.requires_grad = False
    model.eval()

    return model

def create_tracker(model):
    """Create tracker instance with the loaded model."""
    return Tracker(
        model=model,
        mean=torch.tensor([0.485, 0.456, 0.406]).cuda().reshape(1, 3, 1, 1),
        std=torch.tensor([0.229, 0.224, 0.225]).cuda().reshape(1, 3, 1, 1),
        S=16,
        stride=8,
        inference_iters=4,
        target_res=1024,
    )

# -------------------- Wrappers to Update Tracker Based on UI Settings -------------------- #
def extract_with_config(video_path, points, resolution, window_index, tracker):
    """
    Update the tracker configuration using the slider values, then extract the first frame.
    - resolution: Target resolution from slider (e.g., 512, 768, 1024).
    - window_index: An index (0-3) to be mapped to sliding window lengths {0:2, 1:4, 2:8, 3:16}.
    - tracker: The tracker instance to configure.
    """
    tracker.target_res = resolution
    mapping = {0: 2, 1: 4, 2: 8, 3: 16}
    tracker.S = mapping.get(int(window_index), 16)
    return extract_first_frame(video_path, points, tracker)

@torch.no_grad()
@spaces.GPU
def process_with_config(video_path, click_points, resolution, window_index, tracker):
    """
    Update the tracker configuration using the slider values, then process the video.
    """
    tracker.target_res = resolution
    mapping = {0: 2, 1: 4, 2: 8, 3: 16}
    tracker.S = mapping.get(int(window_index), 16)
    return process_video_with_points(video_path, click_points, tracker)

def main():
    """Main function that initializes the model and runs the Gradio interface."""
    # Set torch matmul precision
    torch.set_float32_matmul_precision('medium')

    # Initialize random seeds
    seed_everything(42)
    torch.set_grad_enabled(False)

    # Load model and create tracker
    print("Loading model...")
    model = load_model()
    tracker = create_tracker(model)
    print("Model loaded successfully!")

    # -------------------- Gradio Interface -------------------- #
    # The interface is built in two steps:
    # 1. Upload a video and extract the first frame.
    # 2. Annotate the first frame with multiple points (using gr.Points),
    #    then run tracking on the video.
    with gr.Blocks() as demo:
        gr.Markdown("## Dense Flow Tracking with Clickable Points")

        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video", value="data/244754_medium.mp4")
                extract_btn = gr.Button("Extract First Frame")
                # Add sliders for resolution and sliding window length.
                resolution_slider = gr.Slider(minimum=512, maximum=1024, step=256, value=1024, label="Target Resolution")
                # The slider below outputs an index 0-3; we'll map it to {0:2, 1:4, 2:8, 3:16}
                window_slider = gr.Slider(minimum=0, maximum=3, step=1, value=3, label="Sliding Window Length (Index: 0->2, 1->4, 2->8, 3->16)")
            with gr.Column():
                # This image will display the first frame and be interactive.
                first_frame_display = gr.Image(label="First Frame (Click to add points)", interactive=True)
                clear_pts_btn = gr.Button("Clear Points")

        # Hidden states: video file path, original first frame, and accumulated click points.
        video_state = gr.State(None)
        orig_frame_state = gr.State(None)
        points_state = gr.State([])

        track_btn = gr.Button("Track Points")
        output_video = gr.Video(label="Tracked Video")

        # When "Extract First Frame" is clicked, extract and display the first frame.
        extract_btn.click(
            fn=lambda *args: extract_with_config(*args, tracker),
            inputs=[video_input, points_state, resolution_slider, window_slider],
            outputs=[first_frame_display, video_state, orig_frame_state, points_state]
        )

        clear_pts_btn.click(
            fn=lambda orig_frame, points: (orig_frame, []),
            inputs=[orig_frame_state, points_state],
            outputs=[first_frame_display, points_state]
        )

        # When the user clicks on the image, add a point.
        first_frame_display.select(
            fn=add_point,
            inputs=[orig_frame_state, points_state],
            outputs=[first_frame_display, points_state]
        )

        # When "Track Points" is clicked, process the video using the accumulated points.
        track_btn.click(
            fn=lambda *args: process_with_config(*args, tracker),
            inputs=[video_state, points_state, resolution_slider, window_slider],
            outputs=output_video
        )

    demo.launch()

if __name__ == '__main__':
    main()
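The point-propagation rule inside process_video_with_points is worth isolating: each clicked point keeps its first-frame coordinates, and for every output frame the dense flow (shape (2, H, W), displacement from the first frame to the current frame) is read at the point's rounded pixel location. A minimal sketch of just that step; the helper name advance_points is illustrative, not part of the repository:

def advance_points(click_points, flow):
    # flow: numpy array of shape (2, H, W); flow[0] = dx, flow[1] = dy from the first frame to the current frame
    moved = []
    for x, y in click_points:
        xi, yi = int(round(x)), int(round(y))
        if 0 <= yi < flow.shape[1] and 0 <= xi < flow.shape[2]:
            dx, dy = flow[0, yi, xi], flow[1, yi, xi]
        else:
            dx, dy = 0.0, 0.0  # points that fall outside the frame are left in place
        moved.append([x + dx, y + dy])
    return moved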
checkpoints/648Ai4i4i3n4s_1e-5m_c5c_stage3_from_kub_ns_wa_kk_lsh_dyk_46470/model-000600000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:15b029e1396feb453d361988afafab02d94993f4b43191b4dfcc44aac10fffa5
size 198027808
data/244754_medium.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a26ffc0ea63adfb3e2b44f5f6ea384a6071f74cd8453e7b6d386a829d8c11fcb
size 42948491
demo_dense_visualize.py
ADDED
@@ -0,0 +1,229 @@
import os
import random
import torch
import signal
import socket
import sys
import json
import torch.nn.functional as F
import numpy as np
import argparse
from pathlib import Path
import torch.optim as optim
from torch.cuda.amp import GradScaler
from lightning_fabric import Fabric

import utils.loss
import utils.samp
import utils.data
import utils.improc
import utils.misc
import utils.saveload
from tensorboardX import SummaryWriter
import datetime
import time
import cv2
import imageio
from nets.blocks import InputPadder
from tqdm import tqdm
# from pytorch_lightning.callbacks import BaseFinetuning
from utils.visualizer import Visualizer
from torchvision.transforms.functional import resize

import torch
import requests
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForCausalLM
import numpy as np


torch.set_float32_matmul_precision('medium')

def run_example(processor, model, task_prompt, image, text_input=None):
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
    inputs = processor(text=prompt, images=image, return_tensors="pt").to('cuda', torch.float32)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))

    return parsed_answer


def polygons_to_mask(image, prediction, fill_value=255):
    """
    Converts polygons into a mask.

    Parameters:
    - image: A PIL Image instance whose size will be used for the mask.
    - prediction: Dictionary containing 'polygons' and 'labels'.
      'polygons' is a list where each element is a list of sub-polygons.
    - fill_value: The pixel value used to fill the polygon areas (default 255 for a binary mask).

    Returns:
    - A NumPy array representing the mask (same width and height as the input image).
    """
    # Create a blank grayscale mask image with the same size as the original image.
    mask = Image.new('L', image.size, 0)
    draw = ImageDraw.Draw(mask)

    # Iterate over each set of polygons
    for polygons in prediction['polygons']:
        # Each element in "polygons" can be a sub-polygon
        for poly in polygons:
            # Ensure the polygon is in the right shape and has at least 3 points.
            poly_arr = np.array(poly).reshape(-1, 2)
            if poly_arr.shape[0] < 3:
                print('Skipping invalid polygon:', poly_arr)
                continue
            # Convert the polygon vertices into a list for drawing.
            poly_list = poly_arr.reshape(-1).tolist()
            # Draw the polygon on the mask with the fill_value.
            draw.polygon(poly_list, fill=fill_value)

    # Convert the PIL mask image to a NumPy array and return it.
    return np.array(mask)


class Tracker:
    def __init__(self, model, mean, std, S, stride, inference_iters, target_res, device='cuda'):
        """
        Initializes the Tracker.

        Args:
            model: The model used to compute feature maps and forward window flow.
            mean: Tensor or value used for normalizing the input.
            std: Tensor or value used for normalizing the input.
            S: Window size for the tracker.
            stride: The stride used when updating the window.
            inference_iters: Number of inference iterations.
            device: Torch device, defaults to 'cuda'.
        """
        self.model = model
        self.mean = mean
        self.std = std
        self.S = S
        self.stride = stride
        self.inference_iters = inference_iters
        self.device = device
        self.target_res = target_res

        self.padder = None
        self.cnt = 0
        self.fmap_anchor = None
        self.fmaps2 = None
        self.flows8 = None
        self.visconfs8 = None
        self.flows = []   # List to store computed flows
        self.visibs = []  # List to store visibility confidences
        self.rgbs = []    # List to store RGB frames

    def reset(self):
        """Reset the tracker state."""
        self.padder = None
        self.cnt = 0
        self.fmap_anchor = None
        self.fmaps2 = None
        self.flows8 = None
        self.visconfs8 = None
        self.flows = []
        self.visibs = []
        self.rgbs = []

    def preprocess(self, rgb_frame):
        # Resize frame (scale to keep maximum dimension ~1024)
        scale = min(self.target_res / rgb_frame.shape[0], self.target_res / rgb_frame.shape[1])
        rgb_resized = cv2.resize(rgb_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)

        # Convert to tensor, normalize and move to device.
        rgb_tensor = torch.from_numpy(rgb_resized).permute(2, 0, 1).float().unsqueeze(0).to(self.device)
        rgb_tensor = rgb_tensor / 255.0

        self.rgbs.append(rgb_tensor.cpu())

        # import pdb; pdb.set_trace()
        rgb_tensor = (rgb_tensor - self.mean) / self.std
        return rgb_tensor

    @torch.no_grad()
    def track(self, rgb_frame):
        """
        Process a single RGB frame and return the computed flow when available.

        Args:
            rgb_frame: A NumPy array containing the RGB frame.
                (Assumed to be in RGB; if coming from OpenCV, convert it before passing.)

        Returns:
            flow_predictions: The predicted flow for the current frame (or None if not enough frames have been processed).
        """
        torch.cuda.empty_cache()

        rgb_tensor = self.preprocess(rgb_frame)

        # Initialize padder on the first frame.
        if self.cnt == 0:
            self.padder = InputPadder(rgb_tensor.shape)
        rgb_padded = self.padder.pad(rgb_tensor)[0]
        _, _, H_pad, W_pad = rgb_padded.shape
        C = 256  # Feature map channel dimension (could be parameterized if needed)
        H8, W8 = H_pad // 8, W_pad // 8

        # Accumulate feature maps until the window is full.
        if self.cnt == 0:
            self.fmap_anchor = self.model.get_fmaps(rgb_padded, 1, 1, None, False, False).reshape(1, C, H8, W8)
            self.fmaps2 = self.fmap_anchor[:, None]
            self.cnt += 1
            return None

        new_fmap = self.model.get_fmaps(rgb_padded, 1, 1, None, False, False).reshape(1, 1, C, H8, W8)
        self.fmaps2 = torch.cat([self.fmaps2[:, (1 if self.fmaps2.shape[1] >= self.S else 0):].detach().clone(), new_fmap], dim=1)

        # need to track
        if self.cnt - self.S + 1 >= 0 and (self.cnt - self.S + 1) % self.stride == 0:
            # Initialize or update temporary flow buffers.
            iter_num = self.inference_iters
            if self.flows8 is None:
                self.flows8 = torch.zeros((self.S, 2, H_pad // 8, W_pad // 8), device=self.device)
                self.visconfs8 = torch.zeros((self.S, 2, H_pad // 8, W_pad // 8), device=self.device)
                # iter_num = self.inference_iters
            else:
                self.flows8 = torch.cat([
                    self.flows8[self.stride:self.stride + self.S // 2].detach().clone(),
                    self.flows8[self.stride + self.S // 2 - 1:self.stride + self.S // 2].detach().clone().repeat(self.S // 2, 1, 1, 1)
                ])
                self.visconfs8 = torch.cat([
                    self.visconfs8[self.stride:self.stride + self.S // 2].detach().clone(),
                    self.visconfs8[self.stride + self.S // 2 - 1:self.stride + self.S // 2].detach().clone().repeat(self.S // 2, 1, 1, 1)
                ])

            # import pdb; pdb.set_trace()
            # Compute flow predictions using the model's forward window.
            flow_predictions, visconf_predictions, self.flows8, self.visconfs8, _ = self.model.forward_window(
                self.fmap_anchor,
                self.fmaps2,
                self.visconfs8,
                iters=iter_num,
                flowfeat=None,
                flows8=self.flows8,
                is_training=False
            )
            flow_predictions = self.padder.unpad(flow_predictions[-1][0 if self.cnt == self.S - 1 else -self.stride:])
            visconf_predictions = self.padder.unpad(torch.sigmoid(visconf_predictions[-1][0 if self.cnt == self.S - 1 else -self.stride:]))

            self.cnt += 1
            self.flows.append(flow_predictions.cpu())
            self.visibs.append(visconf_predictions.cpu())

            return flow_predictions, visconf_predictions

        self.cnt += 1
        return None
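A minimal driving loop for the Tracker class above, assuming a tracker built as in app.py's create_tracker (i.e. a model that exposes get_fmaps and forward_window, as nets/net34.py's Net does); variable names outside the class are illustrative:

cap = cv2.VideoCapture("data/244754_medium.mp4")
tracker.reset()
all_flows = []
while True:
    ret, frame = cap.read()
    if not ret:
        break
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    out = tracker.track(rgb)       # None until a window of S frames has been buffered
    if out is not None:
        flows, visconfs = out      # (N, 2, H, W): dense flow from the first frame to each newly covered frame
        all_flows.append(flows.cpu())
cap.release()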
nets/blocks.py
ADDED
@@ -0,0 +1,1337 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn, Tensor
from itertools import repeat
import collections
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
from functools import partial
import einops
import math
from torchvision.ops.misc import Conv2dNormActivation, Permute
from torchvision.ops.stochastic_depth import StochasticDepth

def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

to_2tuple = _ntuple(2)

class InputPadder:
    """ Pads images such that dimensions are divisible by a certain stride """
    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 64) + 1) * 64 - self.ht) % 64
        pad_wd = (((self.wd // 64) + 1) * 64 - self.wd) % 64
        if mode == 'sintel':
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
        else:
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]

def bilinear_sampler(
    input, coords,
    align_corners=True,
    padding_mode="border",
    normalize_coords=True):
    # func from mattie (oct9)
    if input.ndim not in [4, 5]:
        raise ValueError("input must be 4D or 5D.")

    if input.ndim == 4 and not coords.ndim == 4:
        raise ValueError("input is 4D, but coords is not 4D.")

    if input.ndim == 5 and not coords.ndim == 5:
        raise ValueError("input is 5D, but coords is not 5D.")

    if coords.ndim == 5:
        coords = coords[..., [1, 2, 0]]  # t x y -> x y t to match what grid_sample() expects.

    if normalize_coords:
        if align_corners:
            # Normalize coordinates from [0, W/H - 1] to [-1, 1].
            coords = (
                coords
                * torch.tensor([2 / max(size - 1, 1) for size in reversed(input.shape[2:])], device=coords.device)
                - 1
            )
        else:
            # Normalize coordinates from [0, W/H] to [-1, 1].
            coords = coords * torch.tensor([2 / size for size in reversed(input.shape[2:])], device=coords.device) - 1

    return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)


class CorrBlock:
    def __init__(self, fmap1, fmap2, corr_levels, corr_radius):
        self.num_levels = corr_levels
        self.radius = corr_radius
        self.corr_pyramid = []
        # all pairs correlation
        for i in range(self.num_levels):
            corr = CorrBlock.corr(fmap1, fmap2, 1)
            batch, h1, w1, dim, h2, w2 = corr.shape
            corr = corr.reshape(batch*h1*w1, dim, h2, w2)
            fmap2 = F.interpolate(fmap2, scale_factor=0.5, mode='area')
            # print('corr', corr.shape)
            self.corr_pyramid.append(corr)

    def __call__(self, coords, dilation=None):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        if dilation is None:
            dilation = torch.ones(batch, 1, h1, w1, device=coords.device)

        # print(dilation.max(), dilation.mean(), dilation.min())
        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            device = coords.device
            dx = torch.linspace(-r, r, 2*r+1, device=device)
            dy = torch.linspace(-r, r, 2*r+1, device=device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            delta_lvl = delta_lvl * dilation.view(batch * h1 * w1, 1, 1, 1)
            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            coords_lvl = centroid_lvl + delta_lvl
            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        out = out.permute(0, 3, 1, 2).contiguous().float()
        return out

    @staticmethod
    def corr(fmap1, fmap2, num_head):
        batch, dim, h1, w1 = fmap1.shape
        h2, w2 = fmap2.shape[2:]
        fmap1 = fmap1.view(batch, num_head, dim // num_head, h1*w1)
        fmap2 = fmap2.view(batch, num_head, dim // num_head, h2*w2)
        corr = fmap1.transpose(2, 3) @ fmap2
        corr = corr.reshape(batch, num_head, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5)
        return corr / torch.sqrt(torch.tensor(dim).float())

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)

class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x

class CNBlock1d(nn.Module):
    def __init__(
        self,
        dim,
        output_dim,
        layer_scale: float = 1e-6,
        stochastic_depth_prob: float = 0,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dense=True,
        use_attn=True,
        use_mixer=False,
        use_conv=False,
        use_convb=False,
        use_layer_scale=True,
    ) -> None:
        super().__init__()
        self.dense = dense
        self.use_attn = use_attn
        self.use_mixer = use_mixer
        self.use_conv = use_conv
        self.use_layer_scale = use_layer_scale

        if use_attn:
            assert not use_mixer
            assert not use_conv
            assert not use_convb

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        if use_attn:
            num_heads = 8
            self.block = AttnBlock(
                hidden_size=dim,
                num_heads=num_heads,
                mlp_ratio=4,
                attn_class=Attention,
            )
        elif use_mixer:
            self.block = MLPMixerBlock(
                S=16,
                dim=dim,
                depth=1,
                expansion_factor=2,
            )
        elif use_conv:
            self.block = nn.Sequential(
                nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
                Permute([0, 2, 1]),
                norm_layer(dim),
                nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
                nn.GELU(),
                nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
                Permute([0, 2, 1]),
            )
        elif use_convb:
            self.block = nn.Sequential(
                nn.Conv1d(dim, dim, kernel_size=3, padding=1, bias=True, padding_mode='zeros'),
                Permute([0, 2, 1]),
                norm_layer(dim),
                nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
                nn.GELU(),
                nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
                Permute([0, 2, 1]),
            )
        else:
            assert(False) # choose attn, mixer, or conv please

        if self.use_layer_scale:
            self.layer_scale = nn.Parameter(torch.ones(dim, 1) * layer_scale)
        else:
            self.layer_scale = 1.0

        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

        if output_dim != dim:
            self.final = nn.Conv1d(dim, output_dim, kernel_size=1, padding=0)
        else:
            self.final = nn.Identity()

    def forward(self, input, S=None):
        if self.dense:
            assert S is not None
            BS,C,H,W = input.shape
            B = BS//S

            # if S<7:
            #     return input

            input = einops.rearrange(input, '(b s) c h w -> (b h w) c s', b=B, s=S, c=C, h=H, w=W)

            if self.use_mixer or self.use_attn:
                # mixer/transformer blocks want B,S,C
                result = self.layer_scale * self.block(input.permute(0,2,1)).permute(0,2,1)
            else:
                result = self.layer_scale * self.block(input)
            result = self.stochastic_depth(result)
            result += input
            result = self.final(result)

            result = einops.rearrange(result, '(b h w) c s -> (b s) c h w', b=B, s=S, c=C, h=H, w=W)
        else:
            B,S,C = input.shape

            if S<7:
                return input

            input = einops.rearrange(input, 'b s c -> b c s', b=B, s=S, c=C)

            result = self.layer_scale * self.block(input)
            result = self.stochastic_depth(result)
            result += input

            result = self.final(result)

            result = einops.rearrange(result, 'b c s -> b s c', b=B, s=S, c=C)

        return result

class CNBlock2d(nn.Module):
    def __init__(
        self,
        dim,
        output_dim,
        layer_scale: float = 1e-6,
        stochastic_depth_prob: float = 0,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        use_layer_scale=True,
    ) -> None:
        super().__init__()
        self.use_layer_scale = use_layer_scale
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
            Permute([0, 2, 3, 1]),
            norm_layer(dim),
            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
            nn.GELU(),
            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
            Permute([0, 3, 1, 2]),
        )
        if self.use_layer_scale:
            self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        else:
            self.layer_scale = 1.0
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

        if output_dim != dim:
            self.final = nn.Conv2d(dim, output_dim, kernel_size=1, padding=0)
        else:
            self.final = nn.Identity()

    def forward(self, input, S=None):
        result = self.layer_scale * self.block(input)
        result = self.stochastic_depth(result)
        result += input
        result = self.final(result)
        return result

class CNBlockConfig:
    # Stores information listed at Section 3 of the ConvNeXt paper
    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
        downsample: bool,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.downsample = downsample

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "input_channels={input_channels}"
        s += ", out_channels={out_channels}"
        s += ", num_layers={num_layers}"
        s += ", downsample={downsample}"
        s += ")"
        return s.format(**self.__dict__)

class ConvNeXt(nn.Module):
    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        init_weights=True):
        super().__init__()

        self.init_weights = init_weights

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock2d

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
                stage.append(block(cnf.input_channels, cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                if cnf.downsample:
                    layers.append(
                        nn.Sequential(
                            norm_layer(cnf.input_channels),
                            nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                        )
                    )
                else:
                    # we convert the 2x2 downsampling layer into a 3x3 with dilation2 and replicate padding.
                    # replicate padding compensates for the fact that this kernel never saw zero-padding.
                    layers.append(
                        nn.Sequential(
                            norm_layer(cnf.input_channels),
                            nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=3, stride=1, padding=2, dilation=2, padding_mode='zeros'),
                        )
                    )

        self.features = nn.Sequential(*layers)

        # self.final_conv = conv1x1(block_setting[-1].input_channels, output_dim)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        if self.init_weights:
            from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
            pretrained_dict = convnext_tiny(weights=ConvNeXt_Tiny_Weights.DEFAULT).state_dict()
            # from torchvision.models import convnext_base, ConvNeXt_Base_Weights
            # pretrained_dict = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT).state_dict()
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            # if self.input_dim == 6:
            #     for k, v in pretrained_dict.items():
            #         if k == 'conv1.weight':
            #             pretrained_dict[k] = torch.cat((v, v), dim=1)

            # del pretrained_dict['features.4.1.weight']

            for k, v in pretrained_dict.items():
                if k == 'features.4.1.weight': # this is the layer normally in charge of 2x2 downsampling
                    # convert to 3x3 filter
                    pretrained_dict[k] = F.interpolate(v, (3, 3), mode='bicubic', align_corners=True) * (4/9.0)

                    # print('v0', v[0,0])
                    # print('d0', pretrained_dict[k][0,0])

            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict, strict=False)


    def _forward_impl(self, x: Tensor) -> Tensor:
        # with torch.no_grad():
        x = self.features(x)
        # x = self.final_conv(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        )
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

class Attention(nn.Module):
    def __init__(
        self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
    ):
        super().__init__()
        inner_dim = dim_head * num_heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head**-0.5
        self.heads = num_heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, attn_bias=None):
        B, N1, C = x.shape
        H = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
        x = F.scaled_dot_product_attention(q, k, v) # scale default is already dim^-0.5
        x = einops.rearrange(x, 'b h n d -> b n (h d)')
        return self.to_out(x)

class CrossAttnBlock(nn.Module):
    def __init__(
        self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm_context = nn.LayerNorm(hidden_size)
        self.cross_attn = Attention(
            hidden_size,
            context_dim=context_dim,
            num_heads=num_heads,
            qkv_bias=True,
            **block_kwargs
        )

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x, context, mask=None):
        attn_bias = None
        if mask is not None:
            if mask.shape[1] == x.shape[1]:
                mask = mask[:, None, :, None].expand(
                    -1, self.cross_attn.heads, -1, context.shape[1]
                )
            else:
                mask = mask[:, None, None].expand(
                    -1, self.cross_attn.heads, x.shape[1], -1
                )

            max_neg_value = -torch.finfo(x.dtype).max
            attn_bias = (~mask) * max_neg_value
        x = x + self.cross_attn(
            self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
        )
        x = x + self.mlp(self.norm2(x))
        return x

class AttnBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = Attention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, dim_head=hidden_size//num_heads)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x, mask=None):
        attn_bias = mask
        if mask is not None:
            mask = (
                (mask[:, None] * mask[:, :, None])
                .unsqueeze(1)
                .expand(-1, self.attn.num_heads, -1, -1)
            )
            max_neg_value = -torch.finfo(x.dtype).max
            attn_bias = (~mask) * max_neg_value

        x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
        x = x + self.mlp(self.norm2(x))
        return x


class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
        )
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class BasicEncoder(nn.Module):
    def __init__(self, input_dim=3, output_dim=128, stride=4):
        super(BasicEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = "instance"
        self.in_planes = output_dim // 2
        self.norm1 = nn.InstanceNorm2d(self.in_planes)
        self.norm2 = nn.InstanceNorm2d(output_dim * 2)

        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=7,
            stride=2,
            padding=3,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(output_dim // 2, stride=1)
        self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
        self.layer3 = self._make_layer(output_dim, stride=2)
        self.layer4 = self._make_layer(output_dim, stride=2)

        self.conv2 = nn.Conv2d(
            output_dim * 3 + output_dim // 4,
            output_dim * 2,
            kernel_size=3,
            padding=1,
            padding_mode="zeros",
        )
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.InstanceNorm2d)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        _, _, H, W = x.shape

        # with torch.no_grad():
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        a = self.layer1(x)
        b = self.layer2(a)
        c = self.layer3(b)
        d = self.layer4(c)

        def _bilinear_intepolate(x):
            return F.interpolate(
                x,
                (H // self.stride, W // self.stride),
                mode="bilinear",
                align_corners=True,
            )

        a = _bilinear_intepolate(a)
        b = _bilinear_intepolate(b)
        c = _bilinear_intepolate(c)
        d = _bilinear_intepolate(d)

        x = self.conv2(torch.cat([a, b, c, d], dim=1))
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        return x

class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        num_virtual_tracks=64,
        add_space_attn=True,
        linear_layer_for_vis_conf=False,
        use_time_conv=False,
        use_time_mixer=False,
    ):
        super().__init__()
        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
        if linear_layer_for_vis_conf:
            self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
            self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
764 |
+
else:
|
765 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
766 |
+
self.num_virtual_tracks = num_virtual_tracks
|
767 |
+
self.virual_tracks = nn.Parameter(
|
768 |
+
torch.randn(1, num_virtual_tracks, 1, hidden_size)
|
769 |
+
)
|
770 |
+
self.add_space_attn = add_space_attn
|
771 |
+
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
|
772 |
+
|
773 |
+
if use_time_conv:
|
774 |
+
self.time_blocks = nn.ModuleList(
|
775 |
+
[
|
776 |
+
CNBlock1d(hidden_size, hidden_size, dense=False)
|
777 |
+
for _ in range(time_depth)
|
778 |
+
]
|
779 |
+
)
|
780 |
+
elif use_time_mixer:
|
781 |
+
self.time_blocks = nn.ModuleList(
|
782 |
+
[
|
783 |
+
MLPMixerBlock(
|
784 |
+
S=16,
|
785 |
+
dim=hidden_size,
|
786 |
+
depth=1,
|
787 |
+
)
|
788 |
+
for _ in range(time_depth)
|
789 |
+
]
|
790 |
+
)
|
791 |
+
else:
|
792 |
+
self.time_blocks = nn.ModuleList(
|
793 |
+
[
|
794 |
+
AttnBlock(
|
795 |
+
hidden_size,
|
796 |
+
num_heads,
|
797 |
+
mlp_ratio=mlp_ratio,
|
798 |
+
attn_class=Attention,
|
799 |
+
)
|
800 |
+
for _ in range(time_depth)
|
801 |
+
]
|
802 |
+
)
|
803 |
+
|
804 |
+
if add_space_attn:
|
805 |
+
self.space_virtual_blocks = nn.ModuleList(
|
806 |
+
[
|
807 |
+
AttnBlock(
|
808 |
+
hidden_size,
|
809 |
+
num_heads,
|
810 |
+
mlp_ratio=mlp_ratio,
|
811 |
+
attn_class=Attention,
|
812 |
+
)
|
813 |
+
for _ in range(space_depth)
|
814 |
+
]
|
815 |
+
)
|
816 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
817 |
+
[
|
818 |
+
CrossAttnBlock(
|
819 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
820 |
+
)
|
821 |
+
for _ in range(space_depth)
|
822 |
+
]
|
823 |
+
)
|
824 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
825 |
+
[
|
826 |
+
CrossAttnBlock(
|
827 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
828 |
+
)
|
829 |
+
for _ in range(space_depth)
|
830 |
+
]
|
831 |
+
)
|
832 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
833 |
+
self.initialize_weights()
|
834 |
+
|
835 |
+
def initialize_weights(self):
|
836 |
+
def _basic_init(module):
|
837 |
+
if isinstance(module, nn.Linear):
|
838 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
839 |
+
if module.bias is not None:
|
840 |
+
nn.init.constant_(module.bias, 0)
|
841 |
+
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
|
842 |
+
if self.linear_layer_for_vis_conf:
|
843 |
+
torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
|
844 |
+
|
845 |
+
def _trunc_init(module):
|
846 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
847 |
+
if isinstance(module, nn.Linear):
|
848 |
+
torch.nn.init.trunc_normal_(module.weight, std=0.02)
|
849 |
+
if module.bias is not None:
|
850 |
+
nn.init.zeros_(module.bias)
|
851 |
+
|
852 |
+
self.apply(_basic_init)
|
853 |
+
|
854 |
+
def forward(self, input_tensor, mask=None, add_space_attn=True):
|
855 |
+
tokens = self.input_transform(input_tensor)
|
856 |
+
|
857 |
+
B, _, T, _ = tokens.shape
|
858 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
859 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
860 |
+
|
861 |
+
_, N, _, _ = tokens.shape
|
862 |
+
j = 0
|
863 |
+
layers = []
|
864 |
+
for i in range(len(self.time_blocks)):
|
865 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
866 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
867 |
+
|
868 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
869 |
+
if (
|
870 |
+
add_space_attn
|
871 |
+
and hasattr(self, "space_virtual_blocks")
|
872 |
+
and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
|
873 |
+
):
|
874 |
+
space_tokens = (
|
875 |
+
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
876 |
+
) # B N T C -> (B T) N C
|
877 |
+
|
878 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
879 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
880 |
+
|
881 |
+
virtual_tokens = self.space_virtual2point_blocks[j](
|
882 |
+
virtual_tokens, point_tokens, mask=mask
|
883 |
+
)
|
884 |
+
|
885 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
886 |
+
point_tokens = self.space_point2virtual_blocks[j](
|
887 |
+
point_tokens, virtual_tokens, mask=mask
|
888 |
+
)
|
889 |
+
|
890 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
891 |
+
tokens = space_tokens.view(B, T, N, -1).permute(
|
892 |
+
0, 2, 1, 3
|
893 |
+
) # (B T) N C -> B N T C
|
894 |
+
j += 1
|
895 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
896 |
+
|
897 |
+
flow = self.flow_head(tokens)
|
898 |
+
if self.linear_layer_for_vis_conf:
|
899 |
+
vis_conf = self.vis_conf_head(tokens)
|
900 |
+
flow = torch.cat([flow, vis_conf], dim=-1)
|
901 |
+
|
902 |
+
return flow
|
903 |
+
|
904 |
+
|
905 |
+
class MMPreNormResidual(nn.Module):
|
906 |
+
def __init__(self, dim, fn):
|
907 |
+
super().__init__()
|
908 |
+
self.fn = fn
|
909 |
+
self.norm = nn.LayerNorm(dim)
|
910 |
+
|
911 |
+
def forward(self, x):
|
912 |
+
return self.fn(self.norm(x)) + x
|
913 |
+
|
914 |
+
def MMFeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
|
915 |
+
return nn.Sequential(
|
916 |
+
dense(dim, dim * expansion_factor),
|
917 |
+
nn.GELU(),
|
918 |
+
nn.Dropout(dropout),
|
919 |
+
dense(dim * expansion_factor, dim),
|
920 |
+
nn.Dropout(dropout)
|
921 |
+
)
|
922 |
+
|
923 |
+
def MLPMixer(S, input_dim, dim, output_dim, depth=6, expansion_factor=4, dropout=0., do_reduce=False):
|
924 |
+
# input is coming in as B,S,C, as standard for mlp and transformer
|
925 |
+
# chan_first treats S as the channel dim, and transforms it to a new S
|
926 |
+
# chan_last treats C as the channel dim, and transforms it to a new C
|
927 |
+
chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
|
928 |
+
if do_reduce:
|
929 |
+
return nn.Sequential(
|
930 |
+
nn.Linear(input_dim, dim),
|
931 |
+
*[nn.Sequential(
|
932 |
+
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
|
933 |
+
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
|
934 |
+
) for _ in range(depth)],
|
935 |
+
nn.LayerNorm(dim),
|
936 |
+
Reduce('b n c -> b c', 'mean'),
|
937 |
+
nn.Linear(dim, output_dim)
|
938 |
+
)
|
939 |
+
else:
|
940 |
+
return nn.Sequential(
|
941 |
+
nn.Linear(input_dim, dim),
|
942 |
+
*[nn.Sequential(
|
943 |
+
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
|
944 |
+
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
|
945 |
+
) for _ in range(depth)],
|
946 |
+
)
|
947 |
+
|
948 |
+
def MLPMixerBlock(S, dim, depth=1, expansion_factor=4, dropout=0., do_reduce=False):
|
949 |
+
# input is coming in as B,S,C, as standard for mlp and transformer
|
950 |
+
# chan_first treats S as the channel dim, and transforms it to a new S
|
951 |
+
# chan_last treats C as the channel dim, and transforms it to a new C
|
952 |
+
chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
|
953 |
+
return nn.Sequential(
|
954 |
+
*[nn.Sequential(
|
955 |
+
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
|
956 |
+
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
|
957 |
+
) for _ in range(depth)],
|
958 |
+
)
|
959 |
+
|
960 |
+
|
961 |
+
class MlpUpdateFormer(nn.Module):
|
962 |
+
"""
|
963 |
+
Transformer model that updates track estimates.
|
964 |
+
"""
|
965 |
+
|
966 |
+
def __init__(
|
967 |
+
self,
|
968 |
+
space_depth=6,
|
969 |
+
time_depth=6,
|
970 |
+
input_dim=320,
|
971 |
+
hidden_size=384,
|
972 |
+
num_heads=8,
|
973 |
+
output_dim=130,
|
974 |
+
mlp_ratio=4.0,
|
975 |
+
num_virtual_tracks=64,
|
976 |
+
add_space_attn=True,
|
977 |
+
linear_layer_for_vis_conf=False,
|
978 |
+
):
|
979 |
+
super().__init__()
|
980 |
+
self.out_channels = 2
|
981 |
+
self.num_heads = num_heads
|
982 |
+
self.hidden_size = hidden_size
|
983 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
984 |
+
if linear_layer_for_vis_conf:
|
985 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
|
986 |
+
self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
|
987 |
+
else:
|
988 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
989 |
+
self.num_virtual_tracks = num_virtual_tracks
|
990 |
+
self.virual_tracks = nn.Parameter(
|
991 |
+
torch.randn(1, num_virtual_tracks, 1, hidden_size)
|
992 |
+
)
|
993 |
+
self.add_space_attn = add_space_attn
|
994 |
+
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
|
995 |
+
self.time_blocks = nn.ModuleList(
|
996 |
+
[
|
997 |
+
MLPMixer(
|
998 |
+
S=16,
|
999 |
+
input_dim=hidden_size,
|
1000 |
+
dim=hidden_size,
|
1001 |
+
output_dim=hidden_size,
|
1002 |
+
depth=1,
|
1003 |
+
)
|
1004 |
+
for _ in range(time_depth)
|
1005 |
+
]
|
1006 |
+
)
|
1007 |
+
|
1008 |
+
if add_space_attn:
|
1009 |
+
self.space_virtual_blocks = nn.ModuleList(
|
1010 |
+
[
|
1011 |
+
AttnBlock(
|
1012 |
+
hidden_size,
|
1013 |
+
num_heads,
|
1014 |
+
mlp_ratio=mlp_ratio,
|
1015 |
+
attn_class=Attention,
|
1016 |
+
)
|
1017 |
+
for _ in range(space_depth)
|
1018 |
+
]
|
1019 |
+
)
|
1020 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
1021 |
+
[
|
1022 |
+
CrossAttnBlock(
|
1023 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
1024 |
+
)
|
1025 |
+
for _ in range(space_depth)
|
1026 |
+
]
|
1027 |
+
)
|
1028 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
1029 |
+
[
|
1030 |
+
CrossAttnBlock(
|
1031 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
1032 |
+
)
|
1033 |
+
for _ in range(space_depth)
|
1034 |
+
]
|
1035 |
+
)
|
1036 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
1037 |
+
self.initialize_weights()
|
1038 |
+
|
1039 |
+
def initialize_weights(self):
|
1040 |
+
def _basic_init(module):
|
1041 |
+
if isinstance(module, nn.Linear):
|
1042 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
1043 |
+
if module.bias is not None:
|
1044 |
+
nn.init.constant_(module.bias, 0)
|
1045 |
+
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
|
1046 |
+
if self.linear_layer_for_vis_conf:
|
1047 |
+
torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
|
1048 |
+
|
1049 |
+
def _trunc_init(module):
|
1050 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
1051 |
+
if isinstance(module, nn.Linear):
|
1052 |
+
torch.nn.init.trunc_normal_(module.weight, std=0.02)
|
1053 |
+
if module.bias is not None:
|
1054 |
+
nn.init.zeros_(module.bias)
|
1055 |
+
|
1056 |
+
self.apply(_basic_init)
|
1057 |
+
|
1058 |
+
def forward(self, input_tensor, mask=None, add_space_attn=True):
|
1059 |
+
tokens = self.input_transform(input_tensor)
|
1060 |
+
|
1061 |
+
B, _, T, _ = tokens.shape
|
1062 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
1063 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
1064 |
+
|
1065 |
+
_, N, _, _ = tokens.shape
|
1066 |
+
j = 0
|
1067 |
+
layers = []
|
1068 |
+
for i in range(len(self.time_blocks)):
|
1069 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
1070 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
1071 |
+
|
1072 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
1073 |
+
if (
|
1074 |
+
add_space_attn
|
1075 |
+
and hasattr(self, "space_virtual_blocks")
|
1076 |
+
and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
|
1077 |
+
):
|
1078 |
+
space_tokens = (
|
1079 |
+
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
1080 |
+
) # B N T C -> (B T) N C
|
1081 |
+
|
1082 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
1083 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
1084 |
+
|
1085 |
+
virtual_tokens = self.space_virtual2point_blocks[j](
|
1086 |
+
virtual_tokens, point_tokens, mask=mask
|
1087 |
+
)
|
1088 |
+
|
1089 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
1090 |
+
point_tokens = self.space_point2virtual_blocks[j](
|
1091 |
+
point_tokens, virtual_tokens, mask=mask
|
1092 |
+
)
|
1093 |
+
|
1094 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
1095 |
+
tokens = space_tokens.view(B, T, N, -1).permute(
|
1096 |
+
0, 2, 1, 3
|
1097 |
+
) # (B T) N C -> B N T C
|
1098 |
+
j += 1
|
1099 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
1100 |
+
|
1101 |
+
flow = self.flow_head(tokens)
|
1102 |
+
if self.linear_layer_for_vis_conf:
|
1103 |
+
vis_conf = self.vis_conf_head(tokens)
|
1104 |
+
flow = torch.cat([flow, vis_conf], dim=-1)
|
1105 |
+
|
1106 |
+
return flow
|
1107 |
+
|
1108 |
+
class BasicMotionEncoder(nn.Module):
|
1109 |
+
def __init__(self, corr_channel, dim=128, pdim=2):
|
1110 |
+
super(BasicMotionEncoder, self).__init__()
|
1111 |
+
self.pdim = pdim
|
1112 |
+
self.convc1 = nn.Conv2d(corr_channel, dim*4, 1, padding=0)
|
1113 |
+
self.convc2 = nn.Conv2d(dim*4, dim+dim//2, 3, padding=1)
|
1114 |
+
if pdim==2 or pdim==4:
|
1115 |
+
self.convf1 = nn.Conv2d(pdim, dim*2, 5, padding=2)
|
1116 |
+
self.convf2 = nn.Conv2d(dim*2, dim//2, 3, padding=1)
|
1117 |
+
self.conv = nn.Conv2d(dim*2, dim-pdim, 3, padding=1)
|
1118 |
+
else:
|
1119 |
+
self.conv = nn.Conv2d(dim+dim//2+pdim, dim, 3, padding=1)
|
1120 |
+
|
1121 |
+
def forward(self, flow, corr):
|
1122 |
+
cor = F.relu(self.convc1(corr))
|
1123 |
+
cor = F.relu(self.convc2(cor))
|
1124 |
+
if self.pdim==2 or self.pdim==4:
|
1125 |
+
flo = F.relu(self.convf1(flow))
|
1126 |
+
flo = F.relu(self.convf2(flo))
|
1127 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
1128 |
+
out = F.relu(self.conv(cor_flo))
|
1129 |
+
return torch.cat([out, flow], dim=1)
|
1130 |
+
else:
|
1131 |
+
# the flow is already encoded to something nice
|
1132 |
+
cor_flo = torch.cat([cor, flow], dim=1)
|
1133 |
+
return F.relu(self.conv(cor_flo))
|
1134 |
+
# return torch.cat([out, flow], dim=1)
|
1135 |
+
|
1136 |
+
def conv133_encoder(input_dim, dim, expansion_factor=4):
|
1137 |
+
return nn.Sequential(
|
1138 |
+
nn.Conv2d(input_dim, dim*expansion_factor, kernel_size=1),
|
1139 |
+
nn.GELU(),
|
1140 |
+
nn.Conv2d(dim*expansion_factor, dim*expansion_factor, kernel_size=3, padding=1),
|
1141 |
+
nn.GELU(),
|
1142 |
+
nn.Conv2d(dim*expansion_factor, dim, kernel_size=3, padding=1),
|
1143 |
+
)
|
1144 |
+
|
1145 |
+
class BasicUpdateBlock(nn.Module):
|
1146 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
|
1147 |
+
# flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
|
1148 |
+
super(BasicUpdateBlock, self).__init__()
|
1149 |
+
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
|
1150 |
+
self.compressor = conv1x1(2*cdim+hdim, hdim)
|
1151 |
+
|
1152 |
+
self.refine = []
|
1153 |
+
for i in range(num_blocks):
|
1154 |
+
self.refine.append(CNBlock1d(hdim, hdim))
|
1155 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
1156 |
+
self.refine = nn.ModuleList(self.refine)
|
1157 |
+
|
1158 |
+
def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
|
1159 |
+
BS,C,H,W = flowfeat.shape
|
1160 |
+
B = BS//S
|
1161 |
+
|
1162 |
+
# with torch.no_grad():
|
1163 |
+
motion_features = self.encoder(flow, corr)
|
1164 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
|
1165 |
+
|
1166 |
+
for blk in self.refine:
|
1167 |
+
flowfeat = blk(flowfeat, S)
|
1168 |
+
return flowfeat
|
1169 |
+
|
1170 |
+
class FullUpdateBlock(nn.Module):
|
1171 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=2, use_attn=False):
|
1172 |
+
# flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
|
1173 |
+
super(FullUpdateBlock, self).__init__()
|
1174 |
+
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim, pdim=pdim)
|
1175 |
+
|
1176 |
+
# note we have hdim==cdim
|
1177 |
+
# compressor chans:
|
1178 |
+
# dim for flowfeat
|
1179 |
+
# dim for ctxfeat
|
1180 |
+
# dim for motion_features
|
1181 |
+
# pdim for flow (if p 2, like if we give sincos(relflow))
|
1182 |
+
# 2 for visconf
|
1183 |
+
|
1184 |
+
if pdim==2:
|
1185 |
+
# hdim==cdim
|
1186 |
+
# dim for flowfeat
|
1187 |
+
# dim for ctxfeat
|
1188 |
+
# dim for motion_features
|
1189 |
+
# 2 for visconf
|
1190 |
+
self.compressor = conv1x1(2*cdim+hdim+2, hdim)
|
1191 |
+
else:
|
1192 |
+
# we concatenate the flow info again, to not lose it (e.g., from the relu)
|
1193 |
+
self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
|
1194 |
+
|
1195 |
+
self.refine = []
|
1196 |
+
for i in range(num_blocks):
|
1197 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
|
1198 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
1199 |
+
self.refine = nn.ModuleList(self.refine)
|
1200 |
+
|
1201 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
1202 |
+
BS,C,H,W = flowfeat.shape
|
1203 |
+
B = BS//S
|
1204 |
+
|
1205 |
+
# print('flowfeat', flowfeat.shape)
|
1206 |
+
# print('ctxfeat', ctxfeat.shape)
|
1207 |
+
# print('visconf', visconf.shape)
|
1208 |
+
# print('corr', corr.shape)
|
1209 |
+
# print('flow', flow.shape)
|
1210 |
+
|
1211 |
+
motion_features = self.encoder(flow, corr)
|
1212 |
+
|
1213 |
+
# print('cat', torch.cat([flowfeat, ctxfeat, motion_features, visconf, flow], dim=1).shape)
|
1214 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
|
1215 |
+
for blk in self.refine:
|
1216 |
+
flowfeat = blk(flowfeat, S)
|
1217 |
+
return flowfeat
|
1218 |
+
|
1219 |
+
class MixerUpdateBlock(nn.Module):
|
1220 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
|
1221 |
+
# flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
|
1222 |
+
super(MixerUpdateBlock, self).__init__()
|
1223 |
+
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
|
1224 |
+
self.compressor = conv1x1(2*cdim+hdim, hdim)
|
1225 |
+
|
1226 |
+
self.refine = []
|
1227 |
+
for i in range(num_blocks):
|
1228 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_mixer=True))
|
1229 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
1230 |
+
self.refine = nn.ModuleList(self.refine)
|
1231 |
+
|
1232 |
+
def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
|
1233 |
+
BS,C,H,W = flowfeat.shape
|
1234 |
+
B = BS//S
|
1235 |
+
|
1236 |
+
# with torch.no_grad():
|
1237 |
+
motion_features = self.encoder(flow, corr)
|
1238 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
|
1239 |
+
|
1240 |
+
for ii, blk in enumerate(self.refine):
|
1241 |
+
flowfeat = blk(flowfeat, S)
|
1242 |
+
return flowfeat
|
1243 |
+
|
1244 |
+
class FacUpdateBlock(nn.Module):
|
1245 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=84, use_attn=False):
|
1246 |
+
super(FacUpdateBlock, self).__init__()
|
1247 |
+
self.corr_encoder = conv133_encoder(corr_channel, cdim)
|
1248 |
+
# note we have hdim==cdim
|
1249 |
+
# compressor chans:
|
1250 |
+
# dim for flowfeat
|
1251 |
+
# dim for ctxfeat
|
1252 |
+
# dim for corr
|
1253 |
+
# pdim for flow
|
1254 |
+
# 2 for visconf
|
1255 |
+
self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
|
1256 |
+
self.refine = []
|
1257 |
+
for i in range(num_blocks):
|
1258 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
|
1259 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
1260 |
+
self.refine = nn.ModuleList(self.refine)
|
1261 |
+
|
1262 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
1263 |
+
BS,C,H,W = flowfeat.shape
|
1264 |
+
B = BS//S
|
1265 |
+
|
1266 |
+
# print('flowfeat', flowfeat.shape)
|
1267 |
+
# print('ctxfeat', ctxfeat.shape)
|
1268 |
+
# print('visconf', visconf.shape)
|
1269 |
+
# print('corr', corr.shape)
|
1270 |
+
# print('flow', flow.shape)
|
1271 |
+
corr = self.corr_encoder(corr)
|
1272 |
+
# print('cat', torch.cat([flowfeat, ctxfeat, motion_features, visconf, flow], dim=1).shape)
|
1273 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corr, visconf, flow], dim=1))
|
1274 |
+
for blk in self.refine:
|
1275 |
+
flowfeat = blk(flowfeat, S)
|
1276 |
+
return flowfeat
|
1277 |
+
|
1278 |
+
class CleanUpdateBlock(nn.Module):
|
1279 |
+
def __init__(self, corr_channel, num_blocks, cdim=128, hdim=256, pdim=84, use_attn=False, use_layer_scale=True):
|
1280 |
+
super(CleanUpdateBlock, self).__init__()
|
1281 |
+
self.corr_encoder = conv133_encoder(corr_channel, cdim)
|
1282 |
+
# compressor chans:
|
1283 |
+
# cdim for flowfeat
|
1284 |
+
# cdim for ctxfeat
|
1285 |
+
# cdim for corrfeat
|
1286 |
+
# pdim for flow
|
1287 |
+
# 2 for visconf
|
1288 |
+
self.compressor = conv1x1(3*cdim+pdim+2, hdim)
|
1289 |
+
self.refine = []
|
1290 |
+
for i in range(num_blocks):
|
1291 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_layer_scale=use_layer_scale))
|
1292 |
+
self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
|
1293 |
+
self.refine = nn.ModuleList(self.refine)
|
1294 |
+
self.final_conv = conv1x1(hdim, cdim)
|
1295 |
+
|
1296 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
1297 |
+
BS,C,H,W = flowfeat.shape
|
1298 |
+
B = BS//S
|
1299 |
+
corrfeat = self.corr_encoder(corr)
|
1300 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corrfeat, flow, visconf], dim=1))
|
1301 |
+
for blk in self.refine:
|
1302 |
+
flowfeat = blk(flowfeat, S)
|
1303 |
+
flowfeat = self.final_conv(flowfeat)
|
1304 |
+
return flowfeat
|
1305 |
+
|
1306 |
+
class RelUpdateBlock(nn.Module):
|
1307 |
+
def __init__(self, corr_channel, num_blocks, cdim=128, hdim=128, pdim=4, use_attn=True, use_mixer=False, use_conv=False, use_convb=False, use_layer_scale=True, no_time=False, no_space=False, final_space=False, no_ctx=False):
|
1308 |
+
super(RelUpdateBlock, self).__init__()
|
1309 |
+
self.motion_encoder = BasicMotionEncoder(corr_channel, dim=hdim, pdim=pdim) # B,hdim,H,W
|
1310 |
+
self.no_ctx = no_ctx
|
1311 |
+
if no_ctx:
|
1312 |
+
self.compressor = conv1x1(cdim+hdim+2, hdim)
|
1313 |
+
else:
|
1314 |
+
self.compressor = conv1x1(2*cdim+hdim+2, hdim)
|
1315 |
+
self.refine = []
|
1316 |
+
for i in range(num_blocks):
|
1317 |
+
if not no_time:
|
1318 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb, use_layer_scale=use_layer_scale))
|
1319 |
+
if not no_space:
|
1320 |
+
self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
|
1321 |
+
if final_space:
|
1322 |
+
self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
|
1323 |
+
self.refine = nn.ModuleList(self.refine)
|
1324 |
+
self.final_conv = conv1x1(hdim, cdim)
|
1325 |
+
|
1326 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
1327 |
+
BS,C,H,W = flowfeat.shape
|
1328 |
+
B = BS//S
|
1329 |
+
motion_features = self.motion_encoder(flow, corr)
|
1330 |
+
if self.no_ctx:
|
1331 |
+
flowfeat = self.compressor(torch.cat([flowfeat, motion_features, visconf], dim=1))
|
1332 |
+
else:
|
1333 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
|
1334 |
+
for blk in self.refine:
|
1335 |
+
flowfeat = blk(flowfeat, S)
|
1336 |
+
flowfeat = self.final_conv(flowfeat)
|
1337 |
+
return flowfeat
|
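Taken together, these blocks form the per-window refinement machinery: BasicEncoder turns each RGB frame into a single stride-8 feature map, and the *UpdateBlock variants fuse correlation, motion, and context features over a temporal window. Below is a minimal shape-check sketch (not part of the commit) of the encoder path; it assumes this file is importable as nets.blocks and that the defaults shown above are unchanged.

import torch
from nets.blocks import BasicEncoder

# BasicEncoder with stride=8 is the configuration Net uses when use_basicencoder=True.
encoder = BasicEncoder(input_dim=3, output_dim=128, stride=8)
frames = torch.randn(2, 3, 256, 384)    # B,3,H,W with H and W divisible by 8
feats = encoder(frames)                 # features pooled from 4 pyramid levels, then fused
print(feats.shape)                      # torch.Size([2, 128, 32, 48]) -> B,128,H/8,W/8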
nets/net34.py
ADDED
@@ -0,0 +1,647 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils.samp
import utils.misc
import numpy as np

from nets.blocks import CNBlockConfig, ConvNeXt, conv1x1, RelUpdateBlock, InputPadder, CorrBlock, BasicEncoder

# init mostly from raft82 in alltrack

class Net(nn.Module):
    def __init__(
        self,
        seqlen,
        noise_level=0,
        use_attn=True,
        use_mixer=False,
        use_conv=False,
        use_convb=False,
        use_basicencoder=False,
        use_sinmotion=False,
        use_relmotion=False,
        use_sinrelmotion=False,
        use_feats8=False,
        no_time=False,
        no_space=False,
        final_space=False,
        no_split=False,
        no_ctx=False,
        full_split=False,
        half_corr=False,
        corr_levels=5,
        corr_radius=4,
        num_blocks=3,
        dim=128,
        hdim=128,
        init_weights=True,
    ):
        super(Net, self).__init__()

        self.dim = dim
        self.hdim = hdim

        self.noise_level = noise_level
        self.no_time = no_time
        self.no_space = no_space
        self.final_space = final_space
        self.seqlen = seqlen
        self.corr_levels = corr_levels
        self.corr_radius = corr_radius
        self.corr_channel = self.corr_levels * (self.corr_radius * 2 + 1) ** 2
        self.num_blocks = num_blocks

        self.use_feats8 = use_feats8
        self.use_basicencoder = use_basicencoder
        self.use_sinmotion = use_sinmotion
        self.use_relmotion = use_relmotion
        self.use_sinrelmotion = use_sinrelmotion
        self.no_split = no_split
        self.no_ctx = no_ctx
        self.full_split = full_split
        self.half_corr = half_corr

        if use_basicencoder:
            if self.full_split:
                self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
                self.cnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
            else:
                if self.no_split:
                    self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
                else:
                    self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim*2, stride=8)
        else:
            block_setting = [
                CNBlockConfig(96, 192, 3, True), # 4x
                CNBlockConfig(192, 384, 3, False), # 8x
                CNBlockConfig(384, None, 9, False), # 8x
            ]
            self.cnn = ConvNeXt(block_setting, stochastic_depth_prob=0.0, init_weights=init_weights)
            if self.no_split:
                self.dot_conv = conv1x1(384, dim)
            else:
                self.dot_conv = conv1x1(384, dim*2)

        # # conv for iter 0 results
        # self.init_conv = conv3x3(2 * dim, 2 * dim)
        self.upsample_weight = nn.Sequential(
            # convex combination of 3x3 patches
            nn.Conv2d(dim, dim * 2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim * 2, 64 * 9, 1, padding=0)
        )
        self.flow_head = nn.Sequential(
            nn.Conv2d(dim, 2*dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(2*dim, 2, kernel_size=3, padding=1)
        )
        self.visconf_head = nn.Sequential(
            nn.Conv2d(dim, 2*dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(2*dim, 2, kernel_size=3, padding=1)
        )

        if self.use_sinrelmotion:
            self.pdim = 84 # 32*2
        elif self.use_relmotion:
            self.pdim = 4
        elif self.use_sinmotion:
            self.pdim = 42
        else:
            self.pdim = 2

        self.update_block = RelUpdateBlock(self.corr_channel, self.num_blocks, cdim=dim, hdim=hdim, pdim=self.pdim,
                                           use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb,
                                           use_layer_scale=True, no_time=no_time, no_space=no_space, final_space=final_space,
                                           no_ctx=no_ctx)

        time_line = torch.linspace(0, seqlen-1, seqlen).reshape(1, seqlen, 1)
        self.register_buffer("time_emb", utils.misc.get_1d_sincos_pos_embed_from_grid(self.dim, time_line[0])) # 1,S,C

    def fetch_time_embed(self, t, dtype, is_training=False):
        S = self.time_emb.shape[1]
        # print('fetching time_embed for t', t, '(we have %d)' % (S))
        # print('self.time_emb', self.time_emb.shape)
        if t == S:
            return self.time_emb.to(dtype)
        elif t==1:
            if is_training:
                ind = np.random.choice(S)
                return self.time_emb[:,ind:ind+1].to(dtype)
            else:
                return self.time_emb[:,1:2].to(dtype)
        else:
            time_emb = self.time_emb.float()
            time_emb = F.interpolate(time_emb.permute(0, 2, 1), size=t, mode="linear").permute(0, 2, 1)
            return time_emb.to(dtype)

    def coords_grid(self, batch, ht, wd, device, dtype):
        coords = torch.meshgrid(torch.arange(ht, device=device, dtype=dtype), torch.arange(wd, device=device, dtype=dtype))
        coords = torch.stack(coords[::-1], dim=0)
        return coords[None].repeat(batch, 1, 1, 1)

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords2 - coords1"""
        N, C, H, W = img.shape
        coords1 = self.coords_grid(N, H//8, W//8, device=img.device)
        coords2 = self.coords_grid(N, H//8, W//8, device=img.device)
        return coords1, coords2

    def upsample_data(self, flow, mask):
        """ Upsample [H/8, W/8, C] -> [H, W, C] using convex combination """
        N, C, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)

        return up_flow.reshape(N, 2, 8*H, 8*W).to(flow.dtype)

    def get_T_padded_images(self, images, T, S, is_training, stride=None, pad=True):
        B,T,C,H,W = images.shape
        indices = None
        if T > 2:
            step = S // 2 if stride is None else stride
            # starts = list(range(step,max(T,S),step))
            # indices = [mid-step for mid in mids]
            indices = []
            start = 0
            while start + S < T:
                indices.append(start)
                start += step
            indices.append(start)
            Tpad = indices[-1]+S-T
            # print(indices, Tpad)
            # import pdb; pdb.set_trace()
            if pad:
                if is_training:
                    assert Tpad == 0
                else:
                    images = images.reshape(B,1,T,C*H*W)
                    if Tpad > 0:
                        padding_tensor = images[:,:,-1:,:].expand(B,1,Tpad,C*H*W)
                        images = torch.cat([images, padding_tensor], dim=2)
                    images = images.reshape(B,T+Tpad,C,H,W)
                    T = T+Tpad
        else:
            assert T == 2
        return images, T, indices

    def get_fmaps(self, images_, B, T, sw, is_training, nograd_backbone):
        _, _, H_pad, W_pad = images_.shape # revised HW

        C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
        if self.no_split:
            C = self.dim

        fmaps_chunk_size = 64
        if (not is_training) and (T > fmaps_chunk_size):
            images = images_.reshape(B,T,3,H_pad,W_pad)
            fmaps = []
            for t in range(0, T, fmaps_chunk_size):
                images_chunk = images[:, t : t + fmaps_chunk_size]
                images_chunk = images_chunk.cuda()
                if self.use_basicencoder:
                    if self.full_split:
                        fmaps_chunk1 = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
                        fmaps_chunk2 = self.cnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
                        fmaps_chunk = torch.cat([fmaps_chunk1, fmaps_chunk2], axis=1)
                    else:
                        fmaps_chunk = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
                else:
                    fmaps_chunk = self.cnn(images_chunk.reshape(-1, 3, H_pad, W_pad))
                    if t==0 and sw is not None and sw.save_this:
                        sw.summ_feat('1_model/fmap_raw', fmaps_chunk[0:1])
                    fmaps_chunk = self.dot_conv(fmaps_chunk) # B*T,C,H8,W8
                T_chunk = images_chunk.shape[1]
                fmaps.append(fmaps_chunk.reshape(B, -1, C, H8, W8))
            fmaps_ = torch.cat(fmaps, dim=1).reshape(-1, C, H8, W8)
        else:
            if not is_training:
                # sometimes we need to move things to cuda here
                images_ = images_.cuda()
            if self.use_basicencoder:
                if self.full_split:
                    # if self.half_corr:
                    fmaps1_ = self.fnet(images_)
                    fmaps2_ = self.cnet(images_)
                    fmaps_ = torch.cat([fmaps1_, fmaps2_], axis=1)
                else:
                    fmaps_ = self.fnet(images_)
            else:
                if nograd_backbone:
                    with torch.no_grad():
                        fmaps_ = self.cnn(images_)
                else:
                    fmaps_ = self.cnn(images_)
                if sw is not None and sw.save_this:
                    sw.summ_feat('1_model/fmap_raw', fmaps_[0:1])
                fmaps_ = self.dot_conv(fmaps_) # B*T,C,H8,W8
        return fmaps_

    def forward(self, images, iters=6, sw=None, nograd_backbone=False, is_training=False, stride=None):
        B,T,C,H,W = images.shape
        S = self.seqlen
        device = images.device
        dtype = images.dtype
        # print('images', images.shape, 'device', device)

        T_bak = T
        if stride is not None:
            pad = False
        else:
            pad = True
        images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride=stride, pad=pad)
        # print('indices', indices)

        images = images.contiguous()
        images_ = images.reshape(B*T,3,H,W)
        padder = InputPadder(images_.shape)
        images_ = padder.pad(images_)[0]

        _, _, H_pad, W_pad = images_.shape # revised HW
        C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
        C2 = C//2
        if self.no_split:
            C = self.dim
            C2 = C

        fmaps = self.get_fmaps(images_, B, T, sw, is_training, nograd_backbone).reshape(B,T,C,H8,W8)
        # print('fmaps_', fmaps_.shape)
        device = fmaps.device

        # if sw is not None and sw.save_this:
        #     sw.summ_feat('1_model/fmap_dc', fmaps_[0:1])

        # fmaps = fmaps_.reshape(B,T,C,H8,W8)
        # del fmaps_
        fmap_anchor = fmaps[:,0]

        # if not is_training:
        #     del images
        #     del images_

        if T<=2 or is_training:
            # note: collecting preds can get expensive on a long video
            all_flow_preds = []
            all_visconf_preds = []
        else:
            all_flow_preds = None
            all_visconf_preds = None

        if T > 2: # multiframe tracking

            # we will store our final outputs in these tensors
            full_flows = torch.zeros((B,T,2,H,W), dtype=dtype, device=device)
            full_visconfs = torch.zeros((B,T,2,H,W), dtype=dtype, device=device)
            # 1/8 resolution
            full_flows8 = torch.zeros((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
            full_visconfs8 = torch.zeros((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)

            if is_training and self.noise_level and np.random.rand() < 0.1:
                # full_flows8 = 2.0*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
                # print('flows8 += randn4, const on time')
                # full_flows8 = float(self.noise_level)*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
                # print('noise const on time')
                full_flows8 += np.random.rand()*float(self.noise_level)*torch.randn((B,1,2,H_pad//8,W_pad//8), dtype=dtype, device=device)

            if self.use_feats8:
                full_feats8 = torch.zeros((B,T,C2,H_pad//8,W_pad//8), dtype=dtype, device=device)
            visits = np.zeros((T))

            for ii, ind in enumerate(indices):
                ara = np.arange(ind,ind+S)
                if ii < len(indices)-1:
                    next_ind = indices[ii+1]
                    next_ara = np.arange(next_ind,next_ind+S)

                # print("torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024), 'ara', ara)
                # print('ara', ara)
                fmaps2 = fmaps[:,ara]
                flows8 = full_flows8[:,ara].reshape(B*(S),2,H_pad//8,W_pad//8).detach()
                visconfs8 = full_visconfs8[:,ara].reshape(B*(S),2,H_pad//8,W_pad//8).detach()

                if self.use_feats8:
                    if ind==0:
                        feats8 = None
                    else:
                        feats8 = full_feats8[:,ara].reshape(B*(S),C2,H_pad//8,W_pad//8).detach()
                else:
                    feats8 = None

                flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
                    fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=feats8, flows8=flows8,
                    is_training=is_training)

                unpad_flow_predictions = []
                unpad_visconf_predictions = []
                for i in range(len(flow_predictions)):
                    flow_predictions[i] = padder.unpad(flow_predictions[i])
                    unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
                    visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
                    # print('visconf_predictions[%d]' % i, visconf_predictions[i].shape)
                    unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))

                full_flows[:,ara] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)
                full_flows8[:,ara] = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
                full_visconfs[:,ara] = unpad_visconf_predictions[-1].reshape(B,S,2,H,W)
                full_visconfs8[:,ara] = visconfs8.reshape(B,S,2,H_pad//8,W_pad//8)
                if self.use_feats8:
                    full_feats8[:,ara] = feats8.reshape(B,S,C2,H_pad//8,W_pad//8)
                visits[ara] += 1

                if is_training:
                    all_flow_preds.append(unpad_flow_predictions)
                    all_visconf_preds.append(unpad_visconf_predictions)
                else:
                    del unpad_flow_predictions
                    del unpad_visconf_predictions

                # for the next iter, replace empty data with nearest available preds
                invalid_idx = np.where(visits==0)[0]
                valid_idx = np.where(visits>0)[0]
                for idx in invalid_idx:
                    nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
                    # print('replacing %d with %d' % (idx, nearest))
                    full_flows8[:,idx] = full_flows8[:,nearest]
                    full_visconfs8[:,idx] = full_visconfs8[:,nearest]
                    if self.use_feats8:
                        full_feats8[:,idx] = full_feats8[:,nearest]
        else: # flow

            flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
            visconfs8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)

            if is_training and self.noise_level and np.random.rand() < 0.1:
                # full_flows8 = 2.0*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
                # print('flows8 += randn4, const on time')
                # full_flows8 = float(self.noise_level)*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
                # print('noise on flow too')
                flows8 += np.random.rand()*float(self.noise_level)*torch.randn((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)

            flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
                fmap_anchor, fmaps[:,1:2], visconfs8, iters=iters, flowfeat=None, flows8=flows8,
                is_training=is_training)
            unpad_flow_predictions = []
            unpad_visconf_predictions = []
            for i in range(len(flow_predictions)):
                flow_predictions[i] = padder.unpad(flow_predictions[i])
                all_flow_preds.append(flow_predictions[i].reshape(B,2,H,W))
                visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
                all_visconf_preds.append(visconf_predictions[i].reshape(B,2,H,W))
            full_flows = all_flow_preds[-1].reshape(B,2,H,W)
            full_visconfs = all_visconf_preds[-1].reshape(B,2,H,W)

        if (not is_training) and (T > 2):
            # print('full_flows', full_flows.shape)
            # print('T_bak', T_bak)
            full_flows = full_flows[:,:T_bak]
            full_visconfs = full_visconfs[:,:T_bak]
            # print('full_flows trim', full_flows.shape)

        return full_flows, full_visconfs, all_flow_preds, all_visconf_preds#, bak_flows8

    def forward_sliding(self, images, iters=6, sw=None, nograd_backbone=False, is_training=False, window_len=None, stride=None):
        B,T,C,H,W = images.shape
        S = self.seqlen if window_len is None else window_len
        device = images.device
        dtype = images.dtype
        # print('images', images.shape, 'device', device)
        stride = S // 2 if stride is None else stride

        T_bak = T
        images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride)
        # print('indices', indices)
        assert stride <= S // 2

        images = images.contiguous()
        images_ = images.reshape(B*T,3,H,W)
        padder = InputPadder(images_.shape)
        images_ = padder.pad(images_)[0]

        _, _, H_pad, W_pad = images_.shape # revised HW
        C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
        C2 = C//2
        if self.no_split:
            C = self.dim
            C2 = C

        all_flow_preds = None
        all_visconf_preds = None

        if T<=2:
            # note: collecting preds can get expensive on a long video
            all_flow_preds = []
            all_visconf_preds = []

            fmaps = self.get_fmaps(images_, B, T, sw, is_training, nograd_backbone).reshape(B,T,C,H8,W8)
            device = fmaps.device

            flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
            visconfs8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)

            if is_training and self.noise_level and np.random.rand() < 0.1:
                # full_flows8 = 2.0*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
                # print('flows8 += randn4, const on time')
                # full_flows8 = float(self.noise_level)*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
                # print('noise on flow too')
                flows8 += np.random.rand()*float(self.noise_level)*torch.randn((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)

            fmap_anchor = fmaps[:,0]

            flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
                fmap_anchor, fmaps[:,1:2], visconfs8, iters=iters, flowfeat=None, flows8=flows8,
                is_training=is_training)
            unpad_flow_predictions = []
            unpad_visconf_predictions = []
            for i in range(len(flow_predictions)):
                flow_predictions[i] = padder.unpad(flow_predictions[i])
                all_flow_preds.append(flow_predictions[i].reshape(B,2,H,W))
                visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
                all_visconf_preds.append(visconf_predictions[i].reshape(B,2,H,W))
            full_flows = all_flow_preds[-1].reshape(B,2,H,W).detach().cpu()
            full_visconfs = all_visconf_preds[-1].reshape(B,2,H,W).detach().cpu()

            return full_flows, full_visconfs, all_flow_preds, all_visconf_preds#, bak_flows8

        assert T > 2 # multiframe tracking

        if is_training:
            all_flow_preds = []
            all_visconf_preds = []

        # # we will store our final outputs in these tensors cpu
        full_flows = torch.zeros((B,T,2,H,W), dtype=dtype, device='cpu')
        full_visconfs = torch.zeros((B,T,2,H,W), dtype=dtype, device='cpu')

        images_ = images_.reshape(B,T,3,H_pad,W_pad)
        fmap_anchor = self.get_fmaps(images_[:,:1].reshape(-1,3,H_pad,W_pad), B, 1, sw, is_training, nograd_backbone).reshape(B,C,H8,W8)
        device = fmap_anchor.device
        full_visited = torch.zeros((T,), dtype=torch.bool, device=device)

        for ii, ind in enumerate(indices):
            ara = np.arange(ind,ind+S)
            if ii == 0:
                flows8 = torch.zeros((B,S,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
                visconfs8 = torch.zeros((B,S,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
                fmaps2 = self.get_fmaps(images_[:,ara].reshape(-1,3,H_pad,W_pad), B, S, sw, is_training, nograd_backbone).reshape(B,S,C,H8,W8)
                if is_training and self.noise_level and np.random.rand() < 0.1:
                    flows8 += np.random.rand()*float(self.noise_level)*torch.randn((B,1,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
            else:
                # import pdb; pdb.set_trace()
                flows8 = torch.cat([flows8[:,stride:stride+S//2], flows8[:,stride+S//2-1:stride+S//2].repeat(1,S//2,1,1,1)], dim=1)
                visconfs8 = torch.cat([visconfs8[:,stride:stride+S//2], visconfs8[:,stride+S//2-1:stride+S//2].repeat(1,S//2,1,1,1)], dim=1)
                fmaps2 = torch.cat([fmaps2[:,stride:stride+S//2],
                                    self.get_fmaps(images_[:,np.arange(ind+S//2,ind+S)].reshape(-1,3,H_pad,W_pad), B, S//2, sw, is_training, nograd_backbone).reshape(B,S//2,C,H8,W8)], dim=1)

            flows8 = flows8.reshape(B*S,2,H_pad//8,W_pad//8).detach()
            visconfs8 = visconfs8.reshape(B*S,2,H_pad//8,W_pad//8).detach()

            # print("torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024), 'ara', ara)
            # print('ara', ara)
            flow_predictions, visconf_predictions, flows8, visconfs8, _ = self.forward_window(
                fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=None, flows8=flows8,
                is_training=is_training)

            unpad_flow_predictions = []
            unpad_visconf_predictions = []
            for i in range(len(flow_predictions)):
                flow_predictions[i] = padder.unpad(flow_predictions[i])
                unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
                visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
                # print('visconf_predictions[%d]' % i, visconf_predictions[i].shape)
                unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))

            current_visiting = torch.zeros((T,), dtype=torch.bool, device=device)
            current_visiting[ara] = True

            to_fill = current_visiting & (~full_visited)
            to_fill_sum = to_fill.sum().item()
            full_flows[:,to_fill] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)[:,-to_fill_sum:].detach().cpu()
            full_visconfs[:,to_fill] = unpad_visconf_predictions[-1].reshape(B,S,2,H,W)[:,-to_fill_sum:].detach().cpu()
            full_visited |= current_visiting

            if is_training:
                all_flow_preds.append(unpad_flow_predictions)
                all_visconf_preds.append(unpad_visconf_predictions)
            else:
                del unpad_flow_predictions
                del unpad_visconf_predictions

            flows8 = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
            visconfs8 = visconfs8.reshape(B,S,2,H_pad//8,W_pad//8)

        if not is_training:
            # print('full_flows', full_flows.shape)
            # print('T_bak', T_bak)
            full_flows = full_flows[:,:T_bak]
            full_visconfs = full_visconfs[:,:T_bak]
            # print('full_flows trim', full_flows.shape)

        return full_flows, full_visconfs, all_flow_preds, all_visconf_preds#, bak_flows8

    def forward_window(self, fmap1_single, fmaps2, visconfs8, iters=None, flowfeat=None, flows8=None, sw=None, is_training=False):
        B,S,C,H8,W8 = fmaps2.shape
        device = fmaps2.device
        dtype = fmaps2.dtype

        flow_predictions = []
        visconf_predictions = []

        # print('fmap1_single', fmap1_single.shape)
        # print('fmaps2', fmaps2.shape)

        fmap1 = fmap1_single.unsqueeze(1).repeat(1,S,1,1,1) # B,S,C,H,W
        # print('fmap1', fmap1.shape)
        fmap1 = fmap1.reshape(B*(S),C,H8,W8).contiguous()
        # print('fmap1', fmap1.shape)

        fmap2 = fmaps2.reshape(B*(S),C,H8,W8).contiguous()

        visconfs8 = visconfs8.reshape(B*(S),2,H8,W8).contiguous()

        if not self.half_corr:
            corr_fn = CorrBlock(fmap1, fmap2, self.corr_levels, self.corr_radius)

        coords1 = self.coords_grid(B*(S), H8, W8, device=fmap1.device, dtype=dtype)

        if self.no_split:
            flowfeat, ctxfeat = fmap1.clone(), fmap1.clone()
        else:
            if flowfeat is not None:
                _, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)
            else:
                flowfeat, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)

        if self.half_corr:
            # we discard the cnet output for fmap2
            # in general maybe we shouldn't compute it, if half_corr pays off
            flowfeat2, _ = torch.split(fmap2, [self.dim, self.dim], dim=1)
            corr_fn = CorrBlock(flowfeat, flowfeat2, self.corr_levels, self.corr_radius)

        # add pos emb to ctxfeat (and not flowfeat), since ctxfeat is untouched across iters
        time_emb = self.fetch_time_embed(S, ctxfeat.dtype, is_training).reshape(1,S,self.dim,1,1).repeat(B,1,1,1,1)
        ctxfeat = ctxfeat + time_emb.reshape(B*S,self.dim,1,1)

        if self.no_ctx:
            flowfeat = flowfeat + time_emb.reshape(B*S,self.dim,1,1)

        for itr in range(iters):
            _, _, H8, W8 = flows8.shape
            flows8 = flows8.detach()
            coords2 = (coords1 + flows8).detach() # B*S,2,H,W
            corr = corr_fn(coords2).to(dtype)

            if self.use_relmotion or self.use_sinrelmotion:
                coords_ = coords2.reshape(B,S,2,H8*W8).permute(0,1,3,2) # B,S,H8*W8,2
                rel_coords_forward = coords_[:, :-1] - coords_[:, 1:]
                rel_coords_backward = coords_[:, 1:] - coords_[:, :-1]
                rel_coords_forward = torch.nn.functional.pad(
                    rel_coords_forward, (0, 0, 0, 0, 0, 1) # pad the 3rd-last dim (S) by (0,1)
                )
                rel_coords_backward = torch.nn.functional.pad(
                    rel_coords_backward, (0, 0, 0, 0, 1, 0) # pad the 3rd-last dim (S) by (1,0)
                )
                rel_coords = torch.cat([rel_coords_forward, rel_coords_backward], dim=-1) # B,S,H8*W8,4

                if self.use_sinrelmotion:
                    rel_pos_emb_input = utils.misc.posenc(
                        rel_coords,
                        min_deg=0,
                        max_deg=10,
                    ) # B,S,H*W,pdim
                    motion = rel_pos_emb_input.reshape(B*S,H8,W8,self.pdim).permute(0,3,1,2).to(dtype) # B*S,pdim,H8,W8
                else:
                    motion = rel_coords.reshape(B*S,H8,W8,4).permute(0,3,1,2).to(dtype) # B*S,4,H8,W8

            else:
                if self.use_sinmotion:
                    pos_emb_input = utils.misc.posenc(
                        flows8.reshape(B,S,H8*W8,2),
                        min_deg=0,
                        max_deg=10,
                    ) # B,S,H*W,pdim
                    motion = pos_emb_input.reshape(B*S,H8,W8,self.pdim).permute(0,3,1,2).to(dtype) # B*S,pdim,H8,W8
                else:
                    motion = flows8

            flowfeat = self.update_block(flowfeat, ctxfeat, visconfs8, corr, motion, S)
            flow_update = self.flow_head(flowfeat)
            visconf_update = self.visconf_head(flowfeat)
            weight_update = .25 * self.upsample_weight(flowfeat)
            flows8 = flows8 + flow_update
            visconfs8 = visconfs8 + visconf_update
            flow_up = self.upsample_data(flows8, weight_update)
            flow_predictions.append(flow_up)
            visconf_up = self.upsample_data(visconfs8, weight_update)
            visconf_predictions.append(visconf_up)

        return flow_predictions, visconf_predictions, flows8, visconfs8, flowfeat#, bak_flows8
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.26.4
|
2 |
+
imageio==2.19.3
|
3 |
+
imageio-ffmpeg==0.4.7
|
4 |
+
gradio
|
5 |
+
spaces
|
6 |
+
matplotlib
|
7 |
+
pillow
|
8 |
+
torch==2.2.0
|
9 |
+
torchvision==0.17.0
|
10 |
+
albumentations
|
11 |
+
pytorch-lightning==2.2.5
|
12 |
+
opencv-python
|
13 |
+
scikit-learn
|
14 |
+
scikit-image
|
15 |
+
einops
|
16 |
+
tensorboardX
|
17 |
+
transformers
|
utils/basic.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from os.path import isfile
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
EPS = 1e-6
|
7 |
+
import copy
|
8 |
+
|
9 |
+
def sub2ind(height, width, y, x):
|
10 |
+
return y*width + x
|
11 |
+
|
12 |
+
def ind2sub(height, width, ind):
|
13 |
+
y = ind // width
|
14 |
+
x = ind % width
|
15 |
+
return y, x
|
16 |
+
|
17 |
+
def get_lr_str(lr):
|
18 |
+
lrn = "%.1e" % lr # e.g., 5.0e-04
|
19 |
+
lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4
|
20 |
+
return lrn
|
21 |
+
|
22 |
+
def strnum(x):
|
23 |
+
s = '%g' % x
|
24 |
+
if '.' in s:
|
25 |
+
if x < 1.0:
|
26 |
+
s = s[s.index('.'):]
|
27 |
+
s = s[:min(len(s),4)]
|
28 |
+
return s
|
29 |
+
|
30 |
+
def assert_same_shape(t1, t2):
|
31 |
+
for (x, y) in zip(list(t1.shape), list(t2.shape)):
|
32 |
+
assert(x==y)
|
33 |
+
|
34 |
+
def print_stats(name, tensor):
|
35 |
+
shape = tensor.shape
|
36 |
+
tensor = tensor.detach().cpu().numpy()
|
37 |
+
print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
|
38 |
+
|
39 |
+
def print_stats_py(name, tensor):
|
40 |
+
shape = tensor.shape
|
41 |
+
print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
|
42 |
+
|
43 |
+
def print_(name, tensor):
|
44 |
+
tensor = tensor.detach().cpu().numpy()
|
45 |
+
print(name, tensor, tensor.shape)
|
46 |
+
|
47 |
+
def mkdir(path):
|
48 |
+
if not os.path.exists(path):
|
49 |
+
os.makedirs(path)
|
50 |
+
|
51 |
+
def normalize_single(d):
|
52 |
+
# d is a whatever shape torch tensor
|
53 |
+
dmin = torch.min(d)
|
54 |
+
dmax = torch.max(d)
|
55 |
+
d = (d-dmin)/(EPS+(dmax-dmin))
|
56 |
+
return d
|
57 |
+
|
58 |
+
def normalize(d):
|
59 |
+
# d is B x whatever. normalize within each element of the batch
|
60 |
+
out = torch.zeros(d.size(), dtype=d.dtype, device=d.device)
|
61 |
+
B = list(d.size())[0]
|
62 |
+
for b in list(range(B)):
|
63 |
+
out[b] = normalize_single(d[b])
|
64 |
+
return out
|
65 |
+
|
66 |
+
def hard_argmax2d(tensor):
|
67 |
+
B, C, Y, X = list(tensor.shape)
|
68 |
+
assert(C==1)
|
69 |
+
|
70 |
+
# flatten the Tensor along the height and width axes
|
71 |
+
flat_tensor = tensor.reshape(B, -1)
|
72 |
+
# argmax of the flat tensor
|
73 |
+
argmax = torch.argmax(flat_tensor, dim=1)
|
74 |
+
|
75 |
+
# convert the indices into 2d coordinates
|
76 |
+
argmax_y = torch.floor(argmax / X) # row
|
77 |
+
argmax_x = argmax % X # col
|
78 |
+
|
79 |
+
argmax_y = argmax_y.reshape(B)
|
80 |
+
argmax_x = argmax_x.reshape(B)
|
81 |
+
return argmax_y, argmax_x
|
82 |
+
|
83 |
+
def argmax2d(heat, hard=True):
|
84 |
+
B, C, Y, X = list(heat.shape)
|
85 |
+
assert(C==1)
|
86 |
+
|
87 |
+
if hard:
|
88 |
+
# hard argmax
|
89 |
+
loc_y, loc_x = hard_argmax2d(heat)
|
90 |
+
loc_y = loc_y.float()
|
91 |
+
loc_x = loc_x.float()
|
92 |
+
else:
|
93 |
+
heat = heat.reshape(B, Y*X)
|
94 |
+
prob = torch.nn.functional.softmax(heat, dim=1)
|
95 |
+
|
96 |
+
grid_y, grid_x = meshgrid2d(B, Y, X)
|
97 |
+
|
98 |
+
grid_y = grid_y.reshape(B, -1)
|
99 |
+
grid_x = grid_x.reshape(B, -1)
|
100 |
+
|
101 |
+
loc_y = torch.sum(grid_y*prob, dim=1)
|
102 |
+
loc_x = torch.sum(grid_x*prob, dim=1)
|
103 |
+
# these are B
|
104 |
+
|
105 |
+
return loc_y, loc_x
|
106 |
+
|
107 |
+
def reduce_masked_mean(x, mask, dim=None, keepdim=False, broadcast=False):
|
108 |
+
# x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
|
109 |
+
# returns shape-1
|
110 |
+
# axis can be a list of axes
|
111 |
+
if not broadcast:
|
112 |
+
for (a,b) in zip(x.size(), mask.size()):
|
113 |
+
if not a==b:
|
114 |
+
print('some shape mismatch:', x.shape, mask.shape)
|
115 |
+
assert(a==b) # some shape mismatch!
|
116 |
+
# assert(x.size() == mask.size())
|
117 |
+
prod = x*mask
|
118 |
+
if dim is None:
|
119 |
+
numer = torch.sum(prod)
|
120 |
+
denom = EPS+torch.sum(mask)
|
121 |
+
else:
|
122 |
+
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
123 |
+
denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim)
|
124 |
+
mean = numer/denom
|
125 |
+
return mean
|
126 |
+
|
127 |
+
def reduce_masked_median(x, mask, keep_batch=False):
|
128 |
+
# x and mask are the same shape
|
129 |
+
assert(x.size() == mask.size())
|
130 |
+
device = x.device
|
131 |
+
|
132 |
+
B = list(x.shape)[0]
|
133 |
+
x = x.detach().cpu().numpy()
|
134 |
+
mask = mask.detach().cpu().numpy()
|
135 |
+
|
136 |
+
if keep_batch:
|
137 |
+
x = np.reshape(x, [B, -1])
|
138 |
+
mask = np.reshape(mask, [B, -1])
|
139 |
+
meds = np.zeros([B], np.float32)
|
140 |
+
for b in list(range(B)):
|
141 |
+
xb = x[b]
|
142 |
+
mb = mask[b]
|
143 |
+
if np.sum(mb) > 0:
|
144 |
+
xb = xb[mb > 0]
|
145 |
+
meds[b] = np.median(xb)
|
146 |
+
else:
|
147 |
+
meds[b] = np.nan
|
148 |
+
meds = torch.from_numpy(meds).to(device)
|
149 |
+
return meds.float()
|
150 |
+
else:
|
151 |
+
x = np.reshape(x, [-1])
|
152 |
+
mask = np.reshape(mask, [-1])
|
153 |
+
if np.sum(mask) > 0:
|
154 |
+
x = x[mask > 0]
|
155 |
+
med = np.median(x)
|
156 |
+
else:
|
157 |
+
med = np.nan
|
158 |
+
med = np.array([med], np.float32)
|
159 |
+
med = torch.from_numpy(med).to(device)
|
160 |
+
return med.float()
|
161 |
+
|
162 |
+
def pack_seqdim(tensor, B):
|
163 |
+
shapelist = list(tensor.shape)
|
164 |
+
B_, S = shapelist[:2]
|
165 |
+
assert(B==B_)
|
166 |
+
otherdims = shapelist[2:]
|
167 |
+
tensor = torch.reshape(tensor, [B*S]+otherdims)
|
168 |
+
return tensor
|
169 |
+
|
170 |
+
def unpack_seqdim(tensor, B):
|
171 |
+
shapelist = list(tensor.shape)
|
172 |
+
BS = shapelist[0]
|
173 |
+
assert(BS%B==0)
|
174 |
+
otherdims = shapelist[1:]
|
175 |
+
S = int(BS/B)
|
176 |
+
tensor = torch.reshape(tensor, [B,S]+otherdims)
|
177 |
+
return tensor
|
178 |
+
|
179 |
+
def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False):
|
180 |
+
# returns a meshgrid sized B x Y x X
|
181 |
+
|
182 |
+
grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device))
|
183 |
+
grid_y = torch.reshape(grid_y, [1, Y, 1])
|
184 |
+
grid_y = grid_y.repeat(B, 1, X)
|
185 |
+
|
186 |
+
grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device))
|
187 |
+
grid_x = torch.reshape(grid_x, [1, 1, X])
|
188 |
+
grid_x = grid_x.repeat(B, Y, 1)
|
189 |
+
|
190 |
+
if norm:
|
191 |
+
grid_y, grid_x = normalize_grid2d(
|
192 |
+
grid_y, grid_x, Y, X)
|
193 |
+
|
194 |
+
if stack:
|
195 |
+
# note we stack in xy order
|
196 |
+
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
|
197 |
+
if on_chans:
|
198 |
+
grid = torch.stack([grid_x, grid_y], dim=1)
|
199 |
+
else:
|
200 |
+
grid = torch.stack([grid_x, grid_y], dim=-1)
|
201 |
+
return grid
|
202 |
+
else:
|
203 |
+
return grid_y, grid_x
|
204 |
+
|
205 |
+
def meshgrid2d_py(B, Y, X, stack=False, on_chans=False):
|
206 |
+
grid_y = np.linspace(0.0, Y-1, Y)
|
207 |
+
grid_y = np.reshape(grid_y, [1, Y, 1])
|
208 |
+
grid_y = np.tile(grid_y, [B, 1, X])
|
209 |
+
|
210 |
+
grid_x = np.linspace(0.0, X-1, X)
|
211 |
+
grid_x = np.reshape(grid_x, [1, 1, X])
|
212 |
+
grid_x = np.tile(grid_x, [B, Y, 1])
|
213 |
+
|
214 |
+
if stack:
|
215 |
+
# note we stack in xy order
|
216 |
+
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
|
217 |
+
if on_chans:
|
218 |
+
grid = np.stack([grid_x, grid_y], axis=1)
|
219 |
+
else:
|
220 |
+
grid = np.stack([grid_x, grid_y], axis=-1)
|
221 |
+
return grid
|
222 |
+
else:
|
223 |
+
# outputs are Y x X
|
224 |
+
return grid_y, grid_x
|
225 |
+
|
226 |
+
|
227 |
+
def meshgrid3d(B, Z, Y, X, stack=False, norm=False, device='cuda'):
|
228 |
+
# returns a meshgrid sized B x Z x Y x X
|
229 |
+
|
230 |
+
grid_z = torch.linspace(0.0, Z-1, Z, device=device)
|
231 |
+
grid_z = torch.reshape(grid_z, [1, Z, 1, 1])
|
232 |
+
grid_z = grid_z.repeat(B, 1, Y, X)
|
233 |
+
|
234 |
+
grid_y = torch.linspace(0.0, Y-1, Y, device=device)
|
235 |
+
grid_y = torch.reshape(grid_y, [1, 1, Y, 1])
|
236 |
+
grid_y = grid_y.repeat(B, Z, 1, X)
|
237 |
+
|
238 |
+
grid_x = torch.linspace(0.0, X-1, X, device=device)
|
239 |
+
grid_x = torch.reshape(grid_x, [1, 1, 1, X])
|
240 |
+
grid_x = grid_x.repeat(B, Z, Y, 1)
|
241 |
+
|
242 |
+
if norm:
|
243 |
+
grid_z, grid_y, grid_x = normalize_grid3d(
|
244 |
+
grid_z, grid_y, grid_x, Z, Y, X)
|
245 |
+
|
246 |
+
if stack:
|
247 |
+
# note we stack in xyz order
|
248 |
+
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
|
249 |
+
grid = torch.stack([grid_x, grid_y, grid_z], dim=-1)
|
250 |
+
return grid
|
251 |
+
else:
|
252 |
+
return grid_z, grid_y, grid_x
|
253 |
+
|
254 |
+
def normalize_grid2d(grid_y, grid_x, Y, X, clamp_extreme=True):
|
255 |
+
# make things in [-1,1]
|
256 |
+
grid_y = 2.0*(grid_y / float(Y-1)) - 1.0
|
257 |
+
grid_x = 2.0*(grid_x / float(X-1)) - 1.0
|
258 |
+
|
259 |
+
if clamp_extreme:
|
260 |
+
grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)
|
261 |
+
grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)
|
262 |
+
|
263 |
+
return grid_y, grid_x
|
264 |
+
|
265 |
+
def normalize_grid3d(grid_z, grid_y, grid_x, Z, Y, X, clamp_extreme=True):
|
266 |
+
# make things in [-1,1]
|
267 |
+
grid_z = 2.0*(grid_z / float(Z-1)) - 1.0
|
268 |
+
grid_y = 2.0*(grid_y / float(Y-1)) - 1.0
|
269 |
+
grid_x = 2.0*(grid_x / float(X-1)) - 1.0
|
270 |
+
|
271 |
+
if clamp_extreme:
|
272 |
+
grid_z = torch.clamp(grid_z, min=-2.0, max=2.0)
|
273 |
+
grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)
|
274 |
+
grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)
|
275 |
+
|
276 |
+
return grid_z, grid_y, grid_x
|
277 |
+
|
278 |
+
def gridcloud2d(B, Y, X, norm=False, device='cuda'):
|
279 |
+
# we want to sample for each location in the grid
|
280 |
+
grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device)
|
281 |
+
x = torch.reshape(grid_x, [B, -1])
|
282 |
+
y = torch.reshape(grid_y, [B, -1])
|
283 |
+
# these are B x N
|
284 |
+
xy = torch.stack([x, y], dim=2)
|
285 |
+
# this is B x N x 2
|
286 |
+
return xy
|
287 |
+
|
288 |
+
def gridcloud3d(B, Z, Y, X, norm=False, device='cuda'):
|
289 |
+
# we want to sample for each location in the grid
|
290 |
+
grid_z, grid_y, grid_x = meshgrid3d(B, Z, Y, X, norm=norm, device=device)
|
291 |
+
x = torch.reshape(grid_x, [B, -1])
|
292 |
+
y = torch.reshape(grid_y, [B, -1])
|
293 |
+
z = torch.reshape(grid_z, [B, -1])
|
294 |
+
# these are B x N
|
295 |
+
xyz = torch.stack([x, y, z], dim=2)
|
296 |
+
# this is B x N x 3
|
297 |
+
return xyz
|
298 |
+
|
299 |
+
# import re
|
300 |
+
# def readPFM(file):
|
301 |
+
# file = open(file, 'rb')
|
302 |
+
|
303 |
+
# color = None
|
304 |
+
# width = None
|
305 |
+
# height = None
|
306 |
+
# scale = None
|
307 |
+
# endian = None
|
308 |
+
|
309 |
+
# header = file.readline().rstrip()
|
310 |
+
# if header == b'PF':
|
311 |
+
# color = True
|
312 |
+
# elif header == b'Pf':
|
313 |
+
# color = False
|
314 |
+
# else:
|
315 |
+
# raise Exception('Not a PFM file.')
|
316 |
+
|
317 |
+
# dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
|
318 |
+
# if dim_match:
|
319 |
+
# width, height = map(int, dim_match.groups())
|
320 |
+
# else:
|
321 |
+
# raise Exception('Malformed PFM header.')
|
322 |
+
|
323 |
+
# scale = float(file.readline().rstrip())
|
324 |
+
# if scale < 0: # little-endian
|
325 |
+
# endian = '<'
|
326 |
+
# scale = -scale
|
327 |
+
# else:
|
328 |
+
# endian = '>' # big-endian
|
329 |
+
|
330 |
+
# data = np.fromfile(file, endian + 'f')
|
331 |
+
# shape = (height, width, 3) if color else (height, width)
|
332 |
+
|
333 |
+
# data = np.reshape(data, shape)
|
334 |
+
# data = np.flipud(data)
|
335 |
+
# return data
|
336 |
+
|
337 |
+
def normalize_boxlist2d(boxlist2d, H, W):
|
338 |
+
boxlist2d = boxlist2d.clone()
|
339 |
+
ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
|
340 |
+
ymin = ymin / float(H)
|
341 |
+
ymax = ymax / float(H)
|
342 |
+
xmin = xmin / float(W)
|
343 |
+
xmax = xmax / float(W)
|
344 |
+
boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
|
345 |
+
return boxlist2d
|
346 |
+
|
347 |
+
def unnormalize_boxlist2d(boxlist2d, H, W):
|
348 |
+
boxlist2d = boxlist2d.clone()
|
349 |
+
ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
|
350 |
+
ymin = ymin * float(H)
|
351 |
+
ymax = ymax * float(H)
|
352 |
+
xmin = xmin * float(W)
|
353 |
+
xmax = xmax * float(W)
|
354 |
+
boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
|
355 |
+
return boxlist2d
|
356 |
+
|
357 |
+
def unnormalize_box2d(box2d, H, W):
|
358 |
+
return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
|
359 |
+
|
360 |
+
def normalize_box2d(box2d, H, W):
|
361 |
+
return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
|
362 |
+
|
363 |
+
def get_gaussian_kernel_2d(channels, kernel_size=3, sigma=2.0, mid_one=False, device='cuda'):
|
364 |
+
C = channels
|
365 |
+
xy_grid = gridcloud2d(C, kernel_size, kernel_size, device=device) # C x N x 2
|
366 |
+
|
367 |
+
mean = (kernel_size - 1)/2.0
|
368 |
+
variance = sigma**2.0
|
369 |
+
|
370 |
+
gaussian_kernel = (1.0/(2.0*np.pi*variance)**1.5) * torch.exp(-torch.sum((xy_grid - mean)**2.0, dim=-1) / (2.0*variance)) # C X N
|
371 |
+
gaussian_kernel = gaussian_kernel.view(C, 1, kernel_size, kernel_size) # C x 1 x 3 x 3
|
372 |
+
kernel_sum = torch.sum(gaussian_kernel, dim=(2,3), keepdim=True)
|
373 |
+
|
374 |
+
gaussian_kernel = gaussian_kernel / kernel_sum # normalize
|
375 |
+
|
376 |
+
if mid_one:
|
377 |
+
# normalize so that the middle element is 1
|
378 |
+
maxval = gaussian_kernel[:,:,(kernel_size//2),(kernel_size//2)].reshape(C, 1, 1, 1)
|
379 |
+
gaussian_kernel = gaussian_kernel / maxval
|
380 |
+
|
381 |
+
return gaussian_kernel
|
382 |
+
|
383 |
+
def gaussian_blur_2d(input, kernel_size=3, sigma=2.0, reflect_pad=False, mid_one=False):
|
384 |
+
B, C, Z, X = input.shape
|
385 |
+
kernel = get_gaussian_kernel_2d(C, kernel_size, sigma, mid_one=mid_one, device=input.device)
|
386 |
+
if reflect_pad:
|
387 |
+
pad = (kernel_size - 1)//2
|
388 |
+
out = F.pad(input, (pad, pad, pad, pad), mode='reflect')
|
389 |
+
out = F.conv2d(out, kernel, padding=0, groups=C)
|
390 |
+
else:
|
391 |
+
out = F.conv2d(input, kernel, padding=(kernel_size - 1)//2, groups=C)
|
392 |
+
return out
|
393 |
+
|
394 |
+
def gradient2d(x, absolute=False, square=False, return_sum=False):
|
395 |
+
# x should be B x C x H x W
|
396 |
+
dh = x[:, :, 1:, :] - x[:, :, :-1, :]
|
397 |
+
dw = x[:, :, :, 1:] - x[:, :, :, :-1]
|
398 |
+
|
399 |
+
zeros = torch.zeros_like(x)
|
400 |
+
zero_h = zeros[:, :, 0:1, :]
|
401 |
+
zero_w = zeros[:, :, :, 0:1]
|
402 |
+
dh = torch.cat([dh, zero_h], axis=2)
|
403 |
+
dw = torch.cat([dw, zero_w], axis=3)
|
404 |
+
if absolute:
|
405 |
+
dh = torch.abs(dh)
|
406 |
+
dw = torch.abs(dw)
|
407 |
+
if square:
|
408 |
+
dh = dh ** 2
|
409 |
+
dw = dw ** 2
|
410 |
+
if return_sum:
|
411 |
+
return dh+dw
|
412 |
+
else:
|
413 |
+
return dh, dw
|
414 |
+
|
415 |
+
def gradient1d(x, absolute=False, square=False):
|
416 |
+
# x should be B,S,C
|
417 |
+
dx = x[:, 1:] - x[:, :-1]
|
418 |
+
zero = torch.zeros_like(x[:,0:1])
|
419 |
+
dx = torch.cat([dx, zero], axis=1)
|
420 |
+
if absolute:
|
421 |
+
dx = torch.abs(dx)
|
422 |
+
if square:
|
423 |
+
dx = dx ** 2
|
424 |
+
return dx
|
425 |
+
|
426 |
+
def smart_cat(tensor1, tensor2, dim):
|
427 |
+
if tensor1 is None:
|
428 |
+
return tensor2
|
429 |
+
return torch.cat([tensor1, tensor2], dim=dim)
|
utils/data.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import dataclasses
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Any, Optional, Dict
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass(eq=False)
|
9 |
+
class VideoData:
|
10 |
+
"""
|
11 |
+
Dataclass for storing video tracks data.
|
12 |
+
"""
|
13 |
+
|
14 |
+
video: torch.Tensor # B, S, C, H, W
|
15 |
+
trajs: torch.Tensor # B, S, N, 2
|
16 |
+
visibs: torch.Tensor # B, S, N
|
17 |
+
# optional data
|
18 |
+
valids: Optional[torch.Tensor] = None # B, S, N
|
19 |
+
hards: Optional[torch.Tensor] = None # B, S, N
|
20 |
+
segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
|
21 |
+
seq_name: Optional[str] = None
|
22 |
+
dname: Optional[str] = None
|
23 |
+
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
|
24 |
+
transforms: Optional[Dict[str, Any]] = None
|
25 |
+
aug_video: Optional[torch.Tensor] = None
|
26 |
+
|
27 |
+
|
28 |
+
def collate_fn(batch):
|
29 |
+
"""
|
30 |
+
Collate function for video tracks data.
|
31 |
+
"""
|
32 |
+
video = torch.stack([b.video for b in batch], dim=0)
|
33 |
+
trajs = torch.stack([b.trajs for b in batch], dim=0)
|
34 |
+
visibs = torch.stack([b.visibs for b in batch], dim=0)
|
35 |
+
query_points = segmentation = None
|
36 |
+
if batch[0].query_points is not None:
|
37 |
+
query_points = torch.stack([b.query_points for b in batch], dim=0)
|
38 |
+
if batch[0].segmentation is not None:
|
39 |
+
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
|
40 |
+
seq_name = [b.seq_name for b in batch]
|
41 |
+
dname = [b.dname for b in batch]
|
42 |
+
|
43 |
+
return VideoData(
|
44 |
+
video=video,
|
45 |
+
trajs=trajs,
|
46 |
+
visibs=visibs,
|
47 |
+
segmentation=segmentation,
|
48 |
+
seq_name=seq_name,
|
49 |
+
dname=dname,
|
50 |
+
query_points=query_points,
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
def collate_fn_train(batch):
|
55 |
+
"""
|
56 |
+
Collate function for video tracks data during training.
|
57 |
+
"""
|
58 |
+
gotit = [gotit for _, gotit in batch]
|
59 |
+
video = torch.stack([b.video for b, _ in batch], dim=0)
|
60 |
+
trajs = torch.stack([b.trajs for b, _ in batch], dim=0)
|
61 |
+
visibs = torch.stack([b.visibs for b, _ in batch], dim=0)
|
62 |
+
valids = torch.stack([b.valids for b, _ in batch], dim=0)
|
63 |
+
seq_name = [b.seq_name for b, _ in batch]
|
64 |
+
dname = [b.dname for b, _ in batch]
|
65 |
+
query_points = transforms = aug_video = hards = None
|
66 |
+
if batch[0][0].query_points is not None:
|
67 |
+
query_points = torch.stack([b.query_points for b, _ in batch], dim=0)
|
68 |
+
if batch[0][0].hards is not None:
|
69 |
+
hards = torch.stack([b.hards for b, _ in batch], dim=0)
|
70 |
+
|
71 |
+
if batch[0][0].transforms is not None:
|
72 |
+
transforms = [b.transforms for b, _ in batch]
|
73 |
+
|
74 |
+
if batch[0][0].aug_video is not None:
|
75 |
+
aug_video = torch.stack([b.aug_video for b, _ in batch], dim=0)
|
76 |
+
return (
|
77 |
+
VideoData(
|
78 |
+
video=video,
|
79 |
+
trajs=trajs,
|
80 |
+
visibs=visibs,
|
81 |
+
valids=valids,
|
82 |
+
hards=hards,
|
83 |
+
seq_name=seq_name,
|
84 |
+
dname=dname,
|
85 |
+
query_points=query_points,
|
86 |
+
aug_video=aug_video,
|
87 |
+
transforms=transforms,
|
88 |
+
),
|
89 |
+
gotit,
|
90 |
+
)
|
91 |
+
|
92 |
+
|
93 |
+
def try_to_cuda(t: Any) -> Any:
|
94 |
+
"""
|
95 |
+
Try to move the input variable `t` to a cuda device.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
t: Input.
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
t_cuda: `t` moved to a cuda device, if supported.
|
102 |
+
"""
|
103 |
+
try:
|
104 |
+
t = t.float().cuda()
|
105 |
+
except AttributeError:
|
106 |
+
pass
|
107 |
+
return t
|
108 |
+
|
109 |
+
|
110 |
+
def dataclass_to_cuda_(obj):
|
111 |
+
"""
|
112 |
+
Move all contents of a dataclass to cuda inplace if supported.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
batch: Input dataclass.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
batch_cuda: `batch` moved to a cuda device, if supported.
|
119 |
+
"""
|
120 |
+
for f in dataclasses.fields(obj):
|
121 |
+
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
|
122 |
+
return obj
|
utils/geom.py
ADDED
@@ -0,0 +1,771 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import utils.basic
|
3 |
+
# import utils.box
|
4 |
+
# import utils.vox
|
5 |
+
import numpy as np
|
6 |
+
import torchvision.ops as ops
|
7 |
+
from utils.basic import print_
|
8 |
+
|
9 |
+
def split_intrinsics(K):
|
10 |
+
# K is B x 3 x 3 or B x 4 x 4
|
11 |
+
fx = K[:,0,0]
|
12 |
+
fy = K[:,1,1]
|
13 |
+
x0 = K[:,0,2]
|
14 |
+
y0 = K[:,1,2]
|
15 |
+
return fx, fy, x0, y0
|
16 |
+
|
17 |
+
# def apply_pix_T_cam_py(pix_T_cam, xyz):
|
18 |
+
|
19 |
+
# fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
|
20 |
+
|
21 |
+
# # xyz is shaped B x H*W x 3
|
22 |
+
# # returns xy, shaped B x H*W x 2
|
23 |
+
|
24 |
+
# B, N, C = list(xyz.shape)
|
25 |
+
# assert(C==3)
|
26 |
+
|
27 |
+
# x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2]
|
28 |
+
|
29 |
+
# fx = np.reshape(fx, [B, 1])
|
30 |
+
# fy = np.reshape(fy, [B, 1])
|
31 |
+
# x0 = np.reshape(x0, [B, 1])
|
32 |
+
# y0 = np.reshape(y0, [B, 1])
|
33 |
+
|
34 |
+
# EPS = 1e-4
|
35 |
+
# z = np.clip(z, EPS, None)
|
36 |
+
# x = (x*fx)/(z)+x0
|
37 |
+
# y = (y*fy)/(z)+y0
|
38 |
+
# xy = np.stack([x, y], axis=-1)
|
39 |
+
# return xy
|
40 |
+
|
41 |
+
def apply_pix_T_cam(pix_T_cam, xyz):
|
42 |
+
|
43 |
+
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
|
44 |
+
|
45 |
+
# xyz is shaped B x H*W x 3
|
46 |
+
# returns xy, shaped B x H*W x 2
|
47 |
+
|
48 |
+
B, N, C = list(xyz.shape)
|
49 |
+
assert(C==3)
|
50 |
+
|
51 |
+
x, y, z = torch.unbind(xyz, axis=-1)
|
52 |
+
|
53 |
+
fx = torch.reshape(fx, [B, 1])
|
54 |
+
fy = torch.reshape(fy, [B, 1])
|
55 |
+
x0 = torch.reshape(x0, [B, 1])
|
56 |
+
y0 = torch.reshape(y0, [B, 1])
|
57 |
+
|
58 |
+
EPS = 1e-4
|
59 |
+
z = torch.clamp(z, min=EPS)
|
60 |
+
x = (x*fx)/(z)+x0
|
61 |
+
y = (y*fy)/(z)+y0
|
62 |
+
xy = torch.stack([x, y], axis=-1)
|
63 |
+
return xy
|
64 |
+
|
65 |
+
# def apply_4x4_py(RT, xyz):
|
66 |
+
# # print('RT', RT.shape)
|
67 |
+
# B, N, _ = list(xyz.shape)
|
68 |
+
# ones = np.ones_like(xyz[:,:,0:1])
|
69 |
+
# xyz1 = np.concatenate([xyz, ones], 2)
|
70 |
+
# # print('xyz1', xyz1.shape)
|
71 |
+
# xyz1_t = xyz1.transpose(0,2,1)
|
72 |
+
# # print('xyz1_t', xyz1_t.shape)
|
73 |
+
# # this is B x 4 x N
|
74 |
+
# xyz2_t = np.matmul(RT, xyz1_t)
|
75 |
+
# # print('xyz2_t', xyz2_t.shape)
|
76 |
+
# xyz2 = xyz2_t.transpose(0,2,1)
|
77 |
+
# # print('xyz2', xyz2.shape)
|
78 |
+
# xyz2 = xyz2[:,:,:3]
|
79 |
+
# return xyz2
|
80 |
+
|
81 |
+
def apply_4x4(RT, xyz):
|
82 |
+
B, N, _ = list(xyz.shape)
|
83 |
+
ones = torch.ones_like(xyz[:,:,0:1])
|
84 |
+
xyz1 = torch.cat([xyz, ones], 2)
|
85 |
+
xyz1_t = torch.transpose(xyz1, 1, 2)
|
86 |
+
# this is B x 4 x N
|
87 |
+
xyz2_t = torch.matmul(RT, xyz1_t)
|
88 |
+
xyz2 = torch.transpose(xyz2_t, 1, 2)
|
89 |
+
xyz2 = xyz2[:,:,:3]
|
90 |
+
return xyz2
|
91 |
+
|
92 |
+
def apply_3x3(RT, xy):
|
93 |
+
B, N, _ = list(xy.shape)
|
94 |
+
ones = torch.ones_like(xy[:,:,0:1])
|
95 |
+
xy1 = torch.cat([xy, ones], 2)
|
96 |
+
xy1_t = torch.transpose(xy1, 1, 2)
|
97 |
+
# this is B x 4 x N
|
98 |
+
xy2_t = torch.matmul(RT, xy1_t)
|
99 |
+
xy2 = torch.transpose(xy2_t, 1, 2)
|
100 |
+
xy2 = xy2[:,:,:2]
|
101 |
+
return xy2
|
102 |
+
|
103 |
+
def generate_polygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts):
|
104 |
+
'''
|
105 |
+
Start with the center of the polygon at ctr_x, ctr_y,
|
106 |
+
Then creates the polygon by sampling points on a circle around the center.
|
107 |
+
Random noise is added by varying the angular spacing between sequential points,
|
108 |
+
and by varying the radial distance of each point from the centre.
|
109 |
+
|
110 |
+
Params:
|
111 |
+
ctr_x, ctr_y - coordinates of the "centre" of the polygon
|
112 |
+
avg_r - in px, the average radius of this polygon, this roughly controls how large the polygon is, really only useful for order of magnitude.
|
113 |
+
irregularity - [0,1] indicating how much variance there is in the angular spacing of vertices. [0,1] will map to [0, 2pi/numberOfVerts]
|
114 |
+
spikiness - [0,1] indicating how much variance there is in each vertex from the circle of radius avg_r. [0,1] will map to [0, avg_r]
|
115 |
+
num_verts
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
np.array [num_verts, 2] - CCW order.
|
119 |
+
'''
|
120 |
+
# spikiness
|
121 |
+
spikiness = np.clip(spikiness, 0, 1) * avg_r
|
122 |
+
|
123 |
+
# generate n angle steps
|
124 |
+
irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts
|
125 |
+
lower = (2*np.pi / num_verts) - irregularity
|
126 |
+
upper = (2*np.pi / num_verts) + irregularity
|
127 |
+
|
128 |
+
# angle steps
|
129 |
+
angle_steps = np.random.uniform(lower, upper, num_verts)
|
130 |
+
sc = (2 * np.pi) / angle_steps.sum()
|
131 |
+
angle_steps *= sc
|
132 |
+
|
133 |
+
# get all radii
|
134 |
+
angle = np.random.uniform(0, 2*np.pi)
|
135 |
+
radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r)
|
136 |
+
|
137 |
+
# compute all points
|
138 |
+
points = []
|
139 |
+
for i in range(num_verts):
|
140 |
+
x = ctr_x + radii[i] * np.cos(angle)
|
141 |
+
y = ctr_y + radii[i] * np.sin(angle)
|
142 |
+
points.append([x, y])
|
143 |
+
angle += angle_steps[i]
|
144 |
+
|
145 |
+
return np.array(points).astype(int)
|
146 |
+
|
147 |
+
|
148 |
+
def get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05, sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05, shy_max=0.05):
|
149 |
+
'''
|
150 |
+
Params:
|
151 |
+
rot_min: rotation amount min
|
152 |
+
rot_max: rotation amount max
|
153 |
+
|
154 |
+
tx_min: translation x min
|
155 |
+
tx_max: translation x max
|
156 |
+
|
157 |
+
ty_min: translation y min
|
158 |
+
ty_max: translation y max
|
159 |
+
|
160 |
+
sx_min: scaling x min
|
161 |
+
sx_max: scaling x max
|
162 |
+
|
163 |
+
sy_min: scaling y min
|
164 |
+
sy_max: scaling y max
|
165 |
+
|
166 |
+
shx_min: shear x min
|
167 |
+
shx_max: shear x max
|
168 |
+
|
169 |
+
shy_min: shear y min
|
170 |
+
shy_max: shear y max
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
transformation matrix: (B, 3, 3)
|
174 |
+
'''
|
175 |
+
# rotation
|
176 |
+
if rot_max - rot_min != 0:
|
177 |
+
rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B)
|
178 |
+
rot_amount = np.pi/180.0*rot_amount
|
179 |
+
else:
|
180 |
+
rot_amount = rot_min
|
181 |
+
rotation = np.zeros((B, 3, 3)) # B, 3, 3
|
182 |
+
rotation[:, 2, 2] = 1
|
183 |
+
rotation[:, 0, 0] = np.cos(rot_amount)
|
184 |
+
rotation[:, 0, 1] = -np.sin(rot_amount)
|
185 |
+
rotation[:, 1, 0] = np.sin(rot_amount)
|
186 |
+
rotation[:, 1, 1] = np.cos(rot_amount)
|
187 |
+
|
188 |
+
# translation
|
189 |
+
translation = np.zeros((B, 3, 3)) # B, 3, 3
|
190 |
+
translation[:, [0,1,2], [0,1,2]] = 1
|
191 |
+
if (tx_max - tx_min) > 0:
|
192 |
+
trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B)
|
193 |
+
translation[:, 0, 2] = trans_x
|
194 |
+
# else:
|
195 |
+
# translation[:, 0, 2] = tx_max
|
196 |
+
if ty_max - ty_min != 0:
|
197 |
+
trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B)
|
198 |
+
translation[:, 1, 2] = trans_y
|
199 |
+
# else:
|
200 |
+
# translation[:, 1, 2] = ty_max
|
201 |
+
|
202 |
+
# scaling
|
203 |
+
scaling = np.zeros((B, 3, 3)) # B, 3, 3
|
204 |
+
scaling[:, [0,1,2], [0,1,2]] = 1
|
205 |
+
if (sx_max - sx_min) > 0:
|
206 |
+
scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B)
|
207 |
+
scaling[:, 0, 0] = scale_x
|
208 |
+
# else:
|
209 |
+
# scaling[:, 0, 0] = sx_max
|
210 |
+
if (sy_max - sy_min) > 0:
|
211 |
+
scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B)
|
212 |
+
scaling[:, 1, 1] = scale_y
|
213 |
+
# else:
|
214 |
+
# scaling[:, 1, 1] = sy_max
|
215 |
+
|
216 |
+
# shear
|
217 |
+
shear = np.zeros((B, 3, 3)) # B, 3, 3
|
218 |
+
shear[:, [0,1,2], [0,1,2]] = 1
|
219 |
+
if (shx_max - shx_min) > 0:
|
220 |
+
shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B)
|
221 |
+
shear[:, 0, 1] = shear_x
|
222 |
+
# else:
|
223 |
+
# shear[:, 0, 1] = shx_max
|
224 |
+
if (shy_max - shy_min) > 0:
|
225 |
+
shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B)
|
226 |
+
shear[:, 1, 0] = shear_y
|
227 |
+
# else:
|
228 |
+
# shear[:, 1, 0] = shy_max
|
229 |
+
|
230 |
+
# compose all those
|
231 |
+
rt = np.einsum("ijk,ikl->ijl", rotation, translation)
|
232 |
+
ss = np.einsum("ijk,ikl->ijl", scaling, shear)
|
233 |
+
trans = np.einsum("ijk,ikl->ijl", rt, ss)
|
234 |
+
|
235 |
+
return trans
|
236 |
+
|
237 |
+
def get_centroid_from_box2d(box2d):
|
238 |
+
ymin = box2d[:,0]
|
239 |
+
xmin = box2d[:,1]
|
240 |
+
ymax = box2d[:,2]
|
241 |
+
xmax = box2d[:,3]
|
242 |
+
x = (xmin+xmax)/2.0
|
243 |
+
y = (ymin+ymax)/2.0
|
244 |
+
return y, x
|
245 |
+
|
246 |
+
def normalize_boxlist2d(boxlist2d, H, W):
|
247 |
+
boxlist2d = boxlist2d.clone()
|
248 |
+
ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
|
249 |
+
ymin = ymin / float(H)
|
250 |
+
ymax = ymax / float(H)
|
251 |
+
xmin = xmin / float(W)
|
252 |
+
xmax = xmax / float(W)
|
253 |
+
boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
|
254 |
+
return boxlist2d
|
255 |
+
|
256 |
+
def unnormalize_boxlist2d(boxlist2d, H, W):
|
257 |
+
boxlist2d = boxlist2d.clone()
|
258 |
+
ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
|
259 |
+
ymin = ymin * float(H)
|
260 |
+
ymax = ymax * float(H)
|
261 |
+
xmin = xmin * float(W)
|
262 |
+
xmax = xmax * float(W)
|
263 |
+
boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
|
264 |
+
return boxlist2d
|
265 |
+
|
266 |
+
def unnormalize_box2d(box2d, H, W):
|
267 |
+
return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
|
268 |
+
|
269 |
+
def normalize_box2d(box2d, H, W):
|
270 |
+
return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
|
271 |
+
|
272 |
+
def get_size_from_box2d(box2d):
|
273 |
+
ymin = box2d[:,0]
|
274 |
+
xmin = box2d[:,1]
|
275 |
+
ymax = box2d[:,2]
|
276 |
+
xmax = box2d[:,3]
|
277 |
+
height = ymax-ymin
|
278 |
+
width = xmax-xmin
|
279 |
+
return height, width
|
280 |
+
|
281 |
+
def crop_and_resize(im, boxlist, PH, PW, boxlist_is_normalized=False):
|
282 |
+
B, C, H, W = im.shape
|
283 |
+
B2, N, D = boxlist.shape
|
284 |
+
assert(B==B2)
|
285 |
+
assert(D==4)
|
286 |
+
# PH, PW is the size to resize to
|
287 |
+
|
288 |
+
# print('im', im.shape)
|
289 |
+
|
290 |
+
# output is B,N,C,PH,PW
|
291 |
+
|
292 |
+
# pt wants xy xy, unnormalized
|
293 |
+
if boxlist_is_normalized:
|
294 |
+
boxlist_unnorm = unnormalize_boxlist2d(boxlist, H, W)
|
295 |
+
else:
|
296 |
+
boxlist_unnorm = boxlist
|
297 |
+
|
298 |
+
ymin, xmin, ymax, xmax = boxlist_unnorm.unbind(2)
|
299 |
+
# boxlist_pt = torch.stack([boxlist_unnorm[:,1], boxlist_unnorm[:,0], boxlist_unnorm[:,3], boxlist_unnorm[:,2]], dim=1)
|
300 |
+
boxlist_pt = torch.stack([xmin, ymin, xmax, ymax], dim=2)
|
301 |
+
# we want a B-len list of K x 4 arrays
|
302 |
+
|
303 |
+
# print('im', im.shape)
|
304 |
+
# print('boxlist', boxlist.shape)
|
305 |
+
# print('boxlist_pt', boxlist_pt.shape)
|
306 |
+
|
307 |
+
# boxlist_pt = list(boxlist_pt.unbind(0))
|
308 |
+
|
309 |
+
# crops = []
|
310 |
+
# for b in range(B):
|
311 |
+
# crops_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW))
|
312 |
+
# crops.append(crops_b)
|
313 |
+
# # # crops = im
|
314 |
+
|
315 |
+
# inds = torch.arange(B).reshape(B,1)
|
316 |
+
# # boxlist_pt = torch.cat([inds, boxlist_
|
317 |
+
# boxlist_pt = list(boxlist_pt.unbind(0)
|
318 |
+
|
319 |
+
# crops = ops.roi_align(im, [boxlist_pt[b]], output_size=(PH, PW))
|
320 |
+
crops = ops.roi_align(im, list(boxlist_pt.unbind(0)), output_size=(PH, PW))
|
321 |
+
|
322 |
+
|
323 |
+
# print('crops', crops.shape)
|
324 |
+
# crops = crops.reshape(B,N,C,PH,PW)
|
325 |
+
|
326 |
+
|
327 |
+
# crops = []
|
328 |
+
# for b in range(B):
|
329 |
+
# crop_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW))
|
330 |
+
# print('crop_b', crop_b.shape)
|
331 |
+
# crops.append(crop_b)
|
332 |
+
# crops = torch.stack(crops, dim=0)
|
333 |
+
|
334 |
+
# print('crops', crops.shape)
|
335 |
+
# boxlist_list = boxlist_pt.unbind(0)
|
336 |
+
# print('rgb_crop', rgb_crop.shape)
|
337 |
+
|
338 |
+
return crops
|
339 |
+
|
340 |
+
|
341 |
+
# def get_boxlist_from_centroid_and_size(cy, cx, h, w, clip=True):
|
342 |
+
# # cy,cx are both B,N
|
343 |
+
# ymin = cy - h/2
|
344 |
+
# ymax = cy + h/2
|
345 |
+
# xmin = cx - w/2
|
346 |
+
# xmax = cx + w/2
|
347 |
+
|
348 |
+
# box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)
|
349 |
+
# if clip:
|
350 |
+
# box = torch.clamp(box, 0, 1)
|
351 |
+
# return box
|
352 |
+
|
353 |
+
|
354 |
+
def get_boxlist_from_centroid_and_size(cy, cx, h, w):#, clip=False):
|
355 |
+
# cy,cx are the same shape
|
356 |
+
ymin = cy - h/2
|
357 |
+
ymax = cy + h/2
|
358 |
+
xmin = cx - w/2
|
359 |
+
xmax = cx + w/2
|
360 |
+
|
361 |
+
# if clip:
|
362 |
+
# ymin = torch.clamp(ymin, 0, H-1)
|
363 |
+
# ymax = torch.clamp(ymax, 0, H-1)
|
364 |
+
# xmin = torch.clamp(xmin, 0, W-1)
|
365 |
+
# xmax = torch.clamp(xmax, 0, W-1)
|
366 |
+
|
367 |
+
box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)
|
368 |
+
return box
|
369 |
+
|
370 |
+
|
371 |
+
def get_box2d_from_mask(mask, normalize=False):
|
372 |
+
# mask is B, 1, H, W
|
373 |
+
|
374 |
+
B, C, H, W = mask.shape
|
375 |
+
assert(C==1)
|
376 |
+
xy = utils.basic.gridcloud2d(B, H, W, norm=False, device=mask.device) # B, H*W, 2
|
377 |
+
|
378 |
+
box = torch.zeros((B, 4), dtype=torch.float32, device=mask.device)
|
379 |
+
for b in range(B):
|
380 |
+
xy_b = xy[b] # H*W, 2
|
381 |
+
mask_b = mask[b].reshape(H*W)
|
382 |
+
xy_ = xy_b[mask_b > 0]
|
383 |
+
x_ = xy_[:,0]
|
384 |
+
y_ = xy_[:,1]
|
385 |
+
ymin = torch.min(y_)
|
386 |
+
ymax = torch.max(y_)
|
387 |
+
xmin = torch.min(x_)
|
388 |
+
xmax = torch.max(x_)
|
389 |
+
box[b] = torch.stack([ymin, xmin, ymax, xmax], dim=0)
|
390 |
+
if normalize:
|
391 |
+
box = normalize_boxlist2d(box.unsqueeze(1), H, W).squeeze(1)
|
392 |
+
return box
|
393 |
+
|
394 |
+
def convert_box2d_to_intrinsics(box2d, pix_T_cam, H, W, use_image_aspect_ratio=True, mult_padding=1.0):
|
395 |
+
# box2d is B x 4, with ymin, xmin, ymax, xmax in normalized coords
|
396 |
+
# ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)
|
397 |
+
# H, W is the original size of the image
|
398 |
+
# mult_padding is relative to object size in pixels
|
399 |
+
|
400 |
+
# i assume we're rendering an image the same size as the original (H, W)
|
401 |
+
|
402 |
+
if not mult_padding==1.0:
|
403 |
+
y, x = get_centroid_from_box2d(box2d)
|
404 |
+
h, w = get_size_from_box2d(box2d)
|
405 |
+
box2d = get_box2d_from_centroid_and_size(
|
406 |
+
y, x, h*mult_padding, w*mult_padding, clip=False)
|
407 |
+
|
408 |
+
if use_image_aspect_ratio:
|
409 |
+
h, w = get_size_from_box2d(box2d)
|
410 |
+
y, x = get_centroid_from_box2d(box2d)
|
411 |
+
|
412 |
+
# note h,w are relative right now
|
413 |
+
# we need to undo this, to see the real ratio
|
414 |
+
|
415 |
+
h = h*float(H)
|
416 |
+
w = w*float(W)
|
417 |
+
box_ratio = h/w
|
418 |
+
im_ratio = H/float(W)
|
419 |
+
|
420 |
+
# print('box_ratio:', box_ratio)
|
421 |
+
# print('im_ratio:', im_ratio)
|
422 |
+
|
423 |
+
if box_ratio >= im_ratio:
|
424 |
+
w = h/im_ratio
|
425 |
+
# print('setting w:', h/im_ratio)
|
426 |
+
else:
|
427 |
+
h = w*im_ratio
|
428 |
+
# print('setting h:', w*im_ratio)
|
429 |
+
|
430 |
+
box2d = get_box2d_from_centroid_and_size(
|
431 |
+
y, x, h/float(H), w/float(W), clip=False)
|
432 |
+
|
433 |
+
assert(h > 1e-4)
|
434 |
+
assert(w > 1e-4)
|
435 |
+
|
436 |
+
ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)
|
437 |
+
|
438 |
+
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
|
439 |
+
|
440 |
+
# the topleft of the new image will now have a different offset from the center of projection
|
441 |
+
|
442 |
+
new_x0 = x0 - xmin*W
|
443 |
+
new_y0 = y0 - ymin*H
|
444 |
+
|
445 |
+
pix_T_cam = pack_intrinsics(fx, fy, new_x0, new_y0)
|
446 |
+
# this alone will give me an image in original resolution,
|
447 |
+
# with its topleft at the box corner
|
448 |
+
|
449 |
+
box_h, box_w = get_size_from_box2d(box2d)
|
450 |
+
# these are normalized, and shaped B. (e.g., [0.4], [0.3])
|
451 |
+
|
452 |
+
# we are going to scale the image by the inverse of this,
|
453 |
+
# since we are zooming into this area
|
454 |
+
|
455 |
+
sy = 1./box_h
|
456 |
+
sx = 1./box_w
|
457 |
+
|
458 |
+
pix_T_cam = scale_intrinsics(pix_T_cam, sx, sy)
|
459 |
+
return pix_T_cam, box2d
|
460 |
+
|
461 |
+
def generatePolygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts):
|
462 |
+
'''
|
463 |
+
Start with the center of the polygon at ctr_x, ctr_y,
|
464 |
+
Then creates the polygon by sampling points on a circle around the center.
|
465 |
+
Random noise is added by varying the angular spacing between sequential points,
|
466 |
+
and by varying the radial distance of each point from the centre.
|
467 |
+
|
468 |
+
Params:
|
469 |
+
ctr_x, ctr_y - coordinates of the "centre" of the polygon
|
470 |
+
avg_r - in px, the average radius of this polygon, this roughly controls how large the polygon is, really only useful for order of magnitude.
|
471 |
+
irregularity - [0,1] indicating how much variance there is in the angular spacing of vertices. [0,1] will map to [0, 2pi/numberOfVerts]
|
472 |
+
spikiness - [0,1] indicating how much variance there is in each vertex from the circle of radius avg_r. [0,1] will map to [0, avg_r]
|
473 |
+
num_verts
|
474 |
+
|
475 |
+
Returns:
|
476 |
+
np.array [num_verts, 2] - CCW order.
|
477 |
+
'''
|
478 |
+
# spikiness
|
479 |
+
spikiness = np.clip(spikiness, 0, 1) * avg_r
|
480 |
+
|
481 |
+
# generate n angle steps
|
482 |
+
irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts
|
483 |
+
lower = (2*np.pi / num_verts) - irregularity
|
484 |
+
upper = (2*np.pi / num_verts) + irregularity
|
485 |
+
|
486 |
+
# angle steps
|
487 |
+
angle_steps = np.random.uniform(lower, upper, num_verts)
|
488 |
+
sc = (2 * np.pi) / angle_steps.sum()
|
489 |
+
angle_steps *= sc
|
490 |
+
|
491 |
+
# get all radii
|
492 |
+
angle = np.random.uniform(0, 2*np.pi)
|
493 |
+
radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r)
|
494 |
+
|
495 |
+
# compute all points
|
496 |
+
points = []
|
497 |
+
for i in range(num_verts):
|
498 |
+
x = ctr_x + radii[i] * np.cos(angle)
|
499 |
+
y = ctr_y + radii[i] * np.sin(angle)
|
500 |
+
points.append([x, y])
|
501 |
+
angle += angle_steps[i]
|
502 |
+
|
503 |
+
return np.array(points).astype(int)
|
504 |
+
|
505 |
+
def get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05, sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05, shy_max=0.05):
|
506 |
+
'''
|
507 |
+
Params:
|
508 |
+
rot_min: rotation amount min
|
509 |
+
rot_max: rotation amount max
|
510 |
+
|
511 |
+
tx_min: translation x min
|
512 |
+
tx_max: translation x max
|
513 |
+
|
514 |
+
ty_min: translation y min
|
515 |
+
ty_max: translation y max
|
516 |
+
|
517 |
+
sx_min: scaling x min
|
518 |
+
sx_max: scaling x max
|
519 |
+
|
520 |
+
sy_min: scaling y min
|
521 |
+
sy_max: scaling y max
|
522 |
+
|
523 |
+
shx_min: shear x min
|
524 |
+
shx_max: shear x max
|
525 |
+
|
526 |
+
shy_min: shear y min
|
527 |
+
shy_max: shear y max
|
528 |
+
|
529 |
+
Returns:
|
530 |
+
transformation matrix: (B, 3, 3)
|
531 |
+
'''
|
532 |
+
# rotation
|
533 |
+
if rot_max - rot_min != 0:
|
534 |
+
rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B)
|
535 |
+
rot_amount = np.pi/180.0*rot_amount
|
536 |
+
else:
|
537 |
+
rot_amount = rot_min
|
538 |
+
rotation = np.zeros((B, 3, 3)) # B, 3, 3
|
539 |
+
rotation[:, 2, 2] = 1
|
540 |
+
rotation[:, 0, 0] = np.cos(rot_amount)
|
541 |
+
rotation[:, 0, 1] = -np.sin(rot_amount)
|
542 |
+
rotation[:, 1, 0] = np.sin(rot_amount)
|
543 |
+
rotation[:, 1, 1] = np.cos(rot_amount)
|
544 |
+
|
545 |
+
# translation
|
546 |
+
translation = np.zeros((B, 3, 3)) # B, 3, 3
|
547 |
+
translation[:, [0,1,2], [0,1,2]] = 1
|
548 |
+
if tx_max - tx_min != 0:
|
549 |
+
trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B)
|
550 |
+
translation[:, 0, 2] = trans_x
|
551 |
+
else:
|
552 |
+
translation[:, 0, 2] = tx_max
|
553 |
+
if ty_max - ty_min != 0:
|
554 |
+
trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B)
|
555 |
+
translation[:, 1, 2] = trans_y
|
556 |
+
else:
|
557 |
+
translation[:, 1, 2] = ty_max
|
558 |
+
|
559 |
+
# scaling
|
560 |
+
scaling = np.zeros((B, 3, 3)) # B, 3, 3
|
561 |
+
scaling[:, [0,1,2], [0,1,2]] = 1
|
562 |
+
if sx_max - sx_min != 0:
|
563 |
+
scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B)
|
564 |
+
scaling[:, 0, 0] = scale_x
|
565 |
+
else:
|
566 |
+
scaling[:, 0, 0] = sx_max
|
567 |
+
if sy_max - sy_min != 0:
|
568 |
+
scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B)
|
569 |
+
scaling[:, 1, 1] = scale_y
|
570 |
+
else:
|
571 |
+
scaling[:, 1, 1] = sy_max
|
572 |
+
|
573 |
+
# shear
|
574 |
+
shear = np.zeros((B, 3, 3)) # B, 3, 3
|
575 |
+
shear[:, [0,1,2], [0,1,2]] = 1
|
576 |
+
if shx_max - shx_min != 0:
|
577 |
+
shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B)
|
578 |
+
shear[:, 0, 1] = shear_x
|
579 |
+
else:
|
580 |
+
shear[:, 0, 1] = shx_max
|
581 |
+
if shy_max - shy_min != 0:
|
582 |
+
shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B)
|
583 |
+
shear[:, 1, 0] = shear_y
|
584 |
+
else:
|
585 |
+
shear[:, 1, 0] = shy_max
|
586 |
+
|
587 |
+
# compose all those
|
588 |
+
rt = np.einsum("ijk,ikl->ijl", rotation, translation)
|
589 |
+
ss = np.einsum("ijk,ikl->ijl", scaling, shear)
|
590 |
+
trans = np.einsum("ijk,ikl->ijl", rt, ss)
|
591 |
+
|
592 |
+
return trans
|
593 |
+
|
594 |
+
def pixels2camera(x,y,z,fx,fy,x0,y0):
|
595 |
+
# x and y are locations in pixel coordinates, z is a depth in meters
|
596 |
+
# they can be images or pointclouds
|
597 |
+
# fx, fy, x0, y0 are camera intrinsics
|
598 |
+
# returns xyz, sized B x N x 3
|
599 |
+
|
600 |
+
B = x.shape[0]
|
601 |
+
|
602 |
+
fx = torch.reshape(fx, [B,1])
|
603 |
+
fy = torch.reshape(fy, [B,1])
|
604 |
+
x0 = torch.reshape(x0, [B,1])
|
605 |
+
y0 = torch.reshape(y0, [B,1])
|
606 |
+
|
607 |
+
x = torch.reshape(x, [B,-1])
|
608 |
+
y = torch.reshape(y, [B,-1])
|
609 |
+
z = torch.reshape(z, [B,-1])
|
610 |
+
|
611 |
+
# unproject
|
612 |
+
x = (z/fx)*(x-x0)
|
613 |
+
y = (z/fy)*(y-y0)
|
614 |
+
|
615 |
+
xyz = torch.stack([x,y,z], dim=2)
|
616 |
+
# B x N x 3
|
617 |
+
return xyz
|
618 |
+
|
619 |
+
def camera2pixels(xyz, pix_T_cam):
|
620 |
+
# xyz is shaped B x H*W x 3
|
621 |
+
# returns xy, shaped B x H*W x 2
|
622 |
+
|
623 |
+
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
|
624 |
+
x, y, z = torch.unbind(xyz, dim=-1)
|
625 |
+
B = list(z.shape)[0]
|
626 |
+
|
627 |
+
fx = torch.reshape(fx, [B,1])
|
628 |
+
fy = torch.reshape(fy, [B,1])
|
629 |
+
x0 = torch.reshape(x0, [B,1])
|
630 |
+
y0 = torch.reshape(y0, [B,1])
|
631 |
+
x = torch.reshape(x, [B,-1])
|
632 |
+
y = torch.reshape(y, [B,-1])
|
633 |
+
z = torch.reshape(z, [B,-1])
|
634 |
+
|
635 |
+
EPS = 1e-4
|
636 |
+
z = torch.clamp(z, min=EPS)
|
637 |
+
x = (x*fx)/z + x0
|
638 |
+
y = (y*fy)/z + y0
|
639 |
+
xy = torch.stack([x, y], dim=-1)
|
640 |
+
return xy
|
641 |
+
|
642 |
+
def depth2pointcloud(z, pix_T_cam):
|
643 |
+
B, C, H, W = list(z.shape)
|
644 |
+
device = z.device
|
645 |
+
y, x = utils.basic.meshgrid2d(B, H, W, device=device)
|
646 |
+
z = torch.reshape(z, [B, H, W])
|
647 |
+
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
|
648 |
+
xyz = pixels2camera(x, y, z, fx, fy, x0, y0)
|
649 |
+
return xyz
|
650 |
+
|
651 |
+
|
652 |
+
|
653 |
+
def create_depth_image_single(xy, z, H, W, force_positive=True, max_val=100.0, serial=False, slices=20):
|
654 |
+
# turn the xy coordinates into image inds
|
655 |
+
xy = torch.round(xy).long()
|
656 |
+
depth = torch.zeros(H*W, dtype=torch.float32, device=xy.device)
|
657 |
+
depth[:] = max_val
|
658 |
+
|
659 |
+
# lidar reports a sphere of measurements
|
660 |
+
# only use the inds that are within the image bounds
|
661 |
+
# also, only use forward-pointing depths (z > 0)
|
662 |
+
valid_inds = (xy[:,0] <= W-1) & (xy[:,1] <= H-1) & (xy[:,0] >= 0) & (xy[:,1] >= 0) & (z[:] > 0)
|
663 |
+
|
664 |
+
# gather these up
|
665 |
+
xy = xy[valid_inds]
|
666 |
+
z = z[valid_inds]
|
667 |
+
|
668 |
+
inds = utils.basic.sub2ind(H, W, xy[:,1], xy[:,0]).long()
|
669 |
+
if not serial:
|
670 |
+
depth[inds] = z
|
671 |
+
else:
|
672 |
+
if False:
|
673 |
+
for (index, replacement) in zip(inds, z):
|
674 |
+
if depth[index] > replacement:
|
675 |
+
depth[index] = replacement
|
676 |
+
# ok my other idea is:
|
677 |
+
# sort the depths by distance
|
678 |
+
# create N depth maps
|
679 |
+
# merge them back-to-front
|
680 |
+
|
681 |
+
# actually maybe you don't even need the separate maps
|
682 |
+
|
683 |
+
sort_inds = torch.argsort(z, descending=True)
|
684 |
+
xy = xy[sort_inds]
|
685 |
+
z = z[sort_inds]
|
686 |
+
N = len(sort_inds)
|
687 |
+
def split(a, n):
|
688 |
+
k, m = divmod(len(a), n)
|
689 |
+
return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
|
690 |
+
|
691 |
+
slice_inds = split(range(N), slices)
|
692 |
+
for si in slice_inds:
|
693 |
+
mini_z = z[si]
|
694 |
+
mini_xy = xy[si]
|
695 |
+
inds = utils.basic.sub2ind(H, W, mini_xy[:,1], mini_xy[:,0]).long()
|
696 |
+
depth[inds] = mini_z
|
697 |
+
# cool; this is rougly as fast as the parallel, and as accurate as the serial
|
698 |
+
|
699 |
+
if False:
|
700 |
+
print('inds', inds.shape)
|
701 |
+
unique, inverse, counts = torch.unique(inds, return_inverse=True, return_counts=True)
|
702 |
+
print('unique', unique.shape)
|
703 |
+
|
704 |
+
perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
|
705 |
+
inverse, perm = inverse.flip([0]), perm.flip([0])
|
706 |
+
perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
|
707 |
+
|
708 |
+
# new_inds = inds[inverse_inds]
|
709 |
+
# depth[new_inds] = z[unique_inds]
|
710 |
+
|
711 |
+
depth[unique] = z[perm]
|
712 |
+
|
713 |
+
# now for the duplicates...
|
714 |
+
|
715 |
+
dup = counts > 1
|
716 |
+
dup_unique = unique[dup]
|
717 |
+
print('dup_unique', dup_unique.shape)
|
718 |
+
depth[dup_unique] = 0.5
|
719 |
+
|
720 |
+
if force_positive:
|
721 |
+
# valid = (depth > 0.0).float()
|
722 |
+
depth[torch.where(depth == 0.0)] = max_val
|
723 |
+
# else:
|
724 |
+
# valid = torch.ones_like(depth)
|
725 |
+
|
726 |
+
valid = (depth > 0.0).float() * (depth < max_val).float()
|
727 |
+
|
728 |
+
depth = torch.reshape(depth, [1, H, W])
|
729 |
+
valid = torch.reshape(valid, [1, H, W])
|
730 |
+
return depth, valid
|
731 |
+
|
732 |
+
def create_depth_image(pix_T_cam, xyz_cam, H, W, offset_amount=0, max_val=100.0, serial=False, slices=20):
|
733 |
+
B, N, D = list(xyz_cam.shape)
|
734 |
+
assert(D==3)
|
735 |
+
B2, E, F = list(pix_T_cam.shape)
|
736 |
+
assert(B==B2)
|
737 |
+
assert(E==4)
|
738 |
+
assert(F==4)
|
739 |
+
xy = apply_pix_T_cam(pix_T_cam, xyz_cam)
|
740 |
+
z = xyz_cam[:,:,2]
|
741 |
+
|
742 |
+
depth = torch.zeros(B, 1, H, W, dtype=torch.float32, device=xyz_cam.device)
|
743 |
+
valid = torch.zeros(B, 1, H, W, dtype=torch.float32, device=xyz_cam.device)
|
744 |
+
for b in list(range(B)):
|
745 |
+
xy_b, z_b = xy[b], z[b]
|
746 |
+
ind = z_b > 0
|
747 |
+
xy_b = xy_b[ind]
|
748 |
+
z_b = z_b[ind]
|
749 |
+
depth_b, valid_b = create_depth_image_single(xy_b, z_b, H, W, max_val=max_val, serial=serial, slices=slices)
|
750 |
+
if offset_amount:
|
751 |
+
depth_b = depth_b.reshape(-1)
|
752 |
+
valid_b = valid_b.reshape(-1)
|
753 |
+
|
754 |
+
for off_x in range(offset_amount):
|
755 |
+
for off_y in range(offset_amount):
|
756 |
+
for sign in [-1,1]:
|
757 |
+
offset = np.array([sign*off_x,sign*off_y]).astype(np.float32)
|
758 |
+
offset = torch.from_numpy(offset).reshape(1, 2).to(xyz_cam.device)
|
759 |
+
# offsets.append(offset)
|
760 |
+
depth_, valid_ = create_depth_image_single(xy_b + offset, z_b, H, W, max_val=max_val)
|
761 |
+
depth_ = depth_.reshape(-1)
|
762 |
+
valid_ = valid_.reshape(-1)
|
763 |
+
# at invalid locations, use this new value
|
764 |
+
depth_b[valid_b==0] = depth_[valid_b==0]
|
765 |
+
valid_b[valid_b==0] = valid_[valid_b==0]
|
766 |
+
|
767 |
+
depth_b = depth_b.reshape(1, H, W)
|
768 |
+
valid_b = valid_b.reshape(1, H, W)
|
769 |
+
depth[b] = depth_b
|
770 |
+
valid[b] = valid_b
|
771 |
+
return depth, valid
|
utils/improc.py
ADDED
@@ -0,0 +1,1645 @@
import torch
import numpy as np
import utils.basic
import utils.py
from sklearn.decomposition import PCA
from matplotlib import cm
import matplotlib.pyplot as plt
import cv2
import torch.nn.functional as F
import torchvision
EPS = 1e-6

from skimage.color import (
    rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb,
    rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb)

def _convert(input_, type_):
    return {
        'float': input_.float(),
        'double': input_.double(),
    }.get(type_, input_)

def _generic_transform_sk_3d(transform, in_type='', out_type=''):
    def apply_transform_individual(input_):
        device = input_.device
        input_ = input_.cpu()
        input_ = _convert(input_, in_type)

        input_ = input_.permute(1, 2, 0).detach().numpy()
        transformed = transform(input_)
        output = torch.from_numpy(transformed).float().permute(2, 0, 1)
        output = _convert(output, out_type)
        return output.to(device)

    def apply_transform(input_):
        to_stack = []
        for image in input_:
            to_stack.append(apply_transform_individual(image))
        return torch.stack(to_stack)
    return apply_transform

hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb)

def preprocess_color_tf(x):
    import tensorflow as tf
    return tf.cast(x, tf.float32) * 1./255 - 0.5

def preprocess_color(x):
    if isinstance(x, np.ndarray):
        return x.astype(np.float32) * 1./255 - 0.5
    else:
        return x.float() * 1./255 - 0.5

+
|
54 |
+
def pca_embed(emb, keep, valid=None):
|
55 |
+
## emb -- [S,H/2,W/2,C]
|
56 |
+
## keep is the number of principal components to keep
|
57 |
+
## Helper function for reduce_emb.
|
58 |
+
emb = emb + EPS
|
59 |
+
#emb is B x C x H x W
|
60 |
+
emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C
|
61 |
+
|
62 |
+
if valid:
|
63 |
+
valid = valid.cpu().detach().numpy().reshape((H*W))
|
64 |
+
|
65 |
+
emb_reduced = list()
|
66 |
+
|
67 |
+
B, H, W, C = np.shape(emb)
|
68 |
+
for img in emb:
|
69 |
+
if np.isnan(img).any():
|
70 |
+
emb_reduced.append(np.zeros([H, W, keep]))
|
71 |
+
continue
|
72 |
+
|
73 |
+
pixels_kd = np.reshape(img, (H*W, C))
|
74 |
+
|
75 |
+
if valid:
|
76 |
+
pixels_kd_pca = pixels_kd[valid]
|
77 |
+
else:
|
78 |
+
pixels_kd_pca = pixels_kd
|
79 |
+
|
80 |
+
P = PCA(keep)
|
81 |
+
P.fit(pixels_kd_pca)
|
82 |
+
|
83 |
+
if valid:
|
84 |
+
pixels3d = P.transform(pixels_kd)*valid
|
85 |
+
else:
|
86 |
+
pixels3d = P.transform(pixels_kd)
|
87 |
+
|
88 |
+
out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32)
|
89 |
+
if np.isnan(out_img).any():
|
90 |
+
emb_reduced.append(np.zeros([H, W, keep]))
|
91 |
+
continue
|
92 |
+
|
93 |
+
emb_reduced.append(out_img)
|
94 |
+
|
95 |
+
emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32)
|
96 |
+
|
97 |
+
return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2)
|
98 |
+
|
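# Illustrative note (not part of the original file): pca_embed maps a
# B x C x H x W feature tensor to B x keep x H x W by fitting a PCA per image,
# e.g. a hypothetical call:
#   feat = torch.randn(2, 64, 32, 32)
#   feat_low = pca_embed(feat, keep=3)  # 2 x 3 x 32 x 32, usable as an RGB-like vis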
def pca_embed_together(emb, keep):
    ## emb -- [S,H/2,W/2,C]
    ## keep is the number of principal components to keep
    ## Helper function for reduce_emb.
    emb = emb + EPS
    # emb is B x C x H x W
    emb = emb.permute(0, 2, 3, 1).cpu().detach().float().numpy() # this is B x H x W x C

    B, H, W, C = np.shape(emb)
    if np.isnan(emb).any():
        return torch.zeros(B, keep, H, W)

    pixelskd = np.reshape(emb, (B*H*W, C))
    P = PCA(keep)
    P.fit(pixelskd)
    pixels3d = P.transform(pixelskd)
    out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32)

    if np.isnan(out_img).any():
        return torch.zeros(B, keep, H, W)

    return torch.from_numpy(out_img).permute(0, 3, 1, 2)

def reduce_emb(emb, valid=None, inbound=None, together=False):
    ## emb -- [S,C,H/2,W/2], inbound -- [S,1,H/2,W/2]
    ## Reduce number of chans to 3 with PCA. For vis.
    # S,H,W,C = emb.shape.as_list()
    S, C, H, W = list(emb.size())
    keep = 4

    if together:
        reduced_emb = pca_embed_together(emb, keep)
    else:
        reduced_emb = pca_embed(emb, keep, valid) # not im

    reduced_emb = reduced_emb[:,1:]
    reduced_emb = utils.basic.normalize(reduced_emb) - 0.5
    if inbound is not None:
        emb_inbound = emb*inbound
    else:
        emb_inbound = None

    return reduced_emb, emb_inbound

def get_feat_pca(feat, valid=None):
    B, C, D, W = list(feat.size())
    # feat is B x C x D x W. If 3D input, average it through Height dimension before passing into this function.

    pca, _ = reduce_emb(feat, valid=valid, inbound=None, together=True)
    # pca is B x 3 x W x D
    return pca

def gif_and_tile(ims, just_gif=False):
    S = len(ims)
    # each im is B x H x W x C
    # i want a gif in the left, and the tiled frames on the right
    # for the gif tool, this means making a B x S x H x W tensor
    # where the leftmost part is sequential and the rest is tiled
    gif = torch.stack(ims, dim=1)
    if just_gif:
        return gif
    til = torch.cat(ims, dim=2)
    til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1)
    im = torch.cat([gif, til], dim=3)
    return im

def back2color(i, blacken_zeros=False):
    if blacken_zeros:
        const = torch.tensor([-0.5])
        i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i)
        return back2color(i)
    else:
        return ((i+0.5)*255).type(torch.ByteTensor)

def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False):
    # xy is B x N x 2, containing float x and y coordinates of N things
    # grid_xs and grid_ys are B x N x Y x X

    B, N, Y, X = list(grid_xs.shape)

    mu_x = xy[:,:,0].clone()
    mu_y = xy[:,:,1].clone()

    x_valid = (mu_x>-0.5) & (mu_x<float(X+0.5))
    y_valid = (mu_y>-0.5) & (mu_y<float(Y+0.5))
    not_valid = ~(x_valid & y_valid)

    mu_x[not_valid] = -10000
    mu_y[not_valid] = -10000

    mu_x = mu_x.reshape(B, N, 1, 1).repeat(1, 1, Y, X)
    mu_y = mu_y.reshape(B, N, 1, 1).repeat(1, 1, Y, X)

    sigma_sq = sigma*sigma
    # sigma_sq = (sigma*sigma).reshape(B, N, 1, 1)
    sq_diff_x = (grid_xs - mu_x)**2
    sq_diff_y = (grid_ys - mu_y)**2

    term1 = 1./2.*np.pi*sigma_sq
    term2 = torch.exp(-(sq_diff_x+sq_diff_y)/(2.*sigma_sq))
    gauss = term1*term2

    if norm:
        # normalize so each gaussian peaks at 1
        gauss_ = gauss.reshape(B*N, Y, X)
        gauss_ = utils.basic.normalize(gauss_)
        gauss = gauss_.reshape(B, N, Y, X)

    return gauss

def xy2heatmaps(xy, Y, X, sigma=30.0, norm=True):
    # xy is B x N x 2

    B, N, D = list(xy.shape)
    assert(D==2)

    device = xy.device

    grid_y, grid_x = utils.basic.meshgrid2d(B, Y, X, device=device)
    # grid_x and grid_y are B x Y x X
    grid_xs = grid_x.unsqueeze(1).repeat(1, N, 1, 1)
    grid_ys = grid_y.unsqueeze(1).repeat(1, N, 1, 1)
    heat = xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=norm)
    return heat

def draw_circles_at_xy(xy, Y, X, sigma=12.5, round=False):
    B, N, D = list(xy.shape)
    assert(D==2)
    prior = xy2heatmaps(xy, Y, X, sigma=sigma)
    # prior is B x N x Y x X
    if round:
        prior = (prior > 0.5).float()
    return prior

def seq2color(im, norm=True, colormap='coolwarm'):
    B, S, H, W = list(im.shape)
    # S is sequential

    # prep a mask of the valid pixels, so we can blacken the invalids later
    mask = torch.max(im, dim=1, keepdim=True)[0]

    # turn the S dim into an explicit sequence
    coeffs = np.linspace(1.0, float(S), S).astype(np.float32)/float(S)

    # # increase the spacing from the center
    # coeffs[:int(S/2)] -= 2.0
    # coeffs[int(S/2)+1:] += 2.0

    coeffs = torch.from_numpy(coeffs).float().cuda()
    coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W)
    # scale each channel by the right coeff
    im = im * coeffs
    # now im is in [1/S, 1], except for the invalid parts which are 0
    # keep the highest valid coeff at each pixel
    im = torch.max(im, dim=1, keepdim=True)[0]

    out = []
    for b in range(B):
        im_ = im[b]
        # move channels out to last dim_
        im_ = im_.detach().cpu().numpy()
        im_ = np.squeeze(im_)
        # im_ is H x W
        if colormap=='coolwarm':
            im_ = cm.coolwarm(im_)[:, :, :3]
        elif colormap=='PiYG':
            im_ = cm.PiYG(im_)[:, :, :3]
        elif colormap=='winter':
            im_ = cm.winter(im_)[:, :, :3]
        elif colormap=='spring':
            im_ = cm.spring(im_)[:, :, :3]
        elif colormap=='onediff':
            im_ = np.reshape(im_, (-1))
            im0_ = cm.spring(im_)[:, :3]
            im1_ = cm.winter(im_)[:, :3]
            im1_[im_==1/float(S)] = im0_[im_==1/float(S)]
            im_ = np.reshape(im1_, (H, W, 3))
        else:
            assert(False) # invalid colormap
        # move channels into dim 0
        im_ = np.transpose(im_, [2, 0, 1])
        im_ = torch.from_numpy(im_).float().cuda()
        out.append(im_)
    out = torch.stack(out, dim=0)

    # blacken the invalid pixels, instead of using the 0-color
    out = out*mask
    # out = out*255.0

    # put it in [-0.5, 0.5]
    out = out - 0.5

    return out

def colorize(d):
    # this is actually just grayscale right now

    if d.ndim==2:
        d = d.unsqueeze(dim=0)
    else:
        assert(d.ndim==3)

    # color_map = cm.get_cmap('plasma')
    color_map = cm.get_cmap('inferno')
    # S1, D = traj.shape

    # print('d1', d.shape)
    C,H,W = d.shape
    assert(C==1)
    d = d.reshape(-1)
    d = d.detach().cpu().numpy()
    # print('d2', d.shape)
    color = np.array(color_map(d)) * 255 # rgba
    # print('color1', color.shape)
    color = np.reshape(color[:,:3], [H*W, 3])
    # print('color2', color.shape)
    color = torch.from_numpy(color).permute(1,0).reshape(3,H,W)
    # # gather
    # cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray')
    # if cmap=='RdBu' or cmap=='RdYlGn':
    #     colors = cm(np.arange(256))[:, :3]
    # else:
    #     colors = cm.colors
    # colors = np.array(colors).astype(np.float32)
    # colors = np.reshape(colors, [-1, 3])
    # colors = tf.constant(colors, dtype=tf.float32)

    # value = tf.gather(colors, indices)
    # colorize(value, normalize=True, vmin=None, vmax=None, cmap=None, vals=255)

    # copy to the three chans
    # d = d.repeat(3, 1, 1)
    return color


def oned2inferno(d, norm=True, do_colorize=False):
    # convert a 1chan input to a 3chan image output

    # if it's just B x H x W, add a C dim
    if d.ndim==3:
        d = d.unsqueeze(dim=1)
    # d should be B x C x H x W, where C=1
    B, C, H, W = list(d.shape)
    assert(C==1)

    if norm:
        d = utils.basic.normalize(d)

    if do_colorize:
        rgb = torch.zeros(B, 3, H, W)
        for b in list(range(B)):
            rgb[b] = colorize(d[b])
    else:
        rgb = d.repeat(1, 3, 1, 1)*255.0
    # rgb = (255.0*rgb).type(torch.ByteTensor)
    rgb = rgb.type(torch.ByteTensor)

    # rgb = tf.cast(255.0*rgb, tf.uint8)
    # rgb = tf.reshape(rgb, [-1, hyp.H, hyp.W, 3])
    # rgb = tf.expand_dims(rgb, axis=0)
    return rgb

def oned2gray(d, norm=True):
    # convert a 1chan input to a 3chan image output

    # if it's just B x H x W, add a C dim
    if d.ndim==3:
        d = d.unsqueeze(dim=1)
    # d should be B x C x H x W, where C=1
    B, C, H, W = list(d.shape)
    assert(C==1)

    if norm:
        d = utils.basic.normalize(d)

    rgb = d.repeat(1,3,1,1)
    rgb = (255.0*rgb).type(torch.ByteTensor)
    return rgb


def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20, shadow=True):

    rgb = vis.detach().cpu().numpy()[0]
    rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
    rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    color = (255, 255, 255)
    # print('putting frame id', frame_id)

    frame_str = utils.basic.strnum(frame_id)

    text_color_bg = (0,0,0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
    text_w, text_h = text_size
    if shadow:
        cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)

    cv2.putText(
        rgb,
        frame_str,
        (left, top), # from left, from top
        font,
        scale, # font scale (float)
        color,
        1) # font thickness (int)
    rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
    vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    return vis

def draw_frame_str_on_vis(vis, frame_str, scale=0.5, left=5, top=40, shadow=True):

    rgb = vis.detach().cpu().numpy()[0]
    rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
    rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    color = (255, 255, 255)

    text_color_bg = (0,0,0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
    text_w, text_h = text_size
    if shadow:
        cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)

    cv2.putText(
        rgb,
        frame_str,
        (left, top), # from left, from top
        font,
        scale, # font scale (float)
        color,
        1) # font thickness (int)
    rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
    vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    return vis

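# Illustrative note (not part of the original file): oned2inferno converts a
# 1-channel map into a uint8 3-channel image, e.g. a hypothetical call:
#   d = torch.rand(1, 1, 240, 320)
#   vis = oned2inferno(d, norm=True, do_colorize=True)  # 1 x 3 x 240 x 320, torch.uint8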
434 |
+
COLORMAP_FILE = "./utils/bremm.png"
|
435 |
+
class ColorMap2d:
|
436 |
+
def __init__(self, filename=None):
|
437 |
+
self._colormap_file = filename or COLORMAP_FILE
|
438 |
+
self._img = plt.imread(self._colormap_file)
|
439 |
+
|
440 |
+
self._height = self._img.shape[0]
|
441 |
+
self._width = self._img.shape[1]
|
442 |
+
|
443 |
+
def __call__(self, X):
|
444 |
+
assert len(X.shape) == 2
|
445 |
+
output = np.zeros((X.shape[0], 3))
|
446 |
+
for i in range(X.shape[0]):
|
447 |
+
x, y = X[i, :]
|
448 |
+
xp = int((self._width-1) * x)
|
449 |
+
yp = int((self._height-1) * y)
|
450 |
+
xp = np.clip(xp, 0, self._width-1)
|
451 |
+
yp = np.clip(yp, 0, self._height-1)
|
452 |
+
output[i, :] = self._img[yp, xp]
|
453 |
+
return output
|
454 |
+
|
455 |
+
def get_n_colors(N, sequential=False):
|
456 |
+
label_colors = []
|
457 |
+
for ii in range(N):
|
458 |
+
if sequential:
|
459 |
+
rgb = cm.winter(ii/(N-1))
|
460 |
+
rgb = (np.array(rgb) * 255).astype(np.uint8)[:3]
|
461 |
+
else:
|
462 |
+
# rgb = np.zeros(3)
|
463 |
+
# while np.sum(rgb) < 128: # ensure min brightness
|
464 |
+
# rgb = np.random.randint(0,256,3)
|
465 |
+
rgb = cm.gist_rainbow(ii/(N-1))
|
466 |
+
rgb = (np.array(rgb) * 255).astype(np.uint8)[:3]
|
467 |
+
|
468 |
+
label_colors.append(rgb)
|
469 |
+
return label_colors
|
470 |
+
|
471 |
+
class Summ_writer(object):
|
472 |
+
def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False):
|
473 |
+
self.writer = writer
|
474 |
+
self.global_step = global_step
|
475 |
+
self.log_freq = log_freq
|
476 |
+
self.scalar_freq = scalar_freq
|
477 |
+
self.fps = fps
|
478 |
+
self.just_gif = just_gif
|
479 |
+
self.maxwidth = 10000
|
480 |
+
self.save_this = (self.global_step % self.log_freq == 0)
|
481 |
+
self.scalar_freq = max(scalar_freq,1)
|
482 |
+
self.save_scalar = (self.global_step % self.scalar_freq == 0)
|
483 |
+
if self.save_this:
|
484 |
+
self.save_scalar = True
|
485 |
+
|
486 |
+
def summ_gif(self, name, tensor, blacken_zeros=False):
|
487 |
+
# tensor should be in B x S x C x H x W
|
488 |
+
|
489 |
+
assert tensor.dtype in {torch.uint8,torch.float32}
|
490 |
+
shape = list(tensor.shape)
|
491 |
+
|
492 |
+
if tensor.dtype == torch.float32:
|
493 |
+
tensor = back2color(tensor, blacken_zeros=blacken_zeros)
|
494 |
+
|
495 |
+
video_to_write = tensor[0:1]
|
496 |
+
|
497 |
+
S = video_to_write.shape[1]
|
498 |
+
if S==1:
|
499 |
+
# video_to_write is 1 x 1 x C x H x W
|
500 |
+
self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step)
|
501 |
+
else:
|
502 |
+
self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step)
|
503 |
+
|
504 |
+
return video_to_write
|
505 |
+
|
506 |
+
def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1):
|
507 |
+
B, C, H, W = list(rgb.shape)
|
508 |
+
assert(C==3)
|
509 |
+
B2, N, D = list(boxlist.shape)
|
510 |
+
assert(B2==B)
|
511 |
+
assert(D==4) # ymin, xmin, ymax, xmax
|
512 |
+
|
513 |
+
rgb = back2color(rgb)
|
514 |
+
if scores is None:
|
515 |
+
scores = torch.ones(B2, N).float()
|
516 |
+
if tids is None:
|
517 |
+
# tids = torch.arange(N).reshape(1,N).repeat(B2,1).long()
|
518 |
+
tids = torch.zeros(B2, N).long()
|
519 |
+
out = self.draw_boxlist2d_on_image_py(
|
520 |
+
rgb[0].cpu().detach().numpy(),
|
521 |
+
boxlist[0].cpu().detach().numpy(),
|
522 |
+
scores[0].cpu().detach().numpy(),
|
523 |
+
tids[0].cpu().detach().numpy(),
|
524 |
+
linewidth=linewidth)
|
525 |
+
out = torch.from_numpy(out).type(torch.ByteTensor).permute(2, 0, 1)
|
526 |
+
out = torch.unsqueeze(out, dim=0)
|
527 |
+
out = preprocess_color(out)
|
528 |
+
out = torch.reshape(out, [1, C, H, W])
|
529 |
+
return out
|
530 |
+
|
531 |
+
def draw_boxlist2d_on_image_py(self, rgb, boxlist, scorelist, tidlist, linewidth=1):
|
532 |
+
# all inputs are numpy tensors
|
533 |
+
# rgb is H x W x 3
|
534 |
+
# boxlist is N x 4
|
535 |
+
# scorelist is N
|
536 |
+
# tidlist is N
|
537 |
+
|
538 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
539 |
+
# rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
|
540 |
+
rgb = rgb.astype(np.uint8).copy()
|
541 |
+
|
542 |
+
H, W, C = rgb.shape
|
543 |
+
assert(C==3)
|
544 |
+
N, D = boxlist.shape
|
545 |
+
assert(D==4)
|
546 |
+
M = scorelist.shape[0]
|
547 |
+
assert(M==N)
|
548 |
+
O = tidlist.shape[0]
|
549 |
+
assert(M==O)
|
550 |
+
|
551 |
+
# color_map = cm.get_cmap('Accent')
|
552 |
+
# color_map = cm.get_cmap('Set3')
|
553 |
+
color_map = cm.get_cmap('tab20')
|
554 |
+
color_map = color_map.colors
|
555 |
+
# print('color_map', color_map)
|
556 |
+
|
557 |
+
# draw
|
558 |
+
for (box, score, tid) in zip(boxlist, scorelist, tidlist):
|
559 |
+
# box is 4
|
560 |
+
if not np.isclose(score, 0.0, atol=1e-3):
|
561 |
+
# ymin, xmin, ymax, xmax = box
|
562 |
+
xmin, ymin, xmax, ymax = box
|
563 |
+
|
564 |
+
color = color_map[tid]
|
565 |
+
color = np.array(color)*255.0
|
566 |
+
color = color.round()
|
567 |
+
|
568 |
+
if not np.isclose(score, 1.0, atol=1e-3):
|
569 |
+
cv2.putText(rgb,
|
570 |
+
# '%d (%.2f)' % (tidlist[ind], scorelist[ind]),
|
571 |
+
'%.2f' % (score),
|
572 |
+
(int(xmin), int(ymin)),
|
573 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
574 |
+
0.5, # font size
|
575 |
+
color),
|
576 |
+
|
577 |
+
xmin = int(np.clip(xmin,0,W-1))
|
578 |
+
ymin = int(np.clip(ymin,0,H-1))
|
579 |
+
xmax = int(np.clip(xmax,0,W-1))
|
580 |
+
ymax = int(np.clip(ymax,0,H-1))
|
581 |
+
|
582 |
+
# print('xmin, xmax, ymin, ymax', xmin, xmax, ymin, ymax)
|
583 |
+
|
584 |
+
cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_4)
|
585 |
+
cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_4)
|
586 |
+
cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_4)
|
587 |
+
cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_4)
|
588 |
+
|
589 |
+
# rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
590 |
+
return rgb
|
591 |
+
|
592 |
+
def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, frame_str=None, only_return=False, linewidth=1):
|
593 |
+
B, C, H, W = list(rgb.shape)
|
594 |
+
boxlist_vis = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth)
|
595 |
+
return self.summ_rgb(name, boxlist_vis, frame_id=frame_id, frame_str=frame_str, only_return=only_return)
|
596 |
+
|
597 |
+
def summ_rgbs(self, name, ims, frame_ids=None, frame_strs=None, blacken_zeros=False, only_return=False):
|
598 |
+
if self.save_this:
|
599 |
+
|
600 |
+
ims = gif_and_tile(ims, just_gif=self.just_gif)
|
601 |
+
vis = ims
|
602 |
+
|
603 |
+
assert vis.dtype in {torch.uint8,torch.float32}
|
604 |
+
|
605 |
+
if vis.dtype == torch.float32:
|
606 |
+
vis = back2color(vis, blacken_zeros)
|
607 |
+
|
608 |
+
B, S, C, H, W = list(vis.shape)
|
609 |
+
|
610 |
+
if frame_ids is not None:
|
611 |
+
assert(len(frame_ids)==S)
|
612 |
+
for s in range(S):
|
613 |
+
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
|
614 |
+
|
615 |
+
if frame_strs is not None:
|
616 |
+
assert(len(frame_strs)==S)
|
617 |
+
for s in range(S):
|
618 |
+
vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
|
619 |
+
|
620 |
+
if int(W) > self.maxwidth:
|
621 |
+
vis = vis[:,:,:,:self.maxwidth]
|
622 |
+
|
623 |
+
if only_return:
|
624 |
+
return vis
|
625 |
+
else:
|
626 |
+
return self.summ_gif(name, vis, blacken_zeros)
|
627 |
+
|
628 |
+
def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, frame_str=None, only_return=False, halfres=False, shadow=True):
|
629 |
+
if self.save_this:
|
630 |
+
assert ims.dtype in {torch.uint8,torch.float32}
|
631 |
+
|
632 |
+
if ims.dtype == torch.float32:
|
633 |
+
ims = back2color(ims, blacken_zeros)
|
634 |
+
|
635 |
+
#ims is B x C x H x W
|
636 |
+
vis = ims[0:1] # just the first one
|
637 |
+
B, C, H, W = list(vis.shape)
|
638 |
+
|
639 |
+
if halfres:
|
640 |
+
vis = F.interpolate(vis, scale_factor=0.5)
|
641 |
+
|
642 |
+
if frame_id is not None:
|
643 |
+
vis = draw_frame_id_on_vis(vis, frame_id, shadow=shadow)
|
644 |
+
|
645 |
+
if frame_str is not None:
|
646 |
+
vis = draw_frame_str_on_vis(vis, frame_str, shadow=shadow)
|
647 |
+
|
648 |
+
if int(W) > self.maxwidth:
|
649 |
+
vis = vis[:,:,:,:self.maxwidth]
|
650 |
+
|
651 |
+
if only_return:
|
652 |
+
return vis
|
653 |
+
else:
|
654 |
+
return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros)
|
655 |
+
|
656 |
+
def flow2color(self, flow, clip=0.0):
|
657 |
+
B, C, H, W = list(flow.size())
|
658 |
+
assert(C==2)
|
659 |
+
flow = flow[0:1].detach()
|
660 |
+
|
661 |
+
if False:
|
662 |
+
flow = flow[0].detach().cpu().permute(1,2,0).numpy() # H,W,2
|
663 |
+
if clip > 0:
|
664 |
+
clip_flow = clip
|
665 |
+
else:
|
666 |
+
clip_flow = None
|
667 |
+
im = utils.py.flow_to_image(flow, clip_flow=clip_flow, convert_to_bgr=True)
|
668 |
+
# im = utils.py.flow_to_image(flow, convert_to_bgr=True)
|
669 |
+
im = torch.from_numpy(im).permute(2,0,1).unsqueeze(0).byte() # 1,3,H,W
|
670 |
+
im = torch.flip(im, dims=[1]).clone() # BGR
|
671 |
+
|
672 |
+
# # i prefer black bkg
|
673 |
+
# white_pixels = (im == 255).all(dim=1, keepdim=True)
|
674 |
+
# im[white_pixels.expand(-1, 3, -1, -1)] = 0
|
675 |
+
|
676 |
+
return im
|
677 |
+
|
678 |
+
# flow_abs = torch.abs(flow)
|
679 |
+
# flow_mean = flow_abs.mean(dim=[1,2,3])
|
680 |
+
# flow_std = flow_abs.std(dim=[1,2,3])
|
681 |
+
if clip==0:
|
682 |
+
clip = torch.max(torch.abs(flow)).item()
|
683 |
+
|
684 |
+
# if clip:
|
685 |
+
flow = torch.clamp(flow, -clip, clip)/clip
|
686 |
+
# else:
|
687 |
+
# # # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2)
|
688 |
+
# # flow_max = flow_mean + flow_std*2 + 1e-10
|
689 |
+
# # for b in range(B):
|
690 |
+
# # flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1)
|
691 |
+
|
692 |
+
# flow_max = torch.max(flow_abs[b])
|
693 |
+
# for b in range(B):
|
694 |
+
# flow[b] = flow[b].clamp(-flow_max.item(), flow_max.item()) / flow_max[b].clamp(min=1)
|
695 |
+
|
696 |
+
|
697 |
+
radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W
|
698 |
+
radius_clipped = torch.clamp(radius, 0.0, 1.0)
|
699 |
+
|
700 |
+
angle = torch.atan2(-flow[:, 1:2], -flow[:, 0:1]) / np.pi # B x 1 x H x W
|
701 |
+
|
702 |
+
hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
|
703 |
+
# hue = torch.mod(angle / (2 * np.pi) + 1.0, 1.0)
|
704 |
+
|
705 |
+
saturation = torch.ones_like(hue) * 0.75
|
706 |
+
value = radius_clipped
|
707 |
+
hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W
|
708 |
+
|
709 |
+
#flow = tf.image.hsv_to_rgb(hsv)
|
710 |
+
flow = hsv_to_rgb(hsv)
|
711 |
+
flow = (flow*255.0).type(torch.ByteTensor)
|
712 |
+
# flow = torch.flip(flow, dims=[1]).clone() # BGR
|
713 |
+
return flow
|
714 |
+
|
715 |
+
def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None, frame_str=None, shadow=True):
|
716 |
+
# flow is B x C x D x W
|
717 |
+
if self.save_this:
|
718 |
+
return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id, frame_str=frame_str, shadow=shadow)
|
719 |
+
else:
|
720 |
+
return None
|
721 |
+
|
722 |
+
def summ_oneds(self, name, ims, frame_ids=None, frame_strs=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False):
|
723 |
+
if self.save_this:
|
724 |
+
if bev:
|
725 |
+
B, C, H, _, W = list(ims[0].shape)
|
726 |
+
if reduce_max:
|
727 |
+
ims = [torch.max(im, dim=3)[0] for im in ims]
|
728 |
+
else:
|
729 |
+
ims = [torch.mean(im, dim=3) for im in ims]
|
730 |
+
elif fro:
|
731 |
+
B, C, _, H, W = list(ims[0].shape)
|
732 |
+
if reduce_max:
|
733 |
+
ims = [torch.max(im, dim=2)[0] for im in ims]
|
734 |
+
else:
|
735 |
+
ims = [torch.mean(im, dim=2) for im in ims]
|
736 |
+
|
737 |
+
|
738 |
+
if len(ims) != 1: # sequence
|
739 |
+
im = gif_and_tile(ims, just_gif=self.just_gif)
|
740 |
+
else:
|
741 |
+
im = torch.stack(ims, dim=1) # single frame
|
742 |
+
|
743 |
+
B, S, C, H, W = list(im.shape)
|
744 |
+
|
745 |
+
if logvis and max_val:
|
746 |
+
max_val = np.log(max_val)
|
747 |
+
im = torch.log(torch.clamp(im, 0)+1.0)
|
748 |
+
im = torch.clamp(im, 0, max_val)
|
749 |
+
im = im/max_val
|
750 |
+
norm = False
|
751 |
+
elif max_val:
|
752 |
+
im = torch.clamp(im, 0, max_val)
|
753 |
+
im = im/max_val
|
754 |
+
norm = False
|
755 |
+
|
756 |
+
if norm:
|
757 |
+
# normalize before oned2inferno,
|
758 |
+
# so that the ranges are similar within B across S
|
759 |
+
im = utils.basic.normalize(im)
|
760 |
+
|
761 |
+
im = im.view(B*S, C, H, W)
|
762 |
+
vis = oned2inferno(im, norm=norm, do_colorize=do_colorize)
|
763 |
+
vis = vis.view(B, S, 3, H, W)
|
764 |
+
|
765 |
+
if frame_ids is not None:
|
766 |
+
assert(len(frame_ids)==S)
|
767 |
+
for s in range(S):
|
768 |
+
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
|
769 |
+
|
770 |
+
if frame_strs is not None:
|
771 |
+
assert(len(frame_strs)==S)
|
772 |
+
for s in range(S):
|
773 |
+
vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
|
774 |
+
|
775 |
+
if W > self.maxwidth:
|
776 |
+
vis = vis[...,:self.maxwidth]
|
777 |
+
|
778 |
+
if only_return:
|
779 |
+
return vis
|
780 |
+
else:
|
781 |
+
self.summ_gif(name, vis)
|
782 |
+
|
783 |
+
def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, frame_str=None, only_return=False, shadow=True):
|
784 |
+
if self.save_this:
|
785 |
+
|
786 |
+
if bev:
|
787 |
+
B, C, H, _, W = list(im.shape)
|
788 |
+
if max_along_y:
|
789 |
+
im = torch.max(im, dim=3)[0]
|
790 |
+
else:
|
791 |
+
im = torch.mean(im, dim=3)
|
792 |
+
elif fro:
|
793 |
+
B, C, _, H, W = list(im.shape)
|
794 |
+
if max_along_y:
|
795 |
+
im = torch.max(im, dim=2)[0]
|
796 |
+
else:
|
797 |
+
im = torch.mean(im, dim=2)
|
798 |
+
else:
|
799 |
+
B, C, H, W = list(im.shape)
|
800 |
+
|
801 |
+
im = im[0:1] # just the first one
|
802 |
+
assert(C==1)
|
803 |
+
|
804 |
+
if logvis and max_val:
|
805 |
+
max_val = np.log(max_val)
|
806 |
+
im = torch.log(im)
|
807 |
+
im = torch.clamp(im, 0, max_val)
|
808 |
+
im = im/max_val
|
809 |
+
norm = False
|
810 |
+
elif max_val:
|
811 |
+
im = torch.clamp(im, 0, max_val)/max_val
|
812 |
+
norm = False
|
813 |
+
|
814 |
+
vis = oned2inferno(im, norm=norm)
|
815 |
+
if W > self.maxwidth:
|
816 |
+
vis = vis[...,:self.maxwidth]
|
817 |
+
return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return, shadow=shadow)
|
818 |
+
|
819 |
+
def summ_4chan(self, name, im, norm=True, frame_id=None, frame_str=None, only_return=False):
|
820 |
+
if self.save_this:
|
821 |
+
|
822 |
+
B, C, H, W = list(im.shape)
|
823 |
+
|
824 |
+
im = im[0:1] # just the first one
|
825 |
+
assert(C==4)
|
826 |
+
|
827 |
+
# d = utils.basic.normalize(d)
|
828 |
+
im0 = im[:,0:1]
|
829 |
+
im1 = im[:,1:2]
|
830 |
+
im2 = im[:,2:3]
|
831 |
+
im3 = im[:,3:4]
|
832 |
+
|
833 |
+
im0 = utils.basic.normalize(im0).round()
|
834 |
+
im1 = utils.basic.normalize(im1).round()
|
835 |
+
im2 = utils.basic.normalize(im2).round()
|
836 |
+
im3 = utils.basic.normalize(im3).round()
|
837 |
+
|
838 |
+
# kp_vis = sw.summ_rgbs('tff/2_kp_s%d' % s, kp.unbind(1), only_return=True)
|
839 |
+
# kp_any = (torch.max(kp_vis, dim=2, keepdims=True)[0]).repeat(1, 1, 3, 1, 1)
|
840 |
+
# kp_vis[kp_any==0] = fcp_vis[kp_any==0]
|
841 |
+
|
842 |
+
# vis0 = oned2inferno(im0, norm=False)
|
843 |
+
# vis1 = oned2inferno(im1, norm=False)
|
844 |
+
# vis2 = oned2inferno(im2, norm=False)
|
845 |
+
# vis3 = oned2inferno(im3, norm=False)
|
846 |
+
|
847 |
+
# vis0 = self.summ_seg('', im0[:,0:1]*1, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20')
|
848 |
+
# vis1 = self.summ_seg('', im1[:,0:1]*2, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20')
|
849 |
+
# vis2 = self.summ_seg('', im2[:,0:1]*3, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20')
|
850 |
+
# vis3 = self.summ_seg('', im3[:,0:1]*4, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20')
|
851 |
+
|
852 |
+
vis0 = self.summ_seg('', im0[:,0]*1, only_return=True, colormap='tab20')
|
853 |
+
vis1 = self.summ_seg('', im1[:,0]*2, only_return=True, colormap='tab20')
|
854 |
+
vis2 = self.summ_seg('', im2[:,0]*3, only_return=True, colormap='tab20')
|
855 |
+
vis3 = self.summ_seg('', im3[:,0]*4, only_return=True, colormap='tab20')
|
856 |
+
|
857 |
+
# vis_any = (torch.max(vis2, dim=2, keepdims=True)[0]).repeat(1, 1, 3, 1, 1)
|
858 |
+
# vis3[vis_any==0] = fcp_vis[kp_any==0]
|
859 |
+
|
860 |
+
vis0_any = (torch.max(vis0, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1)
|
861 |
+
vis1_any = (torch.max(vis1, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1)
|
862 |
+
vis2_any = (torch.max(vis2, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1)
|
863 |
+
vis3_any = (torch.max(vis3, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1)
|
864 |
+
vis0[vis0_any==0] = vis1[vis0_any==0]
|
865 |
+
vis0[vis1_any==0] = vis2[vis1_any==0]
|
866 |
+
vis0[vis2_any==0] = vis3[vis2_any==0]
|
867 |
+
|
868 |
+
print('vis0', vis0.shape, vis0.device)
|
869 |
+
|
870 |
+
vis0 = vis0.cpu()
|
871 |
+
|
872 |
+
# vis = oned2inferno(im, norm=norm)
|
873 |
+
# if W > self.maxwidth:
|
874 |
+
# vis = vis[...,:self.maxwidth]
|
875 |
+
return self.summ_rgb(name, vis0, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return)
|
876 |
+
|
877 |
+
|
878 |
+
def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None, frame_strs=None):
|
879 |
+
if self.save_this:
|
880 |
+
if valids is not None:
|
881 |
+
valids = torch.stack(valids, dim=1)
|
882 |
+
|
883 |
+
feats = torch.stack(feats, dim=1)
|
884 |
+
# feats leads with B x S x C
|
885 |
+
|
886 |
+
if feats.ndim==6:
|
887 |
+
|
888 |
+
# feats is B x S x C x D x H x W
|
889 |
+
if fro:
|
890 |
+
reduce_dim = 3
|
891 |
+
else:
|
892 |
+
reduce_dim = 4
|
893 |
+
|
894 |
+
if valids is None:
|
895 |
+
feats = torch.mean(feats, dim=reduce_dim)
|
896 |
+
else:
|
897 |
+
valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1)
|
898 |
+
feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim)
|
899 |
+
|
900 |
+
B, S, C, D, W = list(feats.size())
|
901 |
+
|
902 |
+
if not pca:
|
903 |
+
# feats leads with B x S x C
|
904 |
+
feats = torch.mean(torch.abs(feats), dim=2, keepdims=True)
|
905 |
+
# feats leads with B x S x 1
|
906 |
+
feats = torch.unbind(feats, dim=1)
|
907 |
+
return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
908 |
+
|
909 |
+
else:
|
910 |
+
__p = lambda x: utils.basic.pack_seqdim(x, B)
|
911 |
+
__u = lambda x: utils.basic.unpack_seqdim(x, B)
|
912 |
+
|
913 |
+
feats_ = __p(feats)
|
914 |
+
|
915 |
+
if valids is None:
|
916 |
+
feats_pca_ = get_feat_pca(feats_)
|
917 |
+
else:
|
918 |
+
valids_ = __p(valids)
|
919 |
+
feats_pca_ = get_feat_pca(feats_, valids)
|
920 |
+
|
921 |
+
feats_pca = __u(feats_pca_)
|
922 |
+
|
923 |
+
return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
924 |
+
|
925 |
+
def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None, frame_str=None):
|
926 |
+
if self.save_this:
|
927 |
+
if feat.ndim==5: # B x C x D x H x W
|
928 |
+
|
929 |
+
if bev:
|
930 |
+
reduce_axis = 3
|
931 |
+
elif fro:
|
932 |
+
reduce_axis = 2
|
933 |
+
else:
|
934 |
+
# default to bev
|
935 |
+
reduce_axis = 3
|
936 |
+
|
937 |
+
if valid is None:
|
938 |
+
feat = torch.mean(feat, dim=reduce_axis)
|
939 |
+
else:
|
940 |
+
valid = valid.repeat(1, feat.size()[1], 1, 1, 1)
|
941 |
+
feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis)
|
942 |
+
|
943 |
+
B, C, D, W = list(feat.shape)
|
944 |
+
|
945 |
+
if not pca:
|
946 |
+
feat = torch.mean(torch.abs(feat), dim=1, keepdims=True)
|
947 |
+
# feat is B x 1 x D x W
|
948 |
+
return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
949 |
+
else:
|
950 |
+
feat_pca = get_feat_pca(feat, valid)
|
951 |
+
return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
952 |
+
|
953 |
+
def summ_scalar(self, name, value):
|
954 |
+
if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()):
|
955 |
+
value = value.detach().cpu().numpy()
|
956 |
+
if not np.isnan(value):
|
957 |
+
if (self.log_freq == 1):
|
958 |
+
self.writer.add_scalar(name, value, global_step=self.global_step)
|
959 |
+
elif self.save_this or self.save_scalar:
|
960 |
+
self.writer.add_scalar(name, value, global_step=self.global_step)
|
961 |
+
|
962 |
+
def summ_seg(self, name, seg, only_return=False, frame_id=None, frame_str=None, colormap='tab20', label_colors=None):
|
963 |
+
if not self.save_this:
|
964 |
+
return
|
965 |
+
|
966 |
+
B,H,W = seg.shape
|
967 |
+
|
968 |
+
if label_colors is None:
|
969 |
+
custom_label_colors = False
|
970 |
+
# label_colors = get_n_colors(int(torch.max(seg).item()), sequential=True)
|
971 |
+
label_colors = cm.get_cmap(colormap).colors
|
972 |
+
label_colors = [[int(i*255) for i in l] for l in label_colors]
|
973 |
+
else:
|
974 |
+
custom_label_colors = True
|
975 |
+
# label_colors = matplotlib.cm.get_cmap(colormap).colors
|
976 |
+
# label_colors = [[int(i*255) for i in l] for l in label_colors]
|
977 |
+
# print('label_colors', label_colors)
|
978 |
+
|
979 |
+
# label_colors = [
|
980 |
+
# (0, 0, 0), # None
|
981 |
+
# (70, 70, 70), # Buildings
|
982 |
+
# (190, 153, 153), # Fences
|
983 |
+
# (72, 0, 90), # Other
|
984 |
+
# (220, 20, 60), # Pedestrians
|
985 |
+
# (153, 153, 153), # Poles
|
986 |
+
# (157, 234, 50), # RoadLines
|
987 |
+
# (128, 64, 128), # Roads
|
988 |
+
# (244, 35, 232), # Sidewalks
|
989 |
+
# (107, 142, 35), # Vegetation
|
990 |
+
# (0, 0, 255), # Vehicles
|
991 |
+
# (102, 102, 156), # Walls
|
992 |
+
# (220, 220, 0) # TrafficSigns
|
993 |
+
# ]
|
994 |
+
|
995 |
+
r = torch.zeros_like(seg,dtype=torch.uint8)
|
996 |
+
g = torch.zeros_like(seg,dtype=torch.uint8)
|
997 |
+
b = torch.zeros_like(seg,dtype=torch.uint8)
|
998 |
+
|
999 |
+
for label in range(0,len(label_colors)):
|
1000 |
+
if (not custom_label_colors):# and (N > 20):
|
1001 |
+
label_ = label % 20
|
1002 |
+
else:
|
1003 |
+
label_ = label
|
1004 |
+
|
1005 |
+
idx = (seg == label)
|
1006 |
+
r[idx] = label_colors[label_][0]
|
1007 |
+
g[idx] = label_colors[label_][1]
|
1008 |
+
b[idx] = label_colors[label_][2]
|
1009 |
+
|
1010 |
+
rgb = torch.stack([r,g,b],axis=1)
|
1011 |
+
return self.summ_rgb(name,rgb,only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
1012 |
+
|
1013 |
+
def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap='coolwarm', vals=None, linewidth=1, max_show=1024):
|
1014 |
+
# trajs is B, S, N, 2
|
1015 |
+
# rgbs is B, S, C, H, W
|
1016 |
+
B, S, C, H, W = rgbs.shape
|
1017 |
+
B, S2, N, D = trajs.shape
|
1018 |
+
assert(S==S2)
|
1019 |
+
|
1020 |
+
|
1021 |
+
rgbs = rgbs[0] # S, C, H, W
|
1022 |
+
trajs = trajs[0] # S, N, 2
|
1023 |
+
if valids is None:
|
1024 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
1025 |
+
else:
|
1026 |
+
valids = valids[0]
|
1027 |
+
|
1028 |
+
if visibs is None:
|
1029 |
+
visibs = torch.ones_like(trajs[:,:,0]) # S, N
|
1030 |
+
else:
|
1031 |
+
visibs = visibs[0]
|
1032 |
+
|
1033 |
+
if vals is not None:
|
1034 |
+
vals = vals[0] # N
|
1035 |
+
# print('vals', vals.shape)
|
1036 |
+
|
1037 |
+
if N > max_show:
|
1038 |
+
inds = np.random.choice(N, max_show)
|
1039 |
+
trajs = trajs[:,inds]
|
1040 |
+
valids = valids[:,inds]
|
1041 |
+
visibs = visibs[:,inds]
|
1042 |
+
if vals is not None:
|
1043 |
+
vals = vals[inds]
|
1044 |
+
N = trajs.shape[1]
|
1045 |
+
|
1046 |
+
trajs = trajs.clamp(-16, W+16)
|
1047 |
+
|
1048 |
+
rgbs_color = []
|
1049 |
+
for rgb in rgbs:
|
1050 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
1051 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
1052 |
+
rgbs_color.append(rgb) # each element 3 x H x W
|
1053 |
+
|
1054 |
+
for i in range(min(N, max_show)):
|
1055 |
+
if cmap=='onediff' and i==0:
|
1056 |
+
cmap_ = 'spring'
|
1057 |
+
elif cmap=='onediff':
|
1058 |
+
cmap_ = 'winter'
|
1059 |
+
else:
|
1060 |
+
cmap_ = cmap
|
1061 |
+
traj = trajs[:,i].long().detach().cpu().numpy() # S, 2
|
1062 |
+
valid = valids[:,i].long().detach().cpu().numpy() # S
|
1063 |
+
|
1064 |
+
# print('traj', traj.shape)
|
1065 |
+
# print('valid', valid.shape)
|
1066 |
+
|
1067 |
+
if vals is not None:
|
1068 |
+
# val = vals[:,i].float().detach().cpu().numpy() # []
|
1069 |
+
val = vals[i].float().detach().cpu().numpy() # []
|
1070 |
+
# print('val', val.shape)
|
1071 |
+
else:
|
1072 |
+
val = None
|
1073 |
+
|
1074 |
+
for t in range(S):
|
1075 |
+
if valid[t]:
|
1076 |
+
rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj[:t+1], S=S, show_dots=show_dots, cmap=cmap_, val=val, linewidth=linewidth)
|
1077 |
+
|
1078 |
+
for i in range(min(N, max_show)):
|
1079 |
+
if cmap=='onediff' and i==0:
|
1080 |
+
cmap_ = 'spring'
|
1081 |
+
elif cmap=='onediff':
|
1082 |
+
cmap_ = 'winter'
|
1083 |
+
else:
|
1084 |
+
cmap_ = cmap
|
1085 |
+
traj = trajs[:,i] # S,2
|
1086 |
+
vis = visibs[:,i].round() # S
|
1087 |
+
valid = valids[:,i] # S
|
1088 |
+
rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
|
1089 |
+
|
1090 |
+
rgbs = []
|
1091 |
+
for rgb in rgbs_color:
|
1092 |
+
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
1093 |
+
rgbs.append(preprocess_color(rgb))
|
1094 |
+
|
1095 |
+
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
1096 |
+
|
1097 |
+
def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap=None, linewidth=1, max_show=1024):
|
1098 |
+
# trajs is B, S, N, 2
|
1099 |
+
# rgbs is B, S, C, H, W
|
1100 |
+
B, S, C, H, W = rgbs.shape
|
1101 |
+
B, S2, N, D = trajs.shape
|
1102 |
+
assert(S==S2)
|
1103 |
+
|
1104 |
+
rgbs = rgbs[0] # S, C, H, W
|
1105 |
+
trajs = trajs[0] # S, N, 2
|
1106 |
+
visibles = visibles[0] # S, N
|
1107 |
+
if valids is None:
|
1108 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
1109 |
+
else:
|
1110 |
+
valids = valids[0]
|
1111 |
+
|
1112 |
+
rgbs_color = []
|
1113 |
+
for rgb in rgbs:
|
1114 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
1115 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
1116 |
+
rgbs_color.append(rgb) # each element 3 x H x W
|
1117 |
+
|
1118 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
1119 |
+
visibles = visibles.float().detach().cpu().numpy() # S, N
|
1120 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
1121 |
+
|
1122 |
+
for i in range(min(N, max_show)):
|
1123 |
+
if cmap=='onediff' and i==0:
|
1124 |
+
cmap_ = 'spring'
|
1125 |
+
elif cmap=='onediff':
|
1126 |
+
cmap_ = 'winter'
|
1127 |
+
else:
|
1128 |
+
cmap_ = cmap
|
1129 |
+
traj = trajs[:,i] # S,2
|
1130 |
+
vis = visibles[:,i] # S
|
1131 |
+
valid = valids[:,i] # S
|
1132 |
+
rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
|
1133 |
+
|
1134 |
+
for i in range(min(N, max_show)):
|
1135 |
+
if cmap=='onediff' and i==0:
|
1136 |
+
cmap_ = 'spring'
|
1137 |
+
elif cmap=='onediff':
|
1138 |
+
cmap_ = 'winter'
|
1139 |
+
else:
|
1140 |
+
cmap_ = cmap
|
1141 |
+
traj = trajs[:,i] # S,2
|
1142 |
+
vis = visibles[:,i] # S
|
1143 |
+
valid = valids[:,i] # S
|
1144 |
+
rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth)
|
1145 |
+
|
1146 |
+
rgbs = []
|
1147 |
+
for rgb in rgbs_color:
|
1148 |
+
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
1149 |
+
rgbs.append(preprocess_color(rgb))
|
1150 |
+
|
1151 |
+
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
1152 |
+
|
1153 |
+
def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=True, show_lines=True, frame_id=None, frame_str=None, only_return=False, cmap='coolwarm', linewidth=1, max_show=1024):
|
1154 |
+
# trajs is B, S, N, 2
|
1155 |
+
# rgb is B, C, H, W
|
1156 |
+
B, C, H, W = rgb.shape
|
1157 |
+
B, S, N, D = trajs.shape
|
1158 |
+
|
1159 |
+
rgb = rgb[0] # S, C, H, W
|
1160 |
+
trajs = trajs[0] # S, N, 2
|
1161 |
+
|
1162 |
+
if valids is None:
|
1163 |
+
valids = torch.ones_like(trajs[:,:,0])
|
1164 |
+
else:
|
1165 |
+
valids = valids[0]
|
1166 |
+
|
1167 |
+
rgb_color = back2color(rgb).detach().cpu().numpy()
|
1168 |
+
rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last
|
1169 |
+
|
1170 |
+
# using maxdist will dampen the colors for short motions
|
1171 |
+
# norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1)) # N
|
1172 |
+
# maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy()
|
1173 |
+
maxdist = None
|
1174 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
1175 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
1176 |
+
|
1177 |
+
if N > max_show:
|
1178 |
+
inds = np.random.choice(N, max_show)
|
1179 |
+
trajs = trajs[:,inds]
|
1180 |
+
valids = valids[:,inds]
|
1181 |
+
N = trajs.shape[1]
|
1182 |
+
|
1183 |
+
for i in range(min(N, max_show)):
|
1184 |
+
if cmap=='onediff' and i==0:
|
1185 |
+
cmap_ = 'spring'
|
1186 |
+
elif cmap=='onediff':
|
1187 |
+
cmap_ = 'winter'
|
1188 |
+
else:
|
1189 |
+
cmap_ = cmap
|
1190 |
+
traj = trajs[:,i] # S, 2
|
1191 |
+
valid = valids[:,i] # S
|
1192 |
+
if valid[0]==1:
|
1193 |
+
traj = traj[valid>0]
|
1194 |
+
rgb_color = self.draw_traj_on_image_py(
|
1195 |
+
rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth)
|
1196 |
+
|
1197 |
+
rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0)
|
1198 |
+
rgb = preprocess_color(rgb_color)
|
1199 |
+
return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
1200 |
+
|
1201 |
+
def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None):
|
1202 |
+
# all inputs are numpy tensors
|
1203 |
+
# rgb is 3 x H x W
|
1204 |
+
# traj is S x 2
|
1205 |
+
|
1206 |
+
H, W, C = rgb.shape
|
1207 |
+
assert(C==3)
|
1208 |
+
|
1209 |
+
rgb = rgb.astype(np.uint8).copy()
|
1210 |
+
|
1211 |
+
S1, D = traj.shape
|
1212 |
+
assert(D==2)
|
1213 |
+
|
1214 |
+
color_map = cm.get_cmap(cmap)
|
1215 |
+
S1, D = traj.shape
|
1216 |
+
|
1217 |
+
for s in range(S1):
|
1218 |
+
if val is not None:
|
1219 |
+
color = np.array(color_map(val)[:3]) * 255 # rgb
|
1220 |
+
else:
|
1221 |
+
if maxdist is not None:
|
1222 |
+
val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1)
|
1223 |
+
color = np.array(color_map(val)[:3]) * 255 # rgb
|
1224 |
+
else:
|
1225 |
+
color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
|
1226 |
+
|
1227 |
+
if show_lines and s<(S1-1):
|
1228 |
+
cv2.line(rgb,
|
1229 |
+
(int(traj[s,0]), int(traj[s,1])),
|
1230 |
+
(int(traj[s+1,0]), int(traj[s+1,1])),
|
1231 |
+
color,
|
1232 |
+
linewidth,
|
1233 |
+
cv2.LINE_AA)
|
1234 |
+
if show_dots:
|
1235 |
+
cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1)
|
1236 |
+
|
1237 |
+
# if maxdist is not None:
|
1238 |
+
# val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1)
|
1239 |
+
# color = np.array(color_map(val)[:3]) * 255 # rgb
|
1240 |
+
# else:
|
1241 |
+
# # draw the endpoint of traj, using the next color (which may be the last color)
|
1242 |
+
# color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb
|
1243 |
+
|
1244 |
+
# # emphasize endpoint
|
1245 |
+
# cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1)
|
1246 |
+
|
1247 |
+
return rgb
|
1248 |
+
|
1249 |
+
|
1250 |
+
def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None):
|
1251 |
+
# all inputs are numpy tensors
|
1252 |
+
# rgbs is a list of H,W,3
|
1253 |
+
# traj is S,2
|
1254 |
+
H, W, C = rgbs[0].shape
|
1255 |
+
assert(C==3)
|
1256 |
+
|
1257 |
+
rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
|
1258 |
+
|
1259 |
+
S1, D = traj.shape
|
1260 |
+
assert(D==2)
|
1261 |
+
|
1262 |
+
x = int(np.clip(traj[0,0], 0, W-1))
|
1263 |
+
y = int(np.clip(traj[0,1], 0, H-1))
|
1264 |
+
color = rgbs[0][y,x]
|
1265 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
1266 |
+
for s in range(S):
|
1267 |
+
# bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb
|
1268 |
+
                # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1)
                cv2.polylines(rgbs[s],
                              [traj[:s+1]],
                              False,
                              color,
                              linewidth,
                              cv2.LINE_AA)
        return rgbs

    def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None):
        # all inputs are numpy tensors
        # rgbs is a list of 3,H,W
        # xy is N,2
        H, W, C = rgb.shape
        assert(C==3)

        rgb = rgb.astype(np.uint8).copy()

        N, D = xy.shape
        assert(D==2)

        xy = xy.astype(np.float32)
        xy[:,0] = np.clip(xy[:,0], 0, W-1)
        xy[:,1] = np.clip(xy[:,1], 0, H-1)
        xy = xy.astype(np.int32)

        if colors is None:
            colors = get_n_colors(N)

        for n in range(N):
            color = colors[n]
            # print('color', color)
            # color = (color[0]*255).astype(np.uint8)
            color = (int(color[0]),int(color[1]),int(color[2]))

            # x = int(np.clip(xy[0,0], 0, W-1))
            # y = int(np.clip(xy[0,1], 0, H-1))
            # color_ = rgbs[0][y,x]
            # color_ = (int(color_[0]),int(color_[1]),int(color_[2]))

            cv2.circle(rgb, (int(xy[n,0]), int(xy[n,1])), linewidth, color, 3)
            # vis_color = int(np.squeeze(vis[s])*255)
            # vis_color = (vis_color,vis_color,vis_color)
            # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1)
        return rgb

    def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None):
        # all inputs are numpy tensors
        # rgbs is a list of 3,H,W
        # traj is S,2
        H, W, C = rgbs[0].shape
        assert(C==3)

        rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]

        S1, D = traj.shape
        assert(D==2)

        if cmap is None:
            bremm = ColorMap2d()
            traj_ = traj[0:1].astype(np.float32)
            traj_[:,0] /= float(W)
            traj_[:,1] /= float(H)
            color = bremm(traj_)
            # print('color', color)
            color = (color[0]*255).astype(np.uint8)
            color = (int(color[0]),int(color[1]),int(color[2]))

        for s in range(S):
            if cmap is not None:
                color_map = cm.get_cmap(cmap)
                # color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb
                color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
                # color = color.astype(np.uint8)
                # color = (color[0], color[1], color[2])
                # print('color', color)
                # import ipdb; ipdb.set_trace()

            cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+2, color, -1)
            vis_color = int(np.squeeze(vis[s])*255)
            vis_color = (vis_color,vis_color,vis_color)
            cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1)

        return rgbs

    def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, frame_str=None, only_return=False, show_circ=True, trajs_g=None, valids_g=None, is_g=False, anchor_ind=None, ara=None, anchors=None, frame_ids=None, frame_strs=None):
        B, S, N, D = trajs_e.shape
        assert(N==1)
        assert(D==2)

        rgbs = back2color(rgbs).detach().cpu().byte().numpy()

        rgbs_vis = []
        n = 0
        pad_amount = 128
        trajs_e_py = trajs_e[0].detach().cpu().numpy()
        # trajs_e_py = np.clip(trajs_e_py, min=pad_amount/2, max=pad_amoun
        trajs_e_py = trajs_e_py + pad_amount

        if trajs_g is not None:
            trajs_g_py = trajs_g[0].detach().cpu().numpy()
            trajs_g_py = trajs_g_py + pad_amount

            if valids_g is not None:
                valids_g_py = valids_g[0].detach().cpu().numpy()
            else:
                valids_g_py = np.ones_like(trajs_g_py[:,:,:,0])

        for s in range(S):
            rgb = rgbs[0,s]
            # print('orig rgb', rgb.shape)
            rgb = np.transpose(rgb,(1,2,0)) # H, W, 3

            rgb = np.pad(rgb, ((pad_amount,pad_amount),(pad_amount,pad_amount),(0,0)))
            # print('pad rgb', rgb.shape)
            H, W, C = rgb.shape

            if trajs_g is not None:
                xy_g = trajs_g_py[s,n]
                xy_g[0] = np.clip(xy_g[0], pad_amount, W-pad_amount)
                xy_g[1] = np.clip(xy_g[1], pad_amount, H-pad_amount)
                if valids_g_py[s,n] > 0:
                    rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3)

            xy_e = trajs_e_py[s,n]
            xy_e[0] = np.clip(xy_e[0], pad_amount, W-pad_amount)
            xy_e[1] = np.clip(xy_e[1], pad_amount, H-pad_amount)

            if show_circ:
                # if (anchors is not None) and (s==anchor_ind):
                #     rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,0,0)], linewidth=8, radius=12)

                if (anchors is not None) and (s in anchors):
                    rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,255,255)], linewidth=4, radius=8)

                if is_g:
                    rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3)
                else:
                    rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,0,255)], linewidth=2, radius=3)

            xmin = int(xy_e[0])-pad_amount//2
            xmax = int(xy_e[0])+pad_amount//2
            ymin = int(xy_e[1])-pad_amount//2
            ymax = int(xy_e[1])+pad_amount//2

            rgb_ = rgb[ymin:ymax, xmin:xmax]

            H_, W_ = rgb_.shape[:2]
            # if np.any(rgb_.shape==0):
            #     input()
            if H_==0 or W_==0:
                import ipdb; ipdb.set_trace()

            if (ara is not None) and (s in ara):
                # green border
                rgb_[0,:,0] = 0
                rgb_[0,:,1] = 255
                rgb_[0,:,2] = 0

                rgb_[-1,:,0] = 0
                rgb_[-1,:,1] = 255
                rgb_[-1,:,2] = 0

                rgb_[:,0,0] = 0
                rgb_[:,0,1] = 255
                rgb_[:,0,2] = 0

                rgb_[:,-1,0] = 0
                rgb_[:,-1,1] = 255
                rgb_[:,-1,2] = 0

            if (anchor_ind is not None) and (s==anchor_ind):
                # inner green border
                pad = 4

                rgb_[:pad,:,0] = 0
                rgb_[:pad,:,1] = 255
                rgb_[:pad,:,2] = 0

                rgb_[-pad:,:,0] = 0
                rgb_[-pad:,:,1] = 255
                rgb_[-pad:,:,2] = 0

                rgb_[:,:pad,0] = 0
                rgb_[:,:pad,1] = 255
                rgb_[:,:pad,2] = 0

                rgb_[:,-pad:,0] = 0
                rgb_[:,-pad:,1] = 255
                rgb_[:,-pad:,2] = 0

            rgb_ = rgb_.transpose(2,0,1)
            rgb_ = torch.from_numpy(rgb_)

            if frame_ids is not None:
                # if s==anchor_ind:
                #     frame_ids[s] = frame_ids[s] + '
                rgb_ = draw_frame_id_on_vis(rgb_.unsqueeze(0), frame_ids[s]).squeeze(0)

                if s==anchor_ind:
                    rgb_ = draw_frame_str_on_vis(rgb_.unsqueeze(0), '(A)').squeeze(0)

            if frame_strs is not None:
                rgb_ = draw_frame_str_on_vis(rgb_.unsqueeze(0), frame_strs[s]).squeeze(0)

            rgbs_vis.append(rgb_)

        # nrow = int(np.sqrt(S)*(16.0/9)/2.0)
        nrow = int(np.sqrt(S)*1.5)
        grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0)
        # print('grid_img', grid_img.shape)
        return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, frame_str=frame_str, only_return=only_return)

    def summ_pts_on_rgb(self, name, trajs, rgb, visibs=None, valids=None, frame_id=None, frame_str=None, only_return=False, show_dots=True, colors=None, cmap='coolwarm', linewidth=1, max_show=1024, already_sorted=False):
        # trajs is B, S, N, 2
        # rgbs is B, S, C, H, W
        B, C, H, W = rgb.shape
        B, S, N, D = trajs.shape

        rgb = rgb[0] # C, H, W
        trajs = trajs[0] # S, N, 2
        if valids is None:
            valids = torch.ones_like(trajs[:,:,0]) # S, N
        else:
            valids = valids[0]
        if visibs is None:
            visibs = torch.ones_like(trajs[:,:,0]) # S, N
        else:
            visibs = visibs[0]

        trajs = trajs.clamp(-16, W+16)

        if N > max_show:
            inds = np.random.choice(N, max_show)
            trajs = trajs[:,inds]
            valids = valids[:,inds]
            visibs = visibs[:,inds]
            N = trajs.shape[1]

        if not already_sorted:
            inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
            trajs = trajs[:,inds]
            valids = valids[:,inds]
            visibs = visibs[:,inds]

        rgb = back2color(rgb).detach().cpu().numpy()
        rgb = np.transpose(rgb, [1, 2, 0]) # put channels last

        trajs = trajs.long().detach().cpu().numpy() # S, N, 2
        valids = valids.long().detach().cpu().numpy() # S, N
        visibs = visibs.long().detach().cpu().numpy() # S, N

        rgb = rgb.astype(np.uint8).copy()

        for i in range(min(N, max_show)):
            if cmap=='onediff' and i==0:
                cmap_ = 'spring'
            elif cmap=='onediff':
                cmap_ = 'winter'
            else:
                cmap_ = cmap
            traj = trajs[:,i] # S,2
            valid = valids[:,i] # S
            visib = visibs[:,i] # S

            if colors is None:
                ii = i/(1e-4+N-1.0)
                color_map = cm.get_cmap(cmap)
                color = np.array(color_map(ii)[:3]) * 255 # rgb
            else:
                color = np.array(colors[i]).astype(np.int64)
            color = (int(color[0]),int(color[1]),int(color[2]))

            for s in range(S):
                if valid[s]:
                    if visib[s]:
                        thickness = -1
                    else:
                        thickness = 2
                    cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, thickness)
        rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0)
        rgb = preprocess_color(rgb)
        return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)

    def summ_pts_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', colors=None, linewidth=1, max_show=1024, frame_strs=None):
        # trajs is B, S, N, 2
        # rgbs is B, S, C, H, W
        B, S, C, H, W = rgbs.shape
        B, S2, N, D = trajs.shape
        assert(S==S2)

        rgbs = rgbs[0] # S, C, H, W
        trajs = trajs[0] # S, N, 2
        if valids is None:
            valids = torch.ones_like(trajs[:,:,0]) # S, N
        else:
            valids = valids[0]
        if visibs is None:
            visibs = torch.ones_like(trajs[:,:,0]) # S, N
        else:
            visibs = visibs[0]

        if N > max_show:
            inds = np.random.choice(N, max_show)
            trajs = trajs[:,inds]
            valids = valids[:,inds]
            visibs = visibs[:,inds]
            N = trajs.shape[1]
        inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
        trajs = trajs[:,inds]
        valids = valids[:,inds]
        visibs = visibs[:,inds]

        rgbs_color = []
        for rgb in rgbs:
            rgb = back2color(rgb).detach().cpu().numpy()
            rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
            rgbs_color.append(rgb) # each element 3 x H x W

        trajs = trajs.long().detach().cpu().numpy() # S, N, 2
        valids = valids.long().detach().cpu().numpy() # S, N
        visibs = visibs.long().detach().cpu().numpy() # S, N

        rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color]

        for i in range(min(N, max_show)):
            traj = trajs[:,i] # S,2
            valid = valids[:,i] # S
            visib = visibs[:,i] # S

            if colors is None:
                ii = i/(1e-4+N-1.0)
                color_map = cm.get_cmap(cmap)
                color = np.array(color_map(ii)[:3]) * 255 # rgb
            else:
                color = np.array(colors[i]).astype(np.int64)
            color = (int(color[0]),int(color[1]),int(color[2]))

            for s in range(S):
                if valid[s]:
                    if visib[s]:
                        thickness = -1
                    else:
                        thickness = 2
                    cv2.circle(rgbs_color[s], (int(traj[s,0]), int(traj[s,1])), int(linewidth), color, thickness)
        rgbs = []
        for rgb in rgbs_color:
            rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
            rgbs.append(preprocess_color(rgb))

        return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)


def erode2d(im, times=1):
    device = im.device
    weights2d = torch.ones(1, 1, 3, 3, dtype=im.dtype, device=device)
    for time in range(times):
        im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1).clamp(0, 1)
    return im

def dilate2d(im, times=1):
    device = im.device
    assert(times>0)
    dilation_kernel_size = times*2 + 1
    padding_size = dilation_kernel_size // 2
    dilation_kernel = torch.ones((1, 1, dilation_kernel_size, dilation_kernel_size), device=device)
    im = F.conv2d(im, dilation_kernel, padding=padding_size, groups=im.shape[1]).clamp(0,1)
    return im
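Note (illustration added to this listing, not part of the commit): erode2d and dilate2d above implement binary erosion/dilation of a mask with F.conv2d. A minimal sketch of their behavior on a toy mask, assuming this commit's repo layout so that utils.improc is importable:

# Toy check of dilate2d/erode2d (hypothetical usage, illustration only).
import torch
from utils.improc import dilate2d, erode2d  # assumes this commit's utils/improc.py

mask = torch.zeros(1, 1, 7, 7)
mask[0, 0, 3, 3] = 1.0            # single foreground pixel
grown = dilate2d(mask, times=1)   # times=1 -> 3x3 kernel, so the pixel grows to a 3x3 block
shrunk = erode2d(grown, times=1)  # erosion shrinks the block back to its center
print(grown[0, 0].int())
print(shrunk[0, 0].int())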
utils/loss.py
ADDED
@@ -0,0 +1,220 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import List
import utils.basic


def sequence_loss(
    flow_preds,
    flow_gt,
    valids,
    vis=None,
    gamma=0.8,
    use_huber_loss=False,
    loss_only_for_visible=False,
):
    """Loss function defined over sequence of flow predictions"""
    total_flow_loss = 0.0
    for j in range(len(flow_gt)):
        B, S, N, D = flow_gt[j].shape
        B, S2, N = valids[j].shape
        assert S == S2
        n_predictions = len(flow_preds[j])
        flow_loss = 0.0
        for i in range(n_predictions):
            i_weight = gamma ** (n_predictions - i - 1)
            flow_pred = flow_preds[j][i]
            if use_huber_loss:
                i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
            else:
                i_loss = (flow_pred - flow_gt[j]).abs()  # B, S, N, 2
            i_loss = torch.mean(i_loss, dim=3)  # B, S, N
            valid_ = valids[j].clone()
            if loss_only_for_visible:
                valid_ = valid_ * vis[j]
            flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss, valid_)
        flow_loss = flow_loss / n_predictions
        total_flow_loss += flow_loss
    return total_flow_loss / len(flow_gt)

def sequence_loss_dense(
    flow_preds,
    flow_gt,
    valids,
    vis=None,
    gamma=0.8,
    use_huber_loss=False,
    loss_only_for_visible=False,
):
    """Loss function defined over sequence of flow predictions"""
    total_flow_loss = 0.0
    for j in range(len(flow_gt)):
        # print('flow_gt[j]', flow_gt[j].shape)
        B, S, D, H, W = flow_gt[j].shape
        B, S2, _, H, W = valids[j].shape
        assert S == S2
        n_predictions = len(flow_preds[j])
        flow_loss = 0.0
        for i in range(n_predictions):
            # print('flow_e[j][i]', flow_preds[j][i].shape)
            i_weight = gamma ** (n_predictions - i - 1)
            flow_pred = flow_preds[j][i]  # B,S,2,H,W
            if use_huber_loss:
                i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)  # B,S,2,H,W
            else:
                i_loss = (flow_pred - flow_gt[j]).abs()  # B,S,2,H,W
            i_loss_ = torch.mean(i_loss, dim=2)  # B,S,H,W
            valid_ = valids[j].reshape(B,S,H,W)
            # print(' (%d,%d) i_loss_' % (i,j), i_loss_.shape)
            # print(' (%d,%d) valid_' % (i,j), valid_.shape)
            if loss_only_for_visible:
                valid_ = valid_ * vis[j].reshape(B,-1,H,W)  # usually B,S,H,W, but maybe B,1,H,W
            flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss_, valid_, broadcast=True)
        flow_loss = flow_loss / n_predictions
        total_flow_loss += flow_loss
    return total_flow_loss / len(flow_gt)


def huber_loss(x, y, delta=1.0):
    """Calculate element-wise Huber loss between x and y"""
    diff = x - y
    abs_diff = diff.abs()
    flag = (abs_diff <= delta).float()
    return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)


def sequence_BCE_loss(vis_preds, vis_gts, valids=None, use_logits=False):
    total_bce_loss = 0.0
    # all_vis_preds = [torch.stack(vp) for vp in vis_preds]
    # all_vis_preds = torch.stack(all_vis_preds)
    # utils.basic.print_stats('all_vis_preds', all_vis_preds)
    for j in range(len(vis_preds)):
        n_predictions = len(vis_preds[j])
        bce_loss = 0.0
        for i in range(n_predictions):
            # utils.basic.print_stats('vis_preds[%d][%d]' % (j,i), vis_preds[j][i])
            # utils.basic.print_stats('vis_gts[%d]' % (i), vis_gts[i])
            if use_logits:
                loss = F.binary_cross_entropy_with_logits(vis_preds[j][i], vis_gts[j], reduction='none')
            else:
                loss = F.binary_cross_entropy(vis_preds[j][i], vis_gts[j], reduction='none')
            if valids is None:
                bce_loss += loss.mean()
            else:
                bce_loss += (loss * valids[j]).mean()
        bce_loss = bce_loss / n_predictions
        total_bce_loss += bce_loss
    return total_bce_loss / len(vis_preds)


# def sequence_BCE_loss_dense(vis_preds, vis_gts):
#     total_bce_loss = 0.0
#     for j in range(len(vis_preds)):
#         n_predictions = len(vis_preds[j])
#         bce_loss = 0.0
#         for i in range(n_predictions):
#             vis_e = vis_preds[j][i]
#             vis_g = vis_gts[j]
#             print('vis_e', vis_e.shape, 'vis_g', vis_g.shape)
#             vis_loss = F.binary_cross_entropy(vis_e, vis_g)
#             bce_loss += vis_loss
#         bce_loss = bce_loss / n_predictions
#         total_bce_loss += bce_loss
#     return total_bce_loss / len(vis_preds)


def sequence_prob_loss(
    tracks: torch.Tensor,
    confidence: torch.Tensor,
    target_points: torch.Tensor,
    visibility: torch.Tensor,
    expected_dist_thresh: float = 12.0,
    use_logits=False,
):
    """Loss for classifying if a point is within pixel threshold of its target."""
    # Points with an error larger than 12 pixels are likely to be useless; marking
    # them as occluded will actually improve Jaccard metrics and give
    # qualitatively better results.
    total_logprob_loss = 0.0
    for j in range(len(tracks)):
        n_predictions = len(tracks[j])
        logprob_loss = 0.0
        for i in range(n_predictions):
            err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
            valid = (err <= expected_dist_thresh**2).float()
            if use_logits:
                loss = F.binary_cross_entropy_with_logits(confidence[j][i], valid, reduction="none")
            else:
                loss = F.binary_cross_entropy(confidence[j][i], valid, reduction="none")
            loss *= visibility[j]
            loss = torch.mean(loss, dim=[1, 2])
            logprob_loss += loss
        logprob_loss = logprob_loss / n_predictions
        total_logprob_loss += logprob_loss
    return total_logprob_loss / len(tracks)

def sequence_prob_loss_dense(
    tracks: torch.Tensor,
    confidence: torch.Tensor,
    target_points: torch.Tensor,
    visibility: torch.Tensor,
    expected_dist_thresh: float = 12.0,
    use_logits=False,
):
    """Loss for classifying if a point is within pixel threshold of its target."""
    # Points with an error larger than 12 pixels are likely to be useless; marking
    # them as occluded will actually improve Jaccard metrics and give
    # qualitatively better results.

    # all_confidence = [torch.stack(vp) for vp in confidence]
    # all_confidence = torch.stack(all_confidence)
    # utils.basic.print_stats('all_confidence', all_confidence)

    total_logprob_loss = 0.0
    for j in range(len(tracks)):
        n_predictions = len(tracks[j])
        logprob_loss = 0.0
        for i in range(n_predictions):
            # print('trajs_e', tracks[j][i].shape)
            # print('trajs_g', target_points[j].shape)
            err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=2)
            positive = (err <= expected_dist_thresh**2).float()
            # print('conf', confidence[j][i].shape, 'positive', positive.shape)
            if use_logits:
                loss = F.binary_cross_entropy_with_logits(confidence[j][i].squeeze(2), positive, reduction="none")
            else:
                loss = F.binary_cross_entropy(confidence[j][i].squeeze(2), positive, reduction="none")
            loss *= visibility[j].squeeze(2)  # B,S,H,W
            loss = torch.mean(loss, dim=[1,2,3])
            logprob_loss += loss
        logprob_loss = logprob_loss / n_predictions
        total_logprob_loss += logprob_loss
    return total_logprob_loss / len(tracks)


def masked_mean(data, mask, dim):
    if mask is None:
        return data.mean(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    return mask_mean


def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    if mask is None:
        return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    mask_var = torch.sum(
        mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
    ) / torch.clamp(mask_sum, min=1.0)
    return mask_mean.squeeze(dim), mask_var.squeeze(dim)
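Note (illustration, not part of the commit): the sequence_* losses above expect nested lists, one entry per supervision group, each holding the per-iteration refinements, so the weight gamma**(n_predictions - i - 1) emphasizes later iterations. A minimal sketch with random tensors, assuming this commit's repo layout; all shapes below are hypothetical:

# Hypothetical shapes for sequence_loss (illustration only).
import torch
from utils.loss import sequence_loss  # assumes this commit's utils/loss.py (and utils/basic.py)

B, S, N, n_iters = 1, 8, 16, 4
flow_gt = [torch.randn(B, S, N, 2)]                               # ground-truth tracks
flow_preds = [[torch.randn(B, S, N, 2) for _ in range(n_iters)]]  # per-iteration predictions
valids = [torch.ones(B, S, N)]                                    # which (frame, point) pairs count
vis = [torch.ones(B, S, N)]                                       # used only if loss_only_for_visible
loss = sequence_loss(flow_preds, flow_gt, valids, vis=vis, gamma=0.8)
print(float(loss))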
utils/metric.py
ADDED
@@ -0,0 +1,176 @@
import math
import numpy as np
import cv2

def _seg2bmap(seg, width=None, height=None):
    """
    From a segmentation, compute a binary boundary map with 1 pixel wide
    boundaries. The boundary pixels are offset by 1/2 pixel towards the
    origin from the actual segment boundary.
    Arguments:
        seg    : Segments labeled from 1..k.
        width  : Width of desired bmap <= seg.shape[1]
        height : Height of desired bmap <= seg.shape[0]
    Returns:
        bmap (ndarray): Binary boundary map.
    David Martin <[email protected]>
    January 2003
    """

    seg = seg.astype(bool)
    seg[seg > 0] = 1

    assert np.atleast_3d(seg).shape[2] == 1

    width = seg.shape[1] if width is None else width
    height = seg.shape[0] if height is None else height

    h, w = seg.shape[:2]

    ar1 = float(width) / float(height)
    ar2 = float(w) / float(h)

    assert not (width > w | height > h | abs(ar1 - ar2) >
                0.01), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height)

    e = np.zeros_like(seg)
    s = np.zeros_like(seg)
    se = np.zeros_like(seg)

    e[:, :-1] = seg[:, 1:]
    s[:-1, :] = seg[1:, :]
    se[:-1, :-1] = seg[1:, 1:]

    b = seg ^ e | seg ^ s | seg ^ se
    b[-1, :] = seg[-1, :] ^ e[-1, :]
    b[:, -1] = seg[:, -1] ^ s[:, -1]
    b[-1, -1] = 0

    if w == width and h == height:
        bmap = b
    else:
        bmap = np.zeros((height, width))
        for x in range(w):
            for y in range(h):
                if b[y, x]:
                    j = 1 + math.floor((y - 1) + height / h)
                    i = 1 + math.floor((x - 1) + width / h)
                    bmap[j, i] = 1

    return bmap

# https://github.com/davisvideochallenge/davis2017-evaluation/blob/master/davis2017/metrics.py#L6
def db_eval_iou(annotation, segmentation, void_pixels=None):
    """ Compute region similarity as the Jaccard Index.
    Arguments:
        annotation   (ndarray): binary annotation map.
        segmentation (ndarray): binary segmentation map.
        void_pixels  (ndarray): optional mask with void pixels

    Return:
        jaccard (float): region similarity
    """
    assert annotation.shape == segmentation.shape, \
        f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.'
    annotation = annotation.astype(bool)
    segmentation = segmentation.astype(bool)

    if void_pixels is not None:
        assert annotation.shape == void_pixels.shape, \
            f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.'
        void_pixels = void_pixels.astype(bool)
    else:
        void_pixels = np.zeros_like(segmentation)

    # Intersection between all sets
    inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1))
    union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1))

    j = inters / union
    if j.ndim == 0:
        j = 0 if np.isclose(union, 0) else j
    else:
        j[np.isclose(union, 0)] = 0
    return j


def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008):
    assert annotation.shape == segmentation.shape
    if void_pixels is not None:
        assert annotation.shape == void_pixels.shape
    if annotation.ndim == 3:
        n_frames = annotation.shape[0]
        f_res = np.zeros(n_frames)
        for frame_id in range(n_frames):
            void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :, ]
            f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th)
    elif annotation.ndim == 2:
        f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th)
    else:
        raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions')
    return f_res


def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008):
    """
    Compute mean,recall and decay from per-frame evaluation.
    Calculates precision/recall for boundaries between foreground_mask and
    gt_mask using morphological operators to speed it up.

    Arguments:
        foreground_mask (ndarray): binary segmentation image.
        gt_mask         (ndarray): binary annotated image.
        void_pixels     (ndarray): optional mask with void pixels

    Returns:
        F (float): boundaries F-measure
    """
    assert np.atleast_3d(foreground_mask).shape[2] == 1
    if void_pixels is not None:
        void_pixels = void_pixels.astype(bool)
    else:
        void_pixels = np.zeros_like(foreground_mask).astype(bool)

    bound_pix = bound_th if bound_th >= 1 else \
        np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))

    # Get the pixel boundaries of both masks
    fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels))
    gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels))

    from skimage.morphology import disk

    # fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
    fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
    # gt_dil = binary_dilation(gt_boundary, disk(bound_pix))
    gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))

    # Get the intersection
    gt_match = gt_boundary * fg_dil
    fg_match = fg_boundary * gt_dil

    # Area of the intersection
    n_fg = np.sum(fg_boundary)
    n_gt = np.sum(gt_boundary)

    # % Compute precision and recall
    if n_fg == 0 and n_gt > 0:
        precision = 1
        recall = 0
    elif n_fg > 0 and n_gt == 0:
        precision = 0
        recall = 1
    elif n_fg == 0 and n_gt == 0:
        precision = 1
        recall = 1
    else:
        precision = np.sum(fg_match) / float(n_fg)
        recall = np.sum(gt_match) / float(n_gt)

    # Compute F measure
    if precision + recall == 0:
        F = 0
    else:
        F = 2 * precision * recall / (precision + recall)

    return F
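Note (illustration, not part of the commit): db_eval_iou and db_eval_boundary above follow the DAVIS 2017 evaluation code (region similarity J and boundary measure F). A minimal sketch on toy masks, assuming this commit's layout; the boundary measure additionally needs scikit-image for disk():

# Toy evaluation with the DAVIS-style metrics (illustration only).
import numpy as np
from utils.metric import db_eval_iou, db_eval_boundary  # assumes this commit's utils/metric.py

gt = np.zeros((64, 64), dtype=np.uint8)
gt[16:48, 16:48] = 1        # ground-truth square
pred = np.zeros((64, 64), dtype=np.uint8)
pred[20:52, 20:52] = 1      # shifted prediction

print('J (region similarity):', db_eval_iou(gt, pred))
print('F (boundary measure):', db_eval_boundary(gt, pred))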
utils/misc.py
ADDED
@@ -0,0 +1,1062 @@
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import utils.basic
|
6 |
+
from typing import Tuple, Union
|
7 |
+
|
8 |
+
def standardize_test_data(rgbs, trajs, visibs, valids, S_cap=600, only_first=False, seq_len=None):
|
9 |
+
trajs = trajs.astype(np.float32) # S,N,2
|
10 |
+
visibs = visibs.astype(np.float32) # S,N
|
11 |
+
valids = valids.astype(np.float32) # S,N
|
12 |
+
|
13 |
+
visval_ok = np.sum(valids*visibs, axis=0) > 1
|
14 |
+
trajs = trajs[:,visval_ok]
|
15 |
+
visibs = visibs[:,visval_ok]
|
16 |
+
valids = valids[:,visval_ok]
|
17 |
+
|
18 |
+
# fill in missing data
|
19 |
+
N = trajs.shape[1]
|
20 |
+
for ni in range(N):
|
21 |
+
trajs[:,ni] = utils.misc.data_replace_with_nearest(trajs[:,ni], valids[:,ni])
|
22 |
+
|
23 |
+
# use cap or seq_len
|
24 |
+
if seq_len is not None:
|
25 |
+
S = min(len(rgbs), seq_len)
|
26 |
+
else:
|
27 |
+
S = len(rgbs)
|
28 |
+
S = min(S, S_cap)
|
29 |
+
|
30 |
+
if only_first:
|
31 |
+
# we'll find the best frame to start on
|
32 |
+
best_count = 0
|
33 |
+
best_ind = 0
|
34 |
+
|
35 |
+
for si in range(0,len(rgbs)-64):
|
36 |
+
# try this slice
|
37 |
+
visibs_ = visibs[si:min(si+S,len(rgbs)+1)] # S,N
|
38 |
+
valids_ = valids[si:min(si+S,len(rgbs)+1)] # S,N
|
39 |
+
visval_ok0 = (visibs_[0]*valids_[0]) > 0 # N
|
40 |
+
visval_okA = np.sum(visibs_*valids_, axis=0) > 1 # N
|
41 |
+
all_ok = visval_ok0 & visval_okA
|
42 |
+
# print('- slicing %d to %d; sum(ok) %d' % (si, min(si+S,len(rgbs)+1), np.sum(all_ok)))
|
43 |
+
if np.sum(all_ok) > best_count:
|
44 |
+
best_count = np.sum(all_ok)
|
45 |
+
best_ind = si
|
46 |
+
si = best_ind
|
47 |
+
rgbs = rgbs[si:si+S]
|
48 |
+
trajs = trajs[si:si+S]
|
49 |
+
visibs = visibs[si:si+S]
|
50 |
+
valids = valids[si:si+S]
|
51 |
+
vis_ok0 = visibs[0] > 0 # N
|
52 |
+
trajs = trajs[:,vis_ok0]
|
53 |
+
visibs = visibs[:,vis_ok0]
|
54 |
+
valids = valids[:,vis_ok0]
|
55 |
+
# print('- best_count', best_count, 'best_ind', best_ind)
|
56 |
+
|
57 |
+
if seq_len is not None:
|
58 |
+
rgbs = rgbs[:seq_len]
|
59 |
+
trajs = trajs[:seq_len]
|
60 |
+
valids = valids[:seq_len]
|
61 |
+
|
62 |
+
# req two timesteps valid (after seqlen trim)
|
63 |
+
visval_ok = np.sum(visibs*valids, axis=0) > 1
|
64 |
+
trajs = trajs[:,visval_ok]
|
65 |
+
valids = valids[:,visval_ok]
|
66 |
+
visibs = visibs[:,visval_ok]
|
67 |
+
|
68 |
+
return rgbs, trajs, visibs, valids
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
def get_2d_sincos_pos_embed(
|
73 |
+
embed_dim: int, grid_size: Union[int, Tuple[int, int]]
|
74 |
+
) -> torch.Tensor:
|
75 |
+
"""
|
76 |
+
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
77 |
+
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
78 |
+
Args:
|
79 |
+
- embed_dim: The embedding dimension.
|
80 |
+
- grid_size: The grid size.
|
81 |
+
Returns:
|
82 |
+
- pos_embed: The generated 2D positional embedding.
|
83 |
+
"""
|
84 |
+
if isinstance(grid_size, tuple):
|
85 |
+
grid_size_h, grid_size_w = grid_size
|
86 |
+
else:
|
87 |
+
grid_size_h = grid_size_w = grid_size
|
88 |
+
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
89 |
+
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
90 |
+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
91 |
+
grid = torch.stack(grid, dim=0)
|
92 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
93 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
94 |
+
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
95 |
+
|
96 |
+
|
97 |
+
def get_2d_sincos_pos_embed_from_grid(
|
98 |
+
embed_dim: int, grid: torch.Tensor
|
99 |
+
) -> torch.Tensor:
|
100 |
+
"""
|
101 |
+
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
- embed_dim: The embedding dimension.
|
105 |
+
- grid: The grid to generate the embedding from.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
- emb: The generated 2D positional embedding.
|
109 |
+
"""
|
110 |
+
assert embed_dim % 2 == 0
|
111 |
+
|
112 |
+
# use half of dimensions to encode grid_h
|
113 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
114 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
115 |
+
|
116 |
+
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
117 |
+
return emb
|
118 |
+
|
119 |
+
|
120 |
+
def get_1d_sincos_pos_embed_from_grid(
|
121 |
+
embed_dim: int, pos: torch.Tensor
|
122 |
+
) -> torch.Tensor:
|
123 |
+
"""
|
124 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
- embed_dim: The embedding dimension.
|
128 |
+
- pos: The position to generate the embedding from.
|
129 |
+
|
130 |
+
Returns:
|
131 |
+
- emb: The generated 1D positional embedding.
|
132 |
+
"""
|
133 |
+
assert embed_dim % 2 == 0
|
134 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
135 |
+
omega /= embed_dim / 2.0
|
136 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
137 |
+
|
138 |
+
pos = pos.reshape(-1) # (M,)
|
139 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
140 |
+
|
141 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
142 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
143 |
+
|
144 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
145 |
+
return emb[None].float()
|
146 |
+
|
147 |
+
|
148 |
+
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
149 |
+
"""
|
150 |
+
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
- xy: The coordinates to generate the embedding from.
|
154 |
+
- C: The size of the embedding.
|
155 |
+
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
- pe: The generated 2D positional embedding.
|
159 |
+
"""
|
160 |
+
B, N, D = xy.shape
|
161 |
+
assert D == 2
|
162 |
+
|
163 |
+
x = xy[:, :, 0:1]
|
164 |
+
y = xy[:, :, 1:2]
|
165 |
+
div_term = (
|
166 |
+
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
|
167 |
+
).reshape(1, 1, int(C / 2))
|
168 |
+
|
169 |
+
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
170 |
+
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
171 |
+
|
172 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
173 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
174 |
+
|
175 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
176 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
177 |
+
|
178 |
+
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
|
179 |
+
if cat_coords:
|
180 |
+
pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
|
181 |
+
return pe
|
182 |
+
|
183 |
+
# from datasets.dataset import mask2bbox
|
184 |
+
|
185 |
+
# from pips2
|
186 |
+
def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
|
187 |
+
device = xy.device
|
188 |
+
dtype = xy.dtype
|
189 |
+
B, S, D = xy.shape
|
190 |
+
assert(D==2)
|
191 |
+
x = xy[:,:,0]
|
192 |
+
y = xy[:,:,1]
|
193 |
+
assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
194 |
+
omega = torch.arange(C // 4, device=device) / (C // 4 - 1)
|
195 |
+
omega = 1. / (temperature ** omega)
|
196 |
+
|
197 |
+
y = y.flatten()[:, None] * omega[None, :]
|
198 |
+
x = x.flatten()[:, None] * omega[None, :]
|
199 |
+
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
200 |
+
pe = pe.reshape(B,S,C).type(dtype)
|
201 |
+
if cat_coords:
|
202 |
+
pe = torch.cat([pe, xy], dim=2) # B,N,C+2
|
203 |
+
return pe
|
204 |
+
|
205 |
+
|
206 |
+
# # prevent circular imports
|
207 |
+
# def mask2bbox(mask):
|
208 |
+
# if mask.ndim == 3:
|
209 |
+
# mask = mask[..., 0]
|
210 |
+
# ys, xs = np.where(mask > 0.4)
|
211 |
+
# if ys.size == 0 or xs.size==0:
|
212 |
+
# return np.array((0, 0, 0, 0), dtype=int)
|
213 |
+
# lt = np.array([np.min(xs), np.min(ys)])
|
214 |
+
# rb = np.array([np.max(xs), np.max(ys)]) + 1
|
215 |
+
# return np.concatenate([lt, rb])
|
216 |
+
|
217 |
+
# def get_stark_2d_embedding(H, W, C=64, device='cuda:0', temperature=10000, normalize=True):
|
218 |
+
# scale = 2*math.pi
|
219 |
+
# mask = torch.ones((1,H,W), dtype=torch.float32, device=device)
|
220 |
+
# y_embed = mask.cumsum(1, dtype=torch.float32) # cumulative sum along axis 1 (h axis) --> (b, h, w)
|
221 |
+
# x_embed = mask.cumsum(2, dtype=torch.float32) # cumulative sum along axis 2 (w axis) --> (b, h, w)
|
222 |
+
# if normalize:
|
223 |
+
# eps = 1e-6
|
224 |
+
# y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale # 2pi * (y / sigma(y))
|
225 |
+
# x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale # 2pi * (x / sigma(x))
|
226 |
+
|
227 |
+
# dim_t = torch.arange(C, dtype=torch.float32, device=device) # (0,1,2,...,d/2)
|
228 |
+
# dim_t = temperature ** (2 * (dim_t // 2) / C)
|
229 |
+
|
230 |
+
# pos_x = x_embed[:, :, :, None] / dim_t # (b,h,w,d/2)
|
231 |
+
# pos_y = y_embed[:, :, :, None] / dim_t # (b,h,w,d/2)
|
232 |
+
# pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) # (b,h,w,d/2)
|
233 |
+
# pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) # (b,h,w,d/2)
|
234 |
+
# pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # (b,h,w,d)
|
235 |
+
# return pos
|
236 |
+
|
237 |
+
# def get_1d_embedding(x, C, cat_coords=False):
|
238 |
+
# B, N, D = x.shape
|
239 |
+
# assert(D==1)
|
240 |
+
|
241 |
+
# div_term = (torch.arange(0, C, 2, device=x.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2))
|
242 |
+
|
243 |
+
# pe_x = torch.zeros(B, N, C, device=x.device, dtype=torch.float32)
|
244 |
+
|
245 |
+
# pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
246 |
+
# pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
247 |
+
|
248 |
+
# if cat_coords:
|
249 |
+
# pe_x = torch.cat([pe, x], dim=2) # B,N,C*2+2
|
250 |
+
# return pe_x
|
251 |
+
|
252 |
+
# def posemb_sincos_2d(h, w, dim, temperature=10000, dtype=torch.float32, device='cuda:0'):
|
253 |
+
|
254 |
+
# y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
255 |
+
# assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
256 |
+
# omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
|
257 |
+
# omega = 1. / (temperature ** omega)
|
258 |
+
|
259 |
+
# y = y.flatten()[:, None] * omega[None, :]
|
260 |
+
# x = x.flatten()[:, None] * omega[None, :]
|
261 |
+
# pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) # B,C,H,W
|
262 |
+
# return pe.type(dtype)
|
263 |
+
|
264 |
+
def iou(bbox1, bbox2):
|
265 |
+
# bbox1, bbox2: [x1, y1, x2, y2]
|
266 |
+
x1, y1, x2, y2 = bbox1
|
267 |
+
x1_, y1_, x2_, y2_ = bbox2
|
268 |
+
inter_x1 = max(x1, x1_)
|
269 |
+
inter_y1 = max(y1, y1_)
|
270 |
+
inter_x2 = min(x2, x2_)
|
271 |
+
inter_y2 = min(y2, y2_)
|
272 |
+
inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
|
273 |
+
area1 = (x2 - x1) * (y2 - y1)
|
274 |
+
area2 = (x2_ - x1_) * (y2_ - y1_)
|
275 |
+
iou = inter_area / (area1 + area2 - inter_area)
|
276 |
+
return iou
|
277 |
+
|
278 |
+
# def get_2d_embedding(xy, C, cat_coords=False):
|
279 |
+
# B, N, D = xy.shape
|
280 |
+
# assert(D==2)
|
281 |
+
|
282 |
+
# x = xy[:,:,0:1]
|
283 |
+
# y = xy[:,:,1:2]
|
284 |
+
# div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2))
|
285 |
+
|
286 |
+
# pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
287 |
+
# pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
288 |
+
|
289 |
+
# pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
290 |
+
# pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
291 |
+
|
292 |
+
# pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
293 |
+
# pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
294 |
+
|
295 |
+
# pe = torch.cat([pe_x, pe_y], dim=2) # B,N,C*2
|
296 |
+
# if cat_coords:
|
297 |
+
# pe = torch.cat([pe, xy], dim=2) # B,N,C*2+2
|
298 |
+
# return pe
|
299 |
+
|
300 |
+
# def get_3d_embedding(xyz, C, cat_coords=False):
|
301 |
+
# B, N, D = xyz.shape
|
302 |
+
# assert(D==3)
|
303 |
+
|
304 |
+
# x = xyz[:,:,0:1]
|
305 |
+
# y = xyz[:,:,1:2]
|
306 |
+
# z = xyz[:,:,2:3]
|
307 |
+
# div_term = (torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2))
|
308 |
+
|
309 |
+
# pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
310 |
+
# pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
311 |
+
# pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
312 |
+
|
313 |
+
# pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
314 |
+
# pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
315 |
+
|
316 |
+
# pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
317 |
+
# pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
318 |
+
|
319 |
+
# pe_z[:, :, 0::2] = torch.sin(z * div_term)
|
320 |
+
# pe_z[:, :, 1::2] = torch.cos(z * div_term)
|
321 |
+
|
322 |
+
# pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
|
323 |
+
# if cat_coords:
|
324 |
+
# pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
|
325 |
+
# return pe
|
326 |
+
|
327 |
+
class SimplePool():
|
328 |
+
def __init__(self, pool_size, version='pt', min_size=1):
|
329 |
+
self.pool_size = pool_size
|
330 |
+
self.version = version
|
331 |
+
self.items = []
|
332 |
+
self.min_size = min_size
|
333 |
+
|
334 |
+
if not (version=='pt' or version=='np'):
|
335 |
+
print('version = %s; please choose pt or np')
|
336 |
+
assert(False) # please choose pt or np
|
337 |
+
|
338 |
+
def __len__(self):
|
339 |
+
return len(self.items)
|
340 |
+
|
341 |
+
def mean(self, min_size=None):
|
342 |
+
if min_size is None:
|
343 |
+
pool_size_thresh = self.min_size
|
344 |
+
elif min_size=='half':
|
345 |
+
pool_size_thresh = self.pool_size/2
|
346 |
+
else:
|
347 |
+
pool_size_thresh = min_size
|
348 |
+
|
349 |
+
if self.version=='np':
|
350 |
+
if len(self.items) >= pool_size_thresh:
|
351 |
+
return np.sum(self.items)/float(len(self.items))
|
352 |
+
else:
|
353 |
+
return np.nan
|
354 |
+
if self.version=='pt':
|
355 |
+
if len(self.items) >= pool_size_thresh:
|
356 |
+
return torch.sum(self.items)/float(len(self.items))
|
357 |
+
else:
|
358 |
+
return torch.from_numpy(np.nan)
|
359 |
+
|
360 |
+
def sample(self, with_replacement=True):
|
361 |
+
idx = np.random.randint(len(self.items))
|
362 |
+
if with_replacement:
|
363 |
+
return self.items[idx]
|
364 |
+
else:
|
365 |
+
return self.items.pop(idx)
|
366 |
+
|
367 |
+
def fetch(self, num=None):
|
368 |
+
if self.version=='pt':
|
369 |
+
item_array = torch.stack(self.items)
|
370 |
+
elif self.version=='np':
|
371 |
+
item_array = np.stack(self.items)
|
372 |
+
if num is not None:
|
373 |
+
# there better be some items
|
374 |
+
assert(len(self.items) >= num)
|
375 |
+
|
376 |
+
# if there are not that many elements just return however many there are
|
377 |
+
if len(self.items) < num:
|
378 |
+
return item_array
|
379 |
+
else:
|
380 |
+
idxs = np.random.randint(len(self.items), size=num)
|
381 |
+
return item_array[idxs]
|
382 |
+
else:
|
383 |
+
return item_array
|
384 |
+
|
385 |
+
def is_full(self):
|
386 |
+
full = len(self.items)==self.pool_size
|
387 |
+
return full
|
388 |
+
|
389 |
+
def empty(self):
|
390 |
+
self.items = []
|
391 |
+
|
392 |
+
def have_min_size(self):
|
393 |
+
return len(self.items) >= self.min_size
|
394 |
+
|
395 |
+
|
396 |
+
def update(self, items):
|
397 |
+
for item in items:
|
398 |
+
if len(self.items) < self.pool_size:
|
399 |
+
# the pool is not full, so let's add this in
|
400 |
+
self.items.append(item)
|
401 |
+
else:
|
402 |
+
# the pool is full
|
403 |
+
# pop from the front
|
404 |
+
self.items.pop(0)
|
405 |
+
# add to the back
|
406 |
+
self.items.append(item)
|
407 |
+
return self.items
|
408 |
+
|
409 |
+
|
410 |
+
class SimpleHeap():
|
411 |
+
def __init__(self, pool_size, version='pt'):
|
412 |
+
self.pool_size = pool_size
|
413 |
+
self.version = version
|
414 |
+
self.items = []
|
415 |
+
self.vals = []
|
416 |
+
|
417 |
+
if not (version=='pt' or version=='np'):
|
418 |
+
print('version = %s; please choose pt or np')
|
419 |
+
assert(False) # please choose pt or np
|
420 |
+
|
421 |
+
def __len__(self):
|
422 |
+
return len(self.items)
|
423 |
+
|
424 |
+
def sample(self, random=True, with_replacement=True, semirandom=False):
|
425 |
+
vals_arr = np.stack(self.vals)
|
426 |
+
if random:
|
427 |
+
ind = np.random.randint(len(self.items))
|
428 |
+
else:
|
429 |
+
if semirandom and len(vals_arr)>1:
|
430 |
+
# choose from the harder half
|
431 |
+
inds = np.argsort(vals_arr) # ascending
|
432 |
+
inds = inds[len(vals_arr)//2:]
|
433 |
+
ind = np.random.choice(inds)
|
434 |
+
else:
|
435 |
+
# find the most valuable element
|
436 |
+
ind = np.argmax(vals_arr)
|
437 |
+
|
438 |
+
|
439 |
+
if with_replacement:
|
440 |
+
return self.items[ind]
|
441 |
+
else:
|
442 |
+
item = self.items.pop(ind)
|
443 |
+
val = self.vals.pop(ind)
|
444 |
+
return item
|
445 |
+
|
446 |
+
def fetch(self, num=None):
|
447 |
+
if self.version=='pt':
|
448 |
+
item_array = torch.stack(self.items)
|
449 |
+
elif self.version=='np':
|
450 |
+
item_array = np.stack(self.items)
|
451 |
+
if num is not None:
|
452 |
+
# there better be some items
|
453 |
+
assert(len(self.items) >= num)
|
454 |
+
|
455 |
+
# if there are not that many elements just return however many there are
|
456 |
+
if len(self.items) < num:
|
457 |
+
return item_array
|
458 |
+
else:
|
459 |
+
idxs = np.random.randint(len(self.items), size=num)
|
460 |
+
return item_array[idxs]
|
461 |
+
else:
|
462 |
+
return item_array
|
463 |
+
|
464 |
+
def is_full(self):
|
465 |
+
full = len(self.items)==self.pool_size
|
466 |
+
return full
|
467 |
+
|
468 |
+
def empty(self):
|
469 |
+
self.items = []
|
470 |
+
|
471 |
+
def update(self, vals, items):
|
472 |
+
for val,item in zip(vals, items):
|
473 |
+
if len(self.items) < self.pool_size:
|
474 |
+
# the pool is not full, so let's add this in
|
475 |
+
self.items.append(item)
|
476 |
+
self.vals.append(val)
|
477 |
+
else:
|
478 |
+
# the pool is full
|
479 |
+
# find our least-valuable element
|
480 |
+
# and see if we should replace it
|
481 |
+
vals_arr = np.stack(self.vals)
|
482 |
+
ind = np.argmin(vals_arr)
|
483 |
+
if vals_arr[ind] < val:
|
484 |
+
# pop the min
|
485 |
+
self.items.pop(ind)
|
486 |
+
self.vals.pop(ind)
|
487 |
+
|
488 |
+
# add to the back
|
489 |
+
self.items.append(item)
|
490 |
+
self.vals.append(val)
|
491 |
+
return self.items
|
492 |
+
|
493 |
+
|
494 |
+
def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
|
495 |
+
"""
|
496 |
+
Input:
|
497 |
+
xyz: pointcloud data, [B, N, C], where C is probably 3
|
498 |
+
npoint: number of samples
|
499 |
+
Return:
|
500 |
+
inds: sampled pointcloud index, [B, npoint]
|
501 |
+
"""
|
502 |
+
device = xyz.device
|
503 |
+
B, N, C = xyz.shape
|
504 |
+
xyz = xyz.float()
|
505 |
+
inds = torch.zeros(B, npoint, dtype=torch.long, device=device)
|
506 |
+
distance = torch.ones((B, N), dtype=torch.float32, device=device) * 1e10
|
507 |
+
if deterministic:
|
508 |
+
farthest = torch.randint(0, 1, (B,), dtype=torch.long, device=device)
|
509 |
+
else:
|
510 |
+
farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
|
511 |
+
batch_indices = torch.arange(B, dtype=torch.long, device=device)
|
512 |
+
for i in range(npoint):
|
513 |
+
if include_ends:
|
514 |
+
if i==0:
|
515 |
+
farthest = 0
|
516 |
+
elif i==1:
|
517 |
+
farthest = N-1
|
518 |
+
inds[:, i] = farthest
|
519 |
+
centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
|
520 |
+
dist = torch.sum((xyz - centroid) ** 2, -1)
|
521 |
+
mask = dist < distance
|
522 |
+
distance[mask] = dist[mask]
|
523 |
+
farthest = torch.max(distance, -1)[1]
|
524 |
+
|
525 |
+
if npoint > N:
|
526 |
+
# if we need more samples, make them random
|
527 |
+
distance += torch.randn_like(distance)
|
528 |
+
return inds
|
529 |
+
|
530 |
+
def farthest_point_sample_py(xyz, npoint, deterministic=False):
|
531 |
+
N,C = xyz.shape
|
532 |
+
inds = np.zeros(npoint, dtype=np.int32)
|
533 |
+
distance = np.ones(N) * 1e10
|
534 |
+
if deterministic:
|
535 |
+
farthest = 0
|
536 |
+
else:
|
537 |
+
farthest = np.random.randint(0, N, dtype=np.int32)
|
538 |
+
for i in range(npoint):
|
539 |
+
inds[i] = farthest
|
540 |
+
centroid = xyz[farthest, :].reshape(1,C)
|
541 |
+
dist = np.sum((xyz - centroid) ** 2, -1)
|
542 |
+
mask = dist < distance
|
543 |
+
distance[mask] = dist[mask]
|
544 |
+
farthest = np.argmax(distance, -1)
|
545 |
+
if npoint > N:
|
546 |
+
# if we need more samples, make them random
|
547 |
+
distance += np.random.randn(*distance.shape)
|
548 |
+
return inds
|
549 |
+
|
def balanced_ce_loss(pred, gt, pos_weight=0.5, valid=None, dim=None, return_both=False, use_halfmask=False, H=64, W=64):
    # # pred and gt are the same shape
    # for (a,b) in zip(pred.size(), gt.size()):
    #     if not a==b:
    #         print('mismatch: pred, gt', pred.shape, gt.shape)
    #     assert(a==b) # some shape mismatch!

    pred = pred.reshape(-1)
    gt = gt.reshape(-1)
    device = pred.device

    if valid is not None:
        valid = valid.reshape(-1)
        for (a,b) in zip(pred.size(), valid.size()):
            assert(a==b) # some shape mismatch!
    else:
        valid = torch.ones_like(gt)

    pos = (gt > 0.95).float()
    if use_halfmask:
        pos_wide = (gt >= 0.5).float()
        halfmask = (gt == 0.5).float()
    else:
        pos_wide = pos

    neg = (gt < 0.05).float()

    label = pos_wide*2.0 - 1.0
    a = -label * pred
    b = F.relu(a)
    loss = b + torch.log(torch.exp(-b)+torch.exp(a-b))

    if torch.sum(pos*valid)>0:
        pos_loss = loss[(pos*valid) > 0].mean()
    else:
        pos_loss = torch.tensor(0.0, requires_grad=True, device=device)

    if torch.sum(neg*valid)>0:
        neg_loss = loss[(neg*valid) > 0].mean()
    else:
        neg_loss = torch.tensor(0.0, requires_grad=True, device=device)
    balanced_loss = pos_weight*pos_loss + (1-pos_weight)*neg_loss
    return balanced_loss

    # pos_loss = utils.basic.reduce_masked_mean(loss, pos*valid, dim=dim)
    # neg_loss = utils.basic.reduce_masked_mean(loss, neg*valid, dim=dim)

    if use_halfmask:
        # here we will find the pixels which are already leaning positive,
        # and encourage them to be more positive
        B = loss.shape[0]
        loss_ = loss.reshape(B,-1)
        mask_ = halfmask.reshape(B,-1) * valid.reshape(B,-1)

        # to avoid the issue where spikes become spikier,
        # we will only apply this loss on batch els where we predicted zero positives
        pred_sig_ = torch.sigmoid(pred).reshape(B,-1)
        no_pred_ = torch.max(pred_sig_.round(), axis=1)[0] < 1 # B
        # and only on batch els where we have negatives available
        have_neg_ = torch.sum(neg, dim=1)>0 # B

        loss_ = loss_[no_pred_ & have_neg_] # N,H*W
        mask_ = mask_[no_pred_ & have_neg_] # N,H*W
        N = loss_.shape[0]

        if N > 0:
            # we want:
            # in the neg pixels,
            # set them to the max loss of the pos pixels,
            # so that they do not contribute to the min
            loss__ = loss_.reshape(-1)
            mask__ = mask_.reshape(-1)
            if torch.sum(mask__)>0:
                # print('loss_', loss_.shape, 'mask_', mask_.shape, 'loss__', loss__.shape, 'mask__', mask__.shape)
                mloss__ = loss__.detach()
                mloss__[mask__==0] = torch.max(loss__[mask__==1])
                mloss_ = mloss__.reshape(N,H*W)

                # now, in each batch el, take a tiny region around the argmin, so we can boost this region
                minloss_mask_ = torch.zeros_like(mloss_).scatter(1,mloss_.argmin(1,True),value=1)
                minloss_mask_ = utils.improc.dilate2d(minloss_mask_.view(N,1,H,W), times=3).reshape(N,H*W)

                loss__ = loss_.reshape(-1)
                minloss_mask__ = minloss_mask_.reshape(-1)
                half_loss = loss__[minloss_mask__>0].mean()

                # print('N', N, 'half_loss', half_loss)
                pos_loss = pos_loss + half_loss

    # if False:
    #     min_pos = 8

    #     # only apply the loss when we have some negatives available,
    #     # otherwise it's a whole "ignore" frame, which may mean
    #     # we are unsure if the target is even there
    #     if torch.sum(mask__==0) > 0: # negatives available
    #         # only apply the loss when the halfmask is larger area than
    #         # min_pos (the number of pixels we want to boost),
    #         # so that indexing will work
    #         if torch.all(torch.sum(mask_==1, dim=1) >= min_pos): # topk indexing will work
    #             # in the pixels we will not use,
    #             # set them to the max of the pixels we may use,
    #             # so that they do not contribute to the min
    #             loss__[mask__==0] = torch.max(loss__[mask__==1])
    #             loss_ = loss__.reshape(B,-1)

    #             half_loss = torch.mean(torch.topk(loss_, min_pos, dim=1, largest=False)[0], dim=1) # B

    #             have_neg = (torch.sum(neg, dim=1)>0).float() # B
    #             pos_loss = pos_loss + half_loss*have_neg

    # half_loss = []
    # for b in range(B):
    #     loss_b = loss_[b]
    #     mask_b = mask_[b]
    #     if torch.sum(mask_b):
    #         inds = torch.nonzero(mask_b).reshape(-1)
    #         half_loss.append(torch.min(loss_b[inds]))
    # if len(half_loss):
    #     # # half_loss_ = half_loss.reshape(-1)
    #     # half_loss = torch.min(half_loss, dim=1)[0] # B
    #     # half_loss = torch.mean(torch.topk(half_loss, 4, dim=1, largest=False)[0], dim=1) # B
    #     pos_loss = pos_loss + torch.stack(half_loss).mean()

    if return_both:
        return pos_loss, neg_loss
    balanced_loss = pos_weight*pos_loss + (1-pos_weight)*neg_loss

    return balanced_loss

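# Sketch (illustrative only): balanced_ce_loss above averages the per-pixel
# sigmoid BCE separately over positive and negative pixels and blends the two
# means with pos_weight, so a sparse heatmap target is not drowned out by the
# background. The form b + log(exp(-b) + exp(a-b)) with a = -label*pred and
# b = relu(a) is the numerically stable softplus(-label*pred), i.e.
# BCE-with-logits written for +/-1 labels.
#
#   pred = torch.randn(2, 64*64)                  # raw logits
#   gt = (torch.rand(2, 64*64) > 0.97).float()    # sparse binary target
#   loss = balanced_ce_loss(pred, gt, pos_weight=0.5)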
def dice_loss(pred, gt):
    # gt has ignores at 0.5
    # pred and gt are the same shape
    for (a,b) in zip(pred.size(), gt.size()):
        assert(a==b) # some shape mismatch!

    prob = pred.sigmoid()

    # flatten everything except batch
    prob = prob.flatten(1)
    gt = gt.flatten(1)

    pos = (gt > 0.95).float()
    neg = (gt < 0.05).float()
    valid = (pos+neg).float().clamp(0,1)

    numerator = 2 * (prob * pos * valid).sum(1)
    denominator = (prob*valid).sum(1) + (pos*valid).sum(1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss

def sigmoid_focal_loss(pred, gt, alpha=0.25, gamma=2):#, use_halfmask=False):
    # gt has ignores at 0.5
    # pred and gt are the same shape
    for (a,b) in zip(pred.size(), gt.size()):
        assert(a==b) # some shape mismatch!

    # flatten everything except batch
    pred = pred.flatten(1)
    gt = gt.flatten(1)

    pos = (gt > 0.95).float()
    neg = (gt < 0.05).float()
    # if use_halfmask:
    #     pos_wide = (gt >= 0.5).float()
    #     halfmask = (gt == 0.5).float()
    # else:
    #     pos_wide = pos
    valid = (pos+neg).float().clamp(0,1)

    prob = pred.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(pred, pos, reduction="none")
    p_t = prob * pos + (1 - prob) * (1 - pos)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * pos + (1 - alpha) * (1 - pos)
        loss = alpha_t * loss

    loss = (loss*valid).sum(1) / (1 + valid.sum(1))
    return loss

# def dice_loss(inputs, targets, normalizer=1):
#     inputs = inputs.sigmoid()
#     inputs = inputs.flatten(1)
#     numerator = 2 * (inputs * targets).sum(1)
#     denominator = inputs.sum(-1) + targets.sum(-1)
#     loss = 1 - (numerator + 1) / (denominator + 1)
#     return loss.sum() / normalizer

# def sigmoid_focal_loss(inputs, targets, normalizer=1, alpha=0.25, gamma=2):
#     prob = inputs.sigmoid()
#     ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
#     p_t = prob * targets + (1 - prob) * (1 - targets)
#     loss = ce_loss * ((1 - p_t) ** gamma)

#     if alpha >= 0:
#         alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
#         loss = alpha_t * loss

#     return loss.mean(1).sum() / normalizer

def data_replace_with_nearest(xys, valids):
    # replace invalid xys with nearby ones
    invalid_idx = np.where(valids==0)[0]
    valid_idx = np.where(valids==1)[0]
    for idx in invalid_idx:
        nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
        xys[idx] = xys[nearest]
    return xys

def data_get_traj_from_masks(masks):
    if masks.ndim==4:
        masks = masks[...,0]
    S, H, W = masks.shape
    masks = (masks > 0.1).astype(np.float32)
    fills = np.zeros((S))
    xy_means = np.zeros((S,2))
    xy_rands = np.zeros((S,2))
    valids = np.zeros((S))
    for si, mask in enumerate(masks):
        if np.sum(mask) > 0:
            ys, xs = np.where(mask)
            inds = np.random.permutation(len(xs))
            xs, ys = xs[inds], ys[inds]
            x0, x1 = np.min(xs), np.max(xs)+1
            y0, y1 = np.min(ys), np.max(ys)+1
            # if (x1-x0)>0 and (y1-y0)>0:
            xy_means[si] = np.array([xs.mean(), ys.mean()])
            xy_rands[si] = np.array([xs[0], ys[0]])
            valids[si] = 1
            crop = mask[y0:y1, x0:x1]
            fill = np.mean(crop)
            fills[si] = fill
    # print('fills', fills)
    return xy_means, xy_rands, valids, fills

def data_zoom(zoom, xys, visibs, rgbs, valids=None, masks=None, masks2=None, masks3=None, masks4=None):
    S, H, W, C = rgbs.shape
    S,N,D = xys.shape

    _, H, W, C = rgbs.shape
    assert(C==3)
    crop_W = int(W//zoom)
    crop_H = int(H//zoom)

    if np.random.rand() < 0.25: # follow-crop
        # start with xy traj
        # smooth_xys = xys.copy()
        smooth_xys = xys[:,np.random.randint(N)].reshape(S,1,2)
        # make it inbounds
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])

        # smooth it out, to remove info about the traj, and simulate camera motion
        for _ in range(S*3):
            for ii in range(S):
                if ii==0:
                    smooth_xys[ii] = (smooth_xys[ii] + smooth_xys[ii+1])/2.0
                elif ii==S-1:
                    smooth_xys[ii] = (smooth_xys[ii-1] + smooth_xys[ii])/2.0
                else:
                    smooth_xys[ii] = (smooth_xys[ii-1] + smooth_xys[ii] + smooth_xys[ii+1])/3.0
    else: # static (no-hint) crop
        # zero-vel on random available coordinate

        if valids is not None:
            visval = visibs*valids # S,N
            visval = np.sum(visval, axis=1) # S
        else:
            visval = np.sum(visibs, axis=1) # S

        anchor_inds = np.nonzero(visval >= np.mean(visval))[0]
        ind = anchor_inds[np.random.randint(len(anchor_inds))]
        # print('ind', ind)
        smooth_xys = xys[ind:ind+1].repeat(S,axis=0)
        smooth_xys = smooth_xys.mean(axis=1, keepdims=True)
        # xmid = np.random.randint(crop_W//2, W-crop_W//2)
        # ymid = np.random.randint(crop_H//2, H-crop_H//2)
        # smooth_xys = np.stack([xmid, ymid], axis=-1).reshape(1,1,2).repeat(S, axis=0) # S,1,2
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])

    if np.random.rand() < 0.5:
        # add a random alternate trajectory, to help push us off center
        alt_xys = np.random.randint(-crop_H//8, crop_H//8, (S,1,2))
        for _ in range(4): # smooth out
            for ii in range(S):
                if ii==0:
                    alt_xys[ii] = (alt_xys[ii] + alt_xys[ii+1])/2.0
                elif ii==S-1:
                    alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii])/2.0
                else:
                    alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii] + alt_xys[ii+1])/3.0
        smooth_xys = smooth_xys + alt_xys

    smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])

    rgbs_crop = []
    if masks is not None:
        masks_crop = []
    if masks2 is not None:
        masks2_crop = []
    if masks3 is not None:
        masks3_crop = []
    if masks4 is not None:
        masks4_crop = []

    offsets = []
    for si in range(S):
        xy_mid = smooth_xys[si].squeeze(0).round().astype(np.int32) # 2

        xmid, ymid = xy_mid[0], xy_mid[1]

        x0, x1 = np.clip(xmid-crop_W//2, 0, W), np.clip(xmid+crop_W//2, 0, W)
        y0, y1 = np.clip(ymid-crop_H//2, 0, H), np.clip(ymid+crop_H//2, 0, H)
        offset = np.array([x0, y0]).reshape(1,2)

        rgbs_crop.append(rgbs[si,y0:y1,x0:x1])
        if masks is not None:
            masks_crop.append(masks[si,y0:y1,x0:x1])
        if masks2 is not None:
            masks2_crop.append(masks2[si,y0:y1,x0:x1])
        if masks3 is not None:
            masks3_crop.append(masks3[si,y0:y1,x0:x1])
        if masks4 is not None:
            masks4_crop.append(masks4[si,y0:y1,x0:x1])
        xys[si] -= offset

        offsets.append(offset)

    rgbs = np.stack(rgbs_crop, axis=0)
    if masks is not None:
        masks = np.stack(masks_crop, axis=0)
    if masks2 is not None:
        masks2 = np.stack(masks2_crop, axis=0)
    if masks3 is not None:
        masks3 = np.stack(masks3_crop, axis=0)
    if masks4 is not None:
        masks4 = np.stack(masks4_crop, axis=0)

    # update visibility annotations
    for si in range(S):
        oob_inds = np.logical_or(
            np.logical_or(xys[si,:,0] < 0, xys[si,:,0] > crop_W-1),
            np.logical_or(xys[si,:,1] < 0, xys[si,:,1] > crop_H-1))
        visibs[si,oob_inds] = 0

    # if masks4 is not None:
    #     return xys, visibs, valids, rgbs, masks, masks2, masks3, masks4
    # if masks3 is not None:
    #     return xys, visibs, valids, rgbs, masks, masks2, masks3
    # if masks2 is not None:
    #     return xys, visibs, valids, rgbs, masks, masks2
    # if masks is not None:
    #     return xys, visibs, valids, rgbs, masks
    # else:
    #     return xys, visibs, valids, rgbs
    if valids is not None:
        return xys, visibs, rgbs, valids
    else:
        return xys, visibs, rgbs


def data_zoom_bbox(zoom, bboxes, visibs, rgbs):#, valids=None):
    S, H, W, C = rgbs.shape

    _, H, W, C = rgbs.shape
    assert(C==3)
    crop_W = int(W//zoom)
    crop_H = int(H//zoom)

    xys = bboxes[:,0:2]*0.5 + bboxes[:,2:4]*0.5

    if np.random.rand() < 0.25: # follow-crop
        # start with xy traj
        smooth_xys = xys.copy()
        # make it inbounds
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
        # smooth it out, to remove info about the traj, and simulate camera motion
        for _ in range(S*3):
            for ii in range(1,S-1):
                smooth_xys[ii] = (smooth_xys[ii-1] + smooth_xys[ii] + smooth_xys[ii+1])/3.0
    else: # static (no-hint) crop
        # zero-vel on random available coordinate
        anchor_inds = np.nonzero(visibs.reshape(-1)>0.5)[0]
        ind = anchor_inds[np.random.randint(len(anchor_inds))]
        smooth_xys = xys[ind:ind+1].repeat(S,axis=0)
        # xmid = np.random.randint(crop_W//2, W-crop_W//2)
        # ymid = np.random.randint(crop_H//2, H-crop_H//2)
        # smooth_xys = np.stack([xmid, ymid], axis=-1).reshape(1,1,2).repeat(S, axis=0) # S,1,2
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
    # print('xys', xys)
    # print('smooth_xys', smooth_xys)

    if np.random.rand() < 0.5:
        # add a random alternate trajectory, to help push us off center
        alt_xys = np.random.randint(-crop_H//8, crop_H//8, (S,2))
        for _ in range(3):
            for ii in range(1,S-1):
                alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii] + alt_xys[ii+1])/3.0
        smooth_xys = smooth_xys + alt_xys

    smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])

    rgbs_crop = []

    offsets = []
    for si in range(S):
        xy_mid = smooth_xys[si].round().astype(np.int32)
        xmid, ymid = xy_mid[0], xy_mid[1]

        x0, x1 = np.clip(xmid-crop_W//2, 0, W), np.clip(xmid+crop_W//2, 0, W)
        y0, y1 = np.clip(ymid-crop_H//2, 0, H), np.clip(ymid+crop_H//2, 0, H)
        offset = np.array([x0, y0]).reshape(2)

        rgbs_crop.append(rgbs[si,y0:y1,x0:x1])
        xys[si] -= offset
        bboxes[si,0:2] -= offset
        bboxes[si,2:4] -= offset

        offsets.append(offset)

    rgbs = np.stack(rgbs_crop, axis=0)

    # update visibility annotations
    for si in range(S):
        # avoid 1px edge
        oob_inds = np.logical_or(
            np.logical_or(xys[si,0] < 1, xys[si,0] > W-2),
            np.logical_or(xys[si,1] < 1, xys[si,1] > H-2))
        visibs[si,oob_inds] = 0

    # clamp to image bounds
    xys0 = np.minimum(np.maximum(bboxes[:,0:2], np.zeros((2,), dtype=int)), np.array([W, H]) - 1) # S,2
    xys1 = np.minimum(np.maximum(bboxes[:,2:4], np.zeros((2,), dtype=int)), np.array([W, H]) - 1) # S,2
    bboxes = np.concatenate([xys0, xys1], axis=1)
    return bboxes, visibs, rgbs


def data_pad_if_necessary(rgbs, masks, masks2=None):
    S,H,W,C = rgbs.shape

    mask_areas = (masks > 0).reshape(S,-1).sum(axis=1)
    mask_areas_norm = mask_areas / np.max(mask_areas)
    visibs = mask_areas_norm

    bboxes = np.stack([mask2bbox(mask) for mask in masks])
    whs = bboxes[:,2:4] - bboxes[:,0:2]
    whs = whs[visibs > 0.5]
    # print('mean wh', np.mean(whs[:,0]), np.mean(whs[:,1]))
    if np.mean(whs[:,0]) >= W/2:
        # print('padding w')
        pad = ((0,0),(0,0),(W//4,W//4),(0,0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        masks = np.pad(masks, pad[:3], mode="constant")
        if masks2 is not None:
            masks2 = np.pad(masks2, pad[:3], mode="constant")
        # print('rgbs', rgbs.shape)
        # print('masks', masks.shape)
    if np.mean(whs[:,1]) >= H/2:
        # print('padding h')
        pad = ((0,0),(H//4,H//4),(0,0),(0,0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        masks = np.pad(masks, pad[:3], mode="constant")
        if masks2 is not None:
            masks2 = np.pad(masks2, pad[:3], mode="constant", constant_values=0.5)

    if masks2 is not None:
        return rgbs, masks, masks2
    return rgbs, masks

def data_pad_if_necessary_b(rgbs, bboxes, visibs):
    S,H,W,C = rgbs.shape
    whs = bboxes[:,2:4] - bboxes[:,0:2]
    whs = whs[visibs > 0.5]
    if np.mean(whs[:,0]) >= W/2:
        pad = ((0,0),(0,0),(W//4,W//4),(0,0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        bboxes[:,0] += W//4
        bboxes[:,2] += W//4
    if np.mean(whs[:,1]) >= H/2:
        pad = ((0,0),(H//4,H//4),(0,0),(0,0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        bboxes[:,1] += H//4
        bboxes[:,3] += H//4
    return rgbs, bboxes

def posenc(x, min_deg, max_deg):
    """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
    Instead of computing [sin(x), cos(x)], we use the trig identity
    cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
    Args:
      x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
      min_deg: int, the minimum (inclusive) degree of the encoding.
      max_deg: int, the maximum (exclusive) degree of the encoding.
      legacy_posenc_order: bool, keep the same ordering as the original tf code.
    Returns:
      encoded: torch.Tensor, encoded variables.
    """
    if min_deg == max_deg:
        return x
    scales = torch.tensor(
        [2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
    )

    xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
    four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
    return torch.cat([x] + [four_feat], dim=-1)
utils/py.py
ADDED
@@ -0,0 +1,755 @@
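# Usage sketch (illustrative only): flow_to_image(), defined at the bottom of
# this file, normalizes an (H, W, 2) flow field by its maximum magnitude and
# maps direction/magnitude onto the Middlebury color wheel, returning a uint8
# (H, W, 3) image that can be saved or logged directly.
#
#   flow = np.random.randn(240, 320, 2).astype(np.float32)
#   vis = flow_to_image(flow)                 # uint8, (240, 320, 3)
#   # Image.fromarray(vis).save('flow_vis.png')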
import glob, math
import numpy as np
# from scipy import misc
# from scipy import linalg
from PIL import Image
import io
import matplotlib.pyplot as plt
EPS = 1e-6


XMIN = -64.0 # right (neg is left)
XMAX = 64.0 # right
YMIN = -64.0 # down (neg is up)
YMAX = 64.0 # down
ZMIN = -64.0 # forward
ZMAX = 64.0 # forward

def print_stats(name, tensor):
    tensor = tensor.astype(np.float32)
    print('%s min = %.2f, mean = %.2f, max = %.2f' % (name, np.min(tensor), np.mean(tensor), np.max(tensor)), tensor.shape)

def reduce_masked_mean(x, mask, axis=None, keepdims=False):
    # x and mask are the same shape
    # returns shape-1
    # axis can be a list of axes
    prod = x*mask
    numer = np.sum(prod, axis=axis, keepdims=keepdims)
    denom = EPS+np.sum(mask, axis=axis, keepdims=keepdims)
    mean = numer/denom
    return mean

def reduce_masked_sum(x, mask, axis=None, keepdims=False):
    # x and mask are the same shape
    # returns shape-1
    # axis can be a list of axes
    prod = x*mask
    numer = np.sum(prod, axis=axis, keepdims=keepdims)
    return numer

def reduce_masked_median(x, mask, keep_batch=False):
    # x and mask are the same shape
    # returns shape-1
    # axis can be a list of axes

    if not (x.shape == mask.shape):
        print('reduce_masked_median: these shapes should match:', x.shape, mask.shape)
        assert(False)
    # assert(x.shape == mask.shape)

    B = list(x.shape)[0]

    if keep_batch:
        x = np.reshape(x, [B, -1])
        mask = np.reshape(mask, [B, -1])
        meds = np.zeros([B], np.float32)
        for b in list(range(B)):
            xb = x[b]
            mb = mask[b]
            if np.sum(mb) > 0:
                xb = xb[mb > 0]
                meds[b] = np.median(xb)
            else:
                meds[b] = np.nan
        return meds
    else:
        x = np.reshape(x, [-1])
        mask = np.reshape(mask, [-1])
        if np.sum(mask) > 0:
            x = x[mask > 0]
            med = np.median(x)
        else:
            med = np.nan
        med = np.array([med], np.float32)
        return med

def get_nFiles(path):
    return len(glob.glob(path))

def get_file_list(path):
    return glob.glob(path)

def rotm2eul(R):
    # R is 3x3
    sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
    if sy > 1e-6: # singular
        x = math.atan2(R[2,1] , R[2,2])
        y = math.atan2(-R[2,0], sy)
        z = math.atan2(R[1,0], R[0,0])
    else:
        x = math.atan2(-R[1,2], R[1,1])
        y = math.atan2(-R[2,0], sy)
        z = 0
    return x, y, z

def rad2deg(rad):
    return rad*180.0/np.pi

def deg2rad(deg):
    return deg/180.0*np.pi

def eul2rotm(rx, ry, rz):
    # copy of matlab, but order of inputs is different
    # R = [ cy*cz   sy*sx*cz-sz*cx    sy*cx*cz+sz*sx
    #       cy*sz   sy*sx*sz+cz*cx    sy*cx*sz-cz*sx
    #       -sy     cy*sx             cy*cx]
    sinz = np.sin(rz)
    siny = np.sin(ry)
    sinx = np.sin(rx)
    cosz = np.cos(rz)
    cosy = np.cos(ry)
    cosx = np.cos(rx)
    r11 = cosy*cosz
    r12 = sinx*siny*cosz - cosx*sinz
    r13 = cosx*siny*cosz + sinx*sinz
    r21 = cosy*sinz
    r22 = sinx*siny*sinz + cosx*cosz
    r23 = cosx*siny*sinz - sinx*cosz
    r31 = -siny
    r32 = sinx*cosy
    r33 = cosx*cosy
    r1 = np.stack([r11,r12,r13],axis=-1)
    r2 = np.stack([r21,r22,r23],axis=-1)
    r3 = np.stack([r31,r32,r33],axis=-1)
    r = np.stack([r1,r2,r3],axis=0)
    return r

def wrap2pi(rad_angle):
    # puts the angle into the range [-pi, pi]
    return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))

def rot2view(rx,ry,rz,x,y,z):
    # takes rot angles and 3d position as input
    # returns viewpoint angles as output
    # (all in radians)
    # it will perform strangely if z <= 0
    az = wrap2pi(ry - (-np.arctan2(z, x) - 1.5*np.pi))
    el = -wrap2pi(rx - (-np.arctan2(z, y) - 1.5*np.pi))
    th = -rz
    return az, el, th

def invAxB(a,b):
    """
    Compute the relative 3d transformation between a and b.

    Input:
    a -- first pose (homogeneous 4x4 matrix)
    b -- second pose (homogeneous 4x4 matrix)

    Output:
    Relative 3d transformation from a to b.
    """
    return np.dot(np.linalg.inv(a),b)

def merge_rt(r, t):
    # r is 3 x 3
    # t is 3 or maybe 3 x 1
    t = np.reshape(t, [3, 1])
    rt = np.concatenate((r,t), axis=1)
    # rt is 3 x 4
    br = np.reshape(np.array([0,0,0,1], np.float32), [1, 4])
    # br is 1 x 4
    rt = np.concatenate((rt, br), axis=0)
    # rt is 4 x 4
    return rt

def split_rt(rt):
    r = rt[:3,:3]
    t = rt[:3,3]
    r = np.reshape(r, [3, 3])
    t = np.reshape(t, [3, 1])
    return r, t

def split_intrinsics(K):
    # K is 3 x 4 or 4 x 4
    fx = K[0,0]
    fy = K[1,1]
    x0 = K[0,2]
    y0 = K[1,2]
    return fx, fy, x0, y0

def merge_intrinsics(fx, fy, x0, y0):
    # inputs are shaped []
    K = np.eye(4)
    K[0,0] = fx
    K[1,1] = fy
    K[0,2] = x0
    K[1,2] = y0
    # K is shaped 4 x 4
    return K

def scale_intrinsics(K, sx, sy):
    fx, fy, x0, y0 = split_intrinsics(K)
    fx *= sx
    fy *= sy
    x0 *= sx
    y0 *= sy
    return merge_intrinsics(fx, fy, x0, y0)

# def meshgrid(H, W):
#     x = np.linspace(0, W-1, W)
#     y = np.linspace(0, H-1, H)
#     xv, yv = np.meshgrid(x, y)
#     return xv, yv

def compute_distance(transform):
    """
    Compute the distance of the translational component of a 4x4 homogeneous matrix.
    """
    return np.linalg.norm(transform[0:3,3])

def radian_l1_dist(e, g):
    # if our angles are in [0, 360] we can follow this stack overflow answer:
    # https://gamedev.stackexchange.com/questions/4467/comparing-angles-and-working-out-the-difference
    # wrap2pi brings the angles to [-180, 180]; adding pi puts them in [0, 360]
    e = wrap2pi(e)+np.pi
    g = wrap2pi(g)+np.pi
    l = np.abs(np.pi - np.abs(np.abs(e-g) - np.pi))
    return l

def apply_pix_T_cam(pix_T_cam, xyz):
    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
    # xyz is shaped B x H*W x 3
    # returns xy, shaped B x H*W x 2
    N, C = xyz.shape
    x, y, z = np.split(xyz, 3, axis=-1)
    EPS = 1e-4
    z = np.clip(z, EPS, None)
    x = (x*fx)/(z)+x0
    y = (y*fy)/(z)+y0
    xy = np.concatenate([x, y], axis=-1)
    return xy

def apply_4x4(RT, XYZ):
    # RT is 4 x 4
    # XYZ is N x 3

    # put into homogeneous coords
    X, Y, Z = np.split(XYZ, 3, axis=1)
    ones = np.ones_like(X)
    XYZ1 = np.concatenate([X, Y, Z, ones], axis=1)
    # XYZ1 is N x 4

    XYZ1_t = np.transpose(XYZ1)
    # this is 4 x N

    XYZ2_t = np.dot(RT, XYZ1_t)
    # this is 4 x N

    XYZ2 = np.transpose(XYZ2_t)
    # this is N x 4

    XYZ2 = XYZ2[:,:3]
    # this is N x 3

    return XYZ2

def Ref2Mem(xyz, Z, Y, X):
    # xyz is N x 3, in ref coordinates
    # transforms ref coordinates into mem coordinates
    N, C = xyz.shape
    assert(C==3)
    mem_T_ref = get_mem_T_ref(Z, Y, X)
    xyz = apply_4x4(mem_T_ref, xyz)
    return xyz

# def Mem2Ref(xyz_mem, MH, MW, MD):
#     # xyz is B x N x 3, in mem coordinates
#     # transforms mem coordinates into ref coordinates
#     B, N, C = xyz_mem.get_shape().as_list()
#     ref_T_mem = get_ref_T_mem(B, MH, MW, MD)
#     xyz_ref = utils_geom.apply_4x4(ref_T_mem, xyz_mem)
#     return xyz_ref

def get_mem_T_ref(Z, Y, X):
    # sometimes we want the mat itself
    # note this is not a rigid transform

    # for interpretability, let's construct this in two steps...

    # translation
    center_T_ref = np.eye(4, dtype=np.float32)
    center_T_ref[0,3] = -XMIN
    center_T_ref[1,3] = -YMIN
    center_T_ref[2,3] = -ZMIN

    VOX_SIZE_X = (XMAX-XMIN)/float(X)
    VOX_SIZE_Y = (YMAX-YMIN)/float(Y)
    VOX_SIZE_Z = (ZMAX-ZMIN)/float(Z)

    # scaling
    mem_T_center = np.eye(4, dtype=np.float32)
    mem_T_center[0,0] = 1./VOX_SIZE_X
    mem_T_center[1,1] = 1./VOX_SIZE_Y
    mem_T_center[2,2] = 1./VOX_SIZE_Z

    mem_T_ref = np.dot(mem_T_center, center_T_ref)
    return mem_T_ref

def safe_inverse(a):
    r, t = split_rt(a)
    t = np.reshape(t, [3, 1])
    r_transpose = r.T
    inv = np.concatenate([r_transpose, -np.matmul(r_transpose, t)], 1)
    bottom_row = a[3:4, :] # this is [0, 0, 0, 1]
    inv = np.concatenate([inv, bottom_row], 0)
    return inv

def get_ref_T_mem(Z, Y, X):
    mem_T_ref = get_mem_T_ref(X, Y, X)
    # note safe_inverse is inapplicable here,
    # since the transform is nonrigid
    ref_T_mem = np.linalg.inv(mem_T_ref)
    return ref_T_mem

def voxelize_xyz(xyz_ref, Z, Y, X):
    # xyz_ref is N x 3
    xyz_mem = Ref2Mem(xyz_ref, Z, Y, X)
    # this is N x 3
    voxels = get_occupancy(xyz_mem, Z, Y, X)
    voxels = np.reshape(voxels, [Z, Y, X, 1])
    return voxels

def get_inbounds(xyz, Z, Y, X, already_mem=False):
    # xyz is H*W x 3

    if not already_mem:
        xyz = Ref2Mem(xyz, Z, Y, X)

    x_valid = np.logical_and(
        np.greater_equal(xyz[:,0], -0.5),
        np.less(xyz[:,0], float(X)-0.5))
    y_valid = np.logical_and(
        np.greater_equal(xyz[:,1], -0.5),
        np.less(xyz[:,1], float(Y)-0.5))
    z_valid = np.logical_and(
        np.greater_equal(xyz[:,2], -0.5),
        np.less(xyz[:,2], float(Z)-0.5))
    inbounds = np.logical_and(np.logical_and(x_valid, y_valid), z_valid)
    return inbounds

def sub2ind3d_zyx(depth, height, width, d, h, w):
    # same as sub2ind3d, but inputs in zyx order
    # when gathering/scattering with these inds, the tensor should be Z x Y x X
    return d*height*width + h*width + w

def sub2ind3d_yxz(height, width, depth, h, w, d):
    return h*width*depth + w*depth + d

# def ind2sub(height, width, ind):
#     # int input
#     y = int(ind / height)
#     x = ind % height
#     return y, x

def get_occupancy(xyz_mem, Z, Y, X):
    # xyz_mem is N x 3
    # we want to fill a voxel tensor with 1's at these inds

    inbounds = get_inbounds(xyz_mem, Z, Y, X, already_mem=True)
    inds = np.where(inbounds)

    xyz_mem = np.reshape(xyz_mem[inds], [-1, 3])
    # xyz_mem is N x 3

    # this is more accurate than a cast/floor, but runs into issues when Y==0
    xyz_mem = np.round(xyz_mem).astype(np.int32)
    x = xyz_mem[:,0]
    y = xyz_mem[:,1]
    z = xyz_mem[:,2]

    voxels = np.zeros([Z, Y, X], np.float32)
    voxels[z, y, x] = 1.0

    return voxels

def pixels2camera(x,y,z,fx,fy,x0,y0):
    # x and y are locations in pixel coordinates, z is a depth image in meters
    # their shapes are H x W
    # fx, fy, x0, y0 are scalar camera intrinsics
    # returns xyz, sized [B,H*W,3]

    H, W = z.shape

    fx = np.reshape(fx, [1,1])
    fy = np.reshape(fy, [1,1])
    x0 = np.reshape(x0, [1,1])
    y0 = np.reshape(y0, [1,1])

    # unproject
    x = ((z+EPS)/fx)*(x-x0)
    y = ((z+EPS)/fy)*(y-y0)

    x = np.reshape(x, [-1])
    y = np.reshape(y, [-1])
    z = np.reshape(z, [-1])
    xyz = np.stack([x,y,z], axis=1)
    return xyz

def depth2pointcloud(z, pix_T_cam):
    H = z.shape[0]
    W = z.shape[1]
    y, x = meshgrid2d(H, W)
    z = np.reshape(z, [H, W])

    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
    xyz = pixels2camera(x, y, z, fx, fy, x0, y0)
    return xyz

def meshgrid2d(Y, X):
    grid_y = np.linspace(0.0, Y-1, Y)
    grid_y = np.reshape(grid_y, [Y, 1])
    grid_y = np.tile(grid_y, [1, X])

    grid_x = np.linspace(0.0, X-1, X)
    grid_x = np.reshape(grid_x, [1, X])
    grid_x = np.tile(grid_x, [Y, 1])

    # outputs are Y x X
    return grid_y, grid_x

def gridcloud3d(Y, X, Z):
    x_ = np.linspace(0, X-1, X)
    y_ = np.linspace(0, Y-1, Y)
    z_ = np.linspace(0, Z-1, Z)
    y, x, z = np.meshgrid(y_, x_, z_, indexing='ij')
    x = np.reshape(x, [-1])
    y = np.reshape(y, [-1])
    z = np.reshape(z, [-1])
    xyz = np.stack([x,y,z], axis=1).astype(np.float32)
    return xyz

def gridcloud2d(Y, X):
    x_ = np.linspace(0, X-1, X)
    y_ = np.linspace(0, Y-1, Y)
    y, x = np.meshgrid(y_, x_, indexing='ij')
    x = np.reshape(x, [-1])
    y = np.reshape(y, [-1])
    xy = np.stack([x,y], axis=1).astype(np.float32)
    return xy

def normalize(im):
    im = im - np.min(im)
    im = im / np.max(im)
    return im

def wrap2pi(rad_angle):
    # rad_angle can be any shape
    # puts the angle into the range [-pi, pi]
    return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))

def convert_occ_to_height(occ):
    Z, Y, X, C = occ.shape
    assert(C==1)

    height = np.linspace(float(Y), 1.0, Y)
    height = np.reshape(height, [1, Y, 1, 1])
    height = np.max(occ*height, axis=1)/float(Y)
    height = np.reshape(height, [Z, X, C])
    return height

def create_depth_image(xy, Z, H, W):

    # turn the xy coordinates into image inds
    xy = np.round(xy)

    # lidar reports a sphere of measurements
    # only use the inds that are within the image bounds
    # also, only use forward-pointing depths (Z > 0)
    valid = (xy[:,0] < W-1) & (xy[:,1] < H-1) & (xy[:,0] >= 0) & (xy[:,1] >= 0) & (Z[:] > 0)

    # gather these up
    xy = xy[valid]
    Z = Z[valid]

    inds = sub2ind(H,W,xy[:,1],xy[:,0])
    depth = np.zeros((H*W), np.float32)

    for (index, replacement) in zip(inds, Z):
        depth[index] = replacement
    depth[np.where(depth == 0.0)] = 70.0
    depth = np.reshape(depth, [H, W])

    return depth

def vis_depth(depth, maxdepth=80.0, log_vis=True):
    depth[depth<=0.0] = maxdepth
    if log_vis:
        depth = np.log(depth)
        depth = np.clip(depth, 0, np.log(maxdepth))
    else:
        depth = np.clip(depth, 0, maxdepth)
    depth = (depth*255.0).astype(np.uint8)
    return depth

def preprocess_color(x):
    return x.astype(np.float32) * 1./255 - 0.5

def convert_box_to_ref_T_obj(boxes):
    shape = boxes.shape
    boxes = boxes.reshape(-1,9)
    rots = [eul2rotm(rx,ry,rz)
            for rx,ry,rz in boxes[:,6:]]
    rots = np.stack(rots,axis=0)
    trans = boxes[:,:3]
    ref_T_objs = [merge_rt(rot,tran)
                  for rot,tran in zip(rots,trans)]
    ref_T_objs = np.stack(ref_T_objs,axis=0)
    ref_T_objs = ref_T_objs.reshape(shape[:-1]+(4,4))
    ref_T_objs = ref_T_objs.astype(np.float32)
    return ref_T_objs

def get_rot_from_delta(delta, yaw_only=False):
    dx = delta[:,0]
    dy = delta[:,1]
    dz = delta[:,2]

    bot_hyp = np.sqrt(dz**2 + dx**2)
    # top_hyp = np.sqrt(bot_hyp**2 + dy**2)

    pitch = -np.arctan2(dy, bot_hyp)
    yaw = np.arctan2(dz, dx)

    if yaw_only:
        rot = [eul2rotm(0,y,0) for y in yaw]
    else:
        rot = [eul2rotm(0,y,p) for (p,y) in zip(pitch,yaw)]

    rot = np.stack(rot)
    # rot is B x 3 x 3
    return rot

def im2col(im, psize):
    n_channels = 1 if len(im.shape) == 2 else im.shape[0]
    (n_channels, rows, cols) = (1,) * (3 - len(im.shape)) + im.shape

    im_pad = np.zeros((n_channels,
                       int(math.ceil(1.0 * rows / psize) * psize),
                       int(math.ceil(1.0 * cols / psize) * psize)))
    im_pad[:, 0:rows, 0:cols] = im

    final = np.zeros((im_pad.shape[1], im_pad.shape[2], n_channels,
                      psize, psize))
    for c in np.arange(n_channels):
        for x in np.arange(psize):
            for y in np.arange(psize):
                im_shift = np.vstack(
                    (im_pad[c, x:], im_pad[c, :x]))
                im_shift = np.column_stack(
                    (im_shift[:, y:], im_shift[:, :y]))
                final[x::psize, y::psize, c] = np.swapaxes(
                    im_shift.reshape(int(im_pad.shape[1] / psize), psize,
                                     int(im_pad.shape[2] / psize), psize), 1, 2)

    return np.squeeze(final[0:rows - psize + 1, 0:cols - psize + 1])

def filter_discontinuities(depth, filter_size=9, thresh=10):
    H, W = list(depth.shape)

    # Ensure that filter sizes are okay
    assert filter_size % 2 == 1, "Can only use odd filter sizes."

    # Compute discontinuities
    offset = int((filter_size - 1) / 2)
    patches = 1.0 * im2col(depth, filter_size)
    mids = patches[:, :, offset, offset]
    mins = np.min(patches, axis=(2, 3))
    maxes = np.max(patches, axis=(2, 3))

    discont = np.maximum(np.abs(mins - mids),
                         np.abs(maxes - mids))
    mark = discont > thresh

    # Account for offsets
    final_mark = np.zeros((H, W), dtype=np.uint16)
    final_mark[offset:offset + mark.shape[0],
               offset:offset + mark.shape[1]] = mark

    return depth * (1 - final_mark)

def argmax2d(tensor):
    Y, X = list(tensor.shape)
    # flatten the Tensor along the height and width axes
    flat_tensor = tensor.reshape(-1)
    # argmax of the flat tensor
    argmax = np.argmax(flat_tensor)

    # convert the indices into 2d coordinates
    argmax_y = argmax // X # row
    argmax_x = argmax % X # col

    return argmax_y, argmax_x

def plot_traj_3d(traj):
    # traj is S x 3

    # print('traj', traj.shape)
    S, C = list(traj.shape)
    assert(C==3)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    colors = [plt.cm.RdYlBu(i) for i in np.linspace(0,1,S)]
    # print('colors', colors)

    xs = traj[:,0]
    ys = -traj[:,1]
    zs = traj[:,2]

    ax.scatter(xs, zs, ys, s=30, c=colors, marker='o', alpha=1.0, edgecolors=(0,0,0))#, color=color_map[n])

    ax.set_xlabel('X')
    ax.set_ylabel('Z')
    ax.set_zlabel('Y')

    ax.set_xlim(0,1)
    ax.set_ylim(0,1) # this is really Z
    ax.set_zlim(-1,0) # this is really Y

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    image = np.array(Image.open(buf)) # H x W x 4
    image = image[:,:,:3]

    plt.close()
    return image

def camera2pixels(xyz, pix_T_cam):
    # xyz is shaped N x 3
    # returns xy, shaped N x 2

    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
    x, y, z = xyz[:,0], xyz[:,1], xyz[:,2]

    EPS = 1e-4
    z = np.clip(z, EPS, None)
    x = (x*fx)/z + x0
    y = (y*fy)/z + y0
    xy = np.stack([x, y], axis=-1)
    return xy

def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
    col = col+RY
    # YG
    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
    colorwheel[col:col+YG, 1] = 255
    col = col+YG
    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
    col = col+GC
    # CB
    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
    colorwheel[col:col+CB, 2] = 255
    col = col+CB
    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
    col = col+BM
    # MR
    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
    colorwheel[col:col+MR, 0] = 255
    return colorwheel

def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel() # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    a = np.arctan2(-v, -u)/np.pi
    fk = (a+1) / 2*(ncols-1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0
    f = fk - k0
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:,i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1-f)*col0 + f*col1
        idx = (rad <= 1)
        col[idx] = 1 - rad[idx] * (1-col[idx])
        col[~idx] = col[~idx] * 0.75 # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2-i if convert_to_bgr else i
        flow_image[:,:,ch_idx] = np.floor(255 * col)
    return flow_image

def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Expects a two dimensional flow image of shape.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, -clip_flow, clip_flow) / clip_flow
        # flow_uv = np.clamp(flow, -clip, clip)/clip

    u = flow_uv[:,:,0]
    v = flow_uv[:,:,1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)
utils/samp.py
ADDED
@@ -0,0 +1,213 @@
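# Usage sketch (illustrative only): bilinear_sample2d(), defined at the bottom
# of this file, reads a (B, C, H, W) feature map at real-valued pixel
# coordinates (x, y in pixel units, not normalized to [-1, 1]) and returns
# (B, C, N) features; with return_inbounds=True it also flags samples that land
# inside the image.
#
#   fmap = torch.randn(1, 32, 64, 64)
#   x = torch.rand(1, 100) * 63.0
#   y = torch.rand(1, 100) * 63.0
#   feats, inb = bilinear_sample2d(fmap, x, y, return_inbounds=True)  # (1,32,100), (1,100)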
import torch
import utils.basic
import torch.nn.functional as F

def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Sample a tensor using bilinear interpolation

    `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
    coordinates :attr:`coords` using bilinear interpolation. It is the same
    as `torch.nn.functional.grid_sample()` but with a different coordinate
    convention.

    The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
    :math:`B` is the batch size, :math:`C` is the number of channels,
    :math:`H` is the height of the image, and :math:`W` is the width of the
    image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
    interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.

    Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
    in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
    that in this case the order of the components is slightly different
    from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.

    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
    left-most image pixel :math:`W-1` to the center of the right-most
    pixel.

    If `align_corners` is `False`, the coordinate :math:`x` is assumed to
    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
    the left-most pixel :math:`W` to the right edge of the right-most
    pixel.

    Similar conventions apply to the :math:`y` for the range
    :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
    :math:`[0,T-1]` and :math:`[0,T]`.

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.

    Returns:
        Tensor: sampled points.
    """

    sizes = input.shape[2:]

    assert len(sizes) in [2, 3]

    if len(sizes) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        coords = coords[..., [1, 2, 0]]

    if align_corners:
        coords = coords * torch.tensor(
            [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
        )
    else:
        coords = coords * torch.tensor(
            [2 / size for size in reversed(sizes)], device=coords.device
        )

    coords -= 1

    return F.grid_sample(
        input, coords, align_corners=align_corners, padding_mode=padding_mode
    )


def sample_features4d(input, coords):
    r"""Sample spatial features

    `sample_features4d(input, coords)` samples the spatial features
    :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.

    The field is sampled at coordinates :attr:`coords` using bilinear
    interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
    3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
    same convention as :func:`bilinear_sampler` with `align_corners=True`.

    The output tensor has one feature per point, and has shape :math:`(B,
    R, C)`.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.

    Returns:
        Tensor: sampled features.
    """

    B, _, _, _ = input.shape

    # B R 2 -> B R 1 2
    coords = coords.unsqueeze(2)

    # B C R 1
    feats = bilinear_sampler(input, coords)

    return feats.permute(0, 2, 1, 3).view(
        B, -1, feats.shape[1] * feats.shape[3]
    )  # B C R 1 -> B R C


def sample_features5d(input, coords):
    r"""Sample spatio-temporal features

    `sample_features5d(input, coords)` works in the same way as
    :func:`sample_features4d` but for spatio-temporal features and points:
    :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
    a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
    x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.

    Args:
        input (Tensor): spatio-temporal features.
        coords (Tensor): spatio-temporal points.

    Returns:
        Tensor: sampled features.
    """

    B, T, _, _, _ = input.shape

    # B T C H W -> B C T H W
    input = input.permute(0, 2, 1, 3, 4)

    # B R1 R2 3 -> B R1 R2 1 3
    coords = coords.unsqueeze(3)

    # B C R1 R2 1
    feats = bilinear_sampler(input, coords)

    return feats.permute(0, 2, 3, 1, 4).view(
        B, feats.shape[2], feats.shape[3], feats.shape[1]
    )  # B C R1 R2 1 -> B R1 R2 C


def bilinear_sample2d(im, x, y, return_inbounds=False):
    # x and y are each B, N
    # output is B, C, N
    B, C, H, W = list(im.shape)
    N = list(x.shape)[1]

    x = x.float()
    y = y.float()
    H_f = torch.tensor(H, dtype=torch.float32)
    W_f = torch.tensor(W, dtype=torch.float32)

    # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()

    max_y = (H_f - 1).int()
    max_x = (W_f - 1).int()

    x0 = torch.floor(x).int()
    x1 = x0 + 1
    y0 = torch.floor(y).int()
    y1 = y0 + 1

    x0_clip = torch.clamp(x0, 0, max_x)
    x1_clip = torch.clamp(x1, 0, max_x)
    y0_clip = torch.clamp(y0, 0, max_y)
    y1_clip = torch.clamp(y1, 0, max_y)
    dim2 = W
    dim1 = W * H

    base = torch.arange(0, B, dtype=torch.int64, device=x.device)*dim1
    base = torch.reshape(base, [B, 1]).repeat([1, N])

    base_y0 = base + y0_clip * dim2
    base_y1 = base + y1_clip * dim2

    idx_y0_x0 = base_y0 + x0_clip
    idx_y0_x1 = base_y0 + x1_clip
    idx_y1_x0 = base_y1 + x0_clip
    idx_y1_x1 = base_y1 + x1_clip

    # use the indices to lookup pixels in the flat image
    # im is B x C x H x W
    # move C out to last dim
    im_flat = (im.permute(0, 2, 3, 1)).reshape(B*H*W, C)
    i_y0_x0 = im_flat[idx_y0_x0.long()]
    i_y0_x1 = im_flat[idx_y0_x1.long()]
    i_y1_x0 = im_flat[idx_y1_x0.long()]
    i_y1_x1 = im_flat[idx_y1_x1.long()]

    # Finally calculate interpolated values.
    x0_f = x0.float()
    x1_f = x1.float()
    y0_f = y0.float()
    y1_f = y1.float()

    w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
    w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
    w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
    w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)

    output = w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + \
             w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
    # output is B*N x C
    output = output.view(B, -1, C)
    output = output.permute(0, 2, 1)
    # output is B x C x N

    if return_inbounds:
        x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
        y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
        inbounds = (x_valid & y_valid).float()
        inbounds = inbounds.reshape(B, N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
        return output, inbounds

    return output # B, C, N
utils/saveload.py
ADDED
@@ -0,0 +1,59 @@
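# Usage sketch (illustrative only; `model` and `optimizer` below are
# placeholders): save(), defined below, writes model/optimizer/step (plus the
# scheduler, if given) to '<ckpt_dir>/<model_name>-<step>.pth' and prunes older
# checkpoints; load() restores the most recent matching file and returns the
# parsed step.
#
#   save('checkpoints/run1', model, optimizer, None, global_step=1000)
#   step = load(None, 'checkpoints/run1', model, optimizer=optimizer)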
import pathlib
import os
import torch

def save(ckpt_dir, module, optimizer, scheduler, global_step, keep_latest=2, model_name='model'):
    pathlib.Path(ckpt_dir).mkdir(exist_ok=True, parents=True)
    prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*pth' % model_name))
    prev_ckpts.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    if len(prev_ckpts) > keep_latest-1:
        for f in prev_ckpts[keep_latest-1:]:
            f.unlink()
    save_path = '%s/%s-%09d.pth' % (ckpt_dir, model_name, global_step)
    save_dict = {
        "model": module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "global_step": global_step,
    }
    if scheduler is not None:
        save_dict['scheduler'] = scheduler.state_dict()
    print(f"saving {save_path}")
    torch.save(save_dict, save_path)
    return False

def load(fabric, ckpt_dir, model, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None, strict=True, verbose=True, weights_only=False):
    if verbose:
        print('reading ckpt from %s' % ckpt_dir)
    if not os.path.exists(ckpt_dir):
        print('...there is no full checkpoint in %s' % ckpt_dir)
        print('-- note this function no longer appends "saved_checkpoints/" before the ckpt_dir --')
        assert(False)
    else:
        prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*pth' % model_name))
        prev_ckpts.sort(key=lambda p: p.stat().st_mtime, reverse=True)
        if len(prev_ckpts):
            path = prev_ckpts[0]
            # e.g., './checkpoints/2Ai4_5e-4_base18_1539/model-000050000.pth'
            step = int(str(path).split('-')[-1].split('.')[0])
            if verbose:
                print('...found checkpoint %s; (parsed step %d from path)' % (path, step))
            if fabric is not None:
                checkpoint = fabric.load(path)
            else:
                checkpoint = torch.load(path, weights_only=weights_only)
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint['optimizer'])
            if scheduler is not None:
                scheduler.load_state_dict(checkpoint['scheduler'])
            assert ignore_load is None  # not ready yet
            if 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict, strict=strict)
        else:
            print('...there is no full checkpoint here!')
    return step
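
A minimal checkpoint round-trip sketch for these helpers (illustrative only; the model, optimizer, and directory name are made up):

# illustrative only: save a checkpoint, then resume from the most recent one
import torch
import utils.saveload

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# save() keeps at most `keep_latest` checkpoints, named '<model_name>-%09d.pth'
utils.saveload.save('checkpoints/example_run', model, optimizer,
                    scheduler=None, global_step=1000, keep_latest=2)

# load() picks the newest checkpoint in the directory and returns the step parsed from its name
step = utils.saveload.load(None, 'checkpoints/example_run', model, optimizer=optimizer)
print('resumed from step', step)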
utils/test.py
ADDED
@@ -0,0 +1,194 @@
import torch
import cv2
import torch.nn.functional as F
import numpy as np

def prep_frame_for_dino(img, scale_size=[192]):
    """
    read a single frame & preprocess
    """
    ori_h, ori_w, _ = img.shape
    if len(scale_size) == 1:
        if(ori_h > ori_w):
            tw = scale_size[0]
            th = (tw * ori_h) / ori_w
            th = int((th // 64) * 64)
        else:
            th = scale_size[0]
            tw = (th * ori_w) / ori_h
            tw = int((tw // 64) * 64)
    else:
        th, tw = scale_size
    img = cv2.resize(img, (tw, th))
    img = img.astype(np.float32)
    img = img / 255.0
    img = img[:, :, ::-1]
    img = np.transpose(img.copy(), (2, 0, 1))
    img = torch.from_numpy(img).float()

    def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]):
        for t, m, s in zip(x, mean, std):
            t.sub_(m)
            t.div_(s)
        return x

    img = color_normalize(img)
    return img, ori_h, ori_w

def get_feats_from_dino(model, frame):
    # batch version of the other func
    B = frame.shape[0]
    patch_size = model.patch_embed.patch_size
    h, w = int(frame.shape[2] / patch_size), int(frame.shape[3] / patch_size)
    out = model.get_intermediate_layers(frame.cuda(), n=1)[0]  # B, 1+h*w, dim
    dim = out.shape[-1]
    out = out[:, 1:, :]  # discard the [CLS] token
    outmap = out.permute(0, 2, 1).reshape(B, dim, h, w)
    return out, outmap, h, w

def restrict_neighborhood(h, w):
    size_mask_neighborhood = 12
    # We restrict the set of source nodes considered to a spatial neighborhood of the query node (i.e. ``local attention'')
    mask = torch.zeros(h, w, h, w)
    for i in range(h):
        for j in range(w):
            for p in range(2 * size_mask_neighborhood + 1):
                for q in range(2 * size_mask_neighborhood + 1):
                    if i - size_mask_neighborhood + p < 0 or i - size_mask_neighborhood + p >= h:
                        continue
                    if j - size_mask_neighborhood + q < 0 or j - size_mask_neighborhood + q >= w:
                        continue
                    mask[i, j, i - size_mask_neighborhood + p, j - size_mask_neighborhood + q] = 1

    mask = mask.reshape(h * w, h * w)
    return mask.cuda(non_blocking=True)

def label_propagation(h, w, feat_tar, list_frame_feats, list_segs, mask_neighborhood=None):
    ncontext = len(list_frame_feats)
    feat_sources = torch.stack(list_frame_feats)  # nmb_context x dim x h*w

    feat_tar = F.normalize(feat_tar, dim=1, p=2)
    feat_sources = F.normalize(feat_sources, dim=1, p=2)

    # print('feat_tar', feat_tar.shape)
    # print('feat_sources', feat_sources.shape)

    feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1)
    aff = torch.exp(torch.bmm(feat_tar, feat_sources) / 0.1)

    size_mask_neighborhood = 12
    if size_mask_neighborhood > 0:
        if mask_neighborhood is None:
            mask_neighborhood = restrict_neighborhood(h, w)
            mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1)
        aff *= mask_neighborhood

    aff = aff.transpose(2, 1).reshape(-1, h*w)  # nmb_context*h*w (source: keys) x h*w (tar: queries)
    topk = 5
    tk_val, _ = torch.topk(aff, dim=0, k=topk)
    tk_val_min, _ = torch.min(tk_val, dim=0)
    aff[aff < tk_val_min] = 0

    aff = aff / torch.sum(aff, keepdim=True, axis=0)

    list_segs = [s.cuda() for s in list_segs]
    segs = torch.cat(list_segs)
    nmb_context, C, h, w = segs.shape
    segs = segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T  # C x nmb_context*h*w
    seg_tar = torch.mm(segs, aff)
    seg_tar = seg_tar.reshape(1, C, h, w)

    return seg_tar, mask_neighborhood

def norm_mask(mask):
    c, h, w = mask.size()
    for cnt in range(c):
        mask_cnt = mask[cnt,:,:]
        if(mask_cnt.max() > 0):
            mask_cnt = (mask_cnt - mask_cnt.min())
            mask_cnt = mask_cnt/mask_cnt.max()
            mask[cnt,:,:] = mask_cnt
    return mask


def get_dino_output(dino, rgbs, trajs_g, vis_g):
    B, S, C, H, W = rgbs.shape

    B1, S1, N, D = trajs_g.shape
    assert(B1==B)
    assert(S1==S)
    assert(D==2)

    assert(B==1)
    xy0 = trajs_g[:,0]  # B, N, 2

    # The queue stores the n preceding frames
    import queue
    import copy
    n_last_frames = 7
    que = queue.Queue(n_last_frames)

    # run dino
    prep_rgbs = []
    for s in range(S):
        prep_rgb, ori_h, ori_w = prep_frame_for_dino(rgbs[0, s].permute(1,2,0).detach().cpu().numpy(), scale_size=[H])
        prep_rgbs.append(prep_rgb)
    prep_rgbs = torch.stack(prep_rgbs, dim=0)  # S, 3, H, W
    with torch.no_grad():
        bs = 8
        idx = 0
        featmaps = []
        while idx < S:
            end_id = min(S, idx+bs)
            _, featmaps_cur, h, w = get_feats_from_dino(dino, prep_rgbs[idx:end_id])  # S, C, h, w
            idx = end_id
            featmaps.append(featmaps_cur)
        featmaps = torch.cat(featmaps, dim=0)
    C = featmaps.shape[1]
    featmaps = featmaps.unsqueeze(0)  # 1, S, C, h, w
    # featmaps = F.normalize(featmaps, dim=2, p=2)

    xy0 = trajs_g[:, 0, :]  # B, N, 2
    patch_size = dino.patch_embed.patch_size
    first_seg = torch.zeros((1, N, H//patch_size, W//patch_size))
    for n in range(N):
        first_seg[0, n, (xy0[0, n, 1]/patch_size).long(), (xy0[0, n, 0]/patch_size).long()] = 1

    frame1_feat = featmaps[0, 0].reshape(C, h*w)  # dim x h*w
    mask_neighborhood = None
    accs = []
    trajs_e = torch.zeros_like(trajs_g)
    trajs_e[0,0] = trajs_g[0,0]
    for cnt in range(1, S):
        used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)]
        used_segs = [first_seg] + [pair[1] for pair in list(que.queue)]

        feat_tar = featmaps[0, cnt].reshape(C, h*w)

        frame_tar_avg, mask_neighborhood = label_propagation(h, w, feat_tar.T, used_frame_feats, used_segs, mask_neighborhood)

        # pop out the oldest frame if necessary
        if que.qsize() == n_last_frames:
            que.get()
        # push current results into queue
        seg = copy.deepcopy(frame_tar_avg)
        que.put([feat_tar, seg])

        # upsampling & argmax
        frame_tar_avg = F.interpolate(frame_tar_avg, scale_factor=patch_size, mode='bilinear', align_corners=False, recompute_scale_factor=False)[0]
        frame_tar_avg = norm_mask(frame_tar_avg)
        _, frame_tar_seg = torch.max(frame_tar_avg, dim=0)

        for n in range(N):
            vis = vis_g[0,cnt,n]
            if len(torch.nonzero(frame_tar_avg[n])) > 0:
                # weighted average
                nz = torch.nonzero(frame_tar_avg[n])
                coord_e = torch.sum(frame_tar_avg[n][nz[:,0], nz[:,1]].reshape(-1,1) * nz.float(), 0) / frame_tar_avg[n][nz[:,0], nz[:,1]].sum()  # 2
                coord_e = coord_e[[1,0]]
            else:
                # stay where it was
                coord_e = trajs_e[0,cnt-1,n]

            trajs_e[0, cnt, n] = coord_e
    return trajs_e
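
A minimal sketch of how get_dino_output might be invoked (illustrative only; the torch.hub DINO checkpoint name and all tensor values are assumptions, not something this commit specifies, and a CUDA device is required because the helpers call .cuda() internally):

# illustrative only: DINO label-propagation baseline on a dummy clip
import torch
from utils.test import get_dino_output

dino = torch.hub.load('facebookresearch/dino:main', 'dino_vits8').cuda().eval()  # assumed checkpoint

B, S, N, H, W = 1, 8, 16, 256, 256
rgbs = torch.randint(0, 255, (B, S, 3, H, W)).float()                 # video frames, 0-255
trajs_g = torch.rand(B, S, N, 2) * torch.tensor([W - 1., H - 1.])     # query tracks, in pixels
vis_g = torch.ones(B, S, N)                                           # visibility flags

trajs_e = get_dino_output(dino, rgbs, trajs_g, vis_g)                 # B, S, N, 2 estimated tracks
print(trajs_e.shape)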
utils/visualizer.py
ADDED
@@ -0,0 +1,365 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import imageio
import torch

from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw


def read_video_from_path(path):
    try:
        reader = imageio.get_reader(path)
    except Exception as e:
        print("Error opening video file: ", e)
        return None
    frames = []
    for i, im in enumerate(reader):
        frames.append(np.array(im))
    return np.stack(frames)


def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True, color_alpha=None):
    # Create a draw object
    draw = ImageDraw.Draw(rgb)
    # Calculate the bounding box of the circle
    left_up_point = (coord[0] - radius, coord[1] - radius)
    right_down_point = (coord[0] + radius, coord[1] + radius)
    # Draw the circle
    color = tuple(list(color) + [color_alpha if color_alpha is not None else 255])

    draw.ellipse(
        [left_up_point, right_down_point],
        fill=tuple(color) if visible else None,
        outline=tuple(color),
    )
    return rgb


def draw_line(rgb, coord_y, coord_x, color, linewidth):
    draw = ImageDraw.Draw(rgb)
    draw.line(
        (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
        fill=tuple(color),
        width=linewidth,
    )
    return rgb


def add_weighted(rgb, alpha, original, beta, gamma):
    return (rgb * alpha + original * beta + gamma).astype("uint8")


class Visualizer:
    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",  # 'cool', 'optical_flow'
        linewidth: int = 2,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,  # -1 for infinite
    ):
        self.mode = mode
        self.save_dir = save_dir
        if mode == "rainbow":
            self.color_map = cm.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = cm.get_cmap(mode)
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        writer=None,  # tensorboard Summary Writer, used for visualization during training
        step: int = 0,
        query_frame=0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
        opacity: float = 1.0,
    ):
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
            if (segm_mask <= 0).all():
                segm_mask = None

        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )
        color_alpha = int(opacity * 255)
        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            video = video.repeat(1, 1, 3, 1, 1)

        res_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
            color_alpha=color_alpha,
        )
        if save_video:
            self.save_video(res_video, filename=filename, writer=writer, step=step)
        return res_video

    def save_video(self, video, filename, writer=None, step=0):
        if writer is not None:
            writer.add_video(
                filename,
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
        else:
            os.makedirs(self.save_dir, exist_ok=True)
            wide_list = list(video.unbind(1))
            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]

            # Prepare the video file path
            save_path = os.path.join(self.save_dir, f"{filename}.mp4")

            # Create a writer object
            video_writer = imageio.get_writer(save_path, fps=self.fps)

            # Write frames to the video file
            for frame in wide_list[2:-1]:
                video_writer.append_data(frame)

            video_writer.close()

            print(f"Video saved to {save_path}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame=0,
        compensate_for_camera_motion=False,
        color_alpha: int = 255,
    ):
        B, T, C, H, W = video.shape
        _, _, N, D = tracks.shape

        assert D == 2
        assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C
        tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2
        if gt_tracks is not None:
            gt_tracks = gt_tracks[0].detach().cpu().numpy()

        res_video = []

        # process input video
        for rgb in video:
            res_video.append(rgb.copy())
        vector_colors = np.zeros((T, N, 3))

        if self.mode == "optical_flow":
            import flow_vis

            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
        elif segm_mask is None:
            if self.mode == "rainbow":
                y_min, y_max = (
                    tracks[query_frame, :, 1].min(),
                    tracks[query_frame, :, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if isinstance(query_frame, torch.Tensor):
                        query_frame_ = query_frame[n]
                    else:
                        query_frame_ = query_frame
                    color = self.color_map(norm(tracks[query_frame_, n, 1]))
                    color = np.array(color[:3])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with time
                for t in range(T):
                    color = np.array(self.color_map(t / T)[:3])[None] * 255
                    vector_colors[t] = np.repeat(color, N, axis=0)
        else:
            if self.mode == "rainbow":
                vector_colors[:, segm_mask <= 0, :] = 255

                y_min, y_max = (
                    tracks[0, segm_mask > 0, 1].min(),
                    tracks[0, segm_mask > 0, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if segm_mask[n] > 0:
                        color = self.color_map(norm(tracks[0, n, 1]))
                        color = np.array(color[:3])[None] * 255
                        vector_colors[:, n] = np.repeat(color, T, axis=0)

            else:
                # color changes with segm class
                segm_mask = segm_mask.cpu()
                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
                vector_colors = np.repeat(color[None], T, axis=0)

        # draw tracks
        if self.tracks_leave_trace != 0:
            for t in range(query_frame + 1, T):
                first_ind = (
                    max(0, t - self.tracks_leave_trace)
                    if self.tracks_leave_trace >= 0
                    else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]

                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(
                        res_video[t], gt_tracks[first_ind : t + 1]
                    )

        # draw points
        for t in range(T):
            img = Image.fromarray(np.uint8(res_video[t]))
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                visible = True
                if visibility is not None:
                    visible = visibility[0, t, i]
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or (
                        compensate_for_camera_motion and segm_mask[i] > 0
                    ):
                        img = draw_circle(
                            img,
                            coord=coord,
                            radius=int(self.linewidth * 2),
                            color=vector_colors[t, i].astype(int),
                            visible=visible,
                            color_alpha=color_alpha,
                        )
            res_video[t] = np.array(img)

        # construct the final rgb sequence
        if self.show_first_frame > 0:
            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x 2
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        T, N, _ = tracks.shape
        rgb = Image.fromarray(np.uint8(rgb))
        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            alpha = (s / T) ** 2
            for i in range(N):
                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if coord_y[0] != 0 and coord_y[1] != 0:
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        vector_color[i].astype(int),
                        self.linewidth,
                    )
            if self.tracks_leave_trace > 0:
                rgb = Image.fromarray(
                    np.uint8(
                        add_weighted(
                            np.array(rgb), alpha, np.array(original), 1 - alpha, 0
                        )
                    )
                )
        rgb = np.array(rgb)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3,
        gt_tracks: np.ndarray,  # T x 2
    ):
        T, N, _ = gt_tracks.shape
        color = np.array((211, 0, 0))
        rgb = Image.fromarray(np.uint8(rgb))
        for t in range(T):
            for i in range(N):
                gt_track = gt_tracks[t][i]  # fixed: keep the full array intact instead of overwriting it
                # draw a red cross
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
        rgb = np.array(rgb)
        return rgb
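
A minimal sketch of driving the Visualizer above (illustrative only; the tensors are random placeholders, and writing the mp4 assumes an imageio ffmpeg backend is available):

# illustrative only: render random tracks over a random clip and save a video
import torch
from utils.visualizer import Visualizer

B, T, N, H, W = 1, 24, 64, 256, 256
video = torch.randint(0, 255, (B, T, 3, H, W)).float()              # (B,T,C,H,W), 0-255
tracks = torch.rand(B, T, N, 2) * torch.tensor([W - 1., H - 1.])    # (B,T,N,2), pixel coords
visibility = torch.ones(B, T, N, 1, dtype=torch.bool)               # all points visible

vis = Visualizer(save_dir='./results', fps=10, mode='rainbow', tracks_leave_trace=-1)
vis.visualize(video, tracks, visibility=visibility, filename='demo_tracks')
# writes ./results/demo_tracks.mp4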