Yang You committed on
Commit 2cabcdf · 1 Parent(s): b9f276b

upload files

README.md CHANGED
@@ -1,10 +1,10 @@
 ---
-title: AllTracker PointVersion
-emoji: 🔥
-colorFrom: purple
-colorTo: purple
+title: DenseTrack
+emoji: 🏃
+colorFrom: pink
+colorTo: pink
 sdk: gradio
-sdk_version: 5.38.2
+sdk_version: 5.21.0
 app_file: app.py
 pinned: false
 ---
app.py ADDED
@@ -0,0 +1,328 @@
import os
import random
import time
import datetime
import cv2
import numpy as np
import torch
import imageio
from pathlib import Path
from tqdm import tqdm
import gradio as gr

# Import your custom modules
import utils.loss
import utils.samp
import utils.data
import utils.improc
import utils.misc
import utils.saveload
from nets.blocks import InputPadder
from nets.net34 import Net
from tensorboardX import SummaryWriter
import imageio
from demo_dense_visualize import Tracker
import spaces

# -------------------- Utility Functions -------------------- #
def count_parameters(model):
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        total_params += parameter.numel()
    print('Total params: %.2f M' % (total_params/1e6))
    return total_params

def seed_everything(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------- Step 1: Extract the First Frame -------------------- #
def extract_first_frame(video_path, _, tracker):
    """
    Opens the video, extracts the first frame, resizes it (largest dimension 1024),
    and returns:
      - the frame for display (to be annotated),
      - the video file path (to store in state),
      - a copy of the original first frame (to be used when adding points)
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, None, None
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None, video_path, None
    # Convert BGR to RGB
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    scale = min(tracker.target_res / frame_rgb.shape[0], tracker.target_res / frame_rgb.shape[1])
    frame_resized = cv2.resize(frame_rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    # Return the displayed frame, the video file path, and a copy of the original frame for point drawing.
    return frame_resized, video_path, frame_resized.copy(), []

# -------------------- Callback to Add a Clicked Point -------------------- #
def add_point(orig_frame, points, evt: gr.SelectData):
    """
    Called when the user clicks on the displayed first frame.
      - orig_frame: The original first frame image (numpy array).
      - points: The current list of point coordinates.
      - evt: Event data from the image click (expects evt.index as (x, y)).

    Returns the updated image (with circles drawn at all points)
    and the updated list of points.
    """
    if points is None:
        points = []
    # evt.index contains the (x, y) coordinates of the click.
    x, y = evt.index
    new_points = points.copy()
    new_points.append([x, y])
    # Draw circles on a copy of the original image.
    updated_frame = orig_frame.copy()
    for (px, py) in new_points:
        cv2.circle(updated_frame, (int(round(px)), int(round(py))), radius=5, color=(0,255,0), thickness=-1)

    # print(updated_frame.shape)
    return updated_frame, new_points

# -------------------- Step 2: Process Video & Track Points -------------------- #
@torch.no_grad()
@spaces.GPU
def process_video_with_points(video_path, click_points, tracker):
    """
    Runs the dense flow prediction over the entire video, tracking the user-selected points.
    Args:
        video_path: Path to the uploaded video.
        click_points: List of [x, y] coordinates selected on the first frame.
            (Coordinates are in the same (resized) space as the displayed first frame.)
        tracker: The tracker instance to use for processing.
    Returns:
        A path to the output video with tracked points overlaid.
    """
    if len(click_points) == 0:
        print("No points selected for tracking.")
        return "Error: No points selected for tracking."

    # Open the video.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return "Error: Could not open video."
    fps = cap.get(cv2.CAP_PROP_FPS)

    # List to store frames with overlaid points.
    output_frames = []
    # Initialize the points with those selected on the first frame.

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    pbar = tqdm(total=total_frames, desc="Processing video")

    tracker.reset()

    frame_disps = []
    try:
        while True:
            torch.cuda.empty_cache()
            ret, frame = cap.read()
            if not ret:
                break

            # Convert frame from BGR to RGB and resize as in your original code.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            scale = min(tracker.target_res / frame_rgb.shape[0], tracker.target_res / frame_rgb.shape[1])
            frame_disp = cv2.resize(frame_rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
            frame_disps.append(frame_disp)

            flows = tracker.track(frame_rgb)

            if flows is not None:
                flows_np = flows[0].cpu().numpy()

                for i, flow_np in enumerate(flows_np):
                    # --- Update tracked points using the flow ---
                    current_points = []
                    for (x, y) in click_points:
                        xi = int(round(x))
                        yi = int(round(y))
                        # print('xi, yi', xi, yi)
                        if 0 <= yi < flow_np.shape[1] and 0 <= xi < flow_np.shape[2]:
                            dx = flow_np[0, yi, xi]
                            dy = flow_np[1, yi, xi]
                            # print('dx, dy', dx, dy)
                        else:
                            dx, dy = 0.0, 0.0
                        current_points.append([x + dx, y + dy])

                    # Draw the updated points on the frame.
                    for (x, y) in current_points:
                        cv2.circle(frame_disps[i], (int(round(x)), int(round(y))), radius=5, color=(0,255,0), thickness=-1)
                    output_frames.append(frame_disps[i])
                frame_disps = []
            pbar.update(1)

    except RuntimeError as e:
        # Check if the error message indicates an OOM error.
        if "out of memory" in str(e).lower():
            torch.cuda.empty_cache()
            pbar.close()
            cap.release()
            print("Error: Out of Memory during video processing.")
            return "Error: Out of Memory during video processing. Please try a smaller video or lower resolution."
        else:
            # Re-raise if it's another type of error.
            raise e
    pbar.close()
    cap.release()

    # -------------------- Save the Output Video -------------------- #
    output_path = "tracked_output.mp4"
    print(len(output_frames), output_frames[0].shape)
    imageio.mimwrite(output_path, output_frames, fps=fps)

    return output_path

# -------------------- Model Initialization -------------------- #
def load_model():
    """Initialize and load the tracking model."""
    # Adjust these paths as needed.
    init_dir = '648Ai4i4i3n4s_1e-5m_c5c_stage3_from_kub_ns_wa_kk_lsh_dyk_46470'
    ckpt_dir = 'checkpoints'
    load_dir = os.path.join(ckpt_dir, init_dir)

    # Create the model and load weights.
    model = Net(16)
    count_parameters(model)
    _ = utils.saveload.load(
        None,
        load_dir,
        model,
        optimizer=None,
        scheduler=None,
        ignore_load=None,
        strict=True,
        verbose=False,
        weights_only=False,
    )
    model.cuda()
    for n, p in model.named_parameters():
        p.requires_grad = False
    model.eval()

    return model

def create_tracker(model):
    """Create tracker instance with the loaded model."""
    return Tracker(
        model=model,
        mean=torch.tensor([0.485, 0.456, 0.406]).cuda().reshape(1, 3, 1, 1),
        std=torch.tensor([0.229, 0.224, 0.225]).cuda().reshape(1, 3, 1, 1),
        S=16,
        stride=8,
        inference_iters=4,
        target_res=1024,
    )

# -------------------- Wrappers to Update Tracker Based on UI Settings -------------------- #
def extract_with_config(video_path, points, resolution, window_index, tracker):
    """
    Update the tracker configuration using the slider values, then extract the first frame.
      - resolution: Target resolution from slider (e.g., 512, 768, 1024).
      - window_index: An index (0–3) to be mapped to sliding window lengths {0:2, 1:4, 2:8, 3:16}.
      - tracker: The tracker instance to configure.
    """
    tracker.target_res = resolution
    mapping = {0: 2, 1: 4, 2: 8, 3: 16}
    tracker.S = mapping.get(int(window_index), 16)
    return extract_first_frame(video_path, points, tracker)

@torch.no_grad()
@spaces.GPU
def process_with_config(video_path, click_points, resolution, window_index, tracker):
    """
    Update the tracker configuration using the slider values, then process the video.
    """
    tracker.target_res = resolution
    mapping = {0: 2, 1: 4, 2: 8, 3: 16}
    tracker.S = mapping.get(int(window_index), 16)
    return process_video_with_points(video_path, click_points, tracker)

def main():
    """Main function that initializes the model and runs the Gradio interface."""
    # Set torch matmul precision
    torch.set_float32_matmul_precision('medium')

    # Initialize random seeds
    seed_everything(42)
    torch.set_grad_enabled(False)

    # Load model and create tracker
    print("Loading model...")
    model = load_model()
    tracker = create_tracker(model)
    print("Model loaded successfully!")

    # -------------------- Gradio Interface -------------------- #
    # The interface is built in two steps:
    #   1. Upload a video and extract the first frame.
    #   2. Annotate the first frame with multiple points (using gr.Points),
    #      then run tracking on the video.
    with gr.Blocks() as demo:
        gr.Markdown("## Dense Flow Tracking with Clickable Points")

        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video", value="data/244754_medium.mp4")
                extract_btn = gr.Button("Extract First Frame")
                # Add sliders for resolution and sliding window length.
                resolution_slider = gr.Slider(minimum=512, maximum=1024, step=256, value=1024, label="Target Resolution")
                # The slider below outputs an index 0-3; we'll map it to {0:2, 1:4, 2:8, 3:16}
                window_slider = gr.Slider(minimum=0, maximum=3, step=1, value=3, label="Sliding Window Length (Index: 0->2, 1->4, 2->8, 3->16)")
            with gr.Column():
                # This image will display the first frame and be interactive.
                first_frame_display = gr.Image(label="First Frame (Click to add points)", interactive=True)
                clear_pts_btn = gr.Button("Clear Points")

        # Hidden states: video file path, original first frame, and accumulated click points.
        video_state = gr.State(None)
        orig_frame_state = gr.State(None)
        points_state = gr.State([])

        track_btn = gr.Button("Track Points")
        output_video = gr.Video(label="Tracked Video")

        # When "Extract First Frame" is clicked, extract and display the first frame.
        extract_btn.click(
            fn=lambda *args: extract_with_config(*args, tracker),
            inputs=[video_input, points_state, resolution_slider, window_slider],
            outputs=[first_frame_display, video_state, orig_frame_state, points_state]
        )

        clear_pts_btn.click(
            fn=lambda orig_frame, points: (orig_frame, []),
            inputs=[orig_frame_state, points_state],
            outputs=[first_frame_display, points_state]
        )

        # When the user clicks on the image, add a point.
        first_frame_display.select(
            fn=add_point,
            inputs=[orig_frame_state, points_state],
            outputs=[first_frame_display, points_state]
        )

        # When "Track Points" is clicked, process the video using the accumulated points.
        track_btn.click(
            fn=lambda *args: process_with_config(*args, tracker),
            inputs=[video_state, points_state, resolution_slider, window_slider],
            outputs=output_video
        )

    demo.launch()

if __name__ == '__main__':
    main()
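For reference (not part of the commit), here is a minimal, self-contained sketch of the point-update rule used in process_video_with_points above: each clicked point is advanced by reading the dense flow at its rounded pixel location, where the flow maps from the first (anchor) frame to the current frame. The function and array names are hypothetical.

import numpy as np

def advance_points(points, flow):
    # points: list of [x, y] in the anchor frame; flow: (2, H, W) array,
    # flow[0] = dx and flow[1] = dy from the anchor frame to the current frame.
    _, H, W = flow.shape
    out = []
    for x, y in points:
        xi, yi = int(round(x)), int(round(y))
        if 0 <= yi < H and 0 <= xi < W:
            dx, dy = float(flow[0, yi, xi]), float(flow[1, yi, xi])
        else:
            dx, dy = 0.0, 0.0  # out-of-bounds points are left in place
        out.append([x + dx, y + dy])
    return out

# Example: a uniform flow of (+2, -1) pixels moves every point accordingly.
flow = np.zeros((2, 4, 4), dtype=np.float32)
flow[0] += 2.0
flow[1] -= 1.0
print(advance_points([[1.0, 1.0]], flow))  # [[3.0, 0.0]]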
checkpoints/648Ai4i4i3n4s_1e-5m_c5c_stage3_from_kub_ns_wa_kk_lsh_dyk_46470/model-000600000.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:15b029e1396feb453d361988afafab02d94993f4b43191b4dfcc44aac10fffa5
size 198027808
data/244754_medium.mp4 ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a26ffc0ea63adfb3e2b44f5f6ea384a6071f74cd8453e7b6d386a829d8c11fcb
size 42948491
demo_dense_visualize.py ADDED
@@ -0,0 +1,229 @@
import os
import random
import torch
import signal
import socket
import sys
import json
import torch.nn.functional as F
import numpy as np
import argparse
from pathlib import Path
import torch.optim as optim
from torch.cuda.amp import GradScaler
from lightning_fabric import Fabric

import utils.loss
import utils.samp
import utils.data
import utils.improc
import utils.misc
import utils.saveload
from tensorboardX import SummaryWriter
import datetime
import time
import cv2
import imageio
from nets.blocks import InputPadder
from tqdm import tqdm
# from pytorch_lightning.callbacks import BaseFinetuning
from utils.visualizer import Visualizer
from torchvision.transforms.functional import resize

import torch
import requests
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForCausalLM
import numpy as np


torch.set_float32_matmul_precision('medium')

def run_example(processor, model, task_prompt, image, text_input=None):
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
    inputs = processor(text=prompt, images=image, return_tensors="pt").to('cuda', torch.float32)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))

    return parsed_answer


def polygons_to_mask(image, prediction, fill_value=255):
    """
    Converts polygons into a mask.

    Parameters:
      - image: A PIL Image instance whose size will be used for the mask.
      - prediction: Dictionary containing 'polygons' and 'labels'.
        'polygons' is a list where each element is a list of sub-polygons.
      - fill_value: The pixel value used to fill the polygon areas (default 255 for a binary mask).

    Returns:
      - A NumPy array representing the mask (same width and height as the input image).
    """
    # Create a blank grayscale mask image with the same size as the original image.
    mask = Image.new('L', image.size, 0)
    draw = ImageDraw.Draw(mask)

    # Iterate over each set of polygons
    for polygons in prediction['polygons']:
        # Each element in "polygons" can be a sub-polygon
        for poly in polygons:
            # Ensure the polygon is in the right shape and has at least 3 points.
            poly_arr = np.array(poly).reshape(-1, 2)
            if poly_arr.shape[0] < 3:
                print('Skipping invalid polygon:', poly_arr)
                continue
            # Convert the polygon vertices into a list for drawing.
            poly_list = poly_arr.reshape(-1).tolist()
            # Draw the polygon on the mask with the fill_value.
            draw.polygon(poly_list, fill=fill_value)

    # Convert the PIL mask image to a NumPy array and return it.
    return np.array(mask)


class Tracker:
    def __init__(self, model, mean, std, S, stride, inference_iters, target_res, device='cuda'):
        """
        Initializes the Tracker.

        Args:
            model: The model used to compute feature maps and forward window flow.
            mean: Tensor or value used for normalizing the input.
            std: Tensor or value used for normalizing the input.
            S: Window size for the tracker.
            stride: The stride used when updating the window.
            inference_iters: Number of inference iterations.
            device: Torch device, defaults to 'cuda'.
        """
        self.model = model
        self.mean = mean
        self.std = std
        self.S = S
        self.stride = stride
        self.inference_iters = inference_iters
        self.device = device
        self.target_res = target_res

        self.padder = None
        self.cnt = 0
        self.fmap_anchor = None
        self.fmaps2 = None
        self.flows8 = None
        self.visconfs8 = None
        self.flows = []  # List to store computed flows
        self.visibs = []  # List to store visibility confidences
        self.rgbs = []  # List to store RGB frames

    def reset(self):
        """Reset the tracker state."""
        self.padder = None
        self.cnt = 0
        self.fmap_anchor = None
        self.fmaps2 = None
        self.flows8 = None
        self.visconfs8 = None
        self.flows = []
        self.visibs = []
        self.rgbs = []

    def preprocess(self, rgb_frame):
        # Resize frame (scale to keep maximum dimension ~1024)
        scale = min(self.target_res / rgb_frame.shape[0], self.target_res / rgb_frame.shape[1])
        rgb_resized = cv2.resize(rgb_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)

        # Convert to tensor, normalize and move to device.
        rgb_tensor = torch.from_numpy(rgb_resized).permute(2, 0, 1).float().unsqueeze(0).to(self.device)
        rgb_tensor = rgb_tensor / 255.0

        self.rgbs.append(rgb_tensor.cpu())

        # import pdb; pdb.set_trace()
        rgb_tensor = (rgb_tensor - self.mean) / self.std
        return rgb_tensor

    @torch.no_grad()
    def track(self, rgb_frame):
        """
        Process a single RGB frame and return the computed flow when available.

        Args:
            rgb_frame: A NumPy array containing the RGB frame.
                (Assumed to be in RGB; if coming from OpenCV, convert it before passing.)

        Returns:
            flow_predictions: The predicted flow for the current frame (or None if not enough frames have been processed).
        """
        torch.cuda.empty_cache()

        rgb_tensor = self.preprocess(rgb_frame)

        # Initialize padder on the first frame.
        if self.cnt == 0:
            self.padder = InputPadder(rgb_tensor.shape)
        rgb_padded = self.padder.pad(rgb_tensor)[0]
        _, _, H_pad, W_pad = rgb_padded.shape
        C = 256  # Feature map channel dimension (could be parameterized if needed)
        H8, W8 = H_pad // 8, W_pad // 8

        # Accumulate feature maps until the window is full.
        if self.cnt == 0:
            self.fmap_anchor = self.model.get_fmaps(rgb_padded, 1, 1, None, False, False).reshape(1, C, H8, W8)
            self.fmaps2 = self.fmap_anchor[:, None]
            self.cnt += 1
            return None

        new_fmap = self.model.get_fmaps(rgb_padded, 1, 1, None, False, False).reshape(1, 1, C, H8, W8)
        self.fmaps2 = torch.cat([self.fmaps2[:, (1 if self.fmaps2.shape[1] >= self.S else 0):].detach().clone(), new_fmap], dim=1)

        # need to track
        if self.cnt - self.S + 1 >= 0 and (self.cnt - self.S + 1) % self.stride == 0:
            # Initialize or update temporary flow buffers.
            iter_num = self.inference_iters
            if self.flows8 is None:
                self.flows8 = torch.zeros((self.S, 2, H_pad // 8, W_pad // 8), device=self.device)
                self.visconfs8 = torch.zeros((self.S, 2, H_pad // 8, W_pad // 8), device=self.device)
                # iter_num = self.inference_iters
            else:
                self.flows8 = torch.cat([
                    self.flows8[self.stride:self.stride + self.S // 2].detach().clone(),
                    self.flows8[self.stride + self.S // 2 - 1:self.stride + self.S // 2].detach().clone().repeat(self.S // 2, 1, 1, 1)
                ])
                self.visconfs8 = torch.cat([
                    self.visconfs8[self.stride:self.stride + self.S // 2].detach().clone(),
                    self.visconfs8[self.stride + self.S // 2 - 1:self.stride + self.S // 2].detach().clone().repeat(self.S // 2, 1, 1, 1)
                ])

            # import pdb; pdb.set_trace()
            # Compute flow predictions using the model's forward window.
            flow_predictions, visconf_predictions, self.flows8, self.visconfs8, _ = self.model.forward_window(
                self.fmap_anchor,
                self.fmaps2,
                self.visconfs8,
                iters=iter_num,
                flowfeat=None,
                flows8=self.flows8,
                is_training=False
            )
            flow_predictions = self.padder.unpad(flow_predictions[-1][0 if self.cnt == self.S - 1 else -self.stride:])
            visconf_predictions = self.padder.unpad(torch.sigmoid(visconf_predictions[-1][0 if self.cnt == self.S - 1 else -self.stride:]))

            self.cnt += 1
            self.flows.append(flow_predictions.cpu())
            self.visibs.append(visconf_predictions.cpu())

            return flow_predictions, visconf_predictions

        self.cnt += 1
        return None
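A minimal sketch (not part of the commit) of how the Tracker above is driven: frames are fed one at a time, track() returns None while the sliding window is filling, and afterwards a (flow, visibility) pair roughly every `stride` frames; judging from the code above, the returned flow is relative to the first (anchor) frame. The driver function below is hypothetical.

import cv2

def run_tracker_on_video(tracker, video_path):
    tracker.reset()
    cap = cv2.VideoCapture(video_path)
    flows = []
    while True:
        ret, frame_bgr = cap.read()
        if not ret:
            break
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)  # track() expects RGB
        out = tracker.track(frame_rgb)
        if out is not None:
            flow, visconf = out  # dense anchor-to-frame flow and visibility confidence
            flows.append(flow.cpu())
    cap.release()
    return flows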
nets/blocks.py ADDED
@@ -0,0 +1,1337 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import nn, Tensor
5
+ from itertools import repeat
6
+ import collections
7
+ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
8
+ from functools import partial
9
+ import einops
10
+ import math
11
+ from torchvision.ops.misc import Conv2dNormActivation, Permute
12
+ from torchvision.ops.stochastic_depth import StochasticDepth
13
+
14
+ def _ntuple(n):
15
+ def parse(x):
16
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
17
+ return tuple(x)
18
+ return tuple(repeat(x, n))
19
+ return parse
20
+
21
+ def exists(val):
22
+ return val is not None
23
+
24
+ def default(val, d):
25
+ return val if exists(val) else d
26
+
27
+ to_2tuple = _ntuple(2)
28
+
29
+ class InputPadder:
30
+ """ Pads images such that dimensions are divisible by a certain stride """
31
+ def __init__(self, dims, mode='sintel'):
32
+ self.ht, self.wd = dims[-2:]
33
+ pad_ht = (((self.ht // 64) + 1) * 64 - self.ht) % 64
34
+ pad_wd = (((self.wd // 64) + 1) * 64 - self.wd) % 64
35
+ if mode == 'sintel':
36
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
37
+ else:
38
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
39
+
40
+ def pad(self, *inputs):
41
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
42
+
43
+ def unpad(self, x):
44
+ ht, wd = x.shape[-2:]
45
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
46
+ return x[..., c[0]:c[1], c[2]:c[3]]
47
+
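A small usage sketch (not part of the commit) of InputPadder: it pads an image tensor so both spatial dimensions are multiples of 64, and later crops network outputs back to the original resolution.

import torch

x = torch.randn(1, 3, 300, 500)   # arbitrary input size
padder = InputPadder(x.shape)
x_pad = padder.pad(x)[0]           # both dims now divisible by 64
assert x_pad.shape[-2] % 64 == 0 and x_pad.shape[-1] % 64 == 0
y = padder.unpad(x_pad)            # crop back to the original size
assert y.shape == x.shape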
48
+ def bilinear_sampler(
49
+ input, coords,
50
+ align_corners=True,
51
+ padding_mode="border",
52
+ normalize_coords=True):
53
+ # func from mattie (oct9)
54
+ if input.ndim not in [4, 5]:
55
+ raise ValueError("input must be 4D or 5D.")
56
+
57
+ if input.ndim == 4 and not coords.ndim == 4:
58
+ raise ValueError("input is 4D, but coords is not 4D.")
59
+
60
+ if input.ndim == 5 and not coords.ndim == 5:
61
+ raise ValueError("input is 5D, but coords is not 5D.")
62
+
63
+ if coords.ndim == 5:
64
+ coords = coords[..., [1, 2, 0]] # t x y -> x y t to match what grid_sample() expects.
65
+
66
+ if normalize_coords:
67
+ if align_corners:
68
+ # Normalize coordinates from [0, W/H - 1] to [-1, 1].
69
+ coords = (
70
+ coords
71
+ * torch.tensor([2 / max(size - 1, 1) for size in reversed(input.shape[2:])], device=coords.device)
72
+ - 1
73
+ )
74
+ else:
75
+ # Normalize coordinates from [0, W/H] to [-1, 1].
76
+ coords = coords * torch.tensor([2 / size for size in reversed(input.shape[2:])], device=coords.device) - 1
77
+
78
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
79
+
80
+
81
+ class CorrBlock:
82
+ def __init__(self, fmap1, fmap2, corr_levels, corr_radius):
83
+ self.num_levels = corr_levels
84
+ self.radius = corr_radius
85
+ self.corr_pyramid = []
86
+ # all pairs correlation
87
+ for i in range(self.num_levels):
88
+ corr = CorrBlock.corr(fmap1, fmap2, 1)
89
+ batch, h1, w1, dim, h2, w2 = corr.shape
90
+ corr = corr.reshape(batch*h1*w1, dim, h2, w2)
91
+ fmap2 = F.interpolate(fmap2, scale_factor=0.5, mode='area')
92
+ # print('corr', corr.shape)
93
+ self.corr_pyramid.append(corr)
94
+
95
+ def __call__(self, coords, dilation=None):
96
+ r = self.radius
97
+ coords = coords.permute(0, 2, 3, 1)
98
+ batch, h1, w1, _ = coords.shape
99
+
100
+ if dilation is None:
101
+ dilation = torch.ones(batch, 1, h1, w1, device=coords.device)
102
+
103
+ # print(dilation.max(), dilation.mean(), dilation.min())
104
+ out_pyramid = []
105
+ for i in range(self.num_levels):
106
+ corr = self.corr_pyramid[i]
107
+ device = coords.device
108
+ dx = torch.linspace(-r, r, 2*r+1, device=device)
109
+ dy = torch.linspace(-r, r, 2*r+1, device=device)
110
+ delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
111
+ delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
112
+ delta_lvl = delta_lvl * dilation.view(batch * h1 * w1, 1, 1, 1)
113
+ centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
114
+ coords_lvl = centroid_lvl + delta_lvl
115
+ corr = bilinear_sampler(corr, coords_lvl)
116
+ corr = corr.view(batch, h1, w1, -1)
117
+ out_pyramid.append(corr)
118
+
119
+ out = torch.cat(out_pyramid, dim=-1)
120
+ out = out.permute(0, 3, 1, 2).contiguous().float()
121
+ return out
122
+
123
+ @staticmethod
124
+ def corr(fmap1, fmap2, num_head):
125
+ batch, dim, h1, w1 = fmap1.shape
126
+ h2, w2 = fmap2.shape[2:]
127
+ fmap1 = fmap1.view(batch, num_head, dim // num_head, h1*w1)
128
+ fmap2 = fmap2.view(batch, num_head, dim // num_head, h2*w2)
129
+ corr = fmap1.transpose(2, 3) @ fmap2
130
+ corr = corr.reshape(batch, num_head, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5)
131
+ return corr / torch.sqrt(torch.tensor(dim).float())
132
+
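A usage sketch (not part of the commit) of CorrBlock, under the assumption that coordinates are passed as an (B, 2, H, W) tensor with x first, as in RAFT-style code: it builds a multi-scale all-pairs correlation pyramid between two feature maps and, for each query pixel, returns a stack of (2r+1)^2 local correlation values per level.

import torch

B, C, H, W = 1, 64, 32, 48
fmap1 = torch.randn(B, C, H, W)
fmap2 = torch.randn(B, C, H, W)

corr_fn = CorrBlock(fmap1, fmap2, corr_levels=4, corr_radius=3)

# Query at the identity correspondence (each pixel looks at itself).
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
coords = torch.stack([xs, ys], dim=0).float()[None]  # (B, 2, H, W), x first
corr = corr_fn(coords)                                # (B, levels*(2r+1)^2, H, W)
print(corr.shape)                                     # torch.Size([1, 196, 32, 48])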
133
+ def conv1x1(in_planes, out_planes, stride=1):
134
+ """1x1 convolution without padding"""
135
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
136
+
137
+ def conv3x3(in_planes, out_planes, stride=1):
138
+ """3x3 convolution with padding"""
139
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
140
+
141
+ class LayerNorm2d(nn.LayerNorm):
142
+ def forward(self, x: Tensor) -> Tensor:
143
+ x = x.permute(0, 2, 3, 1)
144
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
145
+ x = x.permute(0, 3, 1, 2)
146
+ return x
147
+
148
+ class CNBlock1d(nn.Module):
149
+ def __init__(
150
+ self,
151
+ dim,
152
+ output_dim,
153
+ layer_scale: float = 1e-6,
154
+ stochastic_depth_prob: float = 0,
155
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
156
+ dense=True,
157
+ use_attn=True,
158
+ use_mixer=False,
159
+ use_conv=False,
160
+ use_convb=False,
161
+ use_layer_scale=True,
162
+ ) -> None:
163
+ super().__init__()
164
+ self.dense = dense
165
+ self.use_attn = use_attn
166
+ self.use_mixer = use_mixer
167
+ self.use_conv = use_conv
168
+ self.use_layer_scale = use_layer_scale
169
+
170
+ if use_attn:
171
+ assert not use_mixer
172
+ assert not use_conv
173
+ assert not use_convb
174
+
175
+ if norm_layer is None:
176
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
177
+
178
+ if use_attn:
179
+ num_heads = 8
180
+ self.block = AttnBlock(
181
+ hidden_size=dim,
182
+ num_heads=num_heads,
183
+ mlp_ratio=4,
184
+ attn_class=Attention,
185
+ )
186
+ elif use_mixer:
187
+ self.block = MLPMixerBlock(
188
+ S=16,
189
+ dim=dim,
190
+ depth=1,
191
+ expansion_factor=2,
192
+ )
193
+ elif use_conv:
194
+ self.block = nn.Sequential(
195
+ nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
196
+ Permute([0, 2, 1]),
197
+ norm_layer(dim),
198
+ nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
199
+ nn.GELU(),
200
+ nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
201
+ Permute([0, 2, 1]),
202
+ )
203
+ elif use_convb:
204
+ self.block = nn.Sequential(
205
+ nn.Conv1d(dim, dim, kernel_size=3, padding=1, bias=True, padding_mode='zeros'),
206
+ Permute([0, 2, 1]),
207
+ norm_layer(dim),
208
+ nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
209
+ nn.GELU(),
210
+ nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
211
+ Permute([0, 2, 1]),
212
+ )
213
+ else:
214
+ assert(False) # choose attn, mixer, or conv please
215
+
216
+ if self.use_layer_scale:
217
+ self.layer_scale = nn.Parameter(torch.ones(dim, 1) * layer_scale)
218
+ else:
219
+ self.layer_scale = 1.0
220
+
221
+ self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
222
+
223
+ if output_dim != dim:
224
+ self.final = nn.Conv1d(dim, output_dim, kernel_size=1, padding=0)
225
+ else:
226
+ self.final = nn.Identity()
227
+
228
+ def forward(self, input, S=None):
229
+ if self.dense:
230
+ assert S is not None
231
+ BS,C,H,W = input.shape
232
+ B = BS//S
233
+
234
+ # if S<7:
235
+ # return input
236
+
237
+ input = einops.rearrange(input, '(b s) c h w -> (b h w) c s', b=B, s=S, c=C, h=H, w=W)
238
+
239
+ if self.use_mixer or self.use_attn:
240
+ # mixer/transformer blocks want B,S,C
241
+ result = self.layer_scale * self.block(input.permute(0,2,1)).permute(0,2,1)
242
+ else:
243
+ result = self.layer_scale * self.block(input)
244
+ result = self.stochastic_depth(result)
245
+ result += input
246
+ result = self.final(result)
247
+
248
+ result = einops.rearrange(result, '(b h w) c s -> (b s) c h w', b=B, s=S, c=C, h=H, w=W)
249
+ else:
250
+ B,S,C = input.shape
251
+
252
+ if S<7:
253
+ return input
254
+
255
+ input = einops.rearrange(input, 'b s c -> b c s', b=B, s=S, c=C)
256
+
257
+ result = self.layer_scale * self.block(input)
258
+ result = self.stochastic_depth(result)
259
+ result += input
260
+
261
+ result = self.final(result)
262
+
263
+ result = einops.rearrange(result, 'b c s -> b s c', b=B, s=S, c=C)
264
+
265
+ return result
266
+
267
+ class CNBlock2d(nn.Module):
268
+ def __init__(
269
+ self,
270
+ dim,
271
+ output_dim,
272
+ layer_scale: float = 1e-6,
273
+ stochastic_depth_prob: float = 0,
274
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
275
+ use_layer_scale=True,
276
+ ) -> None:
277
+ super().__init__()
278
+ self.use_layer_scale = use_layer_scale
279
+ if norm_layer is None:
280
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
281
+
282
+ self.block = nn.Sequential(
283
+ nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
284
+ Permute([0, 2, 3, 1]),
285
+ norm_layer(dim),
286
+ nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
287
+ nn.GELU(),
288
+ nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
289
+ Permute([0, 3, 1, 2]),
290
+ )
291
+ if self.use_layer_scale:
292
+ self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
293
+ else:
294
+ self.layer_scale = 1.0
295
+ self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
296
+
297
+ if output_dim != dim:
298
+ self.final = nn.Conv2d(dim, output_dim, kernel_size=1, padding=0)
299
+ else:
300
+ self.final = nn.Identity()
301
+
302
+ def forward(self, input, S=None):
303
+ result = self.layer_scale * self.block(input)
304
+ result = self.stochastic_depth(result)
305
+ result += input
306
+ result = self.final(result)
307
+ return result
308
+
309
+ class CNBlockConfig:
310
+ # Stores information listed at Section 3 of the ConvNeXt paper
311
+ def __init__(
312
+ self,
313
+ input_channels: int,
314
+ out_channels: Optional[int],
315
+ num_layers: int,
316
+ downsample: bool,
317
+ ) -> None:
318
+ self.input_channels = input_channels
319
+ self.out_channels = out_channels
320
+ self.num_layers = num_layers
321
+ self.downsample = downsample
322
+
323
+ def __repr__(self) -> str:
324
+ s = self.__class__.__name__ + "("
325
+ s += "input_channels={input_channels}"
326
+ s += ", out_channels={out_channels}"
327
+ s += ", num_layers={num_layers}"
328
+ s += ", downsample={downsample}"
329
+ s += ")"
330
+ return s.format(**self.__dict__)
331
+
332
+ class ConvNeXt(nn.Module):
333
+ def __init__(
334
+ self,
335
+ block_setting: List[CNBlockConfig],
336
+ stochastic_depth_prob: float = 0.0,
337
+ layer_scale: float = 1e-6,
338
+ num_classes: int = 1000,
339
+ block: Optional[Callable[..., nn.Module]] = None,
340
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
341
+ init_weights=True):
342
+ super().__init__()
343
+
344
+ self.init_weights = init_weights
345
+
346
+ if not block_setting:
347
+ raise ValueError("The block_setting should not be empty")
348
+ elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
349
+ raise TypeError("The block_setting should be List[CNBlockConfig]")
350
+
351
+ if block is None:
352
+ block = CNBlock2d
353
+
354
+ if norm_layer is None:
355
+ norm_layer = partial(LayerNorm2d, eps=1e-6)
356
+
357
+ layers: List[nn.Module] = []
358
+
359
+ # Stem
360
+ firstconv_output_channels = block_setting[0].input_channels
361
+ layers.append(
362
+ Conv2dNormActivation(
363
+ 3,
364
+ firstconv_output_channels,
365
+ kernel_size=4,
366
+ stride=4,
367
+ padding=0,
368
+ norm_layer=norm_layer,
369
+ activation_layer=None,
370
+ bias=True,
371
+ )
372
+ )
373
+
374
+ total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
375
+ stage_block_id = 0
376
+ for cnf in block_setting:
377
+ # Bottlenecks
378
+ stage: List[nn.Module] = []
379
+ for _ in range(cnf.num_layers):
380
+ # adjust stochastic depth probability based on the depth of the stage block
381
+ sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
382
+ stage.append(block(cnf.input_channels, cnf.input_channels, layer_scale, sd_prob))
383
+ stage_block_id += 1
384
+ layers.append(nn.Sequential(*stage))
385
+ if cnf.out_channels is not None:
386
+ if cnf.downsample:
387
+ layers.append(
388
+ nn.Sequential(
389
+ norm_layer(cnf.input_channels),
390
+ nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
391
+ )
392
+ )
393
+ else:
394
+ # we convert the 2x2 downsampling layer into a 3x3 with dilation2 and replicate padding.
395
+ # replicate padding compensates for the fact that this kernel never saw zero-padding.
396
+ layers.append(
397
+ nn.Sequential(
398
+ norm_layer(cnf.input_channels),
399
+ nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=3, stride=1, padding=2, dilation=2, padding_mode='zeros'),
400
+ )
401
+ )
402
+
403
+ self.features = nn.Sequential(*layers)
404
+
405
+ # self.final_conv = conv1x1(block_setting[-1].input_channels, output_dim)
406
+
407
+ for m in self.modules():
408
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
409
+ nn.init.trunc_normal_(m.weight, std=0.02)
410
+ if m.bias is not None:
411
+ nn.init.zeros_(m.bias)
412
+
413
+ if self.init_weights:
414
+ from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
415
+ pretrained_dict = convnext_tiny(weights=ConvNeXt_Tiny_Weights.DEFAULT).state_dict()
416
+ # from torchvision.models import convnext_base, ConvNeXt_Base_Weights
417
+ # pretrained_dict = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT).state_dict()
418
+ model_dict = self.state_dict()
419
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
420
+ # if self.input_dim == 6:
421
+ # for k, v in pretrained_dict.items():
422
+ # if k == 'conv1.weight':
423
+ # pretrained_dict[k] = torch.cat((v, v), dim=1)
424
+
425
+ # del pretrained_dict['features.4.1.weight']
426
+
427
+ for k, v in pretrained_dict.items():
428
+ if k == 'features.4.1.weight': # this is the layer normally in charge of 2x2 downsampling
429
+ # convert to 3x3 filter
430
+ pretrained_dict[k] = F.interpolate(v, (3, 3), mode='bicubic', align_corners=True) * (4/9.0)
431
+
432
+ # print('v0', v[0,0])
433
+ # print('d0', pretrained_dict[k][0,0])
434
+
435
+ model_dict.update(pretrained_dict)
436
+ self.load_state_dict(model_dict, strict=False)
437
+
438
+
439
+ def _forward_impl(self, x: Tensor) -> Tensor:
440
+ # with torch.no_grad():
441
+ x = self.features(x)
442
+ # x = self.final_conv(x)
443
+ return x
444
+
445
+ def forward(self, x: Tensor) -> Tensor:
446
+ return self._forward_impl(x)
447
+
448
+ class Mlp(nn.Module):
449
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
450
+
451
+ def __init__(
452
+ self,
453
+ in_features,
454
+ hidden_features=None,
455
+ out_features=None,
456
+ act_layer=nn.GELU,
457
+ norm_layer=None,
458
+ bias=True,
459
+ drop=0.0,
460
+ use_conv=False,
461
+ ):
462
+ super().__init__()
463
+ out_features = out_features or in_features
464
+ hidden_features = hidden_features or in_features
465
+ bias = to_2tuple(bias)
466
+ drop_probs = to_2tuple(drop)
467
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
468
+
469
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
470
+ self.act = act_layer()
471
+ self.drop1 = nn.Dropout(drop_probs[0])
472
+ self.norm = (
473
+ norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
474
+ )
475
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
476
+ self.drop2 = nn.Dropout(drop_probs[1])
477
+
478
+ def forward(self, x):
479
+ x = self.fc1(x)
480
+ x = self.act(x)
481
+ x = self.drop1(x)
482
+ x = self.fc2(x)
483
+ x = self.drop2(x)
484
+ return x
485
+
486
+ class Attention(nn.Module):
487
+ def __init__(
488
+ self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
489
+ ):
490
+ super().__init__()
491
+ inner_dim = dim_head * num_heads
492
+ context_dim = default(context_dim, query_dim)
493
+ self.scale = dim_head**-0.5
494
+ self.heads = num_heads
495
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
496
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
497
+ self.to_out = nn.Linear(inner_dim, query_dim)
498
+
499
+ def forward(self, x, context=None, attn_bias=None):
500
+ B, N1, C = x.shape
501
+ H = self.heads
502
+ q = self.to_q(x)
503
+ context = default(context, x)
504
+ k, v = self.to_kv(context).chunk(2, dim=-1)
505
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
506
+ x = F.scaled_dot_product_attention(q, k, v) # scale default is already dim^-0.5
507
+ x = einops.rearrange(x, 'b h n d -> b n (h d)')
508
+ return self.to_out(x)
509
+
510
+ class CrossAttnBlock(nn.Module):
511
+ def __init__(
512
+ self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
513
+ ):
514
+ super().__init__()
515
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
516
+ self.norm_context = nn.LayerNorm(hidden_size)
517
+ self.cross_attn = Attention(
518
+ hidden_size,
519
+ context_dim=context_dim,
520
+ num_heads=num_heads,
521
+ qkv_bias=True,
522
+ **block_kwargs
523
+ )
524
+
525
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
526
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
527
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
528
+ self.mlp = Mlp(
529
+ in_features=hidden_size,
530
+ hidden_features=mlp_hidden_dim,
531
+ act_layer=approx_gelu,
532
+ drop=0,
533
+ )
534
+
535
+ def forward(self, x, context, mask=None):
536
+ attn_bias = None
537
+ if mask is not None:
538
+ if mask.shape[1] == x.shape[1]:
539
+ mask = mask[:, None, :, None].expand(
540
+ -1, self.cross_attn.heads, -1, context.shape[1]
541
+ )
542
+ else:
543
+ mask = mask[:, None, None].expand(
544
+ -1, self.cross_attn.heads, x.shape[1], -1
545
+ )
546
+
547
+ max_neg_value = -torch.finfo(x.dtype).max
548
+ attn_bias = (~mask) * max_neg_value
549
+ x = x + self.cross_attn(
550
+ self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
551
+ )
552
+ x = x + self.mlp(self.norm2(x))
553
+ return x
554
+
555
+ class AttnBlock(nn.Module):
556
+ def __init__(
557
+ self,
558
+ hidden_size,
559
+ num_heads,
560
+ attn_class: Callable[..., nn.Module] = Attention,
561
+ mlp_ratio=4.0,
562
+ **block_kwargs
563
+ ):
564
+ super().__init__()
565
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
566
+ self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, dim_head=hidden_size//num_heads)
567
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
568
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
569
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
570
+ self.mlp = Mlp(
571
+ in_features=hidden_size,
572
+ hidden_features=mlp_hidden_dim,
573
+ act_layer=approx_gelu,
574
+ drop=0,
575
+ )
576
+
577
+ def forward(self, x, mask=None):
578
+ attn_bias = mask
579
+ if mask is not None:
580
+ mask = (
581
+ (mask[:, None] * mask[:, :, None])
582
+ .unsqueeze(1)
583
+ .expand(-1, self.attn.num_heads, -1, -1)
584
+ )
585
+ max_neg_value = -torch.finfo(x.dtype).max
586
+ attn_bias = (~mask) * max_neg_value
587
+
588
+ x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
589
+ x = x + self.mlp(self.norm2(x))
590
+ return x
591
+
592
+
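A shape-level sketch (not part of the commit) of the attention blocks above: AttnBlock is a pre-norm self-attention + MLP block over a (B, N, C) token tensor, and CrossAttnBlock lets one token set attend to another; both preserve the input shape.

import torch

hidden = 384
blk = AttnBlock(hidden_size=hidden, num_heads=8, mlp_ratio=4.0)
xblk = CrossAttnBlock(hidden_size=hidden, context_dim=hidden, num_heads=8)

tokens = torch.randn(2, 100, hidden)   # (B, N, C)
context = torch.randn(2, 64, hidden)   # e.g. a set of virtual-track tokens

out = blk(tokens)              # self-attention + MLP, same shape as tokens
out = xblk(tokens, context)    # tokens attend to context, same shape as tokens
assert out.shape == tokens.shape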
593
+ class ResidualBlock(nn.Module):
594
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
595
+ super(ResidualBlock, self).__init__()
596
+
597
+ self.conv1 = nn.Conv2d(
598
+ in_planes,
599
+ planes,
600
+ kernel_size=3,
601
+ padding=1,
602
+ stride=stride,
603
+ padding_mode="zeros",
604
+ )
605
+ self.conv2 = nn.Conv2d(
606
+ planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
607
+ )
608
+ self.relu = nn.ReLU(inplace=True)
609
+
610
+ num_groups = planes // 8
611
+
612
+ if norm_fn == "group":
613
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
614
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
615
+ if not stride == 1:
616
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
617
+
618
+ elif norm_fn == "batch":
619
+ self.norm1 = nn.BatchNorm2d(planes)
620
+ self.norm2 = nn.BatchNorm2d(planes)
621
+ if not stride == 1:
622
+ self.norm3 = nn.BatchNorm2d(planes)
623
+
624
+ elif norm_fn == "instance":
625
+ self.norm1 = nn.InstanceNorm2d(planes)
626
+ self.norm2 = nn.InstanceNorm2d(planes)
627
+ if not stride == 1:
628
+ self.norm3 = nn.InstanceNorm2d(planes)
629
+
630
+ elif norm_fn == "none":
631
+ self.norm1 = nn.Sequential()
632
+ self.norm2 = nn.Sequential()
633
+ if not stride == 1:
634
+ self.norm3 = nn.Sequential()
635
+
636
+ if stride == 1:
637
+ self.downsample = None
638
+
639
+ else:
640
+ self.downsample = nn.Sequential(
641
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
642
+ )
643
+
644
+ def forward(self, x):
645
+ y = x
646
+ y = self.relu(self.norm1(self.conv1(y)))
647
+ y = self.relu(self.norm2(self.conv2(y)))
648
+
649
+ if self.downsample is not None:
650
+ x = self.downsample(x)
651
+
652
+ return self.relu(x + y)
653
+
654
+
655
+ class BasicEncoder(nn.Module):
656
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
657
+ super(BasicEncoder, self).__init__()
658
+ self.stride = stride
659
+ self.norm_fn = "instance"
660
+ self.in_planes = output_dim // 2
661
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
662
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
663
+
664
+ self.conv1 = nn.Conv2d(
665
+ input_dim,
666
+ self.in_planes,
667
+ kernel_size=7,
668
+ stride=2,
669
+ padding=3,
670
+ padding_mode="zeros",
671
+ )
672
+ self.relu1 = nn.ReLU(inplace=True)
673
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
674
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
675
+ self.layer3 = self._make_layer(output_dim, stride=2)
676
+ self.layer4 = self._make_layer(output_dim, stride=2)
677
+
678
+ self.conv2 = nn.Conv2d(
679
+ output_dim * 3 + output_dim // 4,
680
+ output_dim * 2,
681
+ kernel_size=3,
682
+ padding=1,
683
+ padding_mode="zeros",
684
+ )
685
+ self.relu2 = nn.ReLU(inplace=True)
686
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
687
+ for m in self.modules():
688
+ if isinstance(m, nn.Conv2d):
689
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
690
+ elif isinstance(m, (nn.InstanceNorm2d)):
691
+ if m.weight is not None:
692
+ nn.init.constant_(m.weight, 1)
693
+ if m.bias is not None:
694
+ nn.init.constant_(m.bias, 0)
695
+
696
+ def _make_layer(self, dim, stride=1):
697
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
698
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
699
+ layers = (layer1, layer2)
700
+
701
+ self.in_planes = dim
702
+ return nn.Sequential(*layers)
703
+
704
+ def forward(self, x):
705
+ _, _, H, W = x.shape
706
+
707
+ # with torch.no_grad():
708
+ x = self.conv1(x)
709
+ x = self.norm1(x)
710
+ x = self.relu1(x)
711
+
712
+ a = self.layer1(x)
713
+ b = self.layer2(a)
714
+ c = self.layer3(b)
715
+ d = self.layer4(c)
716
+
717
+ def _bilinear_intepolate(x):
718
+ return F.interpolate(
719
+ x,
720
+ (H // self.stride, W // self.stride),
721
+ mode="bilinear",
722
+ align_corners=True,
723
+ )
724
+
725
+ a = _bilinear_intepolate(a)
726
+ b = _bilinear_intepolate(b)
727
+ c = _bilinear_intepolate(c)
728
+ d = _bilinear_intepolate(d)
729
+
730
+ x = self.conv2(torch.cat([a, b, c, d], dim=1))
731
+ x = self.norm2(x)
732
+ x = self.relu2(x)
733
+ x = self.conv3(x)
734
+ return x
735
+
736
+ class EfficientUpdateFormer(nn.Module):
737
+ """
738
+ Transformer model that updates track estimates.
739
+ """
740
+
741
+ def __init__(
742
+ self,
743
+ space_depth=6,
744
+ time_depth=6,
745
+ input_dim=320,
746
+ hidden_size=384,
747
+ num_heads=8,
748
+ output_dim=130,
749
+ mlp_ratio=4.0,
750
+ num_virtual_tracks=64,
751
+ add_space_attn=True,
752
+ linear_layer_for_vis_conf=False,
753
+ use_time_conv=False,
754
+ use_time_mixer=False,
755
+ ):
756
+ super().__init__()
757
+ self.out_channels = 2
758
+ self.num_heads = num_heads
759
+ self.hidden_size = hidden_size
760
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
761
+ if linear_layer_for_vis_conf:
762
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
763
+ self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
764
+ else:
765
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
766
+ self.num_virtual_tracks = num_virtual_tracks
767
+ self.virual_tracks = nn.Parameter(
768
+ torch.randn(1, num_virtual_tracks, 1, hidden_size)
769
+ )
770
+ self.add_space_attn = add_space_attn
771
+ self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
772
+
773
+ if use_time_conv:
774
+ self.time_blocks = nn.ModuleList(
775
+ [
776
+ CNBlock1d(hidden_size, hidden_size, dense=False)
777
+ for _ in range(time_depth)
778
+ ]
779
+ )
780
+ elif use_time_mixer:
781
+ self.time_blocks = nn.ModuleList(
782
+ [
783
+ MLPMixerBlock(
784
+ S=16,
785
+ dim=hidden_size,
786
+ depth=1,
787
+ )
788
+ for _ in range(time_depth)
789
+ ]
790
+ )
791
+ else:
792
+ self.time_blocks = nn.ModuleList(
793
+ [
794
+ AttnBlock(
795
+ hidden_size,
796
+ num_heads,
797
+ mlp_ratio=mlp_ratio,
798
+ attn_class=Attention,
799
+ )
800
+ for _ in range(time_depth)
801
+ ]
802
+ )
803
+
804
+ if add_space_attn:
805
+ self.space_virtual_blocks = nn.ModuleList(
806
+ [
807
+ AttnBlock(
808
+ hidden_size,
809
+ num_heads,
810
+ mlp_ratio=mlp_ratio,
811
+ attn_class=Attention,
812
+ )
813
+ for _ in range(space_depth)
814
+ ]
815
+ )
816
+ self.space_point2virtual_blocks = nn.ModuleList(
817
+ [
818
+ CrossAttnBlock(
819
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
820
+ )
821
+ for _ in range(space_depth)
822
+ ]
823
+ )
824
+ self.space_virtual2point_blocks = nn.ModuleList(
825
+ [
826
+ CrossAttnBlock(
827
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
828
+ )
829
+ for _ in range(space_depth)
830
+ ]
831
+ )
832
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
833
+ self.initialize_weights()
834
+
835
+ def initialize_weights(self):
836
+ def _basic_init(module):
837
+ if isinstance(module, nn.Linear):
838
+ torch.nn.init.xavier_uniform_(module.weight)
839
+ if module.bias is not None:
840
+ nn.init.constant_(module.bias, 0)
841
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
842
+ if self.linear_layer_for_vis_conf:
843
+ torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
844
+
845
+ def _trunc_init(module):
846
+ """ViT weight initialization, original timm impl (for reproducibility)"""
847
+ if isinstance(module, nn.Linear):
848
+ torch.nn.init.trunc_normal_(module.weight, std=0.02)
849
+ if module.bias is not None:
850
+ nn.init.zeros_(module.bias)
851
+
852
+ self.apply(_basic_init)
853
+
854
+ def forward(self, input_tensor, mask=None, add_space_attn=True):
855
+ tokens = self.input_transform(input_tensor)
856
+
857
+ B, _, T, _ = tokens.shape
858
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
859
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
860
+
861
+ _, N, _, _ = tokens.shape
862
+ j = 0
863
+ layers = []
864
+ for i in range(len(self.time_blocks)):
865
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
866
+ time_tokens = self.time_blocks[i](time_tokens)
867
+
868
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
869
+ if (
870
+ add_space_attn
871
+ and hasattr(self, "space_virtual_blocks")
872
+ and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
873
+ ):
874
+ space_tokens = (
875
+ tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
876
+ ) # B N T C -> (B T) N C
877
+
878
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
879
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
880
+
881
+ virtual_tokens = self.space_virtual2point_blocks[j](
882
+ virtual_tokens, point_tokens, mask=mask
883
+ )
884
+
885
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
886
+ point_tokens = self.space_point2virtual_blocks[j](
887
+ point_tokens, virtual_tokens, mask=mask
888
+ )
889
+
890
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
891
+ tokens = space_tokens.view(B, T, N, -1).permute(
892
+ 0, 2, 1, 3
893
+ ) # (B T) N C -> B N T C
894
+ j += 1
895
+ tokens = tokens[:, : N - self.num_virtual_tracks]
896
+
897
+ flow = self.flow_head(tokens)
898
+ if self.linear_layer_for_vis_conf:
899
+ vis_conf = self.vis_conf_head(tokens)
900
+ flow = torch.cat([flow, vis_conf], dim=-1)
901
+
902
+ return flow
903
+
904
+
905
+ class MMPreNormResidual(nn.Module):
906
+ def __init__(self, dim, fn):
907
+ super().__init__()
908
+ self.fn = fn
909
+ self.norm = nn.LayerNorm(dim)
910
+
911
+ def forward(self, x):
912
+ return self.fn(self.norm(x)) + x
913
+
914
+ def MMFeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
915
+ return nn.Sequential(
916
+ dense(dim, dim * expansion_factor),
917
+ nn.GELU(),
918
+ nn.Dropout(dropout),
919
+ dense(dim * expansion_factor, dim),
920
+ nn.Dropout(dropout)
921
+ )
922
+
923
+ def MLPMixer(S, input_dim, dim, output_dim, depth=6, expansion_factor=4, dropout=0., do_reduce=False):
924
+ # input is coming in as B,S,C, as standard for mlp and transformer
925
+ # chan_first treats S as the channel dim, and transforms it to a new S
926
+ # chan_last treats C as the channel dim, and transforms it to a new C
927
+ chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
928
+ if do_reduce:
929
+ return nn.Sequential(
930
+ nn.Linear(input_dim, dim),
931
+ *[nn.Sequential(
932
+ MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
933
+ MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
934
+ ) for _ in range(depth)],
935
+ nn.LayerNorm(dim),
936
+ Reduce('b n c -> b c', 'mean'),
937
+ nn.Linear(dim, output_dim)
938
+ )
939
+ else:
940
+ return nn.Sequential(
941
+ nn.Linear(input_dim, dim),
942
+ *[nn.Sequential(
943
+ MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
944
+ MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
945
+ ) for _ in range(depth)],
946
+ )
947
+
948
+ def MLPMixerBlock(S, dim, depth=1, expansion_factor=4, dropout=0., do_reduce=False):
949
+ # input is coming in as B,S,C, as standard for mlp and transformer
950
+ # chan_first treats S as the channel dim, and transforms it to a new S
951
+ # chan_last treats C as the channel dim, and transforms it to a new C
952
+ chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
953
+ return nn.Sequential(
954
+ *[nn.Sequential(
955
+ MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
956
+ MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
957
+ ) for _ in range(depth)],
958
+ )
959
+
960
+
961
+ class MlpUpdateFormer(nn.Module):
962
+ """
963
+ Transformer model that updates track estimates.
964
+ """
965
+
966
+ def __init__(
967
+ self,
968
+ space_depth=6,
969
+ time_depth=6,
970
+ input_dim=320,
971
+ hidden_size=384,
972
+ num_heads=8,
973
+ output_dim=130,
974
+ mlp_ratio=4.0,
975
+ num_virtual_tracks=64,
976
+ add_space_attn=True,
977
+ linear_layer_for_vis_conf=False,
978
+ ):
979
+ super().__init__()
980
+ self.out_channels = 2
981
+ self.num_heads = num_heads
982
+ self.hidden_size = hidden_size
983
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
984
+ if linear_layer_for_vis_conf:
985
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
986
+ self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
987
+ else:
988
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
989
+ self.num_virtual_tracks = num_virtual_tracks
990
+ self.virual_tracks = nn.Parameter(
991
+ torch.randn(1, num_virtual_tracks, 1, hidden_size)
992
+ )
993
+ self.add_space_attn = add_space_attn
994
+ self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
995
+ self.time_blocks = nn.ModuleList(
996
+ [
997
+ MLPMixer(
998
+ S=16,
999
+ input_dim=hidden_size,
1000
+ dim=hidden_size,
1001
+ output_dim=hidden_size,
1002
+ depth=1,
1003
+ )
1004
+ for _ in range(time_depth)
1005
+ ]
1006
+ )
1007
+
1008
+ if add_space_attn:
1009
+ self.space_virtual_blocks = nn.ModuleList(
1010
+ [
1011
+ AttnBlock(
1012
+ hidden_size,
1013
+ num_heads,
1014
+ mlp_ratio=mlp_ratio,
1015
+ attn_class=Attention,
1016
+ )
1017
+ for _ in range(space_depth)
1018
+ ]
1019
+ )
1020
+ self.space_point2virtual_blocks = nn.ModuleList(
1021
+ [
1022
+ CrossAttnBlock(
1023
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
1024
+ )
1025
+ for _ in range(space_depth)
1026
+ ]
1027
+ )
1028
+ self.space_virtual2point_blocks = nn.ModuleList(
1029
+ [
1030
+ CrossAttnBlock(
1031
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
1032
+ )
1033
+ for _ in range(space_depth)
1034
+ ]
1035
+ )
1036
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
1037
+ self.initialize_weights()
1038
+
1039
+ def initialize_weights(self):
1040
+ def _basic_init(module):
1041
+ if isinstance(module, nn.Linear):
1042
+ torch.nn.init.xavier_uniform_(module.weight)
1043
+ if module.bias is not None:
1044
+ nn.init.constant_(module.bias, 0)
1045
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
1046
+ if self.linear_layer_for_vis_conf:
1047
+ torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
1048
+
1049
+ def _trunc_init(module):
1050
+ """ViT weight initialization, original timm impl (for reproducibility)"""
1051
+ if isinstance(module, nn.Linear):
1052
+ torch.nn.init.trunc_normal_(module.weight, std=0.02)
1053
+ if module.bias is not None:
1054
+ nn.init.zeros_(module.bias)
1055
+
1056
+ self.apply(_basic_init)
1057
+
1058
+ def forward(self, input_tensor, mask=None, add_space_attn=True):
1059
+ tokens = self.input_transform(input_tensor)
1060
+
1061
+ B, _, T, _ = tokens.shape
1062
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
1063
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
1064
+
1065
+ _, N, _, _ = tokens.shape
1066
+ j = 0
1067
+ layers = []
1068
+ for i in range(len(self.time_blocks)):
1069
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
1070
+ time_tokens = self.time_blocks[i](time_tokens)
1071
+
1072
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
1073
+ if (
1074
+ add_space_attn
1075
+ and hasattr(self, "space_virtual_blocks")
1076
+ and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
1077
+ ):
1078
+ space_tokens = (
1079
+ tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
1080
+ ) # B N T C -> (B T) N C
1081
+
1082
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
1083
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
1084
+
1085
+ virtual_tokens = self.space_virtual2point_blocks[j](
1086
+ virtual_tokens, point_tokens, mask=mask
1087
+ )
1088
+
1089
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
1090
+ point_tokens = self.space_point2virtual_blocks[j](
1091
+ point_tokens, virtual_tokens, mask=mask
1092
+ )
1093
+
1094
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
1095
+ tokens = space_tokens.view(B, T, N, -1).permute(
1096
+ 0, 2, 1, 3
1097
+ ) # (B T) N C -> B N T C
1098
+ j += 1
1099
+ tokens = tokens[:, : N - self.num_virtual_tracks]
1100
+
1101
+ flow = self.flow_head(tokens)
1102
+ if self.linear_layer_for_vis_conf:
1103
+ vis_conf = self.vis_conf_head(tokens)
1104
+ flow = torch.cat([flow, vis_conf], dim=-1)
1105
+
1106
+ return flow
1107
+
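The space attention in MlpUpdateFormer is factored through `num_virtual_tracks` learned tokens instead of full point-to-point attention: virtual tokens cross-attend to the point tokens, self-attend, and the point tokens then cross-attend back. A rough shape sketch (hypothetical sizes; assumes the AttnBlock/CrossAttnBlock definitions earlier in this file and a window length matching the S=16 time mixer):

    # hypothetical sizes: 1 batch, 1024 points, 16 frames, 320-dim input tokens
    import torch
    model = MlpUpdateFormer(space_depth=3, time_depth=6, input_dim=320, hidden_size=384,
                            output_dim=130, num_virtual_tracks=64, add_space_attn=True)
    x = torch.randn(1, 1024, 16, 320)   # B, N, T, C
    flow = model(x)                     # time mixing over T, factored space attention over N
    print(flow.shape)                   # torch.Size([1, 1024, 16, 130])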
1108
+ class BasicMotionEncoder(nn.Module):
1109
+ def __init__(self, corr_channel, dim=128, pdim=2):
1110
+ super(BasicMotionEncoder, self).__init__()
1111
+ self.pdim = pdim
1112
+ self.convc1 = nn.Conv2d(corr_channel, dim*4, 1, padding=0)
1113
+ self.convc2 = nn.Conv2d(dim*4, dim+dim//2, 3, padding=1)
1114
+ if pdim==2 or pdim==4:
1115
+ self.convf1 = nn.Conv2d(pdim, dim*2, 5, padding=2)
1116
+ self.convf2 = nn.Conv2d(dim*2, dim//2, 3, padding=1)
1117
+ self.conv = nn.Conv2d(dim*2, dim-pdim, 3, padding=1)
1118
+ else:
1119
+ self.conv = nn.Conv2d(dim+dim//2+pdim, dim, 3, padding=1)
1120
+
1121
+ def forward(self, flow, corr):
1122
+ cor = F.relu(self.convc1(corr))
1123
+ cor = F.relu(self.convc2(cor))
1124
+ if self.pdim==2 or self.pdim==4:
1125
+ flo = F.relu(self.convf1(flow))
1126
+ flo = F.relu(self.convf2(flo))
1127
+ cor_flo = torch.cat([cor, flo], dim=1)
1128
+ out = F.relu(self.conv(cor_flo))
1129
+ return torch.cat([out, flow], dim=1)
1130
+ else:
1131
+ # the flow is already encoded to something nice
1132
+ cor_flo = torch.cat([cor, flow], dim=1)
1133
+ return F.relu(self.conv(cor_flo))
1134
+ # return torch.cat([out, flow], dim=1)
1135
+
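For reference, a small sketch of the channel bookkeeping in BasicMotionEncoder when pdim=2, assuming the usual 5-level, radius-4 correlation pyramid (5*(2*4+1)**2 = 405 channels):

    # hypothetical shapes for the motion encoder
    import torch
    corr_channel = 5 * (2 * 4 + 1) ** 2            # 405, matching CorrBlock(levels=5, radius=4)
    enc = BasicMotionEncoder(corr_channel, dim=128, pdim=2)
    flow = torch.randn(2, 2, 40, 72)               # B*S, 2, H/8, W/8
    corr = torch.randn(2, corr_channel, 40, 72)
    out = enc(flow, corr)                          # cat([126 conv features, 2 flow channels])
    print(out.shape)                               # torch.Size([2, 128, 40, 72])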
1136
+ def conv133_encoder(input_dim, dim, expansion_factor=4):
1137
+ return nn.Sequential(
1138
+ nn.Conv2d(input_dim, dim*expansion_factor, kernel_size=1),
1139
+ nn.GELU(),
1140
+ nn.Conv2d(dim*expansion_factor, dim*expansion_factor, kernel_size=3, padding=1),
1141
+ nn.GELU(),
1142
+ nn.Conv2d(dim*expansion_factor, dim, kernel_size=3, padding=1),
1143
+ )
1144
+
1145
+ class BasicUpdateBlock(nn.Module):
1146
+ def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
1147
+ # flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
1148
+ super(BasicUpdateBlock, self).__init__()
1149
+ self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
1150
+ self.compressor = conv1x1(2*cdim+hdim, hdim)
1151
+
1152
+ self.refine = []
1153
+ for i in range(num_blocks):
1154
+ self.refine.append(CNBlock1d(hdim, hdim))
1155
+ self.refine.append(CNBlock2d(hdim, hdim))
1156
+ self.refine = nn.ModuleList(self.refine)
1157
+
1158
+ def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
1159
+ BS,C,H,W = flowfeat.shape
1160
+ B = BS//S
1161
+
1162
+ # with torch.no_grad():
1163
+ motion_features = self.encoder(flow, corr)
1164
+ flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
1165
+
1166
+ for blk in self.refine:
1167
+ flowfeat = blk(flowfeat, S)
1168
+ return flowfeat
1169
+
1170
+ class FullUpdateBlock(nn.Module):
1171
+ def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=2, use_attn=False):
1172
+ # flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
1173
+ super(FullUpdateBlock, self).__init__()
1174
+ self.encoder = BasicMotionEncoder(corr_channel, dim=cdim, pdim=pdim)
1175
+
1176
+ # note we have hdim==cdim
1177
+ # compressor chans:
1178
+ # dim for flowfeat
1179
+ # dim for ctxfeat
1180
+ # dim for motion_features
1181
+ # pdim for flow (if p 2, like if we give sincos(relflow))
1182
+ # 2 for visconf
1183
+
1184
+ if pdim==2:
1185
+ # hdim==cdim
1186
+ # dim for flowfeat
1187
+ # dim for ctxfeat
1188
+ # dim for motion_features
1189
+ # 2 for visconf
1190
+ self.compressor = conv1x1(2*cdim+hdim+2, hdim)
1191
+ else:
1192
+ # we concatenate the flow info again, to not lose it (e.g., from the relu)
1193
+ self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
1194
+
1195
+ self.refine = []
1196
+ for i in range(num_blocks):
1197
+ self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
1198
+ self.refine.append(CNBlock2d(hdim, hdim))
1199
+ self.refine = nn.ModuleList(self.refine)
1200
+
1201
+ def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
1202
+ BS,C,H,W = flowfeat.shape
1203
+ B = BS//S
1204
+
1205
+ # print('flowfeat', flowfeat.shape)
1206
+ # print('ctxfeat', ctxfeat.shape)
1207
+ # print('visconf', visconf.shape)
1208
+ # print('corr', corr.shape)
1209
+ # print('flow', flow.shape)
1210
+
1211
+ motion_features = self.encoder(flow, corr)
1212
+
1213
+ # print('cat', torch.cat([flowfeat, ctxfeat, motion_features, visconf, flow], dim=1).shape)
1214
+ if self.pdim == 2:
+     flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
+ else:
+     # re-concatenate the raw flow so the channel count matches the +pdim compressor input above
+     flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf, flow], dim=1))
1215
+ for blk in self.refine:
1216
+ flowfeat = blk(flowfeat, S)
1217
+ return flowfeat
1218
+
1219
+ class MixerUpdateBlock(nn.Module):
1220
+ def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
1221
+ # flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
1222
+ super(MixerUpdateBlock, self).__init__()
1223
+ self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
1224
+ self.compressor = conv1x1(2*cdim+hdim, hdim)
1225
+
1226
+ self.refine = []
1227
+ for i in range(num_blocks):
1228
+ self.refine.append(CNBlock1d(hdim, hdim, use_mixer=True))
1229
+ self.refine.append(CNBlock2d(hdim, hdim))
1230
+ self.refine = nn.ModuleList(self.refine)
1231
+
1232
+ def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
1233
+ BS,C,H,W = flowfeat.shape
1234
+ B = BS//S
1235
+
1236
+ # with torch.no_grad():
1237
+ motion_features = self.encoder(flow, corr)
1238
+ flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
1239
+
1240
+ for ii, blk in enumerate(self.refine):
1241
+ flowfeat = blk(flowfeat, S)
1242
+ return flowfeat
1243
+
1244
+ class FacUpdateBlock(nn.Module):
1245
+ def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=84, use_attn=False):
1246
+ super(FacUpdateBlock, self).__init__()
1247
+ self.corr_encoder = conv133_encoder(corr_channel, cdim)
1248
+ # note we have hdim==cdim
1249
+ # compressor chans:
1250
+ # dim for flowfeat
1251
+ # dim for ctxfeat
1252
+ # dim for corr
1253
+ # pdim for flow
1254
+ # 2 for visconf
1255
+ self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
1256
+ self.refine = []
1257
+ for i in range(num_blocks):
1258
+ self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
1259
+ self.refine.append(CNBlock2d(hdim, hdim))
1260
+ self.refine = nn.ModuleList(self.refine)
1261
+
1262
+ def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
1263
+ BS,C,H,W = flowfeat.shape
1264
+ B = BS//S
1265
+
1266
+ # print('flowfeat', flowfeat.shape)
1267
+ # print('ctxfeat', ctxfeat.shape)
1268
+ # print('visconf', visconf.shape)
1269
+ # print('corr', corr.shape)
1270
+ # print('flow', flow.shape)
1271
+ corr = self.corr_encoder(corr)
1272
+ # print('cat', torch.cat([flowfeat, ctxfeat, motion_features, visconf, flow], dim=1).shape)
1273
+ flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corr, visconf, flow], dim=1))
1274
+ for blk in self.refine:
1275
+ flowfeat = blk(flowfeat, S)
1276
+ return flowfeat
1277
+
1278
+ class CleanUpdateBlock(nn.Module):
1279
+ def __init__(self, corr_channel, num_blocks, cdim=128, hdim=256, pdim=84, use_attn=False, use_layer_scale=True):
1280
+ super(CleanUpdateBlock, self).__init__()
1281
+ self.corr_encoder = conv133_encoder(corr_channel, cdim)
1282
+ # compressor chans:
1283
+ # cdim for flowfeat
1284
+ # cdim for ctxfeat
1285
+ # cdim for corrfeat
1286
+ # pdim for flow
1287
+ # 2 for visconf
1288
+ self.compressor = conv1x1(3*cdim+pdim+2, hdim)
1289
+ self.refine = []
1290
+ for i in range(num_blocks):
1291
+ self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_layer_scale=use_layer_scale))
1292
+ self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
1293
+ self.refine = nn.ModuleList(self.refine)
1294
+ self.final_conv = conv1x1(hdim, cdim)
1295
+
1296
+ def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
1297
+ BS,C,H,W = flowfeat.shape
1298
+ B = BS//S
1299
+ corrfeat = self.corr_encoder(corr)
1300
+ flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corrfeat, flow, visconf], dim=1))
1301
+ for blk in self.refine:
1302
+ flowfeat = blk(flowfeat, S)
1303
+ flowfeat = self.final_conv(flowfeat)
1304
+ return flowfeat
1305
+
1306
+ class RelUpdateBlock(nn.Module):
1307
+ def __init__(self, corr_channel, num_blocks, cdim=128, hdim=128, pdim=4, use_attn=True, use_mixer=False, use_conv=False, use_convb=False, use_layer_scale=True, no_time=False, no_space=False, final_space=False, no_ctx=False):
1308
+ super(RelUpdateBlock, self).__init__()
1309
+ self.motion_encoder = BasicMotionEncoder(corr_channel, dim=hdim, pdim=pdim) # B,hdim,H,W
1310
+ self.no_ctx = no_ctx
1311
+ if no_ctx:
1312
+ self.compressor = conv1x1(cdim+hdim+2, hdim)
1313
+ else:
1314
+ self.compressor = conv1x1(2*cdim+hdim+2, hdim)
1315
+ self.refine = []
1316
+ for i in range(num_blocks):
1317
+ if not no_time:
1318
+ self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb, use_layer_scale=use_layer_scale))
1319
+ if not no_space:
1320
+ self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
1321
+ if final_space:
1322
+ self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
1323
+ self.refine = nn.ModuleList(self.refine)
1324
+ self.final_conv = conv1x1(hdim, cdim)
1325
+
1326
+ def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
1327
+ BS,C,H,W = flowfeat.shape
1328
+ B = BS//S
1329
+ motion_features = self.motion_encoder(flow, corr)
1330
+ if self.no_ctx:
1331
+ flowfeat = self.compressor(torch.cat([flowfeat, motion_features, visconf], dim=1))
1332
+ else:
1333
+ flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
1334
+ for blk in self.refine:
1335
+ flowfeat = blk(flowfeat, S)
1336
+ flowfeat = self.final_conv(flowfeat)
1337
+ return flowfeat
nets/net34.py ADDED
@@ -0,0 +1,647 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import utils.samp
5
+ import utils.misc
6
+ import numpy as np
7
+
8
+ from nets.blocks import CNBlockConfig, ConvNeXt, conv1x1, RelUpdateBlock, InputPadder, CorrBlock, BasicEncoder
9
+
10
+ # init mostly from raft82 in alltrack
11
+
12
+ class Net(nn.Module):
13
+ def __init__(
14
+ self,
15
+ seqlen,
16
+ noise_level=0,
17
+ use_attn=True,
18
+ use_mixer=False,
19
+ use_conv=False,
20
+ use_convb=False,
21
+ use_basicencoder=False,
22
+ use_sinmotion=False,
23
+ use_relmotion=False,
24
+ use_sinrelmotion=False,
25
+ use_feats8=False,
26
+ no_time=False,
27
+ no_space=False,
28
+ final_space=False,
29
+ no_split=False,
30
+ no_ctx=False,
31
+ full_split=False,
32
+ half_corr=False,
33
+ corr_levels=5,
34
+ corr_radius=4,
35
+ num_blocks=3,
36
+ dim=128,
37
+ hdim=128,
38
+ init_weights=True,
39
+ ):
40
+ super(Net, self).__init__()
41
+
42
+ self.dim = dim
43
+ self.hdim = hdim
44
+
45
+ self.noise_level = noise_level
46
+ self.no_time = no_time
47
+ self.no_space = no_space
48
+ self.final_space = final_space
49
+ self.seqlen = seqlen
50
+ self.corr_levels = corr_levels
51
+ self.corr_radius = corr_radius
52
+ self.corr_channel = self.corr_levels * (self.corr_radius * 2 + 1) ** 2
53
+ self.num_blocks = num_blocks
54
+
55
+ self.use_feats8 = use_feats8
56
+ self.use_basicencoder = use_basicencoder
57
+ self.use_sinmotion = use_sinmotion
58
+ self.use_relmotion = use_relmotion
59
+ self.use_sinrelmotion = use_sinrelmotion
60
+ self.no_split = no_split
61
+ self.no_ctx = no_ctx
62
+ self.full_split = full_split
63
+ self.half_corr = half_corr
64
+
65
+ if use_basicencoder:
66
+ if self.full_split:
67
+ self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
68
+ self.cnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
69
+ else:
70
+ if self.no_split:
71
+ self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
72
+ else:
73
+ self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim*2, stride=8)
74
+ else:
75
+ block_setting = [
76
+ CNBlockConfig(96, 192, 3, True), # 4x
77
+ CNBlockConfig(192, 384, 3, False), # 8x
78
+ CNBlockConfig(384, None, 9, False), # 8x
79
+ ]
80
+ self.cnn = ConvNeXt(block_setting, stochastic_depth_prob=0.0, init_weights=init_weights)
81
+ if self.no_split:
82
+ self.dot_conv = conv1x1(384, dim)
83
+ else:
84
+ self.dot_conv = conv1x1(384, dim*2)
85
+
86
+ # # conv for iter 0 results
87
+ # self.init_conv = conv3x3(2 * dim, 2 * dim)
88
+ self.upsample_weight = nn.Sequential(
89
+ # convex combination of 3x3 patches
90
+ nn.Conv2d(dim, dim * 2, 3, padding=1),
91
+ nn.ReLU(inplace=True),
92
+ nn.Conv2d(dim * 2, 64 * 9, 1, padding=0)
93
+ )
94
+ self.flow_head = nn.Sequential(
95
+ nn.Conv2d(dim, 2*dim, kernel_size=3, padding=1),
96
+ nn.ReLU(inplace=True),
97
+ nn.Conv2d(2*dim, 2, kernel_size=3, padding=1)
98
+ )
99
+ self.visconf_head = nn.Sequential(
100
+ nn.Conv2d(dim, 2*dim, kernel_size=3, padding=1),
101
+ nn.ReLU(inplace=True),
102
+ nn.Conv2d(2*dim, 2, kernel_size=3, padding=1)
103
+ )
104
+
105
+ if self.use_sinrelmotion:
106
+ self.pdim = 84 # 4 + 4*2*10: sincos posenc of the 4-channel relative motion
107
+ elif self.use_relmotion:
108
+ self.pdim = 4
109
+ elif self.use_sinmotion:
110
+ self.pdim = 42
111
+ else:
112
+ self.pdim = 2
113
+
114
+ self.update_block = RelUpdateBlock(self.corr_channel, self.num_blocks, cdim=dim, hdim=hdim, pdim=self.pdim,
115
+ use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb,
116
+ use_layer_scale=True, no_time=no_time, no_space=no_space, final_space=final_space,
117
+ no_ctx=no_ctx)
118
+
119
+ time_line = torch.linspace(0, seqlen-1, seqlen).reshape(1, seqlen, 1)
120
+ self.register_buffer("time_emb", utils.misc.get_1d_sincos_pos_embed_from_grid(self.dim, time_line[0])) # 1,S,C
121
+
122
+ def fetch_time_embed(self, t, dtype, is_training=False):
123
+ S = self.time_emb.shape[1]
124
+ # print('fetching time_embed for t', t, '(we have %d)' % (S))
125
+ # print('self.time_emb', self.time_emb.shape)
126
+ if t == S:
127
+ return self.time_emb.to(dtype)
128
+ elif t==1:
129
+ if is_training:
130
+ ind = np.random.choice(S)
131
+ return self.time_emb[:,ind:ind+1].to(dtype)
132
+ else:
133
+ return self.time_emb[:,1:2].to(dtype)
134
+ else:
135
+ time_emb = self.time_emb.float()
136
+ time_emb = F.interpolate(time_emb.permute(0, 2, 1), size=t, mode="linear").permute(0, 2, 1)
137
+ return time_emb.to(dtype)
138
+
139
+ def coords_grid(self, batch, ht, wd, device, dtype):
140
+ coords = torch.meshgrid(torch.arange(ht, device=device, dtype=dtype), torch.arange(wd, device=device, dtype=dtype))
141
+ coords = torch.stack(coords[::-1], dim=0)
142
+ return coords[None].repeat(batch, 1, 1, 1)
143
+
144
+ def initialize_flow(self, img):
145
+ """ Flow is represented as difference between two coordinate grids flow = coords2 - coords1"""
146
+ N, C, H, W = img.shape
147
+ coords1 = self.coords_grid(N, H//8, W//8, device=img.device, dtype=img.dtype)
149
+ coords2 = self.coords_grid(N, H//8, W//8, device=img.device, dtype=img.dtype)
149
+ return coords1, coords2
150
+
151
+ def upsample_data(self, flow, mask):
152
+ """ Upsample [H/8, W/8, C] -> [H, W, C] using convex combination """
153
+ N, C, H, W = flow.shape
154
+ mask = mask.view(N, 1, 9, 8, 8, H, W)
155
+ mask = torch.softmax(mask, dim=2)
156
+
157
+ up_flow = F.unfold(8 * flow, [3,3], padding=1)
158
+ up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
159
+
160
+ up_flow = torch.sum(mask * up_flow, dim=2)
161
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
162
+
163
+ return up_flow.reshape(N, 2, 8*H, 8*W).to(flow.dtype)
164
+
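upsample_data is RAFT-style convex upsampling: each pixel of the 8x-upsampled field is a softmax-weighted combination of a 3x3 neighbourhood of the 1/8-resolution field, with the 9 weights predicted per output pixel. A quick sanity sketch (hypothetical sizes; the method never touches self, so it can be exercised unbound for a shape check):

    import torch
    flow8 = torch.randn(1, 2, 32, 32)        # 1/8-resolution flow
    mask = torch.randn(1, 64 * 9, 32, 32)    # 9 convex weights per pixel of each 8x8 output cell
    flow_up = Net.upsample_data(None, flow8, mask)   # unbound call, self is unused
    print(flow_up.shape)                     # torch.Size([1, 2, 256, 256])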
165
+ def get_T_padded_images(self, images, T, S, is_training, stride=None, pad=True):
166
+ B,T,C,H,W = images.shape
167
+ indices = None
168
+ if T > 2:
169
+ step = S // 2 if stride is None else stride
170
+ # starts = list(range(step,max(T,S),step))
171
+ # indices = [mid-step for mid in mids]
172
+ indices = []
173
+ start = 0
174
+ while start + S < T:
175
+ indices.append(start)
176
+ start += step
177
+ indices.append(start)
178
+ Tpad = indices[-1]+S-T
179
+ # print(indices, Tpad)
180
+ # import pdb; pdb.set_trace()
181
+ if pad:
182
+ if is_training:
183
+ assert Tpad == 0
184
+ else:
185
+ images = images.reshape(B,1,T,C*H*W)
186
+ if Tpad > 0:
187
+ padding_tensor = images[:,:,-1:,:].expand(B,1,Tpad,C*H*W)
188
+ images = torch.cat([images, padding_tensor], dim=2)
189
+ images = images.reshape(B,T+Tpad,C,H,W)
190
+ T = T+Tpad
191
+ else:
192
+ assert T == 2
193
+ return images, T, indices
194
+
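The window indices produced above stride through the video with a default overlap of S//2, and at test time the last window is padded by repeating the final frame. A small worked example that mirrors the loop logic (hypothetical numbers):

    # T=20 frames, window S=16, default step S//2=8
    T, S, step = 20, 16, 8
    indices, start = [], 0
    while start + S < T:
        indices.append(start)
        start += step
    indices.append(start)
    Tpad = indices[-1] + S - T
    print(indices, Tpad)   # [0, 8] 4 -> windows cover frames 0..15 and 8..23 (last 4 repeat frame 19)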
195
+ def get_fmaps(self, images_, B, T, sw, is_training, nograd_backbone):
196
+ _, _, H_pad, W_pad = images_.shape # revised HW
197
+
198
+ C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
199
+ if self.no_split:
200
+ C = self.dim
201
+
202
+ fmaps_chunk_size = 64
203
+ if (not is_training) and (T > fmaps_chunk_size):
204
+ images = images_.reshape(B,T,3,H_pad,W_pad)
205
+ fmaps = []
206
+ for t in range(0, T, fmaps_chunk_size):
207
+ images_chunk = images[:, t : t + fmaps_chunk_size]
208
+ images_chunk = images_chunk.cuda()
209
+ if self.use_basicencoder:
210
+ if self.full_split:
211
+ fmaps_chunk1 = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
212
+ fmaps_chunk2 = self.cnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
213
+ fmaps_chunk = torch.cat([fmaps_chunk1, fmaps_chunk2], axis=1)
214
+ else:
215
+ fmaps_chunk = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
216
+ else:
217
+ fmaps_chunk = self.cnn(images_chunk.reshape(-1, 3, H_pad, W_pad))
218
+ if t==0 and sw is not None and sw.save_this:
219
+ sw.summ_feat('1_model/fmap_raw', fmaps_chunk[0:1])
220
+ fmaps_chunk = self.dot_conv(fmaps_chunk) # B*T,C,H8,W8
221
+ T_chunk = images_chunk.shape[1]
222
+ fmaps.append(fmaps_chunk.reshape(B, -1, C, H8, W8))
223
+ fmaps_ = torch.cat(fmaps, dim=1).reshape(-1, C, H8, W8)
224
+ else:
225
+ if not is_training:
226
+ # sometimes we need to move things to cuda here
227
+ images_ = images_.cuda()
228
+ if self.use_basicencoder:
229
+ if self.full_split:
230
+ # if self.half_corr:
231
+ fmaps1_ = self.fnet(images_)
232
+ fmaps2_ = self.cnet(images_)
233
+ fmaps_ = torch.cat([fmaps1_, fmaps2_], axis=1)
234
+ else:
235
+ fmaps_ = self.fnet(images_)
236
+ else:
237
+ if nograd_backbone:
238
+ with torch.no_grad():
239
+ fmaps_ = self.cnn(images_)
240
+ else:
241
+ fmaps_ = self.cnn(images_)
242
+ if sw is not None and sw.save_this:
243
+ sw.summ_feat('1_model/fmap_raw', fmaps_[0:1])
244
+ fmaps_ = self.dot_conv(fmaps_) # B*T,C,H8,W8
245
+ return fmaps_
246
+
247
+ def forward(self, images, iters=6, sw=None, nograd_backbone=False, is_training=False, stride=None):
248
+ B,T,C,H,W = images.shape
249
+ S = self.seqlen
250
+ device = images.device
251
+ dtype = images.dtype
252
+ # print('images', images.shape, 'device', device)
253
+
254
+ T_bak = T
255
+ if stride is not None:
256
+ pad = False
257
+ else:
258
+ pad = True
259
+ images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride=stride, pad=pad)
260
+ # print('indices', indices)
261
+
262
+ images = images.contiguous()
263
+ images_ = images.reshape(B*T,3,H,W)
264
+ padder = InputPadder(images_.shape)
265
+ images_ = padder.pad(images_)[0]
266
+
267
+ _, _, H_pad, W_pad = images_.shape # revised HW
268
+ C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
269
+ C2 = C//2
270
+ if self.no_split:
271
+ C = self.dim
272
+ C2 = C
273
+
274
+ fmaps = self.get_fmaps(images_, B, T, sw, is_training, nograd_backbone).reshape(B,T,C,H8,W8)
275
+ # print('fmaps_', fmaps_.shape)
276
+ device = fmaps.device
277
+
278
+ # if sw is not None and sw.save_this:
279
+ # sw.summ_feat('1_model/fmap_dc', fmaps_[0:1])
280
+
281
+ # fmaps = fmaps_.reshape(B,T,C,H8,W8)
282
+ # del fmaps_
283
+ fmap_anchor = fmaps[:,0]
284
+
285
+ # if not is_training:
286
+ # del images
287
+ # del images_
288
+
289
+ if T<=2 or is_training:
290
+ # note: collecting preds can get expensive on a long video
291
+ all_flow_preds = []
292
+ all_visconf_preds = []
293
+ else:
294
+ all_flow_preds = None
295
+ all_visconf_preds = None
296
+
297
+ if T > 2: # multiframe tracking
298
+
299
+ # we will store our final outputs in these tensors
300
+ full_flows = torch.zeros((B,T,2,H,W), dtype=dtype, device=device)
301
+ full_visconfs = torch.zeros((B,T,2,H,W), dtype=dtype, device=device)
302
+ # 1/8 resolution
303
+ full_flows8 = torch.zeros((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
304
+ full_visconfs8 = torch.zeros((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
305
+
306
+ if is_training and self.noise_level and np.random.rand() < 0.1:
307
+ # full_flows8 = 2.0*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
308
+ # print('flows8 += randn4, const on time')
309
+ # full_flows8 = float(self.noise_level)*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
310
+ # print('noise const on time')
311
+ full_flows8 += np.random.rand()*float(self.noise_level)*torch.randn((B,1,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
312
+
313
+ if self.use_feats8:
314
+ full_feats8 = torch.zeros((B,T,C2,H_pad//8,W_pad//8), dtype=dtype, device=device)
315
+ visits = np.zeros((T))
316
+
317
+ for ii, ind in enumerate(indices):
318
+ ara = np.arange(ind,ind+S)
319
+ if ii < len(indices)-1:
320
+ next_ind = indices[ii+1]
321
+ next_ara = np.arange(next_ind,next_ind+S)
322
+
323
+ # print("torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024), 'ara', ara)
324
+ # print('ara', ara)
325
+ fmaps2 = fmaps[:,ara]
326
+ flows8 = full_flows8[:,ara].reshape(B*(S),2,H_pad//8,W_pad//8).detach()
327
+ visconfs8 = full_visconfs8[:,ara].reshape(B*(S),2,H_pad//8,W_pad//8).detach()
328
+
329
+ if self.use_feats8:
330
+ if ind==0:
331
+ feats8 = None
332
+ else:
333
+ feats8 = full_feats8[:,ara].reshape(B*(S),C2,H_pad//8,W_pad//8).detach()
334
+ else:
335
+ feats8 = None
336
+
337
+ flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
338
+ fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=feats8, flows8=flows8,
339
+ is_training=is_training)
340
+
341
+ unpad_flow_predictions = []
342
+ unpad_visconf_predictions = []
343
+ for i in range(len(flow_predictions)):
344
+ flow_predictions[i] = padder.unpad(flow_predictions[i])
345
+ unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
346
+ visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
347
+ # print('visconf_predictions[%d]' % i, visconf_predictions[i].shape)
348
+ unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))
349
+
350
+ full_flows[:,ara] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)
351
+ full_flows8[:,ara] = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
352
+ full_visconfs[:,ara] = unpad_visconf_predictions[-1].reshape(B,S,2,H,W)
353
+ full_visconfs8[:,ara] = visconfs8.reshape(B,S,2,H_pad//8,W_pad//8)
354
+ if self.use_feats8:
355
+ full_feats8[:,ara] = feats8.reshape(B,S,C2,H_pad//8,W_pad//8)
356
+ visits[ara] += 1
357
+
358
+ if is_training:
359
+ all_flow_preds.append(unpad_flow_predictions)
360
+ all_visconf_preds.append(unpad_visconf_predictions)
361
+ else:
362
+ del unpad_flow_predictions
363
+ del unpad_visconf_predictions
364
+
365
+ # for the next iter, replace empty data with nearest available preds
366
+ invalid_idx = np.where(visits==0)[0]
367
+ valid_idx = np.where(visits>0)[0]
368
+ for idx in invalid_idx:
369
+ nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
370
+ # print('replacing %d with %d' % (idx, nearest))
371
+ full_flows8[:,idx] = full_flows8[:,nearest]
372
+ full_visconfs8[:,idx] = full_visconfs8[:,nearest]
373
+ if self.use_feats8:
374
+ full_feats8[:,idx] = full_feats8[:,nearest]
375
+ else: # flow
376
+
377
+ flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
378
+ visconfs8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
379
+
380
+ if is_training and self.noise_level and np.random.rand() < 0.1:
381
+ # full_flows8 = 2.0*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
382
+ # print('flows8 += randn4, const on time')
383
+ # full_flows8 = float(self.noise_level)*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
384
+ # print('noise on flow too')
385
+ flows8 += np.random.rand()*float(self.noise_level)*torch.randn((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
386
+
387
+ flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
388
+ fmap_anchor, fmaps[:,1:2], visconfs8, iters=iters, flowfeat=None, flows8=flows8,
389
+ is_training=is_training)
390
+ unpad_flow_predictions = []
391
+ unpad_visconf_predictions = []
392
+ for i in range(len(flow_predictions)):
393
+ flow_predictions[i] = padder.unpad(flow_predictions[i])
394
+ all_flow_preds.append(flow_predictions[i].reshape(B,2,H,W))
395
+ visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
396
+ all_visconf_preds.append(visconf_predictions[i].reshape(B,2,H,W))
397
+ full_flows = all_flow_preds[-1].reshape(B,2,H,W)
398
+ full_visconfs = all_visconf_preds[-1].reshape(B,2,H,W)
399
+
400
+ if (not is_training) and (T > 2):
401
+ # print('full_flows', full_flows.shape)
402
+ # print('T_bak', T_bak)
403
+ full_flows = full_flows[:,:T_bak]
404
+ full_visconfs = full_visconfs[:,:T_bak]
405
+ # print('full_flows trim', full_flows.shape)
406
+
407
+ return full_flows, full_visconfs, all_flow_preds, all_visconf_preds#, bak_flows8
408
+
409
+ def forward_sliding(self, images, iters=6, sw=None, nograd_backbone=False, is_training=False, window_len=None, stride=None):
410
+ B,T,C,H,W = images.shape
411
+ S = self.seqlen if window_len is None else window_len
412
+ device = images.device
413
+ dtype = images.dtype
414
+ # print('images', images.shape, 'device', device)
415
+ stride = S // 2 if stride is None else stride
416
+
417
+ T_bak = T
418
+ images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride)
419
+ # print('indices', indices)
420
+ assert stride <= S // 2
421
+
422
+ images = images.contiguous()
423
+ images_ = images.reshape(B*T,3,H,W)
424
+ padder = InputPadder(images_.shape)
425
+ images_ = padder.pad(images_)[0]
426
+
427
+ _, _, H_pad, W_pad = images_.shape # revised HW
428
+ C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
429
+ C2 = C//2
430
+ if self.no_split:
431
+ C = self.dim
432
+ C2 = C
433
+
434
+ all_flow_preds = None
435
+ all_visconf_preds = None
436
+
437
+ if T<=2:
438
+ # note: collecting preds can get expensive on a long video
439
+ all_flow_preds = []
440
+ all_visconf_preds = []
441
+
442
+ fmaps = self.get_fmaps(images_, B, T, sw, is_training, nograd_backbone).reshape(B,T,C,H8,W8)
443
+ device = fmaps.device
444
+
445
+ flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
446
+ visconfs8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
447
+
448
+ if is_training and self.noise_level and np.random.rand() < 0.1:
449
+ # full_flows8 = 2.0*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
450
+ # print('flows8 += randn4, const on time')
451
+ # full_flows8 = float(self.noise_level)*torch.randn((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
452
+ # print('noise on flow too')
453
+ flows8 += np.random.rand()*float(self.noise_level)*torch.randn((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
454
+
455
+ fmap_anchor = fmaps[:,0]
456
+
457
+ flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
458
+ fmap_anchor, fmaps[:,1:2], visconfs8, iters=iters, flowfeat=None, flows8=flows8,
459
+ is_training=is_training)
460
+ unpad_flow_predictions = []
461
+ unpad_visconf_predictions = []
462
+ for i in range(len(flow_predictions)):
463
+ flow_predictions[i] = padder.unpad(flow_predictions[i])
464
+ all_flow_preds.append(flow_predictions[i].reshape(B,2,H,W))
465
+ visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
466
+ all_visconf_preds.append(visconf_predictions[i].reshape(B,2,H,W))
467
+ full_flows = all_flow_preds[-1].reshape(B,2,H,W).detach().cpu()
468
+ full_visconfs = all_visconf_preds[-1].reshape(B,2,H,W).detach().cpu()
469
+
470
+ return full_flows, full_visconfs, all_flow_preds, all_visconf_preds#, bak_flows8
471
+
472
+ assert T > 2 # multiframe tracking
473
+
474
+ if is_training:
475
+ all_flow_preds = []
476
+ all_visconf_preds = []
477
+
478
+ # # we will store our final outputs in these tensors cpu
479
+ full_flows = torch.zeros((B,T,2,H,W), dtype=dtype, device='cpu')
480
+ full_visconfs = torch.zeros((B,T,2,H,W), dtype=dtype, device='cpu')
481
+
482
+ images_ = images_.reshape(B,T,3,H_pad,W_pad)
483
+ fmap_anchor = self.get_fmaps(images_[:,:1].reshape(-1,3,H_pad,W_pad), B, 1, sw, is_training, nograd_backbone).reshape(B,C,H8,W8)
484
+ device = fmap_anchor.device
485
+ full_visited = torch.zeros((T,), dtype=torch.bool, device=device)
486
+
487
+ for ii, ind in enumerate(indices):
488
+ ara = np.arange(ind,ind+S)
489
+ if ii == 0:
490
+ flows8 = torch.zeros((B,S,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
491
+ visconfs8 = torch.zeros((B,S,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
492
+ fmaps2 = self.get_fmaps(images_[:,ara].reshape(-1,3,H_pad,W_pad), B, S, sw, is_training, nograd_backbone).reshape(B,S,C,H8,W8)
493
+ if is_training and self.noise_level and np.random.rand() < 0.1:
494
+ flows8 += np.random.rand()*float(self.noise_level)*torch.randn((B,1,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
495
+ else:
496
+ # import pdb; pdb.set_trace()
497
+ flows8 = torch.cat([flows8[:,stride:stride+S//2], flows8[:,stride+S//2-1:stride+S//2].repeat(1,S//2,1,1,1)], dim=1)
498
+ visconfs8 = torch.cat([visconfs8[:,stride:stride+S//2], visconfs8[:,stride+S//2-1:stride+S//2].repeat(1,S//2,1,1,1)], dim=1)
499
+ fmaps2 = torch.cat([fmaps2[:,stride:stride+S//2],
500
+ self.get_fmaps(images_[:,np.arange(ind+S//2,ind+S)].reshape(-1,3,H_pad,W_pad), B, S//2, sw, is_training, nograd_backbone).reshape(B,S//2,C,H8,W8)], dim=1)
501
+
502
+ flows8 = flows8.reshape(B*S,2,H_pad//8,W_pad//8).detach()
503
+ visconfs8 = visconfs8.reshape(B*S,2,H_pad//8,W_pad//8).detach()
504
+
505
+ # print("torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024), 'ara', ara)
506
+ # print('ara', ara)
507
+ flow_predictions, visconf_predictions, flows8, visconfs8, _ = self.forward_window(
508
+ fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=None, flows8=flows8,
509
+ is_training=is_training)
510
+
511
+ unpad_flow_predictions = []
512
+ unpad_visconf_predictions = []
513
+ for i in range(len(flow_predictions)):
514
+ flow_predictions[i] = padder.unpad(flow_predictions[i])
515
+ unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
516
+ visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
517
+ # print('visconf_predictions[%d]' % i, visconf_predictions[i].shape)
518
+ unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))
519
+
520
+ current_visiting = torch.zeros((T,), dtype=torch.bool, device=device)
521
+ current_visiting[ara] = True
522
+
523
+ to_fill = current_visiting & (~full_visited)
524
+ to_fill_sum = to_fill.sum().item()
525
+ full_flows[:,to_fill] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)[:,-to_fill_sum:].detach().cpu()
526
+ full_visconfs[:,to_fill] = unpad_visconf_predictions[-1].reshape(B,S,2,H,W)[:,-to_fill_sum:].detach().cpu()
527
+ full_visited |= current_visiting
528
+
529
+ if is_training:
530
+ all_flow_preds.append(unpad_flow_predictions)
531
+ all_visconf_preds.append(unpad_visconf_predictions)
532
+ else:
533
+ del unpad_flow_predictions
534
+ del unpad_visconf_predictions
535
+
536
+ flows8 = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
537
+ visconfs8 = visconfs8.reshape(B,S,2,H_pad//8,W_pad//8)
538
+
539
+ if not is_training:
540
+ # print('full_flows', full_flows.shape)
541
+ # print('T_bak', T_bak)
542
+ full_flows = full_flows[:,:T_bak]
543
+ full_visconfs = full_visconfs[:,:T_bak]
544
+ # print('full_flows trim', full_flows.shape)
545
+
546
+ return full_flows, full_visconfs, all_flow_preds, all_visconf_preds#, bak_flows8
547
+
548
+ def forward_window(self, fmap1_single, fmaps2, visconfs8, iters=None, flowfeat=None, flows8=None, sw=None, is_training=False):
549
+ B,S,C,H8,W8 = fmaps2.shape
550
+ device = fmaps2.device
551
+ dtype = fmaps2.dtype
552
+
553
+ flow_predictions = []
554
+ visconf_predictions = []
555
+
556
+ # print('fmap1_single', fmap1_single.shape)
557
+ # print('fmaps2', fmaps2.shape)
558
+
559
+ fmap1 = fmap1_single.unsqueeze(1).repeat(1,S,1,1,1) # B,S,C,H,W
560
+ # print('fmap1', fmap1.shape)
561
+ fmap1 = fmap1.reshape(B*(S),C,H8,W8).contiguous()
562
+ # print('fmap1', fmap1.shape)
563
+
564
+ fmap2 = fmaps2.reshape(B*(S),C,H8,W8).contiguous()
565
+
566
+ visconfs8 = visconfs8.reshape(B*(S),2,H8,W8).contiguous()
567
+
568
+ if not self.half_corr:
569
+ corr_fn = CorrBlock(fmap1, fmap2, self.corr_levels, self.corr_radius)
570
+
571
+ coords1 = self.coords_grid(B*(S), H8, W8, device=fmap1.device, dtype=dtype)
572
+
573
+ if self.no_split:
574
+ flowfeat, ctxfeat = fmap1.clone(), fmap1.clone()
575
+ else:
576
+ if flowfeat is not None:
577
+ _, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)
578
+ else:
579
+ flowfeat, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)
580
+
581
+ if self.half_corr:
582
+ # we discard the cnet output for fmap2
583
+ # in general maybe we shouldn't compute it, if half_corr pays off
584
+ flowfeat2, _ = torch.split(fmap2, [self.dim, self.dim], dim=1)
585
+ corr_fn = CorrBlock(flowfeat, flowfeat2, self.corr_levels, self.corr_radius)
586
+
587
+ # add pos emb to ctxfeat (and not flowfeat), since ctxfeat is untouched across iters
588
+ time_emb = self.fetch_time_embed(S, ctxfeat.dtype, is_training).reshape(1,S,self.dim,1,1).repeat(B,1,1,1,1)
589
+ ctxfeat = ctxfeat + time_emb.reshape(B*S,self.dim,1,1)
590
+
591
+ if self.no_ctx:
592
+ flowfeat = flowfeat + time_emb.reshape(B*S,self.dim,1,1)
593
+
594
+ for itr in range(iters):
595
+ _, _, H8, W8 = flows8.shape
596
+ flows8 = flows8.detach()
597
+ coords2 = (coords1 + flows8).detach() # B*S,2,H,W
598
+ corr = corr_fn(coords2).to(dtype)
599
+
600
+ if self.use_relmotion or self.use_sinrelmotion:
601
+ coords_ = coords2.reshape(B,S,2,H8*W8).permute(0,1,3,2) # B,S,H8*W8,2
602
+ rel_coords_forward = coords_[:, :-1] - coords_[:, 1:]
603
+ rel_coords_backward = coords_[:, 1:] - coords_[:, :-1]
604
+ rel_coords_forward = torch.nn.functional.pad(
605
+ rel_coords_forward, (0, 0, 0, 0, 0, 1) # pad the 3rd-last dim (S) by (0,1)
606
+ )
607
+ rel_coords_backward = torch.nn.functional.pad(
608
+ rel_coords_backward, (0, 0, 0, 0, 1, 0) # pad the 3rd-last dim (S) by (1,0)
609
+ )
610
+ rel_coords = torch.cat([rel_coords_forward, rel_coords_backward], dim=-1) # B,S,H8*W8,4
611
+
612
+ if self.use_sinrelmotion:
613
+ rel_pos_emb_input = utils.misc.posenc(
614
+ rel_coords,
615
+ min_deg=0,
616
+ max_deg=10,
617
+ ) # B,S,H*W,pdim
618
+ motion = rel_pos_emb_input.reshape(B*S,H8,W8,self.pdim).permute(0,3,1,2).to(dtype) # B*S,pdim,H8,W8
619
+ else:
620
+ motion = rel_coords.reshape(B*S,H8,W8,4).permute(0,3,1,2).to(dtype) # B*S,4,H8,W8
621
+
622
+ else:
623
+ if self.use_sinmotion:
624
+ pos_emb_input = utils.misc.posenc(
625
+ flows8.reshape(B,S,H8*W8,2),
626
+ min_deg=0,
627
+ max_deg=10,
628
+ ) # B,S,H*W,pdim
629
+ motion = pos_emb_input.reshape(B*S,H8,W8,self.pdim).permute(0,3,1,2).to(dtype) # B*S,pdim,H8,W8
630
+ else:
631
+ motion = flows8
632
+
633
+ flowfeat = self.update_block(flowfeat, ctxfeat, visconfs8, corr, motion, S)
634
+ flow_update = self.flow_head(flowfeat)
635
+ visconf_update = self.visconf_head(flowfeat)
636
+ weight_update = .25 * self.upsample_weight(flowfeat)
637
+ flows8 = flows8 + flow_update
638
+ visconfs8 = visconfs8 + visconf_update
639
+ flow_up = self.upsample_data(flows8, weight_update)
640
+ flow_predictions.append(flow_up)
641
+ visconf_up = self.upsample_data(visconfs8, weight_update)
642
+ visconf_predictions.append(visconf_up)
643
+
644
+ return flow_predictions, visconf_predictions, flows8, visconfs8, flowfeat#, bak_flows8
645
+
646
+
647
+
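For the use_sinrelmotion path in forward_window, the per-pixel forward/backward relative motion (4 channels) is lifted with a sin/cos positional encoding before entering the update block; with min_deg=0 and max_deg=10 this gives 4 + 4*2*10 = 84 channels, which is where self.pdim = 84 comes from (the exact utils.misc.posenc behaviour is assumed from its call above):

    # dimension check for the sin/cos motion encoding
    in_dim, min_deg, max_deg = 4, 0, 10     # forward + backward relative flow = 4 channels
    pdim = in_dim + in_dim * 2 * (max_deg - min_deg)
    print(pdim)                              # 84, matching self.pdim for use_sinrelmotion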
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ numpy==1.26.4
+ imageio==2.19.3
+ imageio-ffmpeg==0.4.7
+ gradio
+ spaces
+ matplotlib
+ pillow
+ torch==2.2.0
+ torchvision==0.17.0
+ albumentations
+ pytorch-lightning==2.2.5
+ opencv-python
+ scikit-learn
+ scikit-image
+ einops
+ tensorboardX
+ transformers
utils/basic.py ADDED
@@ -0,0 +1,429 @@
1
+ import os
2
+ import numpy as np
3
+ from os.path import isfile
4
+ import torch
5
+ import torch.nn.functional as F
6
+ EPS = 1e-6
7
+ import copy
8
+
9
+ def sub2ind(height, width, y, x):
10
+ return y*width + x
11
+
12
+ def ind2sub(height, width, ind):
13
+ y = ind // width
14
+ x = ind % width
15
+ return y, x
16
+
17
+ def get_lr_str(lr):
18
+ lrn = "%.1e" % lr # e.g., 5.0e-04
19
+ lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4
20
+ return lrn
21
+
22
+ def strnum(x):
23
+ s = '%g' % x
24
+ if '.' in s:
25
+ if x < 1.0:
26
+ s = s[s.index('.'):]
27
+ s = s[:min(len(s),4)]
28
+ return s
29
+
30
+ def assert_same_shape(t1, t2):
31
+ for (x, y) in zip(list(t1.shape), list(t2.shape)):
32
+ assert(x==y)
33
+
34
+ def print_stats(name, tensor):
35
+ shape = tensor.shape
36
+ tensor = tensor.detach().cpu().numpy()
37
+ print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
38
+
39
+ def print_stats_py(name, tensor):
40
+ shape = tensor.shape
41
+ print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
42
+
43
+ def print_(name, tensor):
44
+ tensor = tensor.detach().cpu().numpy()
45
+ print(name, tensor, tensor.shape)
46
+
47
+ def mkdir(path):
48
+ if not os.path.exists(path):
49
+ os.makedirs(path)
50
+
51
+ def normalize_single(d):
52
+ # d is a whatever shape torch tensor
53
+ dmin = torch.min(d)
54
+ dmax = torch.max(d)
55
+ d = (d-dmin)/(EPS+(dmax-dmin))
56
+ return d
57
+
58
+ def normalize(d):
59
+ # d is B x whatever. normalize within each element of the batch
60
+ out = torch.zeros(d.size(), dtype=d.dtype, device=d.device)
61
+ B = list(d.size())[0]
62
+ for b in list(range(B)):
63
+ out[b] = normalize_single(d[b])
64
+ return out
65
+
66
+ def hard_argmax2d(tensor):
67
+ B, C, Y, X = list(tensor.shape)
68
+ assert(C==1)
69
+
70
+ # flatten the Tensor along the height and width axes
71
+ flat_tensor = tensor.reshape(B, -1)
72
+ # argmax of the flat tensor
73
+ argmax = torch.argmax(flat_tensor, dim=1)
74
+
75
+ # convert the indices into 2d coordinates
76
+ argmax_y = torch.floor(argmax / X) # row
77
+ argmax_x = argmax % X # col
78
+
79
+ argmax_y = argmax_y.reshape(B)
80
+ argmax_x = argmax_x.reshape(B)
81
+ return argmax_y, argmax_x
82
+
83
+ def argmax2d(heat, hard=True):
84
+ B, C, Y, X = list(heat.shape)
85
+ assert(C==1)
86
+
87
+ if hard:
88
+ # hard argmax
89
+ loc_y, loc_x = hard_argmax2d(heat)
90
+ loc_y = loc_y.float()
91
+ loc_x = loc_x.float()
92
+ else:
93
+ heat = heat.reshape(B, Y*X)
94
+ prob = torch.nn.functional.softmax(heat, dim=1)
95
+
96
+ grid_y, grid_x = meshgrid2d(B, Y, X)
97
+
98
+ grid_y = grid_y.reshape(B, -1)
99
+ grid_x = grid_x.reshape(B, -1)
100
+
101
+ loc_y = torch.sum(grid_y*prob, dim=1)
102
+ loc_x = torch.sum(grid_x*prob, dim=1)
103
+ # these are B
104
+
105
+ return loc_y, loc_x
106
+
107
+ def reduce_masked_mean(x, mask, dim=None, keepdim=False, broadcast=False):
108
+ # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
109
+ # returns shape-1
110
+ # axis can be a list of axes
111
+ if not broadcast:
112
+ for (a,b) in zip(x.size(), mask.size()):
113
+ if not a==b:
114
+ print('some shape mismatch:', x.shape, mask.shape)
115
+ assert(a==b) # some shape mismatch!
116
+ # assert(x.size() == mask.size())
117
+ prod = x*mask
118
+ if dim is None:
119
+ numer = torch.sum(prod)
120
+ denom = EPS+torch.sum(mask)
121
+ else:
122
+ numer = torch.sum(prod, dim=dim, keepdim=keepdim)
123
+ denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim)
124
+ mean = numer/denom
125
+ return mean
126
+
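A small usage sketch for reduce_masked_mean: the mean is taken only over entries where the mask is nonzero, with EPS guarding the empty-mask case.

    import torch
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
    print(reduce_masked_mean(x, mask))           # tensor(1.5000), only the first two entries count
    print(reduce_masked_mean(x, mask, dim=1))    # tensor([1.5000])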
127
+ def reduce_masked_median(x, mask, keep_batch=False):
128
+ # x and mask are the same shape
129
+ assert(x.size() == mask.size())
130
+ device = x.device
131
+
132
+ B = list(x.shape)[0]
133
+ x = x.detach().cpu().numpy()
134
+ mask = mask.detach().cpu().numpy()
135
+
136
+ if keep_batch:
137
+ x = np.reshape(x, [B, -1])
138
+ mask = np.reshape(mask, [B, -1])
139
+ meds = np.zeros([B], np.float32)
140
+ for b in list(range(B)):
141
+ xb = x[b]
142
+ mb = mask[b]
143
+ if np.sum(mb) > 0:
144
+ xb = xb[mb > 0]
145
+ meds[b] = np.median(xb)
146
+ else:
147
+ meds[b] = np.nan
148
+ meds = torch.from_numpy(meds).to(device)
149
+ return meds.float()
150
+ else:
151
+ x = np.reshape(x, [-1])
152
+ mask = np.reshape(mask, [-1])
153
+ if np.sum(mask) > 0:
154
+ x = x[mask > 0]
155
+ med = np.median(x)
156
+ else:
157
+ med = np.nan
158
+ med = np.array([med], np.float32)
159
+ med = torch.from_numpy(med).to(device)
160
+ return med.float()
161
+
162
+ def pack_seqdim(tensor, B):
163
+ shapelist = list(tensor.shape)
164
+ B_, S = shapelist[:2]
165
+ assert(B==B_)
166
+ otherdims = shapelist[2:]
167
+ tensor = torch.reshape(tensor, [B*S]+otherdims)
168
+ return tensor
169
+
170
+ def unpack_seqdim(tensor, B):
171
+ shapelist = list(tensor.shape)
172
+ BS = shapelist[0]
173
+ assert(BS%B==0)
174
+ otherdims = shapelist[1:]
175
+ S = int(BS/B)
176
+ tensor = torch.reshape(tensor, [B,S]+otherdims)
177
+ return tensor
178
+
179
+ def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False):
180
+ # returns a meshgrid sized B x Y x X
181
+
182
+ grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device))
183
+ grid_y = torch.reshape(grid_y, [1, Y, 1])
184
+ grid_y = grid_y.repeat(B, 1, X)
185
+
186
+ grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device))
187
+ grid_x = torch.reshape(grid_x, [1, 1, X])
188
+ grid_x = grid_x.repeat(B, Y, 1)
189
+
190
+ if norm:
191
+ grid_y, grid_x = normalize_grid2d(
192
+ grid_y, grid_x, Y, X)
193
+
194
+ if stack:
195
+ # note we stack in xy order
196
+ # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
197
+ if on_chans:
198
+ grid = torch.stack([grid_x, grid_y], dim=1)
199
+ else:
200
+ grid = torch.stack([grid_x, grid_y], dim=-1)
201
+ return grid
202
+ else:
203
+ return grid_y, grid_x
204
+
205
+ def meshgrid2d_py(B, Y, X, stack=False, on_chans=False):
206
+ grid_y = np.linspace(0.0, Y-1, Y)
207
+ grid_y = np.reshape(grid_y, [1, Y, 1])
208
+ grid_y = np.tile(grid_y, [B, 1, X])
209
+
210
+ grid_x = np.linspace(0.0, X-1, X)
211
+ grid_x = np.reshape(grid_x, [1, 1, X])
212
+ grid_x = np.tile(grid_x, [B, Y, 1])
213
+
214
+ if stack:
215
+ # note we stack in xy order
216
+ # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
217
+ if on_chans:
218
+ grid = np.stack([grid_x, grid_y], axis=1)
219
+ else:
220
+ grid = np.stack([grid_x, grid_y], axis=-1)
221
+ return grid
222
+ else:
223
+ # outputs are Y x X
224
+ return grid_y, grid_x
225
+
226
+
227
+ def meshgrid3d(B, Z, Y, X, stack=False, norm=False, device='cuda'):
228
+ # returns a meshgrid sized B x Z x Y x X
229
+
230
+ grid_z = torch.linspace(0.0, Z-1, Z, device=device)
231
+ grid_z = torch.reshape(grid_z, [1, Z, 1, 1])
232
+ grid_z = grid_z.repeat(B, 1, Y, X)
233
+
234
+ grid_y = torch.linspace(0.0, Y-1, Y, device=device)
235
+ grid_y = torch.reshape(grid_y, [1, 1, Y, 1])
236
+ grid_y = grid_y.repeat(B, Z, 1, X)
237
+
238
+ grid_x = torch.linspace(0.0, X-1, X, device=device)
239
+ grid_x = torch.reshape(grid_x, [1, 1, 1, X])
240
+ grid_x = grid_x.repeat(B, Z, Y, 1)
241
+
242
+ if norm:
243
+ grid_z, grid_y, grid_x = normalize_grid3d(
244
+ grid_z, grid_y, grid_x, Z, Y, X)
245
+
246
+ if stack:
247
+ # note we stack in xyz order
248
+ # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
249
+ grid = torch.stack([grid_x, grid_y, grid_z], dim=-1)
250
+ return grid
251
+ else:
252
+ return grid_z, grid_y, grid_x
253
+
254
+ def normalize_grid2d(grid_y, grid_x, Y, X, clamp_extreme=True):
255
+ # make things in [-1,1]
256
+ grid_y = 2.0*(grid_y / float(Y-1)) - 1.0
257
+ grid_x = 2.0*(grid_x / float(X-1)) - 1.0
258
+
259
+ if clamp_extreme:
260
+ grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)
261
+ grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)
262
+
263
+ return grid_y, grid_x
264
+
265
+ def normalize_grid3d(grid_z, grid_y, grid_x, Z, Y, X, clamp_extreme=True):
266
+ # make things in [-1,1]
267
+ grid_z = 2.0*(grid_z / float(Z-1)) - 1.0
268
+ grid_y = 2.0*(grid_y / float(Y-1)) - 1.0
269
+ grid_x = 2.0*(grid_x / float(X-1)) - 1.0
270
+
271
+ if clamp_extreme:
272
+ grid_z = torch.clamp(grid_z, min=-2.0, max=2.0)
273
+ grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)
274
+ grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)
275
+
276
+ return grid_z, grid_y, grid_x
277
+
278
+ def gridcloud2d(B, Y, X, norm=False, device='cuda'):
279
+ # we want to sample for each location in the grid
280
+ grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device)
281
+ x = torch.reshape(grid_x, [B, -1])
282
+ y = torch.reshape(grid_y, [B, -1])
283
+ # these are B x N
284
+ xy = torch.stack([x, y], dim=2)
285
+ # this is B x N x 2
286
+ return xy
287
+
288
+ def gridcloud3d(B, Z, Y, X, norm=False, device='cuda'):
289
+ # we want to sample for each location in the grid
290
+ grid_z, grid_y, grid_x = meshgrid3d(B, Z, Y, X, norm=norm, device=device)
291
+ x = torch.reshape(grid_x, [B, -1])
292
+ y = torch.reshape(grid_y, [B, -1])
293
+ z = torch.reshape(grid_z, [B, -1])
294
+ # these are B x N
295
+ xyz = torch.stack([x, y, z], dim=2)
296
+ # this is B x N x 3
297
+ return xyz
298
+
299
+ # import re
300
+ # def readPFM(file):
301
+ # file = open(file, 'rb')
302
+
303
+ # color = None
304
+ # width = None
305
+ # height = None
306
+ # scale = None
307
+ # endian = None
308
+
309
+ # header = file.readline().rstrip()
310
+ # if header == b'PF':
311
+ # color = True
312
+ # elif header == b'Pf':
313
+ # color = False
314
+ # else:
315
+ # raise Exception('Not a PFM file.')
316
+
317
+ # dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
318
+ # if dim_match:
319
+ # width, height = map(int, dim_match.groups())
320
+ # else:
321
+ # raise Exception('Malformed PFM header.')
322
+
323
+ # scale = float(file.readline().rstrip())
324
+ # if scale < 0: # little-endian
325
+ # endian = '<'
326
+ # scale = -scale
327
+ # else:
328
+ # endian = '>' # big-endian
329
+
330
+ # data = np.fromfile(file, endian + 'f')
331
+ # shape = (height, width, 3) if color else (height, width)
332
+
333
+ # data = np.reshape(data, shape)
334
+ # data = np.flipud(data)
335
+ # return data
336
+
337
+ def normalize_boxlist2d(boxlist2d, H, W):
338
+ boxlist2d = boxlist2d.clone()
339
+ ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
340
+ ymin = ymin / float(H)
341
+ ymax = ymax / float(H)
342
+ xmin = xmin / float(W)
343
+ xmax = xmax / float(W)
344
+ boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
345
+ return boxlist2d
346
+
347
+ def unnormalize_boxlist2d(boxlist2d, H, W):
348
+ boxlist2d = boxlist2d.clone()
349
+ ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
350
+ ymin = ymin * float(H)
351
+ ymax = ymax * float(H)
352
+ xmin = xmin * float(W)
353
+ xmax = xmax * float(W)
354
+ boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
355
+ return boxlist2d
356
+
357
+ def unnormalize_box2d(box2d, H, W):
358
+ return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
359
+
360
+ def normalize_box2d(box2d, H, W):
361
+ return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
362
+
363
+ def get_gaussian_kernel_2d(channels, kernel_size=3, sigma=2.0, mid_one=False, device='cuda'):
364
+ C = channels
365
+ xy_grid = gridcloud2d(C, kernel_size, kernel_size, device=device) # C x N x 2
366
+
367
+ mean = (kernel_size - 1)/2.0
368
+ variance = sigma**2.0
369
+
370
+ gaussian_kernel = (1.0/(2.0*np.pi*variance)**1.5) * torch.exp(-torch.sum((xy_grid - mean)**2.0, dim=-1) / (2.0*variance)) # C X N
371
+ gaussian_kernel = gaussian_kernel.view(C, 1, kernel_size, kernel_size) # C x 1 x 3 x 3
372
+ kernel_sum = torch.sum(gaussian_kernel, dim=(2,3), keepdim=True)
373
+
374
+ gaussian_kernel = gaussian_kernel / kernel_sum # normalize
375
+
376
+ if mid_one:
377
+ # normalize so that the middle element is 1
378
+ maxval = gaussian_kernel[:,:,(kernel_size//2),(kernel_size//2)].reshape(C, 1, 1, 1)
379
+ gaussian_kernel = gaussian_kernel / maxval
380
+
381
+ return gaussian_kernel
382
+
383
+ def gaussian_blur_2d(input, kernel_size=3, sigma=2.0, reflect_pad=False, mid_one=False):
384
+ B, C, Z, X = input.shape
385
+ kernel = get_gaussian_kernel_2d(C, kernel_size, sigma, mid_one=mid_one, device=input.device)
386
+ if reflect_pad:
387
+ pad = (kernel_size - 1)//2
388
+ out = F.pad(input, (pad, pad, pad, pad), mode='reflect')
389
+ out = F.conv2d(out, kernel, padding=0, groups=C)
390
+ else:
391
+ out = F.conv2d(input, kernel, padding=(kernel_size - 1)//2, groups=C)
392
+ return out
393
+
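A quick sketch of gaussian_blur_2d, which builds one Gaussian kernel per channel and applies it as a depthwise (grouped) convolution; the kernel is created on the input's device, so a CPU tensor works too:

    import torch
    img = torch.rand(1, 3, 64, 64)   # B, C, H, W
    blurred = gaussian_blur_2d(img, kernel_size=5, sigma=1.5, reflect_pad=True)
    print(blurred.shape)             # torch.Size([1, 3, 64, 64]); each channel smoothed independently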
394
+ def gradient2d(x, absolute=False, square=False, return_sum=False):
395
+ # x should be B x C x H x W
396
+ dh = x[:, :, 1:, :] - x[:, :, :-1, :]
397
+ dw = x[:, :, :, 1:] - x[:, :, :, :-1]
398
+
399
+ zeros = torch.zeros_like(x)
400
+ zero_h = zeros[:, :, 0:1, :]
401
+ zero_w = zeros[:, :, :, 0:1]
402
+ dh = torch.cat([dh, zero_h], axis=2)
403
+ dw = torch.cat([dw, zero_w], axis=3)
404
+ if absolute:
405
+ dh = torch.abs(dh)
406
+ dw = torch.abs(dw)
407
+ if square:
408
+ dh = dh ** 2
409
+ dw = dw ** 2
410
+ if return_sum:
411
+ return dh+dw
412
+ else:
413
+ return dh, dw
414
+
415
+ def gradient1d(x, absolute=False, square=False):
416
+ # x should be B,S,C
417
+ dx = x[:, 1:] - x[:, :-1]
418
+ zero = torch.zeros_like(x[:,0:1])
419
+ dx = torch.cat([dx, zero], axis=1)
420
+ if absolute:
421
+ dx = torch.abs(dx)
422
+ if square:
423
+ dx = dx ** 2
424
+ return dx
425
+
426
+ def smart_cat(tensor1, tensor2, dim):
427
+ if tensor1 is None:
428
+ return tensor2
429
+ return torch.cat([tensor1, tensor2], dim=dim)
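A minimal usage sketch for the blur/gradient helpers above (illustrative only; it assumes this file is utils/basic.py and that its top-of-file imports and the gridcloud2d helper used by the kernel builder are present in the same module):

import torch
import utils.basic as basic

x = torch.rand(2, 3, 64, 64)                        # B, C, H, W
blur = basic.gaussian_blur_2d(x, kernel_size=5, sigma=1.5, reflect_pad=True)
dh, dw = basic.gradient2d(x, absolute=True)         # forward differences, zero-padded at the far edge
out = basic.smart_cat(None, blur, dim=1)            # returns blur unchanged when the first argument is None
print(blur.shape, dh.shape, dw.shape, out.shape)    # all torch.Size([2, 3, 64, 64])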
utils/data.py ADDED
@@ -0,0 +1,122 @@
1
+ import torch
2
+ import dataclasses
3
+ import torch.nn.functional as F
4
+ from dataclasses import dataclass
5
+ from typing import Any, Optional, Dict
6
+
7
+
8
+ @dataclass(eq=False)
9
+ class VideoData:
10
+ """
11
+ Dataclass for storing video tracks data.
12
+ """
13
+
14
+ video: torch.Tensor # B, S, C, H, W
15
+ trajs: torch.Tensor # B, S, N, 2
16
+ visibs: torch.Tensor # B, S, N
17
+ # optional data
18
+ valids: Optional[torch.Tensor] = None # B, S, N
19
+ hards: Optional[torch.Tensor] = None # B, S, N
20
+ segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
21
+ seq_name: Optional[str] = None
22
+ dname: Optional[str] = None
23
+ query_points: Optional[torch.Tensor] = None # TapVID evaluation format
24
+ transforms: Optional[Dict[str, Any]] = None
25
+ aug_video: Optional[torch.Tensor] = None
26
+
27
+
28
+ def collate_fn(batch):
29
+ """
30
+ Collate function for video tracks data.
31
+ """
32
+ video = torch.stack([b.video for b in batch], dim=0)
33
+ trajs = torch.stack([b.trajs for b in batch], dim=0)
34
+ visibs = torch.stack([b.visibs for b in batch], dim=0)
35
+ query_points = segmentation = None
36
+ if batch[0].query_points is not None:
37
+ query_points = torch.stack([b.query_points for b in batch], dim=0)
38
+ if batch[0].segmentation is not None:
39
+ segmentation = torch.stack([b.segmentation for b in batch], dim=0)
40
+ seq_name = [b.seq_name for b in batch]
41
+ dname = [b.dname for b in batch]
42
+
43
+ return VideoData(
44
+ video=video,
45
+ trajs=trajs,
46
+ visibs=visibs,
47
+ segmentation=segmentation,
48
+ seq_name=seq_name,
49
+ dname=dname,
50
+ query_points=query_points,
51
+ )
52
+
53
+
54
+ def collate_fn_train(batch):
55
+ """
56
+ Collate function for video tracks data during training.
57
+ """
58
+ gotit = [gotit for _, gotit in batch]
59
+ video = torch.stack([b.video for b, _ in batch], dim=0)
60
+ trajs = torch.stack([b.trajs for b, _ in batch], dim=0)
61
+ visibs = torch.stack([b.visibs for b, _ in batch], dim=0)
62
+ valids = torch.stack([b.valids for b, _ in batch], dim=0)
63
+ seq_name = [b.seq_name for b, _ in batch]
64
+ dname = [b.dname for b, _ in batch]
65
+ query_points = transforms = aug_video = hards = None
66
+ if batch[0][0].query_points is not None:
67
+ query_points = torch.stack([b.query_points for b, _ in batch], dim=0)
68
+ if batch[0][0].hards is not None:
69
+ hards = torch.stack([b.hards for b, _ in batch], dim=0)
70
+
71
+ if batch[0][0].transforms is not None:
72
+ transforms = [b.transforms for b, _ in batch]
73
+
74
+ if batch[0][0].aug_video is not None:
75
+ aug_video = torch.stack([b.aug_video for b, _ in batch], dim=0)
76
+ return (
77
+ VideoData(
78
+ video=video,
79
+ trajs=trajs,
80
+ visibs=visibs,
81
+ valids=valids,
82
+ hards=hards,
83
+ seq_name=seq_name,
84
+ dname=dname,
85
+ query_points=query_points,
86
+ aug_video=aug_video,
87
+ transforms=transforms,
88
+ ),
89
+ gotit,
90
+ )
91
+
92
+
93
+ def try_to_cuda(t: Any) -> Any:
94
+ """
95
+ Try to move the input variable `t` to a cuda device.
96
+
97
+ Args:
98
+ t: Input.
99
+
100
+ Returns:
101
+ t_cuda: `t` moved to a cuda device, if supported.
102
+ """
103
+ try:
104
+ t = t.float().cuda()
105
+ except AttributeError:
106
+ pass
107
+ return t
108
+
109
+
110
+ def dataclass_to_cuda_(obj):
111
+ """
112
+ Move all contents of a dataclass to cuda inplace if supported.
113
+
114
+ Args:
115
+ batch: Input dataclass.
116
+
117
+ Returns:
118
+ batch_cuda: `batch` moved to a cuda device, if supported.
119
+ """
120
+ for f in dataclasses.fields(obj):
121
+ setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
122
+ return obj
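A small sketch of how VideoData and collate_fn above fit together (illustrative; it assumes per-sample tensors carry no batch dimension, since collate_fn adds it by stacking):

import torch
from utils.data import VideoData, collate_fn

def make_sample():
    return VideoData(
        video=torch.zeros(8, 3, 256, 256),   # S, C, H, W for one sample
        trajs=torch.zeros(8, 16, 2),         # S, N, 2
        visibs=torch.ones(8, 16),            # S, N
        seq_name="demo",
        dname="demo",
    )

batch = collate_fn([make_sample(), make_sample()])
print(batch.video.shape, batch.trajs.shape)  # torch.Size([2, 8, 3, 256, 256]) torch.Size([2, 8, 16, 2])
# dataclass_to_cuda_(batch) would then move every tensor field onto the GPU in place.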
utils/geom.py ADDED
@@ -0,0 +1,771 @@
1
+ import torch
2
+ import utils.basic
3
+ # import utils.box
4
+ # import utils.vox
5
+ import numpy as np
6
+ import torchvision.ops as ops
7
+ from utils.basic import print_
8
+
9
+ def split_intrinsics(K):
10
+ # K is B x 3 x 3 or B x 4 x 4
11
+ fx = K[:,0,0]
12
+ fy = K[:,1,1]
13
+ x0 = K[:,0,2]
14
+ y0 = K[:,1,2]
15
+ return fx, fy, x0, y0
16
+
17
+ # def apply_pix_T_cam_py(pix_T_cam, xyz):
18
+
19
+ # fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
20
+
21
+ # # xyz is shaped B x H*W x 3
22
+ # # returns xy, shaped B x H*W x 2
23
+
24
+ # B, N, C = list(xyz.shape)
25
+ # assert(C==3)
26
+
27
+ # x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2]
28
+
29
+ # fx = np.reshape(fx, [B, 1])
30
+ # fy = np.reshape(fy, [B, 1])
31
+ # x0 = np.reshape(x0, [B, 1])
32
+ # y0 = np.reshape(y0, [B, 1])
33
+
34
+ # EPS = 1e-4
35
+ # z = np.clip(z, EPS, None)
36
+ # x = (x*fx)/(z)+x0
37
+ # y = (y*fy)/(z)+y0
38
+ # xy = np.stack([x, y], axis=-1)
39
+ # return xy
40
+
41
+ def apply_pix_T_cam(pix_T_cam, xyz):
42
+
43
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
44
+
45
+ # xyz is shaped B x H*W x 3
46
+ # returns xy, shaped B x H*W x 2
47
+
48
+ B, N, C = list(xyz.shape)
49
+ assert(C==3)
50
+
51
+ x, y, z = torch.unbind(xyz, axis=-1)
52
+
53
+ fx = torch.reshape(fx, [B, 1])
54
+ fy = torch.reshape(fy, [B, 1])
55
+ x0 = torch.reshape(x0, [B, 1])
56
+ y0 = torch.reshape(y0, [B, 1])
57
+
58
+ EPS = 1e-4
59
+ z = torch.clamp(z, min=EPS)
60
+ x = (x*fx)/(z)+x0
61
+ y = (y*fy)/(z)+y0
62
+ xy = torch.stack([x, y], axis=-1)
63
+ return xy
64
+
65
+ # def apply_4x4_py(RT, xyz):
66
+ # # print('RT', RT.shape)
67
+ # B, N, _ = list(xyz.shape)
68
+ # ones = np.ones_like(xyz[:,:,0:1])
69
+ # xyz1 = np.concatenate([xyz, ones], 2)
70
+ # # print('xyz1', xyz1.shape)
71
+ # xyz1_t = xyz1.transpose(0,2,1)
72
+ # # print('xyz1_t', xyz1_t.shape)
73
+ # # this is B x 4 x N
74
+ # xyz2_t = np.matmul(RT, xyz1_t)
75
+ # # print('xyz2_t', xyz2_t.shape)
76
+ # xyz2 = xyz2_t.transpose(0,2,1)
77
+ # # print('xyz2', xyz2.shape)
78
+ # xyz2 = xyz2[:,:,:3]
79
+ # return xyz2
80
+
81
+ def apply_4x4(RT, xyz):
82
+ B, N, _ = list(xyz.shape)
83
+ ones = torch.ones_like(xyz[:,:,0:1])
84
+ xyz1 = torch.cat([xyz, ones], 2)
85
+ xyz1_t = torch.transpose(xyz1, 1, 2)
86
+ # this is B x 4 x N
87
+ xyz2_t = torch.matmul(RT, xyz1_t)
88
+ xyz2 = torch.transpose(xyz2_t, 1, 2)
89
+ xyz2 = xyz2[:,:,:3]
90
+ return xyz2
91
+
92
+ def apply_3x3(RT, xy):
93
+ B, N, _ = list(xy.shape)
94
+ ones = torch.ones_like(xy[:,:,0:1])
95
+ xy1 = torch.cat([xy, ones], 2)
96
+ xy1_t = torch.transpose(xy1, 1, 2)
97
+ # this is B x 4 x N
98
+ xy2_t = torch.matmul(RT, xy1_t)
99
+ xy2 = torch.transpose(xy2_t, 1, 2)
100
+ xy2 = xy2[:,:,:2]
101
+ return xy2
102
+
103
+ def generate_polygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts):
104
+ '''
105
+ Start with the center of the polygon at ctr_x, ctr_y,
106
+ Then creates the polygon by sampling points on a circle around the center.
107
+ Random noise is added by varying the angular spacing between sequential points,
108
+ and by varying the radial distance of each point from the centre.
109
+
110
+ Params:
111
+ ctr_x, ctr_y - coordinates of the "centre" of the polygon
112
+ avg_r - in px, the average radius of this polygon, this roughly controls how large the polygon is, really only useful for order of magnitude.
113
+ irregularity - [0,1] indicating how much variance there is in the angular spacing of vertices. [0,1] will map to [0, 2pi/numberOfVerts]
114
+ spikiness - [0,1] indicating how much variance there is in each vertex from the circle of radius avg_r. [0,1] will map to [0, avg_r]
115
+ num_verts
116
+
117
+ Returns:
118
+ np.array [num_verts, 2] - CCW order.
119
+ '''
120
+ # spikiness
121
+ spikiness = np.clip(spikiness, 0, 1) * avg_r
122
+
123
+ # generate n angle steps
124
+ irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts
125
+ lower = (2*np.pi / num_verts) - irregularity
126
+ upper = (2*np.pi / num_verts) + irregularity
127
+
128
+ # angle steps
129
+ angle_steps = np.random.uniform(lower, upper, num_verts)
130
+ sc = (2 * np.pi) / angle_steps.sum()
131
+ angle_steps *= sc
132
+
133
+ # get all radii
134
+ angle = np.random.uniform(0, 2*np.pi)
135
+ radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r)
136
+
137
+ # compute all points
138
+ points = []
139
+ for i in range(num_verts):
140
+ x = ctr_x + radii[i] * np.cos(angle)
141
+ y = ctr_y + radii[i] * np.sin(angle)
142
+ points.append([x, y])
143
+ angle += angle_steps[i]
144
+
145
+ return np.array(points).astype(int)
146
+
147
+
148
+ def get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05, sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05, shy_max=0.05):
149
+ '''
150
+ Params:
151
+ rot_min: rotation amount min
152
+ rot_max: rotation amount max
153
+
154
+ tx_min: translation x min
155
+ tx_max: translation x max
156
+
157
+ ty_min: translation y min
158
+ ty_max: translation y max
159
+
160
+ sx_min: scaling x min
161
+ sx_max: scaling x max
162
+
163
+ sy_min: scaling y min
164
+ sy_max: scaling y max
165
+
166
+ shx_min: shear x min
167
+ shx_max: shear x max
168
+
169
+ shy_min: shear y min
170
+ shy_max: shear y max
171
+
172
+ Returns:
173
+ transformation matrix: (B, 3, 3)
174
+ '''
175
+ # rotation
176
+ if rot_max - rot_min != 0:
177
+ rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B)
178
+ rot_amount = np.pi/180.0*rot_amount
179
+ else:
180
+ rot_amount = rot_min
181
+ rotation = np.zeros((B, 3, 3)) # B, 3, 3
182
+ rotation[:, 2, 2] = 1
183
+ rotation[:, 0, 0] = np.cos(rot_amount)
184
+ rotation[:, 0, 1] = -np.sin(rot_amount)
185
+ rotation[:, 1, 0] = np.sin(rot_amount)
186
+ rotation[:, 1, 1] = np.cos(rot_amount)
187
+
188
+ # translation
189
+ translation = np.zeros((B, 3, 3)) # B, 3, 3
190
+ translation[:, [0,1,2], [0,1,2]] = 1
191
+ if (tx_max - tx_min) > 0:
192
+ trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B)
193
+ translation[:, 0, 2] = trans_x
194
+ # else:
195
+ # translation[:, 0, 2] = tx_max
196
+ if ty_max - ty_min != 0:
197
+ trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B)
198
+ translation[:, 1, 2] = trans_y
199
+ # else:
200
+ # translation[:, 1, 2] = ty_max
201
+
202
+ # scaling
203
+ scaling = np.zeros((B, 3, 3)) # B, 3, 3
204
+ scaling[:, [0,1,2], [0,1,2]] = 1
205
+ if (sx_max - sx_min) > 0:
206
+ scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B)
207
+ scaling[:, 0, 0] = scale_x
208
+ # else:
209
+ # scaling[:, 0, 0] = sx_max
210
+ if (sy_max - sy_min) > 0:
211
+ scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B)
212
+ scaling[:, 1, 1] = scale_y
213
+ # else:
214
+ # scaling[:, 1, 1] = sy_max
215
+
216
+ # shear
217
+ shear = np.zeros((B, 3, 3)) # B, 3, 3
218
+ shear[:, [0,1,2], [0,1,2]] = 1
219
+ if (shx_max - shx_min) > 0:
220
+ shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B)
221
+ shear[:, 0, 1] = shear_x
222
+ # else:
223
+ # shear[:, 0, 1] = shx_max
224
+ if (shy_max - shy_min) > 0:
225
+ shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B)
226
+ shear[:, 1, 0] = shear_y
227
+ # else:
228
+ # shear[:, 1, 0] = shy_max
229
+
230
+ # compose all those
231
+ rt = np.einsum("ijk,ikl->ijl", rotation, translation)
232
+ ss = np.einsum("ijk,ikl->ijl", scaling, shear)
233
+ trans = np.einsum("ijk,ikl->ijl", rt, ss)
234
+
235
+ return trans
236
+
237
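A quick sketch of composing a random affine and warping 2D points with it (illustrative; assumes this module is importable as utils.geom):

import torch
from utils.geom import get_random_affine_2d, apply_3x3

trans = torch.from_numpy(get_random_affine_2d(B=4)).float()  # 4 x 3 x 3: rotation, translation, scale, shear composed
xy = torch.rand(4, 100, 2)                                   # B, N, 2 points, e.g. in normalized [0,1] coordinates
xy_warped = apply_3x3(trans, xy)                             # B, N, 2 warped points
print(xy_warped.shape)                                       # torch.Size([4, 100, 2])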
+ def get_centroid_from_box2d(box2d):
238
+ ymin = box2d[:,0]
239
+ xmin = box2d[:,1]
240
+ ymax = box2d[:,2]
241
+ xmax = box2d[:,3]
242
+ x = (xmin+xmax)/2.0
243
+ y = (ymin+ymax)/2.0
244
+ return y, x
245
+
246
+ def normalize_boxlist2d(boxlist2d, H, W):
247
+ boxlist2d = boxlist2d.clone()
248
+ ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
249
+ ymin = ymin / float(H)
250
+ ymax = ymax / float(H)
251
+ xmin = xmin / float(W)
252
+ xmax = xmax / float(W)
253
+ boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
254
+ return boxlist2d
255
+
256
+ def unnormalize_boxlist2d(boxlist2d, H, W):
257
+ boxlist2d = boxlist2d.clone()
258
+ ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
259
+ ymin = ymin * float(H)
260
+ ymax = ymax * float(H)
261
+ xmin = xmin * float(W)
262
+ xmax = xmax * float(W)
263
+ boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
264
+ return boxlist2d
265
+
266
+ def unnormalize_box2d(box2d, H, W):
267
+ return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
268
+
269
+ def normalize_box2d(box2d, H, W):
270
+ return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
271
+
272
+ def get_size_from_box2d(box2d):
273
+ ymin = box2d[:,0]
274
+ xmin = box2d[:,1]
275
+ ymax = box2d[:,2]
276
+ xmax = box2d[:,3]
277
+ height = ymax-ymin
278
+ width = xmax-xmin
279
+ return height, width
280
+
281
+ def crop_and_resize(im, boxlist, PH, PW, boxlist_is_normalized=False):
282
+ B, C, H, W = im.shape
283
+ B2, N, D = boxlist.shape
284
+ assert(B==B2)
285
+ assert(D==4)
286
+ # PH, PW is the size to resize to
287
+
288
+ # print('im', im.shape)
289
+
290
+ # output is B,N,C,PH,PW
291
+
292
+ # pt wants xy xy, unnormalized
293
+ if boxlist_is_normalized:
294
+ boxlist_unnorm = unnormalize_boxlist2d(boxlist, H, W)
295
+ else:
296
+ boxlist_unnorm = boxlist
297
+
298
+ ymin, xmin, ymax, xmax = boxlist_unnorm.unbind(2)
299
+ # boxlist_pt = torch.stack([boxlist_unnorm[:,1], boxlist_unnorm[:,0], boxlist_unnorm[:,3], boxlist_unnorm[:,2]], dim=1)
300
+ boxlist_pt = torch.stack([xmin, ymin, xmax, ymax], dim=2)
301
+ # we want a B-len list of K x 4 arrays
302
+
303
+ # print('im', im.shape)
304
+ # print('boxlist', boxlist.shape)
305
+ # print('boxlist_pt', boxlist_pt.shape)
306
+
307
+ # boxlist_pt = list(boxlist_pt.unbind(0))
308
+
309
+ # crops = []
310
+ # for b in range(B):
311
+ # crops_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW))
312
+ # crops.append(crops_b)
313
+ # # # crops = im
314
+
315
+ # inds = torch.arange(B).reshape(B,1)
316
+ # # boxlist_pt = torch.cat([inds, boxlist_
317
+ # boxlist_pt = list(boxlist_pt.unbind(0)
318
+
319
+ # crops = ops.roi_align(im, [boxlist_pt[b]], output_size=(PH, PW))
320
+ crops = ops.roi_align(im, list(boxlist_pt.unbind(0)), output_size=(PH, PW))
321
+
322
+
323
+ # print('crops', crops.shape)
324
+ # crops = crops.reshape(B,N,C,PH,PW)
325
+
326
+
327
+ # crops = []
328
+ # for b in range(B):
329
+ # crop_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW))
330
+ # print('crop_b', crop_b.shape)
331
+ # crops.append(crop_b)
332
+ # crops = torch.stack(crops, dim=0)
333
+
334
+ # print('crops', crops.shape)
335
+ # boxlist_list = boxlist_pt.unbind(0)
336
+ # print('rgb_crop', rgb_crop.shape)
337
+
338
+ return crops
339
+
340
+
341
+ # def get_boxlist_from_centroid_and_size(cy, cx, h, w, clip=True):
342
+ # # cy,cx are both B,N
343
+ # ymin = cy - h/2
344
+ # ymax = cy + h/2
345
+ # xmin = cx - w/2
346
+ # xmax = cx + w/2
347
+
348
+ # box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)
349
+ # if clip:
350
+ # box = torch.clamp(box, 0, 1)
351
+ # return box
352
+
353
+
354
+ def get_boxlist_from_centroid_and_size(cy, cx, h, w):#, clip=False):
355
+ # cy,cx are the same shape
356
+ ymin = cy - h/2
357
+ ymax = cy + h/2
358
+ xmin = cx - w/2
359
+ xmax = cx + w/2
360
+
361
+ # if clip:
362
+ # ymin = torch.clamp(ymin, 0, H-1)
363
+ # ymax = torch.clamp(ymax, 0, H-1)
364
+ # xmin = torch.clamp(xmin, 0, W-1)
365
+ # xmax = torch.clamp(xmax, 0, W-1)
366
+
367
+ box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)
368
+ return box
369
+
370
+
371
+ def get_box2d_from_mask(mask, normalize=False):
372
+ # mask is B, 1, H, W
373
+
374
+ B, C, H, W = mask.shape
375
+ assert(C==1)
376
+ xy = utils.basic.gridcloud2d(B, H, W, norm=False, device=mask.device) # B, H*W, 2
377
+
378
+ box = torch.zeros((B, 4), dtype=torch.float32, device=mask.device)
379
+ for b in range(B):
380
+ xy_b = xy[b] # H*W, 2
381
+ mask_b = mask[b].reshape(H*W)
382
+ xy_ = xy_b[mask_b > 0]
383
+ x_ = xy_[:,0]
384
+ y_ = xy_[:,1]
385
+ ymin = torch.min(y_)
386
+ ymax = torch.max(y_)
387
+ xmin = torch.min(x_)
388
+ xmax = torch.max(x_)
389
+ box[b] = torch.stack([ymin, xmin, ymax, xmax], dim=0)
390
+ if normalize:
391
+ box = normalize_boxlist2d(box.unsqueeze(1), H, W).squeeze(1)
392
+ return box
393
+
394
+ def convert_box2d_to_intrinsics(box2d, pix_T_cam, H, W, use_image_aspect_ratio=True, mult_padding=1.0):
395
+ # box2d is B x 4, with ymin, xmin, ymax, xmax in normalized coords
396
+ # ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)
397
+ # H, W is the original size of the image
398
+ # mult_padding is relative to object size in pixels
399
+
400
+ # i assume we're rendering an image the same size as the original (H, W)
401
+
402
+ if not mult_padding==1.0:
403
+ y, x = get_centroid_from_box2d(box2d)
404
+ h, w = get_size_from_box2d(box2d)
405
+ box2d = get_box2d_from_centroid_and_size(
406
+ y, x, h*mult_padding, w*mult_padding, clip=False)
407
+
408
+ if use_image_aspect_ratio:
409
+ h, w = get_size_from_box2d(box2d)
410
+ y, x = get_centroid_from_box2d(box2d)
411
+
412
+ # note h,w are relative right now
413
+ # we need to undo this, to see the real ratio
414
+
415
+ h = h*float(H)
416
+ w = w*float(W)
417
+ box_ratio = h/w
418
+ im_ratio = H/float(W)
419
+
420
+ # print('box_ratio:', box_ratio)
421
+ # print('im_ratio:', im_ratio)
422
+
423
+ if box_ratio >= im_ratio:
424
+ w = h/im_ratio
425
+ # print('setting w:', h/im_ratio)
426
+ else:
427
+ h = w*im_ratio
428
+ # print('setting h:', w*im_ratio)
429
+
430
+ box2d = get_box2d_from_centroid_and_size(
431
+ y, x, h/float(H), w/float(W), clip=False)
432
+
433
+ assert(h > 1e-4)
434
+ assert(w > 1e-4)
435
+
436
+ ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)
437
+
438
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
439
+
440
+ # the topleft of the new image will now have a different offset from the center of projection
441
+
442
+ new_x0 = x0 - xmin*W
443
+ new_y0 = y0 - ymin*H
444
+
445
+ pix_T_cam = pack_intrinsics(fx, fy, new_x0, new_y0)
446
+ # this alone will give me an image in original resolution,
447
+ # with its topleft at the box corner
448
+
449
+ box_h, box_w = get_size_from_box2d(box2d)
450
+ # these are normalized, and shaped B. (e.g., [0.4], [0.3])
451
+
452
+ # we are going to scale the image by the inverse of this,
453
+ # since we are zooming into this area
454
+
455
+ sy = 1./box_h
456
+ sx = 1./box_w
457
+
458
+ pix_T_cam = scale_intrinsics(pix_T_cam, sx, sy)
459
+ return pix_T_cam, box2d
460
+
461
+ def generatePolygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts):
462
+ '''
463
+ Start with the center of the polygon at ctr_x, ctr_y,
464
+ Then creates the polygon by sampling points on a circle around the center.
465
+ Random noise is added by varying the angular spacing between sequential points,
466
+ and by varying the radial distance of each point from the centre.
467
+
468
+ Params:
469
+ ctr_x, ctr_y - coordinates of the "centre" of the polygon
470
+ avg_r - in px, the average radius of this polygon, this roughly controls how large the polygon is, really only useful for order of magnitude.
471
+ irregularity - [0,1] indicating how much variance there is in the angular spacing of vertices. [0,1] will map to [0, 2pi/numberOfVerts]
472
+ spikiness - [0,1] indicating how much variance there is in each vertex from the circle of radius avg_r. [0,1] will map to [0, avg_r]
473
+ num_verts
474
+
475
+ Returns:
476
+ np.array [num_verts, 2] - CCW order.
477
+ '''
478
+ # spikiness
479
+ spikiness = np.clip(spikiness, 0, 1) * avg_r
480
+
481
+ # generate n angle steps
482
+ irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts
483
+ lower = (2*np.pi / num_verts) - irregularity
484
+ upper = (2*np.pi / num_verts) + irregularity
485
+
486
+ # angle steps
487
+ angle_steps = np.random.uniform(lower, upper, num_verts)
488
+ sc = (2 * np.pi) / angle_steps.sum()
489
+ angle_steps *= sc
490
+
491
+ # get all radii
492
+ angle = np.random.uniform(0, 2*np.pi)
493
+ radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r)
494
+
495
+ # compute all points
496
+ points = []
497
+ for i in range(num_verts):
498
+ x = ctr_x + radii[i] * np.cos(angle)
499
+ y = ctr_y + radii[i] * np.sin(angle)
500
+ points.append([x, y])
501
+ angle += angle_steps[i]
502
+
503
+ return np.array(points).astype(int)
504
+
505
+ def get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05, sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05, shy_max=0.05):
506
+ '''
507
+ Params:
508
+ rot_min: rotation amount min
509
+ rot_max: rotation amount max
510
+
511
+ tx_min: translation x min
512
+ tx_max: translation x max
513
+
514
+ ty_min: translation y min
515
+ ty_max: translation y max
516
+
517
+ sx_min: scaling x min
518
+ sx_max: scaling x max
519
+
520
+ sy_min: scaling y min
521
+ sy_max: scaling y max
522
+
523
+ shx_min: shear x min
524
+ shx_max: shear x max
525
+
526
+ shy_min: shear y min
527
+ shy_max: shear y max
528
+
529
+ Returns:
530
+ transformation matrix: (B, 3, 3)
531
+ '''
532
+ # rotation
533
+ if rot_max - rot_min != 0:
534
+ rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B)
535
+ rot_amount = np.pi/180.0*rot_amount
536
+ else:
537
+ rot_amount = rot_min
538
+ rotation = np.zeros((B, 3, 3)) # B, 3, 3
539
+ rotation[:, 2, 2] = 1
540
+ rotation[:, 0, 0] = np.cos(rot_amount)
541
+ rotation[:, 0, 1] = -np.sin(rot_amount)
542
+ rotation[:, 1, 0] = np.sin(rot_amount)
543
+ rotation[:, 1, 1] = np.cos(rot_amount)
544
+
545
+ # translation
546
+ translation = np.zeros((B, 3, 3)) # B, 3, 3
547
+ translation[:, [0,1,2], [0,1,2]] = 1
548
+ if tx_max - tx_min != 0:
549
+ trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B)
550
+ translation[:, 0, 2] = trans_x
551
+ else:
552
+ translation[:, 0, 2] = tx_max
553
+ if ty_max - ty_min != 0:
554
+ trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B)
555
+ translation[:, 1, 2] = trans_y
556
+ else:
557
+ translation[:, 1, 2] = ty_max
558
+
559
+ # scaling
560
+ scaling = np.zeros((B, 3, 3)) # B, 3, 3
561
+ scaling[:, [0,1,2], [0,1,2]] = 1
562
+ if sx_max - sx_min != 0:
563
+ scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B)
564
+ scaling[:, 0, 0] = scale_x
565
+ else:
566
+ scaling[:, 0, 0] = sx_max
567
+ if sy_max - sy_min != 0:
568
+ scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B)
569
+ scaling[:, 1, 1] = scale_y
570
+ else:
571
+ scaling[:, 1, 1] = sy_max
572
+
573
+ # shear
574
+ shear = np.zeros((B, 3, 3)) # B, 3, 3
575
+ shear[:, [0,1,2], [0,1,2]] = 1
576
+ if shx_max - shx_min != 0:
577
+ shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B)
578
+ shear[:, 0, 1] = shear_x
579
+ else:
580
+ shear[:, 0, 1] = shx_max
581
+ if shy_max - shy_min != 0:
582
+ shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B)
583
+ shear[:, 1, 0] = shear_y
584
+ else:
585
+ shear[:, 1, 0] = shy_max
586
+
587
+ # compose all those
588
+ rt = np.einsum("ijk,ikl->ijl", rotation, translation)
589
+ ss = np.einsum("ijk,ikl->ijl", scaling, shear)
590
+ trans = np.einsum("ijk,ikl->ijl", rt, ss)
591
+
592
+ return trans
593
+
594
+ def pixels2camera(x,y,z,fx,fy,x0,y0):
595
+ # x and y are locations in pixel coordinates, z is a depth in meters
596
+ # they can be images or pointclouds
597
+ # fx, fy, x0, y0 are camera intrinsics
598
+ # returns xyz, sized B x N x 3
599
+
600
+ B = x.shape[0]
601
+
602
+ fx = torch.reshape(fx, [B,1])
603
+ fy = torch.reshape(fy, [B,1])
604
+ x0 = torch.reshape(x0, [B,1])
605
+ y0 = torch.reshape(y0, [B,1])
606
+
607
+ x = torch.reshape(x, [B,-1])
608
+ y = torch.reshape(y, [B,-1])
609
+ z = torch.reshape(z, [B,-1])
610
+
611
+ # unproject
612
+ x = (z/fx)*(x-x0)
613
+ y = (z/fy)*(y-y0)
614
+
615
+ xyz = torch.stack([x,y,z], dim=2)
616
+ # B x N x 3
617
+ return xyz
618
+
619
+ def camera2pixels(xyz, pix_T_cam):
620
+ # xyz is shaped B x H*W x 3
621
+ # returns xy, shaped B x H*W x 2
622
+
623
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
624
+ x, y, z = torch.unbind(xyz, dim=-1)
625
+ B = list(z.shape)[0]
626
+
627
+ fx = torch.reshape(fx, [B,1])
628
+ fy = torch.reshape(fy, [B,1])
629
+ x0 = torch.reshape(x0, [B,1])
630
+ y0 = torch.reshape(y0, [B,1])
631
+ x = torch.reshape(x, [B,-1])
632
+ y = torch.reshape(y, [B,-1])
633
+ z = torch.reshape(z, [B,-1])
634
+
635
+ EPS = 1e-4
636
+ z = torch.clamp(z, min=EPS)
637
+ x = (x*fx)/z + x0
638
+ y = (y*fy)/z + y0
639
+ xy = torch.stack([x, y], dim=-1)
640
+ return xy
641
+
642
+ def depth2pointcloud(z, pix_T_cam):
643
+ B, C, H, W = list(z.shape)
644
+ device = z.device
645
+ y, x = utils.basic.meshgrid2d(B, H, W, device=device)
646
+ z = torch.reshape(z, [B, H, W])
647
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
648
+ xyz = pixels2camera(x, y, z, fx, fy, x0, y0)
649
+ return xyz
650
+
651
+
652
+
653
+ def create_depth_image_single(xy, z, H, W, force_positive=True, max_val=100.0, serial=False, slices=20):
654
+ # turn the xy coordinates into image inds
655
+ xy = torch.round(xy).long()
656
+ depth = torch.zeros(H*W, dtype=torch.float32, device=xy.device)
657
+ depth[:] = max_val
658
+
659
+ # lidar reports a sphere of measurements
660
+ # only use the inds that are within the image bounds
661
+ # also, only use forward-pointing depths (z > 0)
662
+ valid_inds = (xy[:,0] <= W-1) & (xy[:,1] <= H-1) & (xy[:,0] >= 0) & (xy[:,1] >= 0) & (z[:] > 0)
663
+
664
+ # gather these up
665
+ xy = xy[valid_inds]
666
+ z = z[valid_inds]
667
+
668
+ inds = utils.basic.sub2ind(H, W, xy[:,1], xy[:,0]).long()
669
+ if not serial:
670
+ depth[inds] = z
671
+ else:
672
+ if False:
673
+ for (index, replacement) in zip(inds, z):
674
+ if depth[index] > replacement:
675
+ depth[index] = replacement
676
+ # ok my other idea is:
677
+ # sort the depths by distance
678
+ # create N depth maps
679
+ # merge them back-to-front
680
+
681
+ # actually maybe you don't even need the separate maps
682
+
683
+ sort_inds = torch.argsort(z, descending=True)
684
+ xy = xy[sort_inds]
685
+ z = z[sort_inds]
686
+ N = len(sort_inds)
687
+ def split(a, n):
688
+ k, m = divmod(len(a), n)
689
+ return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
690
+
691
+ slice_inds = split(range(N), slices)
692
+ for si in slice_inds:
693
+ mini_z = z[si]
694
+ mini_xy = xy[si]
695
+ inds = utils.basic.sub2ind(H, W, mini_xy[:,1], mini_xy[:,0]).long()
696
+ depth[inds] = mini_z
697
+ # cool; this is rougly as fast as the parallel, and as accurate as the serial
698
+
699
+ if False:
700
+ print('inds', inds.shape)
701
+ unique, inverse, counts = torch.unique(inds, return_inverse=True, return_counts=True)
702
+ print('unique', unique.shape)
703
+
704
+ perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
705
+ inverse, perm = inverse.flip([0]), perm.flip([0])
706
+ perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
707
+
708
+ # new_inds = inds[inverse_inds]
709
+ # depth[new_inds] = z[unique_inds]
710
+
711
+ depth[unique] = z[perm]
712
+
713
+ # now for the duplicates...
714
+
715
+ dup = counts > 1
716
+ dup_unique = unique[dup]
717
+ print('dup_unique', dup_unique.shape)
718
+ depth[dup_unique] = 0.5
719
+
720
+ if force_positive:
721
+ # valid = (depth > 0.0).float()
722
+ depth[torch.where(depth == 0.0)] = max_val
723
+ # else:
724
+ # valid = torch.ones_like(depth)
725
+
726
+ valid = (depth > 0.0).float() * (depth < max_val).float()
727
+
728
+ depth = torch.reshape(depth, [1, H, W])
729
+ valid = torch.reshape(valid, [1, H, W])
730
+ return depth, valid
731
+
732
+ def create_depth_image(pix_T_cam, xyz_cam, H, W, offset_amount=0, max_val=100.0, serial=False, slices=20):
733
+ B, N, D = list(xyz_cam.shape)
734
+ assert(D==3)
735
+ B2, E, F = list(pix_T_cam.shape)
736
+ assert(B==B2)
737
+ assert(E==4)
738
+ assert(F==4)
739
+ xy = apply_pix_T_cam(pix_T_cam, xyz_cam)
740
+ z = xyz_cam[:,:,2]
741
+
742
+ depth = torch.zeros(B, 1, H, W, dtype=torch.float32, device=xyz_cam.device)
743
+ valid = torch.zeros(B, 1, H, W, dtype=torch.float32, device=xyz_cam.device)
744
+ for b in list(range(B)):
745
+ xy_b, z_b = xy[b], z[b]
746
+ ind = z_b > 0
747
+ xy_b = xy_b[ind]
748
+ z_b = z_b[ind]
749
+ depth_b, valid_b = create_depth_image_single(xy_b, z_b, H, W, max_val=max_val, serial=serial, slices=slices)
750
+ if offset_amount:
751
+ depth_b = depth_b.reshape(-1)
752
+ valid_b = valid_b.reshape(-1)
753
+
754
+ for off_x in range(offset_amount):
755
+ for off_y in range(offset_amount):
756
+ for sign in [-1,1]:
757
+ offset = np.array([sign*off_x,sign*off_y]).astype(np.float32)
758
+ offset = torch.from_numpy(offset).reshape(1, 2).to(xyz_cam.device)
759
+ # offsets.append(offset)
760
+ depth_, valid_ = create_depth_image_single(xy_b + offset, z_b, H, W, max_val=max_val)
761
+ depth_ = depth_.reshape(-1)
762
+ valid_ = valid_.reshape(-1)
763
+ # at invalid locations, use this new value
764
+ depth_b[valid_b==0] = depth_[valid_b==0]
765
+ valid_b[valid_b==0] = valid_[valid_b==0]
766
+
767
+ depth_b = depth_b.reshape(1, H, W)
768
+ valid_b = valid_b.reshape(1, H, W)
769
+ depth[b] = depth_b
770
+ valid[b] = valid_b
771
+ return depth, valid
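A small round-trip sketch for the projection helpers in this file (illustrative; the intrinsics values are made up, pix_T_cam is a 4x4 with fx, fy, x0, y0 in the slots that split_intrinsics reads, and utils.basic.meshgrid2d is assumed available):

import torch
from utils.geom import depth2pointcloud, camera2pixels

B, H, W = 1, 48, 64
pix_T_cam = torch.eye(4).unsqueeze(0)         # B x 4 x 4
pix_T_cam[:, 0, 0] = 60.0                     # fx
pix_T_cam[:, 1, 1] = 60.0                     # fy
pix_T_cam[:, 0, 2] = W / 2.0                  # x0
pix_T_cam[:, 1, 2] = H / 2.0                  # y0

depth = torch.ones(B, 1, H, W) * 2.0          # a flat 2m depth map
xyz_cam = depth2pointcloud(depth, pix_T_cam)  # B x H*W x 3 points in camera coordinates
xy_pix = camera2pixels(xyz_cam, pix_T_cam)    # B x H*W x 2, reprojected back onto the pixel grid
print(xyz_cam.shape, xy_pix.shape)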
utils/improc.py ADDED
@@ -0,0 +1,1645 @@
1
+ import torch
2
+ import numpy as np
3
+ import utils.basic
4
+ import utils.py
5
+ from sklearn.decomposition import PCA
6
+ from matplotlib import cm
7
+ import matplotlib.pyplot as plt
8
+ import cv2
9
+ import torch.nn.functional as F
10
+ import torchvision
11
+ EPS = 1e-6
12
+
13
+ from skimage.color import (
14
+ rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb,
15
+ rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb)
16
+
17
+ def _convert(input_, type_):
18
+ return {
19
+ 'float': input_.float(),
20
+ 'double': input_.double(),
21
+ }.get(type_, input_)
22
+
23
+ def _generic_transform_sk_3d(transform, in_type='', out_type=''):
24
+ def apply_transform_individual(input_):
25
+ device = input_.device
26
+ input_ = input_.cpu()
27
+ input_ = _convert(input_, in_type)
28
+
29
+ input_ = input_.permute(1, 2, 0).detach().numpy()
30
+ transformed = transform(input_)
31
+ output = torch.from_numpy(transformed).float().permute(2, 0, 1)
32
+ output = _convert(output, out_type)
33
+ return output.to(device)
34
+
35
+ def apply_transform(input_):
36
+ to_stack = []
37
+ for image in input_:
38
+ to_stack.append(apply_transform_individual(image))
39
+ return torch.stack(to_stack)
40
+ return apply_transform
41
+
42
+ hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb)
43
+
44
+ def preprocess_color_tf(x):
45
+ import tensorflow as tf
46
+ return tf.cast(x,tf.float32) * 1./255 - 0.5
47
+
48
+ def preprocess_color(x):
49
+ if isinstance(x, np.ndarray):
50
+ return x.astype(np.float32) * 1./255 - 0.5
51
+ else:
52
+ return x.float() * 1./255 - 0.5
53
+
54
+ def pca_embed(emb, keep, valid=None):
55
+ ## emb -- [S,H/2,W/2,C]
56
+ ## keep is the number of principal components to keep
57
+ ## Helper function for reduce_emb.
58
+ emb = emb + EPS
59
+ #emb is B x C x H x W
60
+ emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C
61
+
62
+ if valid is not None: # explicit None check; truth-testing a multi-element tensor would raise
63
+ valid = valid.cpu().detach().numpy().reshape(-1) > 0.5 # flatten to a boolean pixel mask (H, W are not defined yet at this point)
64
+
65
+ emb_reduced = list()
66
+
67
+ B, H, W, C = np.shape(emb)
68
+ for img in emb:
69
+ if np.isnan(img).any():
70
+ emb_reduced.append(np.zeros([H, W, keep]))
71
+ continue
72
+
73
+ pixels_kd = np.reshape(img, (H*W, C))
74
+
75
+ if valid is not None:
76
+ pixels_kd_pca = pixels_kd[valid]
77
+ else:
78
+ pixels_kd_pca = pixels_kd
79
+
80
+ P = PCA(keep)
81
+ P.fit(pixels_kd_pca)
82
+
83
+ if valid is not None:
84
+ pixels3d = P.transform(pixels_kd)*valid.reshape(-1, 1) # zero out masked-off pixels row-wise
85
+ else:
86
+ pixels3d = P.transform(pixels_kd)
87
+
88
+ out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32)
89
+ if np.isnan(out_img).any():
90
+ emb_reduced.append(np.zeros([H, W, keep]))
91
+ continue
92
+
93
+ emb_reduced.append(out_img)
94
+
95
+ emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32)
96
+
97
+ return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2)
98
+
99
+ def pca_embed_together(emb, keep):
100
+ ## emb -- [S,H/2,W/2,C]
101
+ ## keep is the number of principal components to keep
102
+ ## Helper function for reduce_emb.
103
+ emb = emb + EPS
104
+ #emb is B x C x H x W
105
+ emb = emb.permute(0, 2, 3, 1).cpu().detach().float().numpy() #this is B x H x W x C
106
+
107
+ B, H, W, C = np.shape(emb)
108
+ if np.isnan(emb).any():
109
+ return torch.zeros(B, keep, H, W)
110
+
111
+ pixelskd = np.reshape(emb, (B*H*W, C))
112
+ P = PCA(keep)
113
+ P.fit(pixelskd)
114
+ pixels3d = P.transform(pixelskd)
115
+ out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32)
116
+
117
+ if np.isnan(out_img).any():
118
+ return torch.zeros(B, keep, H, W)
119
+
120
+ return torch.from_numpy(out_img).permute(0, 3, 1, 2)
121
+
122
+ def reduce_emb(emb, valid=None, inbound=None, together=False):
123
+ ## emb -- [S,C,H/2,W/2], inbound -- [S,1,H/2,W/2]
124
+ ## Reduce number of chans to 3 with PCA. For vis.
125
+ # S,H,W,C = emb.shape.as_list()
126
+ S, C, H, W = list(emb.size())
127
+ keep = 4
128
+
129
+ if together:
130
+ reduced_emb = pca_embed_together(emb, keep)
131
+ else:
132
+ reduced_emb = pca_embed(emb, keep, valid) #not im
133
+
134
+ reduced_emb = reduced_emb[:,1:]
135
+ reduced_emb = utils.basic.normalize(reduced_emb) - 0.5
136
+ if inbound is not None:
137
+ emb_inbound = emb*inbound
138
+ else:
139
+ emb_inbound = None
140
+
141
+ return reduced_emb, emb_inbound
142
+
143
+ def get_feat_pca(feat, valid=None):
144
+ B, C, D, W = list(feat.size())
145
+ # feat is B x C x D x W. If 3D input, average it through Height dimension before passing into this function.
146
+
147
+ pca, _ = reduce_emb(feat, valid=valid,inbound=None, together=True)
148
+ # pca is B x 3 x W x D
149
+ return pca
150
+
151
+ def gif_and_tile(ims, just_gif=False):
152
+ S = len(ims)
153
+ # each im is B x H x W x C
154
+ # i want a gif in the left, and the tiled frames on the right
155
+ # for the gif tool, this means making a B x S x H x W tensor
156
+ # where the leftmost part is sequential and the rest is tiled
157
+ gif = torch.stack(ims, dim=1)
158
+ if just_gif:
159
+ return gif
160
+ til = torch.cat(ims, dim=2)
161
+ til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1)
162
+ im = torch.cat([gif, til], dim=3)
163
+ return im
164
+
165
+ def back2color(i, blacken_zeros=False):
166
+ if blacken_zeros:
167
+ const = torch.tensor([-0.5])
168
+ i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i)
169
+ return back2color(i)
170
+ else:
171
+ return ((i+0.5)*255).type(torch.ByteTensor)
172
+
173
+ def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False):
174
+ # xy is B x N x 2, containing float x and y coordinates of N things
175
+ # grid_xs and grid_ys are B x N x Y x X
176
+
177
+ B, N, Y, X = list(grid_xs.shape)
178
+
179
+ mu_x = xy[:,:,0].clone()
180
+ mu_y = xy[:,:,1].clone()
181
+
182
+ x_valid = (mu_x>-0.5) & (mu_x<float(X+0.5))
183
+ y_valid = (mu_y>-0.5) & (mu_y<float(Y+0.5))
184
+ not_valid = ~(x_valid & y_valid)
185
+
186
+ mu_x[not_valid] = -10000
187
+ mu_y[not_valid] = -10000
188
+
189
+ mu_x = mu_x.reshape(B, N, 1, 1).repeat(1, 1, Y, X)
190
+ mu_y = mu_y.reshape(B, N, 1, 1).repeat(1, 1, Y, X)
191
+
192
+ sigma_sq = sigma*sigma
193
+ # sigma_sq = (sigma*sigma).reshape(B, N, 1, 1)
194
+ sq_diff_x = (grid_xs - mu_x)**2
195
+ sq_diff_y = (grid_ys - mu_y)**2
196
+
197
+ term1 = 1./2.*np.pi*sigma_sq
198
+ term2 = torch.exp(-(sq_diff_x+sq_diff_y)/(2.*sigma_sq))
199
+ gauss = term1*term2
200
+
201
+ if norm:
202
+ # normalize so each gaussian peaks at 1
203
+ gauss_ = gauss.reshape(B*N, Y, X)
204
+ gauss_ = utils.basic.normalize(gauss_)
205
+ gauss = gauss_.reshape(B, N, Y, X)
206
+
207
+ return gauss
208
+
209
+ def xy2heatmaps(xy, Y, X, sigma=30.0, norm=True):
210
+ # xy is B x N x 2
211
+
212
+ B, N, D = list(xy.shape)
213
+ assert(D==2)
214
+
215
+ device = xy.device
216
+
217
+ grid_y, grid_x = utils.basic.meshgrid2d(B, Y, X, device=device)
218
+ # grid_x and grid_y are B x Y x X
219
+ grid_xs = grid_x.unsqueeze(1).repeat(1, N, 1, 1)
220
+ grid_ys = grid_y.unsqueeze(1).repeat(1, N, 1, 1)
221
+ heat = xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=norm)
222
+ return heat
223
+
224
+ def draw_circles_at_xy(xy, Y, X, sigma=12.5, round=False):
225
+ B, N, D = list(xy.shape)
226
+ assert(D==2)
227
+ prior = xy2heatmaps(xy, Y, X, sigma=sigma)
228
+ # prior is B x N x Y x X
229
+ if round:
230
+ prior = (prior > 0.5).float()
231
+ return prior
232
+
233
+ def seq2color(im, norm=True, colormap='coolwarm'):
234
+ B, S, H, W = list(im.shape)
235
+ # S is sequential
236
+
237
+ # prep a mask of the valid pixels, so we can blacken the invalids later
238
+ mask = torch.max(im, dim=1, keepdim=True)[0]
239
+
240
+ # turn the S dim into an explicit sequence
241
+ coeffs = np.linspace(1.0, float(S), S).astype(np.float32)/float(S)
242
+
243
+ # # increase the spacing from the center
244
+ # coeffs[:int(S/2)] -= 2.0
245
+ # coeffs[int(S/2)+1:] += 2.0
246
+
247
+ coeffs = torch.from_numpy(coeffs).float().cuda()
248
+ coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W)
249
+ # scale each channel by the right coeff
250
+ im = im * coeffs
251
+ # now im is in [1/S, 1], except for the invalid parts which are 0
252
+ # keep the highest valid coeff at each pixel
253
+ im = torch.max(im, dim=1, keepdim=True)[0]
254
+
255
+ out = []
256
+ for b in range(B):
257
+ im_ = im[b]
258
+ # move channels out to last dim_
259
+ im_ = im_.detach().cpu().numpy()
260
+ im_ = np.squeeze(im_)
261
+ # im_ is H x W
262
+ if colormap=='coolwarm':
263
+ im_ = cm.coolwarm(im_)[:, :, :3]
264
+ elif colormap=='PiYG':
265
+ im_ = cm.PiYG(im_)[:, :, :3]
266
+ elif colormap=='winter':
267
+ im_ = cm.winter(im_)[:, :, :3]
268
+ elif colormap=='spring':
269
+ im_ = cm.spring(im_)[:, :, :3]
270
+ elif colormap=='onediff':
271
+ im_ = np.reshape(im_, (-1))
272
+ im0_ = cm.spring(im_)[:, :3]
273
+ im1_ = cm.winter(im_)[:, :3]
274
+ im1_[im_==1/float(S)] = im0_[im_==1/float(S)]
275
+ im_ = np.reshape(im1_, (H, W, 3))
276
+ else:
277
+ assert(False) # invalid colormap
278
+ # move channels into dim 0
279
+ im_ = np.transpose(im_, [2, 0, 1])
280
+ im_ = torch.from_numpy(im_).float().cuda()
281
+ out.append(im_)
282
+ out = torch.stack(out, dim=0)
283
+
284
+ # blacken the invalid pixels, instead of using the 0-color
285
+ out = out*mask
286
+ # out = out*255.0
287
+
288
+ # put it in [-0.5, 0.5]
289
+ out = out - 0.5
290
+
291
+ return out
292
+
293
+ def colorize(d):
294
+ # this is actually just grayscale right now
295
+
296
+ if d.ndim==2:
297
+ d = d.unsqueeze(dim=0)
298
+ else:
299
+ assert(d.ndim==3)
300
+
301
+ # color_map = cm.get_cmap('plasma')
302
+ color_map = cm.get_cmap('inferno')
303
+ # S1, D = traj.shape
304
+
305
+ # print('d1', d.shape)
306
+ C,H,W = d.shape
307
+ assert(C==1)
308
+ d = d.reshape(-1)
309
+ d = d.detach().cpu().numpy()
310
+ # print('d2', d.shape)
311
+ color = np.array(color_map(d)) * 255 # rgba
312
+ # print('color1', color.shape)
313
+ color = np.reshape(color[:,:3], [H*W, 3])
314
+ # print('color2', color.shape)
315
+ color = torch.from_numpy(color).permute(1,0).reshape(3,H,W)
316
+ # # gather
317
+ # cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray')
318
+ # if cmap=='RdBu' or cmap=='RdYlGn':
319
+ # colors = cm(np.arange(256))[:, :3]
320
+ # else:
321
+ # colors = cm.colors
322
+ # colors = np.array(colors).astype(np.float32)
323
+ # colors = np.reshape(colors, [-1, 3])
324
+ # colors = tf.constant(colors, dtype=tf.float32)
325
+
326
+ # value = tf.gather(colors, indices)
327
+ # colorize(value, normalize=True, vmin=None, vmax=None, cmap=None, vals=255)
328
+
329
+ # copy to the three chans
330
+ # d = d.repeat(3, 1, 1)
331
+ return color
332
+
333
+
334
+ def oned2inferno(d, norm=True, do_colorize=False):
335
+ # convert a 1chan input to a 3chan image output
336
+
337
+ # if it's just B x H x W, add a C dim
338
+ if d.ndim==3:
339
+ d = d.unsqueeze(dim=1)
340
+ # d should be B x C x H x W, where C=1
341
+ B, C, H, W = list(d.shape)
342
+ assert(C==1)
343
+
344
+ if norm:
345
+ d = utils.basic.normalize(d)
346
+
347
+ if do_colorize:
348
+ rgb = torch.zeros(B, 3, H, W)
349
+ for b in list(range(B)):
350
+ rgb[b] = colorize(d[b])
351
+ else:
352
+ rgb = d.repeat(1, 3, 1, 1)*255.0
353
+ # rgb = (255.0*rgb).type(torch.ByteTensor)
354
+ rgb = rgb.type(torch.ByteTensor)
355
+
356
+ # rgb = tf.cast(255.0*rgb, tf.uint8)
357
+ # rgb = tf.reshape(rgb, [-1, hyp.H, hyp.W, 3])
358
+ # rgb = tf.expand_dims(rgb, axis=0)
359
+ return rgb
360
+
361
+ def oned2gray(d, norm=True):
362
+ # convert a 1chan input to a 3chan image output
363
+
364
+ # if it's just B x H x W, add a C dim
365
+ if d.ndim==3:
366
+ d = d.unsqueeze(dim=1)
367
+ # d should be B x C x H x W, where C=1
368
+ B, C, H, W = list(d.shape)
369
+ assert(C==1)
370
+
371
+ if norm:
372
+ d = utils.basic.normalize(d)
373
+
374
+ rgb = d.repeat(1,3,1,1)
375
+ rgb = (255.0*rgb).type(torch.ByteTensor)
376
+ return rgb
377
+
378
+
379
+ def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20, shadow=True):
380
+
381
+ rgb = vis.detach().cpu().numpy()[0]
382
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
383
+ rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
384
+ color = (255, 255, 255)
385
+ # print('putting frame id', frame_id)
386
+
387
+ frame_str = utils.basic.strnum(frame_id)
388
+
389
+ text_color_bg = (0,0,0)
390
+ font = cv2.FONT_HERSHEY_SIMPLEX
391
+ text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
392
+ text_w, text_h = text_size
393
+ if shadow:
394
+ cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)
395
+
396
+ cv2.putText(
397
+ rgb,
398
+ frame_str,
399
+ (left, top), # from left, from top
400
+ font,
401
+ scale, # font scale (float)
402
+ color,
403
+ 1) # font thickness (int)
404
+ rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
405
+ vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
406
+ return vis
407
+
408
+ def draw_frame_str_on_vis(vis, frame_str, scale=0.5, left=5, top=40, shadow=True):
409
+
410
+ rgb = vis.detach().cpu().numpy()[0]
411
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
412
+ rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
413
+ color = (255, 255, 255)
414
+
415
+ text_color_bg = (0,0,0)
416
+ font = cv2.FONT_HERSHEY_SIMPLEX
417
+ text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
418
+ text_w, text_h = text_size
419
+ if shadow:
420
+ cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)
421
+
422
+ cv2.putText(
423
+ rgb,
424
+ frame_str,
425
+ (left, top), # from left, from top
426
+ font,
427
+ scale, # font scale (float)
428
+ color,
429
+ 1) # font thickness (int)
430
+ rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
431
+ vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
432
+ return vis
433
+
434
+ COLORMAP_FILE = "./utils/bremm.png"
435
+ class ColorMap2d:
436
+ def __init__(self, filename=None):
437
+ self._colormap_file = filename or COLORMAP_FILE
438
+ self._img = plt.imread(self._colormap_file)
439
+
440
+ self._height = self._img.shape[0]
441
+ self._width = self._img.shape[1]
442
+
443
+ def __call__(self, X):
444
+ assert len(X.shape) == 2
445
+ output = np.zeros((X.shape[0], 3))
446
+ for i in range(X.shape[0]):
447
+ x, y = X[i, :]
448
+ xp = int((self._width-1) * x)
449
+ yp = int((self._height-1) * y)
450
+ xp = np.clip(xp, 0, self._width-1)
451
+ yp = np.clip(yp, 0, self._height-1)
452
+ output[i, :] = self._img[yp, xp]
453
+ return output
454
+
455
+ def get_n_colors(N, sequential=False):
456
+ label_colors = []
457
+ for ii in range(N):
458
+ if sequential:
459
+ rgb = cm.winter(ii/(N-1))
460
+ rgb = (np.array(rgb) * 255).astype(np.uint8)[:3]
461
+ else:
462
+ # rgb = np.zeros(3)
463
+ # while np.sum(rgb) < 128: # ensure min brightness
464
+ # rgb = np.random.randint(0,256,3)
465
+ rgb = cm.gist_rainbow(ii/(N-1))
466
+ rgb = (np.array(rgb) * 255).astype(np.uint8)[:3]
467
+
468
+ label_colors.append(rgb)
469
+ return label_colors
470
+
471
+ class Summ_writer(object):
472
+ def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False):
473
+ self.writer = writer
474
+ self.global_step = global_step
475
+ self.log_freq = log_freq
476
+ self.scalar_freq = scalar_freq
477
+ self.fps = fps
478
+ self.just_gif = just_gif
479
+ self.maxwidth = 10000
480
+ self.save_this = (self.global_step % self.log_freq == 0)
481
+ self.scalar_freq = max(scalar_freq,1)
482
+ self.save_scalar = (self.global_step % self.scalar_freq == 0)
483
+ if self.save_this:
484
+ self.save_scalar = True
485
+
486
+ def summ_gif(self, name, tensor, blacken_zeros=False):
487
+ # tensor should be in B x S x C x H x W
488
+
489
+ assert tensor.dtype in {torch.uint8,torch.float32}
490
+ shape = list(tensor.shape)
491
+
492
+ if tensor.dtype == torch.float32:
493
+ tensor = back2color(tensor, blacken_zeros=blacken_zeros)
494
+
495
+ video_to_write = tensor[0:1]
496
+
497
+ S = video_to_write.shape[1]
498
+ if S==1:
499
+ # video_to_write is 1 x 1 x C x H x W
500
+ self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step)
501
+ else:
502
+ self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step)
503
+
504
+ return video_to_write
505
+
506
+ def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1):
507
+ B, C, H, W = list(rgb.shape)
508
+ assert(C==3)
509
+ B2, N, D = list(boxlist.shape)
510
+ assert(B2==B)
511
+ assert(D==4) # ymin, xmin, ymax, xmax
512
+
513
+ rgb = back2color(rgb)
514
+ if scores is None:
515
+ scores = torch.ones(B2, N).float()
516
+ if tids is None:
517
+ # tids = torch.arange(N).reshape(1,N).repeat(B2,1).long()
518
+ tids = torch.zeros(B2, N).long()
519
+ out = self.draw_boxlist2d_on_image_py(
520
+ rgb[0].cpu().detach().numpy(),
521
+ boxlist[0].cpu().detach().numpy(),
522
+ scores[0].cpu().detach().numpy(),
523
+ tids[0].cpu().detach().numpy(),
524
+ linewidth=linewidth)
525
+ out = torch.from_numpy(out).type(torch.ByteTensor).permute(2, 0, 1)
526
+ out = torch.unsqueeze(out, dim=0)
527
+ out = preprocess_color(out)
528
+ out = torch.reshape(out, [1, C, H, W])
529
+ return out
530
+
531
+ def draw_boxlist2d_on_image_py(self, rgb, boxlist, scorelist, tidlist, linewidth=1):
532
+ # all inputs are numpy tensors
533
+ # rgb is H x W x 3
534
+ # boxlist is N x 4
535
+ # scorelist is N
536
+ # tidlist is N
537
+
538
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
539
+ # rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
540
+ rgb = rgb.astype(np.uint8).copy()
541
+
542
+ H, W, C = rgb.shape
543
+ assert(C==3)
544
+ N, D = boxlist.shape
545
+ assert(D==4)
546
+ M = scorelist.shape[0]
547
+ assert(M==N)
548
+ O = tidlist.shape[0]
549
+ assert(M==O)
550
+
551
+ # color_map = cm.get_cmap('Accent')
552
+ # color_map = cm.get_cmap('Set3')
553
+ color_map = cm.get_cmap('tab20')
554
+ color_map = color_map.colors
555
+ # print('color_map', color_map)
556
+
557
+ # draw
558
+ for (box, score, tid) in zip(boxlist, scorelist, tidlist):
559
+ # box is 4
560
+ if not np.isclose(score, 0.0, atol=1e-3):
561
+ # ymin, xmin, ymax, xmax = box
562
+ xmin, ymin, xmax, ymax = box
563
+
564
+ color = color_map[tid]
565
+ color = np.array(color)*255.0
566
+ color = color.round()
567
+
568
+ if not np.isclose(score, 1.0, atol=1e-3):
569
+ cv2.putText(rgb,
570
+ # '%d (%.2f)' % (tidlist[ind], scorelist[ind]),
571
+ '%.2f' % (score),
572
+ (int(xmin), int(ymin)),
573
+ cv2.FONT_HERSHEY_SIMPLEX,
574
+ 0.5, # font size
575
+ color),
576
+
577
+ xmin = int(np.clip(xmin,0,W-1))
578
+ ymin = int(np.clip(ymin,0,H-1))
579
+ xmax = int(np.clip(xmax,0,W-1))
580
+ ymax = int(np.clip(ymax,0,H-1))
581
+
582
+ # print('xmin, xmax, ymin, ymax', xmin, xmax, ymin, ymax)
583
+
584
+ cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_4)
585
+ cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_4)
586
+ cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_4)
587
+ cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_4)
588
+
589
+ # rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
590
+ return rgb
591
+
592
+ def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, frame_str=None, only_return=False, linewidth=1):
593
+ B, C, H, W = list(rgb.shape)
594
+ boxlist_vis = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth)
595
+ return self.summ_rgb(name, boxlist_vis, frame_id=frame_id, frame_str=frame_str, only_return=only_return)
596
+
597
+ def summ_rgbs(self, name, ims, frame_ids=None, frame_strs=None, blacken_zeros=False, only_return=False):
598
+ if self.save_this:
599
+
600
+ ims = gif_and_tile(ims, just_gif=self.just_gif)
601
+ vis = ims
602
+
603
+ assert vis.dtype in {torch.uint8,torch.float32}
604
+
605
+ if vis.dtype == torch.float32:
606
+ vis = back2color(vis, blacken_zeros)
607
+
608
+ B, S, C, H, W = list(vis.shape)
609
+
610
+ if frame_ids is not None:
611
+ assert(len(frame_ids)==S)
612
+ for s in range(S):
613
+ vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
614
+
615
+ if frame_strs is not None:
616
+ assert(len(frame_strs)==S)
617
+ for s in range(S):
618
+ vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
619
+
620
+ if int(W) > self.maxwidth:
621
+ vis = vis[:,:,:,:self.maxwidth]
622
+
623
+ if only_return:
624
+ return vis
625
+ else:
626
+ return self.summ_gif(name, vis, blacken_zeros)
627
+
628
+ def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, frame_str=None, only_return=False, halfres=False, shadow=True):
629
+ if self.save_this:
630
+ assert ims.dtype in {torch.uint8,torch.float32}
631
+
632
+ if ims.dtype == torch.float32:
633
+ ims = back2color(ims, blacken_zeros)
634
+
635
+ #ims is B x C x H x W
636
+ vis = ims[0:1] # just the first one
637
+ B, C, H, W = list(vis.shape)
638
+
639
+ if halfres:
640
+ vis = F.interpolate(vis, scale_factor=0.5)
641
+
642
+ if frame_id is not None:
643
+ vis = draw_frame_id_on_vis(vis, frame_id, shadow=shadow)
644
+
645
+ if frame_str is not None:
646
+ vis = draw_frame_str_on_vis(vis, frame_str, shadow=shadow)
647
+
648
+ if int(W) > self.maxwidth:
649
+ vis = vis[:,:,:,:self.maxwidth]
650
+
651
+ if only_return:
652
+ return vis
653
+ else:
654
+ return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros)
655
+
656
+ def flow2color(self, flow, clip=0.0):
657
+ B, C, H, W = list(flow.size())
658
+ assert(C==2)
659
+ flow = flow[0:1].detach()
660
+
661
+ if False:
662
+ flow = flow[0].detach().cpu().permute(1,2,0).numpy() # H,W,2
663
+ if clip > 0:
664
+ clip_flow = clip
665
+ else:
666
+ clip_flow = None
667
+ im = utils.py.flow_to_image(flow, clip_flow=clip_flow, convert_to_bgr=True)
668
+ # im = utils.py.flow_to_image(flow, convert_to_bgr=True)
669
+ im = torch.from_numpy(im).permute(2,0,1).unsqueeze(0).byte() # 1,3,H,W
670
+ im = torch.flip(im, dims=[1]).clone() # BGR
671
+
672
+ # # i prefer black bkg
673
+ # white_pixels = (im == 255).all(dim=1, keepdim=True)
674
+ # im[white_pixels.expand(-1, 3, -1, -1)] = 0
675
+
676
+ return im
677
+
678
+ # flow_abs = torch.abs(flow)
679
+ # flow_mean = flow_abs.mean(dim=[1,2,3])
680
+ # flow_std = flow_abs.std(dim=[1,2,3])
681
+ if clip==0:
682
+ clip = torch.max(torch.abs(flow)).item()
683
+
684
+ # if clip:
685
+ flow = torch.clamp(flow, -clip, clip)/clip
686
+ # else:
687
+ # # # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2)
688
+ # # flow_max = flow_mean + flow_std*2 + 1e-10
689
+ # # for b in range(B):
690
+ # # flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1)
691
+
692
+ # flow_max = torch.max(flow_abs[b])
693
+ # for b in range(B):
694
+ # flow[b] = flow[b].clamp(-flow_max.item(), flow_max.item()) / flow_max[b].clamp(min=1)
695
+
696
+
697
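+ # standard HSV flow coloring: flow direction maps to hue, clipped magnitude maps to value (brightness), saturation is fixed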
+ radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W
698
+ radius_clipped = torch.clamp(radius, 0.0, 1.0)
699
+
700
+ angle = torch.atan2(-flow[:, 1:2], -flow[:, 0:1]) / np.pi # B x 1 x H x W
701
+
702
+ hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
703
+ # hue = torch.mod(angle / (2 * np.pi) + 1.0, 1.0)
704
+
705
+ saturation = torch.ones_like(hue) * 0.75
706
+ value = radius_clipped
707
+ hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W
708
+
709
+ #flow = tf.image.hsv_to_rgb(hsv)
710
+ flow = hsv_to_rgb(hsv)
711
+ flow = (flow*255.0).type(torch.ByteTensor)
712
+ # flow = torch.flip(flow, dims=[1]).clone() # BGR
713
+ return flow
714
+
715
+ def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None, frame_str=None, shadow=True):
716
+ # flow is B x C x D x W
717
+ if self.save_this:
718
+ return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id, frame_str=frame_str, shadow=shadow)
719
+ else:
720
+ return None
721
+
722
+ def summ_oneds(self, name, ims, frame_ids=None, frame_strs=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False):
723
+ if self.save_this:
724
+ if bev:
725
+ B, C, H, _, W = list(ims[0].shape)
726
+ if reduce_max:
727
+ ims = [torch.max(im, dim=3)[0] for im in ims]
728
+ else:
729
+ ims = [torch.mean(im, dim=3) for im in ims]
730
+ elif fro:
731
+ B, C, _, H, W = list(ims[0].shape)
732
+ if reduce_max:
733
+ ims = [torch.max(im, dim=2)[0] for im in ims]
734
+ else:
735
+ ims = [torch.mean(im, dim=2) for im in ims]
736
+
737
+
738
+ if len(ims) != 1: # sequence
739
+ im = gif_and_tile(ims, just_gif=self.just_gif)
740
+ else:
741
+ im = torch.stack(ims, dim=1) # single frame
742
+
743
+ B, S, C, H, W = list(im.shape)
744
+
745
+ if logvis and max_val:
746
+ max_val = np.log(max_val)
747
+ im = torch.log(torch.clamp(im, 0)+1.0)
748
+ im = torch.clamp(im, 0, max_val)
749
+ im = im/max_val
750
+ norm = False
751
+ elif max_val:
752
+ im = torch.clamp(im, 0, max_val)
753
+ im = im/max_val
754
+ norm = False
755
+
756
+ if norm:
757
+ # normalize before oned2inferno,
758
+ # so that the ranges are similar within B across S
759
+ im = utils.basic.normalize(im)
760
+
761
+ im = im.view(B*S, C, H, W)
762
+ vis = oned2inferno(im, norm=norm, do_colorize=do_colorize)
763
+ vis = vis.view(B, S, 3, H, W)
764
+
765
+ if frame_ids is not None:
766
+ assert(len(frame_ids)==S)
767
+ for s in range(S):
768
+ vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
769
+
770
+ if frame_strs is not None:
771
+ assert(len(frame_strs)==S)
772
+ for s in range(S):
773
+ vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
774
+
775
+ if W > self.maxwidth:
776
+ vis = vis[...,:self.maxwidth]
777
+
778
+ if only_return:
779
+ return vis
780
+ else:
781
+ self.summ_gif(name, vis)
782
+
783
+ def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, frame_str=None, only_return=False, shadow=True):
784
+ if self.save_this:
785
+
786
+ if bev:
787
+ B, C, H, _, W = list(im.shape)
788
+ if max_along_y:
789
+ im = torch.max(im, dim=3)[0]
790
+ else:
791
+ im = torch.mean(im, dim=3)
792
+ elif fro:
793
+ B, C, _, H, W = list(im.shape)
794
+ if max_along_y:
795
+ im = torch.max(im, dim=2)[0]
796
+ else:
797
+ im = torch.mean(im, dim=2)
798
+ else:
799
+ B, C, H, W = list(im.shape)
800
+
801
+ im = im[0:1] # just the first one
802
+ assert(C==1)
803
+
804
+ if logvis and max_val:
805
+ max_val = np.log(max_val)
806
+ im = torch.log(im)
807
+ im = torch.clamp(im, 0, max_val)
808
+ im = im/max_val
809
+ norm = False
810
+ elif max_val:
811
+ im = torch.clamp(im, 0, max_val)/max_val
812
+ norm = False
813
+
814
+ vis = oned2inferno(im, norm=norm)
815
+ if W > self.maxwidth:
816
+ vis = vis[...,:self.maxwidth]
817
+ return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return, shadow=shadow)
818
+
819
+ def summ_4chan(self, name, im, norm=True, frame_id=None, frame_str=None, only_return=False):
820
+ if self.save_this:
821
+
822
+ B, C, H, W = list(im.shape)
823
+
824
+ im = im[0:1] # just the first one
825
+ assert(C==4)
826
+
827
+ # d = utils.basic.normalize(d)
828
+ im0 = im[:,0:1]
829
+ im1 = im[:,1:2]
830
+ im2 = im[:,2:3]
831
+ im3 = im[:,3:4]
832
+
833
+ im0 = utils.basic.normalize(im0).round()
834
+ im1 = utils.basic.normalize(im1).round()
835
+ im2 = utils.basic.normalize(im2).round()
836
+ im3 = utils.basic.normalize(im3).round()
837
+
838
+ # kp_vis = sw.summ_rgbs('tff/2_kp_s%d' % s, kp.unbind(1), only_return=True)
839
+ # kp_any = (torch.max(kp_vis, dim=2, keepdims=True)[0]).repeat(1, 1, 3, 1, 1)
840
+ # kp_vis[kp_any==0] = fcp_vis[kp_any==0]
841
+
842
+ # vis0 = oned2inferno(im0, norm=False)
843
+ # vis1 = oned2inferno(im1, norm=False)
844
+ # vis2 = oned2inferno(im2, norm=False)
845
+ # vis3 = oned2inferno(im3, norm=False)
846
+
847
+ # vis0 = self.summ_seg('', im0[:,0:1]*1, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20')
848
+ # vis1 = self.summ_seg('', im1[:,0:1]*2, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20')
849
+ # vis2 = self.summ_seg('', im2[:,0:1]*3, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20')
850
+ # vis3 = self.summ_seg('', im3[:,0:1]*4, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20')
851
+
852
+ vis0 = self.summ_seg('', im0[:,0]*1, only_return=True, colormap='tab20')
853
+ vis1 = self.summ_seg('', im1[:,0]*2, only_return=True, colormap='tab20')
854
+ vis2 = self.summ_seg('', im2[:,0]*3, only_return=True, colormap='tab20')
855
+ vis3 = self.summ_seg('', im3[:,0]*4, only_return=True, colormap='tab20')
856
+
857
+ # vis_any = (torch.max(vis2, dim=2, keepdims=True)[0]).repeat(1, 1, 3, 1, 1)
858
+ # vis3[vis_any==0] = fcp_vis[kp_any==0]
859
+
860
+ vis0_any = (torch.max(vis0, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1)
861
+ vis1_any = (torch.max(vis1, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1)
862
+ vis2_any = (torch.max(vis2, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1)
863
+ vis3_any = (torch.max(vis3, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1)
864
+ vis0[vis0_any==0] = vis1[vis0_any==0]
865
+ vis0[vis1_any==0] = vis2[vis1_any==0]
866
+ vis0[vis2_any==0] = vis3[vis2_any==0]
867
+
868
+ print('vis0', vis0.shape, vis0.device)
869
+
870
+ vis0 = vis0.cpu()
871
+
872
+ # vis = oned2inferno(im, norm=norm)
873
+ # if W > self.maxwidth:
874
+ # vis = vis[...,:self.maxwidth]
875
+ return self.summ_rgb(name, vis0, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return)
876
+
877
+
878
+ def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None, frame_strs=None):
879
+ if self.save_this:
880
+ if valids is not None:
881
+ valids = torch.stack(valids, dim=1)
882
+
883
+ feats = torch.stack(feats, dim=1)
884
+ # feats leads with B x S x C
885
+
886
+ if feats.ndim==6:
887
+
888
+ # feats is B x S x C x D x H x W
889
+ if fro:
890
+ reduce_dim = 3
891
+ else:
892
+ reduce_dim = 4
893
+
894
+ if valids is None:
895
+ feats = torch.mean(feats, dim=reduce_dim)
896
+ else:
897
+ valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1)
898
+ feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim)
899
+
900
+ B, S, C, D, W = list(feats.size())
901
+
902
+ if not pca:
903
+ # feats leads with B x S x C
904
+ feats = torch.mean(torch.abs(feats), dim=2, keepdims=True)
905
+ # feats leads with B x S x 1
906
+ feats = torch.unbind(feats, dim=1)
907
+ return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
908
+
909
+ else:
910
+ __p = lambda x: utils.basic.pack_seqdim(x, B)
911
+ __u = lambda x: utils.basic.unpack_seqdim(x, B)
912
+
913
+ feats_ = __p(feats)
914
+
915
+ if valids is None:
916
+ feats_pca_ = get_feat_pca(feats_)
917
+ else:
918
+ valids_ = __p(valids)
919
+ feats_pca_ = get_feat_pca(feats_, valids_) # use the packed valids to match feats_
920
+
921
+ feats_pca = __u(feats_pca_)
922
+
923
+ return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
924
+
925
+ def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None, frame_str=None):
926
+ if self.save_this:
927
+ if feat.ndim==5: # B x C x D x H x W
928
+
929
+ if bev:
930
+ reduce_axis = 3
931
+ elif fro:
932
+ reduce_axis = 2
933
+ else:
934
+ # default to bev
935
+ reduce_axis = 3
936
+
937
+ if valid is None:
938
+ feat = torch.mean(feat, dim=reduce_axis)
939
+ else:
940
+ valid = valid.repeat(1, feat.size()[1], 1, 1, 1)
941
+ feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis)
942
+
943
+ B, C, D, W = list(feat.shape)
944
+
945
+ if not pca:
946
+ feat = torch.mean(torch.abs(feat), dim=1, keepdims=True)
947
+ # feat is B x 1 x D x W
948
+ return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
949
+ else:
950
+ feat_pca = get_feat_pca(feat, valid)
951
+ return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
952
+
953
+ def summ_scalar(self, name, value):
954
+ if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()):
955
+ value = value.detach().cpu().numpy()
956
+ if not np.isnan(value):
957
+ if (self.log_freq == 1):
958
+ self.writer.add_scalar(name, value, global_step=self.global_step)
959
+ elif self.save_this or self.save_scalar:
960
+ self.writer.add_scalar(name, value, global_step=self.global_step)
961
+
962
+ def summ_seg(self, name, seg, only_return=False, frame_id=None, frame_str=None, colormap='tab20', label_colors=None):
963
+ if not self.save_this:
964
+ return
965
+
966
+ B,H,W = seg.shape
967
+
968
+ if label_colors is None:
969
+ custom_label_colors = False
970
+ # label_colors = get_n_colors(int(torch.max(seg).item()), sequential=True)
971
+ label_colors = cm.get_cmap(colormap).colors
972
+ label_colors = [[int(i*255) for i in l] for l in label_colors]
973
+ else:
974
+ custom_label_colors = True
975
+ # label_colors = matplotlib.cm.get_cmap(colormap).colors
976
+ # label_colors = [[int(i*255) for i in l] for l in label_colors]
977
+ # print('label_colors', label_colors)
978
+
979
+ # label_colors = [
980
+ # (0, 0, 0), # None
981
+ # (70, 70, 70), # Buildings
982
+ # (190, 153, 153), # Fences
983
+ # (72, 0, 90), # Other
984
+ # (220, 20, 60), # Pedestrians
985
+ # (153, 153, 153), # Poles
986
+ # (157, 234, 50), # RoadLines
987
+ # (128, 64, 128), # Roads
988
+ # (244, 35, 232), # Sidewalks
989
+ # (107, 142, 35), # Vegetation
990
+ # (0, 0, 255), # Vehicles
991
+ # (102, 102, 156), # Walls
992
+ # (220, 220, 0) # TrafficSigns
993
+ # ]
994
+
995
+ r = torch.zeros_like(seg,dtype=torch.uint8)
996
+ g = torch.zeros_like(seg,dtype=torch.uint8)
997
+ b = torch.zeros_like(seg,dtype=torch.uint8)
998
+
999
+ for label in range(0,len(label_colors)):
1000
+ if (not custom_label_colors):# and (N > 20):
1001
+ label_ = label % 20
1002
+ else:
1003
+ label_ = label
1004
+
1005
+ idx = (seg == label)
1006
+ r[idx] = label_colors[label_][0]
1007
+ g[idx] = label_colors[label_][1]
1008
+ b[idx] = label_colors[label_][2]
1009
+
1010
+ rgb = torch.stack([r,g,b],axis=1)
1011
+ return self.summ_rgb(name,rgb,only_return=only_return, frame_id=frame_id, frame_str=frame_str)
1012
+
1013
+ def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap='coolwarm', vals=None, linewidth=1, max_show=1024):
1014
+ # trajs is B, S, N, 2
1015
+ # rgbs is B, S, C, H, W
1016
+ B, S, C, H, W = rgbs.shape
1017
+ B, S2, N, D = trajs.shape
1018
+ assert(S==S2)
1019
+
1020
+
1021
+ rgbs = rgbs[0] # S, C, H, W
1022
+ trajs = trajs[0] # S, N, 2
1023
+ if valids is None:
1024
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
1025
+ else:
1026
+ valids = valids[0]
1027
+
1028
+ if visibs is None:
1029
+ visibs = torch.ones_like(trajs[:,:,0]) # S, N
1030
+ else:
1031
+ visibs = visibs[0]
1032
+
1033
+ if vals is not None:
1034
+ vals = vals[0] # N
1035
+ # print('vals', vals.shape)
1036
+
1037
+ if N > max_show:
1038
+ inds = np.random.choice(N, max_show)
1039
+ trajs = trajs[:,inds]
1040
+ valids = valids[:,inds]
1041
+ visibs = visibs[:,inds]
1042
+ if vals is not None:
1043
+ vals = vals[inds]
1044
+ N = trajs.shape[1]
1045
+
1046
+ trajs = trajs.clamp(-16, W+16)
1047
+
1048
+ rgbs_color = []
1049
+ for rgb in rgbs:
1050
+ rgb = back2color(rgb).detach().cpu().numpy()
1051
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
1052
+ rgbs_color.append(rgb) # each element 3 x H x W
1053
+
1054
+ for i in range(min(N, max_show)):
1055
+ if cmap=='onediff' and i==0:
1056
+ cmap_ = 'spring'
1057
+ elif cmap=='onediff':
1058
+ cmap_ = 'winter'
1059
+ else:
1060
+ cmap_ = cmap
1061
+ traj = trajs[:,i].long().detach().cpu().numpy() # S, 2
1062
+ valid = valids[:,i].long().detach().cpu().numpy() # S
1063
+
1064
+ # print('traj', traj.shape)
1065
+ # print('valid', valid.shape)
1066
+
1067
+ if vals is not None:
1068
+ # val = vals[:,i].float().detach().cpu().numpy() # []
1069
+ val = vals[i].float().detach().cpu().numpy() # []
1070
+ # print('val', val.shape)
1071
+ else:
1072
+ val = None
1073
+
1074
+ for t in range(S):
1075
+ if valid[t]:
1076
+ rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj[:t+1], S=S, show_dots=show_dots, cmap=cmap_, val=val, linewidth=linewidth)
1077
+
1078
+ for i in range(min(N, max_show)):
1079
+ if cmap=='onediff' and i==0:
1080
+ cmap_ = 'spring'
1081
+ elif cmap=='onediff':
1082
+ cmap_ = 'winter'
1083
+ else:
1084
+ cmap_ = cmap
1085
+ traj = trajs[:,i] # S,2
1086
+ vis = visibs[:,i].round() # S
1087
+ valid = valids[:,i] # S
1088
+ rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
1089
+
1090
+ rgbs = []
1091
+ for rgb in rgbs_color:
1092
+ rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
1093
+ rgbs.append(preprocess_color(rgb))
1094
+
1095
+ return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
1096
+
1097
+ def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap=None, linewidth=1, max_show=1024):
1098
+ # trajs is B, S, N, 2
1099
+ # rgbs is B, S, C, H, W
1100
+ B, S, C, H, W = rgbs.shape
1101
+ B, S2, N, D = trajs.shape
1102
+ assert(S==S2)
1103
+
1104
+ rgbs = rgbs[0] # S, C, H, W
1105
+ trajs = trajs[0] # S, N, 2
1106
+ visibles = visibles[0] # S, N
1107
+ if valids is None:
1108
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
1109
+ else:
1110
+ valids = valids[0]
1111
+
1112
+ rgbs_color = []
1113
+ for rgb in rgbs:
1114
+ rgb = back2color(rgb).detach().cpu().numpy()
1115
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
1116
+ rgbs_color.append(rgb) # each element 3 x H x W
1117
+
1118
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
1119
+ visibles = visibles.float().detach().cpu().numpy() # S, N
1120
+ valids = valids.long().detach().cpu().numpy() # S, N
1121
+
1122
+ for i in range(min(N, max_show)):
1123
+ if cmap=='onediff' and i==0:
1124
+ cmap_ = 'spring'
1125
+ elif cmap=='onediff':
1126
+ cmap_ = 'winter'
1127
+ else:
1128
+ cmap_ = cmap
1129
+ traj = trajs[:,i] # S,2
1130
+ vis = visibles[:,i] # S
1131
+ valid = valids[:,i] # S
1132
+ rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
1133
+
1134
+ for i in range(min(N, max_show)):
1135
+ if cmap=='onediff' and i==0:
1136
+ cmap_ = 'spring'
1137
+ elif cmap=='onediff':
1138
+ cmap_ = 'winter'
1139
+ else:
1140
+ cmap_ = cmap
1141
+ traj = trajs[:,i] # S,2
1142
+ vis = visibles[:,i] # S
1143
+ valid = valids[:,i] # S
1144
+ rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth)
1145
+
1146
+ rgbs = []
1147
+ for rgb in rgbs_color:
1148
+ rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
1149
+ rgbs.append(preprocess_color(rgb))
1150
+
1151
+ return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
1152
+
1153
+ def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=True, show_lines=True, frame_id=None, frame_str=None, only_return=False, cmap='coolwarm', linewidth=1, max_show=1024):
1154
+ # trajs is B, S, N, 2
1155
+ # rgb is B, C, H, W
1156
+ B, C, H, W = rgb.shape
1157
+ B, S, N, D = trajs.shape
1158
+
1159
+ rgb = rgb[0] # S, C, H, W
1160
+ trajs = trajs[0] # S, N, 2
1161
+
1162
+ if valids is None:
1163
+ valids = torch.ones_like(trajs[:,:,0])
1164
+ else:
1165
+ valids = valids[0]
1166
+
1167
+ rgb_color = back2color(rgb).detach().cpu().numpy()
1168
+ rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last
1169
+
1170
+ # using maxdist will dampen the colors for short motions
1171
+ # norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1)) # N
1172
+ # maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy()
1173
+ maxdist = None
1174
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
1175
+ valids = valids.long().detach().cpu().numpy() # S, N
1176
+
1177
+ if N > max_show:
1178
+ inds = np.random.choice(N, max_show)
1179
+ trajs = trajs[:,inds]
1180
+ valids = valids[:,inds]
1181
+ N = trajs.shape[1]
1182
+
1183
+ for i in range(min(N, max_show)):
1184
+ if cmap=='onediff' and i==0:
1185
+ cmap_ = 'spring'
1186
+ elif cmap=='onediff':
1187
+ cmap_ = 'winter'
1188
+ else:
1189
+ cmap_ = cmap
1190
+ traj = trajs[:,i] # S, 2
1191
+ valid = valids[:,i] # S
1192
+ if valid[0]==1:
1193
+ traj = traj[valid>0]
1194
+ rgb_color = self.draw_traj_on_image_py(
1195
+ rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth)
1196
+
1197
+ rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0)
1198
+ rgb = preprocess_color(rgb_color)
1199
+ return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
1200
+
1201
+ def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None):
1202
+ # all inputs are numpy tensors
1203
+ # rgb is 3 x H x W
1204
+ # traj is S x 2
1205
+
1206
+ H, W, C = rgb.shape
1207
+ assert(C==3)
1208
+
1209
+ rgb = rgb.astype(np.uint8).copy()
1210
+
1211
+ S1, D = traj.shape
1212
+ assert(D==2)
1213
+
1214
+ color_map = cm.get_cmap(cmap)
1215
+ S1, D = traj.shape
1216
+
1217
+ for s in range(S1):
1218
+ if val is not None:
1219
+ color = np.array(color_map(val)[:3]) * 255 # rgb
1220
+ else:
1221
+ if maxdist is not None:
1222
+ val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1)
1223
+ color = np.array(color_map(val)[:3]) * 255 # rgb
1224
+ else:
1225
+ color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
1226
+
1227
+ if show_lines and s<(S1-1):
1228
+ cv2.line(rgb,
1229
+ (int(traj[s,0]), int(traj[s,1])),
1230
+ (int(traj[s+1,0]), int(traj[s+1,1])),
1231
+ color,
1232
+ linewidth,
1233
+ cv2.LINE_AA)
1234
+ if show_dots:
1235
+ cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1)
1236
+
1237
+ # if maxdist is not None:
1238
+ # val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1)
1239
+ # color = np.array(color_map(val)[:3]) * 255 # rgb
1240
+ # else:
1241
+ # # draw the endpoint of traj, using the next color (which may be the last color)
1242
+ # color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb
1243
+
1244
+ # # emphasize endpoint
1245
+ # cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1)
1246
+
1247
+ return rgb
1248
+
1249
+
1250
+ def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None):
1251
+ # all inputs are numpy tensors
1252
+ # rgbs is a list of H,W,3
1253
+ # traj is S,2
1254
+ H, W, C = rgbs[0].shape
1255
+ assert(C==3)
1256
+
1257
+ rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
1258
+
1259
+ S1, D = traj.shape
1260
+ assert(D==2)
1261
+
1262
+ x = int(np.clip(traj[0,0], 0, W-1))
1263
+ y = int(np.clip(traj[0,1], 0, H-1))
1264
+ color = rgbs[0][y,x]
1265
+ color = (int(color[0]),int(color[1]),int(color[2]))
1266
+ for s in range(S):
1267
+ # bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb
1268
+ # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1)
1269
+ cv2.polylines(rgbs[s],
1270
+ [traj[:s+1]],
1271
+ False,
1272
+ color,
1273
+ linewidth,
1274
+ cv2.LINE_AA)
1275
+ return rgbs
1276
+
1277
+ def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None):
1278
+ # all inputs are numpy tensors
1279
+ # rgbs is a list of 3,H,W
1280
+ # xy is N,2
1281
+ H, W, C = rgb.shape
1282
+ assert(C==3)
1283
+
1284
+ rgb = rgb.astype(np.uint8).copy()
1285
+
1286
+ N, D = xy.shape
1287
+ assert(D==2)
1288
+
1289
+
1290
+ xy = xy.astype(np.float32)
1291
+ xy[:,0] = np.clip(xy[:,0], 0, W-1)
1292
+ xy[:,1] = np.clip(xy[:,1], 0, H-1)
1293
+ xy = xy.astype(np.int32)
1294
+
1295
+
1296
+
1297
+ if colors is None:
1298
+ colors = get_n_colors(N)
1299
+
1300
+ for n in range(N):
1301
+ color = colors[n]
1302
+ # print('color', color)
1303
+ # color = (color[0]*255).astype(np.uint8)
1304
+ color = (int(color[0]),int(color[1]),int(color[2]))
1305
+
1306
+ # x = int(np.clip(xy[0,0], 0, W-1))
1307
+ # y = int(np.clip(xy[0,1], 0, H-1))
1308
+ # color_ = rgbs[0][y,x]
1309
+ # color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
1310
+ # color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
1311
+
1312
+ cv2.circle(rgb, (int(xy[n,0]), int(xy[n,1])), linewidth, color, 3)
1313
+ # vis_color = int(np.squeeze(vis[s])*255)
1314
+ # vis_color = (vis_color,vis_color,vis_color)
1315
+ # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1)
1316
+ return rgb
1317
+
1318
+ def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None):
1319
+ # all inputs are numpy tensors
1320
+ # rgbs is a list of 3,H,W
1321
+ # traj is S,2
1322
+ H, W, C = rgbs[0].shape
1323
+ assert(C==3)
1324
+
1325
+ rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
1326
+
1327
+ S1, D = traj.shape
1328
+ assert(D==2)
1329
+
1330
+ if cmap is None:
1331
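+ # no cmap given: color the point by its normalized start position via a 2D colormap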
+ bremm = ColorMap2d()
1332
+ traj_ = traj[0:1].astype(np.float32)
1333
+ traj_[:,0] /= float(W)
1334
+ traj_[:,1] /= float(H)
1335
+ color = bremm(traj_)
1336
+ # print('color', color)
1337
+ color = (color[0]*255).astype(np.uint8)
1338
+ color = (int(color[0]),int(color[1]),int(color[2]))
1339
+
1340
+ for s in range(S):
1341
+ if cmap is not None:
1342
+ color_map = cm.get_cmap(cmap)
1343
+ # color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb
1344
+ color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
1345
+ # color = color.astype(np.uint8)
1346
+ # color = (color[0], color[1], color[2])
1347
+ # print('color', color)
1348
+ # import ipdb; ipdb.set_trace()
1349
+
1350
+ cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+2, color, -1)
1351
+ vis_color = int(np.squeeze(vis[s])*255)
1352
+ vis_color = (vis_color,vis_color,vis_color)
1353
+ cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1)
1354
+
1355
+ return rgbs
1356
+
1357
+ def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, frame_str=None, only_return=False, show_circ=True, trajs_g=None, valids_g=None, is_g=False, anchor_ind=None, ara=None, anchors=None, frame_ids=None, frame_strs=None):
1358
+ B, S, N, D = trajs_e.shape
1359
+ assert(N==1)
1360
+ assert(D==2)
1361
+
1362
+ rgbs = back2color(rgbs).detach().cpu().byte().numpy()
1363
+
1364
+ rgbs_vis = []
1365
+ n = 0
1366
+ pad_amount = 128
1367
+ trajs_e_py = trajs_e[0].detach().cpu().numpy()
1368
+ # trajs_e_py = np.clip(trajs_e_py, min=pad_amount/2, max=pad_amoun
1369
+ trajs_e_py = trajs_e_py + pad_amount
1370
+
1371
+ if trajs_g is not None:
1372
+ trajs_g_py = trajs_g[0].detach().cpu().numpy()
1373
+ trajs_g_py = trajs_g_py + pad_amount
1374
+
1375
+ if valids_g is not None:
1376
+ valids_g_py = valids_g[0].detach().cpu().numpy()
1377
+ else:
1378
+ valids_g_py = np.ones_like(trajs_g_py[:,:,0]) # S,N
1379
+
1380
+
1381
+ for s in range(S):
1382
+ rgb = rgbs[0,s]
1383
+ # print('orig rgb', rgb.shape)
1384
+ rgb = np.transpose(rgb,(1,2,0)) # H, W, 3
1385
+
1386
+ rgb = np.pad(rgb, ((pad_amount,pad_amount),(pad_amount,pad_amount),(0,0)))
1387
+ # print('pad rgb', rgb.shape)
1388
+ H, W, C = rgb.shape
1389
+
1390
+ if trajs_g is not None:
1391
+ xy_g = trajs_g_py[s,n]
1392
+ xy_g[0] = np.clip(xy_g[0], pad_amount, W-pad_amount)
1393
+ xy_g[1] = np.clip(xy_g[1], pad_amount, H-pad_amount)
1394
+ if valids_g_py[s,n] > 0:
1395
+ rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3)
1396
+
1397
+ xy_e = trajs_e_py[s,n]
1398
+ xy_e[0] = np.clip(xy_e[0], pad_amount, W-pad_amount)
1399
+ xy_e[1] = np.clip(xy_e[1], pad_amount, H-pad_amount)
1400
+
1401
+ if show_circ:
1402
+
1403
+ # if (anchors is not None) and (s==anchor_ind):
1404
+ # rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,0,0)], linewidth=8, radius=12)
1405
+
1406
+ if (anchors is not None) and (s in anchors):
1407
+ rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,255,255)], linewidth=4, radius=8)
1408
+
1409
+ if is_g:
1410
+ rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3)
1411
+ else:
1412
+ rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,0,255)], linewidth=2, radius=3)
1413
+
1414
+
1415
+
1416
+ xmin = int(xy_e[0])-pad_amount//2
1417
+ xmax = int(xy_e[0])+pad_amount//2
1418
+ ymin = int(xy_e[1])-pad_amount//2
1419
+ ymax = int(xy_e[1])+pad_amount//2
1420
+
1421
+ rgb_ = rgb[ymin:ymax, xmin:xmax]
1422
+
1423
+ H_, W_ = rgb_.shape[:2]
1424
+ # if np.any(rgb_.shape==0):
1425
+ # input()
1426
+ if H_==0 or W_==0:
1427
+ import ipdb; ipdb.set_trace()
1428
+
1429
+ if (ara is not None) and (s in ara):
1430
+ # green border
1431
+ rgb_[0,:,0] = 0
1432
+ rgb_[0,:,1] = 255
1433
+ rgb_[0,:,2] = 0
1434
+
1435
+ rgb_[-1,:,0] = 0
1436
+ rgb_[-1,:,1] = 255
1437
+ rgb_[-1,:,2] = 0
1438
+
1439
+ rgb_[:,0,0] = 0
1440
+ rgb_[:,0,1] = 255
1441
+ rgb_[:,0,2] = 0
1442
+
1443
+ rgb_[:,-1,0] = 0
1444
+ rgb_[:,-1,1] = 255
1445
+ rgb_[:,-1,2] = 0
1446
+
1447
+ if (anchor_ind is not None) and (s==anchor_ind):
1448
+ # inner green border
1449
+ pad = 4
1450
+
1451
+ rgb_[:pad,:,0] = 0
1452
+ rgb_[:pad,:,1] = 255
1453
+ rgb_[:pad,:,2] = 0
1454
+
1455
+ rgb_[-pad:,:,0] = 0
1456
+ rgb_[-pad:,:,1] = 255
1457
+ rgb_[-pad:,:,2] = 0
1458
+
1459
+ rgb_[:,:pad,0] = 0
1460
+ rgb_[:,:pad,1] = 255
1461
+ rgb_[:,:pad,2] = 0
1462
+
1463
+ rgb_[:,-pad:,0] = 0
1464
+ rgb_[:,-pad:,1] = 255
1465
+ rgb_[:,-pad:,2] = 0
1466
+
1467
+ rgb_ = rgb_.transpose(2,0,1)
1468
+ rgb_ = torch.from_numpy(rgb_)
1469
+
1470
+ if frame_ids is not None:
1471
+ # if s==anchor_ind:
1472
+ # frame_ids[s] = frame_ids[s] + '
1473
+ rgb_ = draw_frame_id_on_vis(rgb_.unsqueeze(0), frame_ids[s]).squeeze(0)
1474
+
1475
+ if s==anchor_ind:
1476
+ rgb_ = draw_frame_str_on_vis(rgb_.unsqueeze(0), '(A)').squeeze(0)
1477
+
1478
+ if frame_strs is not None:
1479
+ rgb_ = draw_frame_str_on_vis(rgb_.unsqueeze(0), frame_strs[s]).squeeze(0)
1480
+
1481
+ rgbs_vis.append(rgb_)
1482
+
1483
+ # nrow = int(np.sqrt(S)*(16.0/9)/2.0)
1484
+ nrow = int(np.sqrt(S)*1.5)
1485
+ grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0)
1486
+ # print('grid_img', grid_img.shape)
1487
+ return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, frame_str=frame_str, only_return=only_return)
1488
+
1489
+ def summ_pts_on_rgb(self, name, trajs, rgb, visibs=None, valids=None, frame_id=None, frame_str=None, only_return=False, show_dots=True, colors=None, cmap='coolwarm', linewidth=1, max_show=1024, already_sorted=False):
1490
+ # trajs is B, S, N, 2
1491
+ # rgbs is B, S, C, H, W
1492
+ B, C, H, W = rgb.shape
1493
+ B, S, N, D = trajs.shape
1494
+
1495
+ rgb = rgb[0] # C, H, W
1496
+ trajs = trajs[0] # S, N, 2
1497
+ if valids is None:
1498
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
1499
+ else:
1500
+ valids = valids[0]
1501
+ if visibs is None:
1502
+ visibs = torch.ones_like(trajs[:,:,0]) # S, N
1503
+ else:
1504
+ visibs = visibs[0]
1505
+
1506
+ trajs = trajs.clamp(-16, W+16)
1507
+
1508
+ if N > max_show:
1509
+ inds = np.random.choice(N, max_show)
1510
+ trajs = trajs[:,inds]
1511
+ valids = valids[:,inds]
1512
+ visibs = visibs[:,inds]
1513
+ N = trajs.shape[1]
1514
+
1515
+ if not already_sorted:
1516
+ inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
1517
+ trajs = trajs[:,inds]
1518
+ valids = valids[:,inds]
1519
+ visibs = visibs[:,inds]
1520
+
1521
+ rgb = back2color(rgb).detach().cpu().numpy()
1522
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
1523
+
1524
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
1525
+ valids = valids.long().detach().cpu().numpy() # S, N
1526
+ visibs = visibs.long().detach().cpu().numpy() # S, N
1527
+
1528
+ rgb = rgb.astype(np.uint8).copy()
1529
+
1530
+ for i in range(min(N, max_show)):
1531
+ if cmap=='onediff' and i==0:
1532
+ cmap_ = 'spring'
1533
+ elif cmap=='onediff':
1534
+ cmap_ = 'winter'
1535
+ else:
1536
+ cmap_ = cmap
1537
+ traj = trajs[:,i] # S,2
1538
+ valid = valids[:,i] # S
1539
+ visib = visibs[:,i] # S
1540
+
1541
+ if colors is None:
1542
+ ii = i/(1e-4+N-1.0)
1543
+ color_map = cm.get_cmap(cmap)
1544
+ color = np.array(color_map(ii)[:3]) * 255 # rgb
1545
+ else:
1546
+ color = np.array(colors[i]).astype(np.int64)
1547
+ color = (int(color[0]),int(color[1]),int(color[2]))
1548
+
1549
+ for s in range(S):
1550
+ if valid[s]:
1551
+ if visib[s]:
1552
+ thickness = -1
1553
+ else:
1554
+ thickness = 2
1555
+ cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, thickness)
1556
+ rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0)
1557
+ rgb = preprocess_color(rgb)
1558
+ return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
1559
+
1560
+ def summ_pts_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', colors=None, linewidth=1, max_show=1024, frame_strs=None):
1561
+ # trajs is B, S, N, 2
1562
+ # rgbs is B, S, C, H, W
1563
+ B, S, C, H, W = rgbs.shape
1564
+ B, S2, N, D = trajs.shape
1565
+ assert(S==S2)
1566
+
1567
+ rgbs = rgbs[0] # S, C, H, W
1568
+ trajs = trajs[0] # S, N, 2
1569
+ if valids is None:
1570
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
1571
+ else:
1572
+ valids = valids[0]
1573
+ if visibs is None:
1574
+ visibs = torch.ones_like(trajs[:,:,0]) # S, N
1575
+ else:
1576
+ visibs = visibs[0]
1577
+
1578
+ if N > max_show:
1579
+ inds = np.random.choice(N, max_show)
1580
+ trajs = trajs[:,inds]
1581
+ valids = valids[:,inds]
1582
+ visibs = visibs[:,inds]
1583
+ N = trajs.shape[1]
1584
+ inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
1585
+ trajs = trajs[:,inds]
1586
+ valids = valids[:,inds]
1587
+ visibs = visibs[:,inds]
1588
+
1589
+ rgbs_color = []
1590
+ for rgb in rgbs:
1591
+ rgb = back2color(rgb).detach().cpu().numpy()
1592
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
1593
+ rgbs_color.append(rgb) # each element 3 x H x W
1594
+
1595
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
1596
+ valids = valids.long().detach().cpu().numpy() # S, N
1597
+ visibs = visibs.long().detach().cpu().numpy() # S, N
1598
+
1599
+ rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color]
1600
+
1601
+ for i in range(min(N, max_show)):
1602
+ traj = trajs[:,i] # S,2
1603
+ valid = valids[:,i] # S
1604
+ visib = visibs[:,i] # S
1605
+
1606
+ if colors is None:
1607
+ ii = i/(1e-4+N-1.0)
1608
+ color_map = cm.get_cmap(cmap)
1609
+ color = np.array(color_map(ii)[:3]) * 255 # rgb
1610
+ else:
1611
+ color = np.array(colors[i]).astype(np.int64)
1612
+ color = (int(color[0]),int(color[1]),int(color[2]))
1613
+
1614
+ for s in range(S):
1615
+ if valid[s]:
1616
+ if visib[s]:
1617
+ thickness = -1
1618
+ else:
1619
+ thickness = 2
1620
+ cv2.circle(rgbs_color[s], (int(traj[s,0]), int(traj[s,1])), int(linewidth), color, thickness)
1621
+ rgbs = []
1622
+ for rgb in rgbs_color:
1623
+ rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
1624
+ rgbs.append(preprocess_color(rgb))
1625
+
1626
+ return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
1627
+
1628
+
1629
+ def erode2d(im, times=1):
1630
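+ # binary erosion via its dual: dilate the complement with a 3x3 all-ones kernel, then take the complement again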
+ device = im.device
1631
+ weights2d = torch.ones(1, 1, 3, 3, dtype=im.dtype, device=device)
1632
+ for time in range(times):
1633
+ im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1).clamp(0, 1)
1634
+ return im
1635
+
1636
+ def dilate2d(im, times=1):
1637
+ device = im.device
1638
+ assert(times>0)
1639
+ dilation_kernel_size = times*2 + 1
1640
+ padding_size = dilation_kernel_size // 2
1641
+ dilation_kernel = torch.ones((1, 1, dilation_kernel_size, dilation_kernel_size), device=device)
1642
+ im = F.conv2d(im, dilation_kernel, padding=padding_size, groups=im.shape[1]).clamp(0,1)
1643
+ return im
1644
+
1645
+
utils/loss.py ADDED
@@ -0,0 +1,220 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from typing import List
5
+ import utils.basic
6
+
7
+
8
+ def sequence_loss(
9
+ flow_preds,
10
+ flow_gt,
11
+ valids,
12
+ vis=None,
13
+ gamma=0.8,
14
+ use_huber_loss=False,
15
+ loss_only_for_visible=False,
16
+ ):
17
+ """Loss function defined over sequence of flow predictions"""
18
+ total_flow_loss = 0.0
19
+ for j in range(len(flow_gt)):
20
+ B, S, N, D = flow_gt[j].shape
21
+ B, S2, N = valids[j].shape
22
+ assert S == S2
23
+ n_predictions = len(flow_preds[j])
24
+ flow_loss = 0.0
25
+ for i in range(n_predictions):
26
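+ # weight each refinement iteration by gamma^(n_predictions-1-i); with gamma < 1 the last (most refined) prediction counts most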
+ i_weight = gamma ** (n_predictions - i - 1)
27
+ flow_pred = flow_preds[j][i]
28
+ if use_huber_loss:
29
+ i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
30
+ else:
31
+ i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
32
+ i_loss = torch.mean(i_loss, dim=3) # B, S, N
33
+ valid_ = valids[j].clone()
34
+ if loss_only_for_visible:
35
+ valid_ = valid_ * vis[j]
36
+ flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss, valid_)
37
+ flow_loss = flow_loss / n_predictions
38
+ total_flow_loss += flow_loss
39
+ return total_flow_loss / len(flow_gt)
40
+
41
+ def sequence_loss_dense(
42
+ flow_preds,
43
+ flow_gt,
44
+ valids,
45
+ vis=None,
46
+ gamma=0.8,
47
+ use_huber_loss=False,
48
+ loss_only_for_visible=False,
49
+ ):
50
+ """Loss function defined over sequence of flow predictions"""
51
+ total_flow_loss = 0.0
52
+ for j in range(len(flow_gt)):
53
+ # print('flow_gt[j]', flow_gt[j].shape)
54
+ B, S, D, H, W = flow_gt[j].shape
55
+ B, S2, _, H, W = valids[j].shape
56
+ assert S == S2
57
+ n_predictions = len(flow_preds[j])
58
+ flow_loss = 0.0
59
+ # import ipdb; ipdb.set_trace()
60
+ for i in range(n_predictions):
61
+ # print('flow_e[j][i]', flow_preds[j][i].shape)
62
+ i_weight = gamma ** (n_predictions - i - 1)
63
+ flow_pred = flow_preds[j][i] # B,S,2,H,W
64
+ if use_huber_loss:
65
+ i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0) # B,S,2,H,W
66
+ else:
67
+ i_loss = (flow_pred - flow_gt[j]).abs() # B,S,2,H,W
68
+ i_loss_ = torch.mean(i_loss, dim=2) # B,S,H,W
69
+ valid_ = valids[j].reshape(B,S,H,W)
70
+ # print(' (%d,%d) i_loss_' % (i,j), i_loss_.shape)
71
+ # print(' (%d,%d) valid_' % (i,j), valid_.shape)
72
+ if loss_only_for_visible:
73
+ valid_ = valid_ * vis[j].reshape(B,-1,H,W) # usually B,S,H,W, but maybe B,1,H,W
74
+ flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss_, valid_, broadcast=True)
75
+ # import ipdb; ipdb.set_trace()
76
+ flow_loss = flow_loss / n_predictions
77
+ total_flow_loss += flow_loss
78
+ return total_flow_loss / len(flow_gt)
79
+
80
+
81
+ def huber_loss(x, y, delta=1.0):
82
+ """Calculate element-wise Huber loss between x and y"""
83
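+ # piecewise loss: 0.5*diff^2 where |diff| <= delta, delta*(|diff| - 0.5*delta) otherwise, so large residuals grow linearly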
+ diff = x - y
84
+ abs_diff = diff.abs()
85
+ flag = (abs_diff <= delta).float()
86
+ return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
87
+
88
+
89
+ def sequence_BCE_loss(vis_preds, vis_gts, valids=None, use_logits=False):
90
+ total_bce_loss = 0.0
91
+ # all_vis_preds = [torch.stack(vp) for vp in vis_preds]
92
+ # all_vis_preds = torch.stack(all_vis_preds)
93
+ # utils.basic.print_stats('all_vis_preds', all_vis_preds)
94
+ for j in range(len(vis_preds)):
95
+ n_predictions = len(vis_preds[j])
96
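+ # average the visibility BCE over all refinement iterations of this prediction list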
+ bce_loss = 0.0
97
+ for i in range(n_predictions):
98
+ # utils.basic.print_stats('vis_preds[%d][%d]' % (j,i), vis_preds[j][i])
99
+ # utils.basic.print_stats('vis_gts[%d]' % (i), vis_gts[i])
100
+ if use_logits:
101
+ loss = F.binary_cross_entropy_with_logits(vis_preds[j][i], vis_gts[j], reduction='none')
102
+ else:
103
+ loss = F.binary_cross_entropy(vis_preds[j][i], vis_gts[j], reduction='none')
104
+ if valids is None:
105
+ bce_loss += loss.mean()
106
+ else:
107
+ bce_loss += (loss * valids[j]).mean()
108
+ bce_loss = bce_loss / n_predictions
109
+ total_bce_loss += bce_loss
110
+ return total_bce_loss / len(vis_preds)
111
+
112
+
113
+ # def sequence_BCE_loss_dense(vis_preds, vis_gts):
114
+ # total_bce_loss = 0.0
115
+ # for j in range(len(vis_preds)):
116
+ # n_predictions = len(vis_preds[j])
117
+ # bce_loss = 0.0
118
+ # for i in range(n_predictions):
119
+ # vis_e = vis_preds[j][i]
120
+ # vis_g = vis_gts[j]
121
+ # print('vis_e', vis_e.shape, 'vis_g', vis_g.shape)
122
+ # vis_loss = F.binary_cross_entropy(vis_e, vis_g)
123
+ # bce_loss += vis_loss
124
+ # bce_loss = bce_loss / n_predictions
125
+ # total_bce_loss += bce_loss
126
+ # return total_bce_loss / len(vis_preds)
127
+
128
+
129
+ def sequence_prob_loss(
130
+ tracks: torch.Tensor,
131
+ confidence: torch.Tensor,
132
+ target_points: torch.Tensor,
133
+ visibility: torch.Tensor,
134
+ expected_dist_thresh: float = 12.0,
135
+ use_logits=False,
136
+ ):
137
+ """Loss for classifying if a point is within pixel threshold of its target."""
138
+ # Points with an error larger than 12 pixels are likely to be useless; marking
139
+ # them as occluded will actually improve Jaccard metrics and give
140
+ # qualitatively better results.
141
+ total_logprob_loss = 0.0
142
+ for j in range(len(tracks)):
143
+ n_predictions = len(tracks[j])
144
+ logprob_loss = 0.0
145
+ for i in range(n_predictions):
146
+ err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
147
+ valid = (err <= expected_dist_thresh**2).float()
148
+ if use_logits:
149
+ loss = F.binary_cross_entropy_with_logits(confidence[j][i], valid, reduction="none")
150
+ else:
151
+ loss = F.binary_cross_entropy(confidence[j][i], valid, reduction="none")
152
+ loss *= visibility[j]
153
+ loss = torch.mean(loss, dim=[1, 2])
154
+ logprob_loss += loss
155
+ logprob_loss = logprob_loss / n_predictions
156
+ total_logprob_loss += logprob_loss
157
+ return total_logprob_loss / len(tracks)
158
+
159
+ def sequence_prob_loss_dense(
160
+ tracks: torch.Tensor,
161
+ confidence: torch.Tensor,
162
+ target_points: torch.Tensor,
163
+ visibility: torch.Tensor,
164
+ expected_dist_thresh: float = 12.0,
165
+ use_logits=False,
166
+ ):
167
+ """Loss for classifying if a point is within pixel threshold of its target."""
168
+ # Points with an error larger than 12 pixels are likely to be useless; marking
169
+ # them as occluded will actually improve Jaccard metrics and give
170
+ # qualitatively better results.
171
+
172
+ # all_confidence = [torch.stack(vp) for vp in confidence]
173
+ # all_confidence = torch.stack(all_confidence)
174
+ # utils.basic.print_stats('all_confidence', all_confidence)
175
+
176
+ total_logprob_loss = 0.0
177
+ for j in range(len(tracks)):
178
+ n_predictions = len(tracks[j])
179
+ logprob_loss = 0.0
180
+ for i in range(n_predictions):
181
+ # print('trajs_e', tracks[j][i].shape)
182
+ # print('trajs_g', target_points[j].shape)
183
+ err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=2)
184
+ positive = (err <= expected_dist_thresh**2).float()
185
+ # print('conf', confidence[j][i].shape, 'positive', positive.shape)
186
+ if use_logits:
187
+ loss = F.binary_cross_entropy_with_logits(confidence[j][i].squeeze(2), positive, reduction="none")
188
+ else:
189
+ loss = F.binary_cross_entropy(confidence[j][i].squeeze(2), positive, reduction="none")
190
+ loss *= visibility[j].squeeze(2) # B,S,H,W
191
+ loss = torch.mean(loss, dim=[1,2,3])
192
+ logprob_loss += loss
193
+ logprob_loss = logprob_loss / n_predictions
194
+ total_logprob_loss += logprob_loss
195
+ return total_logprob_loss / len(tracks)
196
+
197
+
198
+ def masked_mean(data, mask, dim):
199
+ if mask is None:
200
+ return data.mean(dim=dim, keepdim=True)
201
+ mask = mask.float()
202
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
203
+ mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
204
+ mask_sum, min=1.0
205
+ )
206
+ return mask_mean
207
+
208
+
209
+ def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
210
+ if mask is None:
211
+ return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
212
+ mask = mask.float()
213
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
214
+ mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
215
+ mask_sum, min=1.0
216
+ )
217
+ mask_var = torch.sum(
218
+ mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
219
+ ) / torch.clamp(mask_sum, min=1.0)
220
+ return mask_mean.squeeze(dim), mask_var.squeeze(dim)
utils/metric.py ADDED
@@ -0,0 +1,176 @@
1
+ import math
2
+ import numpy as np
3
+ import cv2
4
+
5
+ def _seg2bmap(seg, width=None, height=None):
6
+ """
7
+ From a segmentation, compute a binary boundary map with 1 pixel wide
8
+ boundaries. The boundary pixels are offset by 1/2 pixel towards the
9
+ origin from the actual segment boundary.
10
+ Arguments:
11
+ seg : Segments labeled from 1..k.
12
+ width : Width of desired bmap <= seg.shape[1]
13
+ height : Height of desired bmap <= seg.shape[0]
14
+ Returns:
15
+ bmap (ndarray): Binary boundary map.
16
+ David Martin <[email protected]>
17
+ January 2003
18
+ """
19
+
20
+ seg = seg.astype(bool)
21
+ seg[seg > 0] = 1
22
+
23
+ assert np.atleast_3d(seg).shape[2] == 1
24
+
25
+ width = seg.shape[1] if width is None else width
26
+ height = seg.shape[0] if height is None else height
27
+
28
+ h, w = seg.shape[:2]
29
+
30
+ ar1 = float(width) / float(height)
31
+ ar2 = float(w) / float(h)
32
+
33
+ assert not (width > w or height > h or abs(ar1 - ar2) > 0.01), \
+ "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
35
+
36
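+ # boundary extraction: compare the mask with its east, south, and south-east shifts;
+ # a pixel that differs from any of these neighbors lies on the segment boundary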
+ e = np.zeros_like(seg)
37
+ s = np.zeros_like(seg)
38
+ se = np.zeros_like(seg)
39
+
40
+ e[:, :-1] = seg[:, 1:]
41
+ s[:-1, :] = seg[1:, :]
42
+ se[:-1, :-1] = seg[1:, 1:]
43
+
44
+ b = seg ^ e | seg ^ s | seg ^ se
45
+ b[-1, :] = seg[-1, :] ^ e[-1, :]
46
+ b[:, -1] = seg[:, -1] ^ s[:, -1]
47
+ b[-1, -1] = 0
48
+
49
+ if w == width and h == height:
50
+ bmap = b
51
+ else:
52
+ bmap = np.zeros((height, width))
53
+ for x in range(w):
54
+ for y in range(h):
55
+ if b[y, x]:
56
+ j = 1 + math.floor((y - 1) + height / h)
57
+ i = 1 + math.floor((x - 1) + width / h)
58
+ bmap[j, i] = 1
59
+
60
+ return bmap
61
+
62
+ # https://github.com/davisvideochallenge/davis2017-evaluation/blob/master/davis2017/metrics.py#L6
63
+ def db_eval_iou(annotation, segmentation, void_pixels=None):
64
+ """ Compute region similarity as the Jaccard Index.
65
+ Arguments:
66
+ annotation (ndarray): binary annotation map.
67
+ segmentation (ndarray): binary segmentation map.
68
+ void_pixels (ndarray): optional mask with void pixels
69
+
70
+ Return:
71
+ jaccard (float): region similarity
72
+ """
73
+ assert annotation.shape == segmentation.shape, \
74
+ f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.'
75
+ annotation = annotation.astype(bool)
76
+ segmentation = segmentation.astype(bool)
77
+
78
+ if void_pixels is not None:
79
+ assert annotation.shape == void_pixels.shape, \
80
+ f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.'
81
+ void_pixels = void_pixels.astype(bool)
82
+ else:
83
+ void_pixels = np.zeros_like(segmentation)
84
+
85
+ # Intersection between all sets
86
+ inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1))
87
+ union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1))
88
+
89
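+ # Jaccard index: intersection over union, computed only on non-void pixels; an empty union gives J = 0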
+ j = inters / union
90
+ if j.ndim == 0:
91
+ j = 0 if np.isclose(union, 0) else j
92
+ else:
93
+ j[np.isclose(union, 0)] = 0
94
+ return j
95
+
96
+
97
+ def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008):
98
+ assert annotation.shape == segmentation.shape
99
+ if void_pixels is not None:
100
+ assert annotation.shape == void_pixels.shape
101
+ if annotation.ndim == 3:
102
+ n_frames = annotation.shape[0]
103
+ f_res = np.zeros(n_frames)
104
+ for frame_id in range(n_frames):
105
+ void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :, ]
106
+ f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th)
107
+ elif annotation.ndim == 2:
108
+ f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th)
109
+ else:
110
+ raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions')
111
+ return f_res
112
+
113
+
114
+ def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008):
115
+ """
116
+ Compute mean,recall and decay from per-frame evaluation.
117
+ Calculates precision/recall for boundaries between foreground_mask and
118
+ gt_mask using morphological operators to speed it up.
119
+
120
+ Arguments:
121
+ foreground_mask (ndarray): binary segmentation image.
122
+ gt_mask (ndarray): binary annotated image.
123
+ void_pixels (ndarray): optional mask with void pixels
124
+
125
+ Returns:
126
+ F (float): boundaries F-measure
127
+ """
128
+ assert np.atleast_3d(foreground_mask).shape[2] == 1
129
+ if void_pixels is not None:
130
+ void_pixels = void_pixels.astype(bool)
131
+ else:
132
+ void_pixels = np.zeros_like(foreground_mask).astype(bool)
133
+
134
+ bound_pix = bound_th if bound_th >= 1 else \
135
+ np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))
136
+
137
+ # Get the pixel boundaries of both masks
138
+ fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels))
139
+ gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels))
140
+
141
+ from skimage.morphology import disk
142
+
143
+ # fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
144
+ fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
145
+ # gt_dil = binary_dilation(gt_boundary, disk(bound_pix))
146
+ gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
147
+
148
+ # Get the intersection
149
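+ # a boundary pixel counts as matched if it lies within bound_pix of the other mask's boundary; the dilation above implements that tolerance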
+ gt_match = gt_boundary * fg_dil
150
+ fg_match = fg_boundary * gt_dil
151
+
152
+ # Area of the intersection
153
+ n_fg = np.sum(fg_boundary)
154
+ n_gt = np.sum(gt_boundary)
155
+
156
+ # % Compute precision and recall
157
+ if n_fg == 0 and n_gt > 0:
158
+ precision = 1
159
+ recall = 0
160
+ elif n_fg > 0 and n_gt == 0:
161
+ precision = 0
162
+ recall = 1
163
+ elif n_fg == 0 and n_gt == 0:
164
+ precision = 1
165
+ recall = 1
166
+ else:
167
+ precision = np.sum(fg_match) / float(n_fg)
168
+ recall = np.sum(gt_match) / float(n_gt)
169
+
170
+ # Compute F measure
171
+ if precision + recall == 0:
172
+ F = 0
173
+ else:
174
+ F = 2 * precision * recall / (precision + recall)
175
+
176
+ return F
utils/misc.py ADDED
@@ -0,0 +1,1062 @@
1
+ import torch
2
+ import numpy as np
3
+ import math
4
+ import torch.nn.functional as F
5
+ import utils.basic
6
+ from typing import Tuple, Union
7
+
8
+ def standardize_test_data(rgbs, trajs, visibs, valids, S_cap=600, only_first=False, seq_len=None):
9
+ trajs = trajs.astype(np.float32) # S,N,2
10
+ visibs = visibs.astype(np.float32) # S,N
11
+ valids = valids.astype(np.float32) # S,N
12
+
13
+ visval_ok = np.sum(valids*visibs, axis=0) > 1
14
+ trajs = trajs[:,visval_ok]
15
+ visibs = visibs[:,visval_ok]
16
+ valids = valids[:,visval_ok]
17
+
18
+ # fill in missing data
19
+ N = trajs.shape[1]
20
+ for ni in range(N):
21
+ trajs[:,ni] = utils.misc.data_replace_with_nearest(trajs[:,ni], valids[:,ni])
22
+
23
+ # use cap or seq_len
24
+ if seq_len is not None:
25
+ S = min(len(rgbs), seq_len)
26
+ else:
27
+ S = len(rgbs)
28
+ S = min(S, S_cap)
29
+
30
+ if only_first:
31
+ # we'll find the best frame to start on
32
+ best_count = 0
33
+ best_ind = 0
34
+
35
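+ # slide a length-S window over the video and pick the start index that maximizes the number of tracks visible and valid at the window's first frame and in at least two frames overall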
+ for si in range(0,len(rgbs)-64):
36
+ # try this slice
37
+ visibs_ = visibs[si:min(si+S,len(rgbs)+1)] # S,N
38
+ valids_ = valids[si:min(si+S,len(rgbs)+1)] # S,N
39
+ visval_ok0 = (visibs_[0]*valids_[0]) > 0 # N
40
+ visval_okA = np.sum(visibs_*valids_, axis=0) > 1 # N
41
+ all_ok = visval_ok0 & visval_okA
42
+ # print('- slicing %d to %d; sum(ok) %d' % (si, min(si+S,len(rgbs)+1), np.sum(all_ok)))
43
+ if np.sum(all_ok) > best_count:
44
+ best_count = np.sum(all_ok)
45
+ best_ind = si
46
+ si = best_ind
47
+ rgbs = rgbs[si:si+S]
48
+ trajs = trajs[si:si+S]
49
+ visibs = visibs[si:si+S]
50
+ valids = valids[si:si+S]
51
+ vis_ok0 = visibs[0] > 0 # N
52
+ trajs = trajs[:,vis_ok0]
53
+ visibs = visibs[:,vis_ok0]
54
+ valids = valids[:,vis_ok0]
55
+ # print('- best_count', best_count, 'best_ind', best_ind)
56
+
57
+ if seq_len is not None:
58
+ rgbs = rgbs[:seq_len]
59
+ trajs = trajs[:seq_len]
60
+ visibs = visibs[:seq_len]
+ valids = valids[:seq_len]
61
+
62
+ # req two timesteps valid (after seqlen trim)
63
+ visval_ok = np.sum(visibs*valids, axis=0) > 1
64
+ trajs = trajs[:,visval_ok]
65
+ valids = valids[:,visval_ok]
66
+ visibs = visibs[:,visval_ok]
67
+
68
+ return rgbs, trajs, visibs, valids
69
+
70
+
71
+
72
+ def get_2d_sincos_pos_embed(
73
+ embed_dim: int, grid_size: Union[int, Tuple[int, int]]
74
+ ) -> torch.Tensor:
75
+ """
76
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
77
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
78
+ Args:
79
+ - embed_dim: The embedding dimension.
80
+ - grid_size: The grid size.
81
+ Returns:
82
+ - pos_embed: The generated 2D positional embedding.
83
+ """
84
+ if isinstance(grid_size, tuple):
85
+ grid_size_h, grid_size_w = grid_size
86
+ else:
87
+ grid_size_h = grid_size_w = grid_size
88
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
89
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
90
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
91
+ grid = torch.stack(grid, dim=0)
92
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
93
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
94
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
95
+
96
+
97
+ def get_2d_sincos_pos_embed_from_grid(
98
+ embed_dim: int, grid: torch.Tensor
99
+ ) -> torch.Tensor:
100
+ """
101
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
102
+
103
+ Args:
104
+ - embed_dim: The embedding dimension.
105
+ - grid: The grid to generate the embedding from.
106
+
107
+ Returns:
108
+ - emb: The generated 2D positional embedding.
109
+ """
110
+ assert embed_dim % 2 == 0
111
+
112
+ # use half of dimensions to encode grid_h
113
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
114
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
115
+
116
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
117
+ return emb
118
+
119
+
120
+ def get_1d_sincos_pos_embed_from_grid(
121
+ embed_dim: int, pos: torch.Tensor
122
+ ) -> torch.Tensor:
123
+ """
124
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
125
+
126
+ Args:
127
+ - embed_dim: The embedding dimension.
128
+ - pos: The position to generate the embedding from.
129
+
130
+ Returns:
131
+ - emb: The generated 1D positional embedding.
132
+ """
133
+ assert embed_dim % 2 == 0
134
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
135
+ omega /= embed_dim / 2.0
136
+ omega = 1.0 / 10000**omega # (D/2,)
137
+
138
+ pos = pos.reshape(-1) # (M,)
139
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
140
+
141
+ emb_sin = torch.sin(out) # (M, D/2)
142
+ emb_cos = torch.cos(out) # (M, D/2)
143
+
144
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
145
+ return emb[None].float()
146
+
147
+
148
+ def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
149
+ """
150
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
151
+
152
+ Args:
153
+ - xy: The coordinates to generate the embedding from.
154
+ - C: The size of the embedding.
155
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
156
+
157
+ Returns:
158
+ - pe: The generated 2D positional embedding.
159
+ """
160
+ B, N, D = xy.shape
161
+ assert D == 2
162
+
163
+ x = xy[:, :, 0:1]
164
+ y = xy[:, :, 1:2]
165
+ div_term = (
166
+ torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
167
+ ).reshape(1, 1, int(C / 2))
168
+
169
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
170
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
171
+
172
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
173
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
174
+
175
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
176
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
177
+
178
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*2)
+ if cat_coords:
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*2+2)
181
+ return pe
182
+
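+ # Illustrative note (not from the original source): for xy of shape (B, N, 2) and C=64,
+ # get_2d_embedding returns (B, N, 128), or (B, N, 130) with cat_coords=True, since the
+ # x and y sinusoids are concatenated and the raw coordinates are optionally prepended:
+ # pe = get_2d_embedding(torch.rand(2, 100, 2) * 64, C=64) # -> torch.Size([2, 100, 128])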
183
+ # from datasets.dataset import mask2bbox
184
+
185
+ # from pips2
186
+ def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
187
+ device = xy.device
188
+ dtype = xy.dtype
189
+ B, S, D = xy.shape
190
+ assert(D==2)
191
+ x = xy[:,:,0]
192
+ y = xy[:,:,1]
193
+ assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
194
+ omega = torch.arange(C // 4, device=device) / (C // 4 - 1)
195
+ omega = 1. / (temperature ** omega)
196
+
197
+ y = y.flatten()[:, None] * omega[None, :]
198
+ x = x.flatten()[:, None] * omega[None, :]
199
+ pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
200
+ pe = pe.reshape(B,S,C).type(dtype)
201
+ if cat_coords:
202
+ pe = torch.cat([pe, xy], dim=2) # B,N,C+2
203
+ return pe
204
+
205
+
206
+ # defined locally (rather than imported from datasets.dataset) to prevent circular imports;
+ # data_pad_if_necessary below relies on it
+ def mask2bbox(mask):
+ if mask.ndim == 3:
+ mask = mask[..., 0]
+ ys, xs = np.where(mask > 0.4)
+ if ys.size == 0 or xs.size==0:
+ return np.array((0, 0, 0, 0), dtype=int)
+ lt = np.array([np.min(xs), np.min(ys)])
+ rb = np.array([np.max(xs), np.max(ys)]) + 1
+ return np.concatenate([lt, rb])
216
+
217
+ # def get_stark_2d_embedding(H, W, C=64, device='cuda:0', temperature=10000, normalize=True):
218
+ # scale = 2*math.pi
219
+ # mask = torch.ones((1,H,W), dtype=torch.float32, device=device)
220
+ # y_embed = mask.cumsum(1, dtype=torch.float32) # cumulative sum along axis 1 (h axis) --> (b, h, w)
221
+ # x_embed = mask.cumsum(2, dtype=torch.float32) # cumulative sum along axis 2 (w axis) --> (b, h, w)
222
+ # if normalize:
223
+ # eps = 1e-6
224
+ # y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale # 2pi * (y / sigma(y))
225
+ # x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale # 2pi * (x / sigma(x))
226
+
227
+ # dim_t = torch.arange(C, dtype=torch.float32, device=device) # (0,1,2,...,d/2)
228
+ # dim_t = temperature ** (2 * (dim_t // 2) / C)
229
+
230
+ # pos_x = x_embed[:, :, :, None] / dim_t # (b,h,w,d/2)
231
+ # pos_y = y_embed[:, :, :, None] / dim_t # (b,h,w,d/2)
232
+ # pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) # (b,h,w,d/2)
233
+ # pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) # (b,h,w,d/2)
234
+ # pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # (b,h,w,d)
235
+ # return pos
236
+
237
+ # def get_1d_embedding(x, C, cat_coords=False):
238
+ # B, N, D = x.shape
239
+ # assert(D==1)
240
+
241
+ # div_term = (torch.arange(0, C, 2, device=x.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2))
242
+
243
+ # pe_x = torch.zeros(B, N, C, device=x.device, dtype=torch.float32)
244
+
245
+ # pe_x[:, :, 0::2] = torch.sin(x * div_term)
246
+ # pe_x[:, :, 1::2] = torch.cos(x * div_term)
247
+
248
+ # if cat_coords:
249
+ # pe_x = torch.cat([pe, x], dim=2) # B,N,C*2+2
250
+ # return pe_x
251
+
252
+ # def posemb_sincos_2d(h, w, dim, temperature=10000, dtype=torch.float32, device='cuda:0'):
253
+
254
+ # y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
255
+ # assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
256
+ # omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
257
+ # omega = 1. / (temperature ** omega)
258
+
259
+ # y = y.flatten()[:, None] * omega[None, :]
260
+ # x = x.flatten()[:, None] * omega[None, :]
261
+ # pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) # B,C,H,W
262
+ # return pe.type(dtype)
263
+
264
+ def iou(bbox1, bbox2):
265
+ # bbox1, bbox2: [x1, y1, x2, y2]
266
+ x1, y1, x2, y2 = bbox1
267
+ x1_, y1_, x2_, y2_ = bbox2
268
+ inter_x1 = max(x1, x1_)
269
+ inter_y1 = max(y1, y1_)
270
+ inter_x2 = min(x2, x2_)
271
+ inter_y2 = min(y2, y2_)
272
+ inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
273
+ area1 = (x2 - x1) * (y2 - y1)
274
+ area2 = (x2_ - x1_) * (y2_ - y1_)
275
+ iou = inter_area / (area1 + area2 - inter_area)
276
+ return iou
277
+
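+ # Illustrative note (not from the original source): boxes are [x1, y1, x2, y2];
+ # two 2x2 boxes overlapping by half share one third of their union:
+ # iou([0, 0, 2, 2], [1, 0, 3, 2]) # -> 0.3333...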
278
+ # def get_2d_embedding(xy, C, cat_coords=False):
279
+ # B, N, D = xy.shape
280
+ # assert(D==2)
281
+
282
+ # x = xy[:,:,0:1]
283
+ # y = xy[:,:,1:2]
284
+ # div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2))
285
+
286
+ # pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
287
+ # pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
288
+
289
+ # pe_x[:, :, 0::2] = torch.sin(x * div_term)
290
+ # pe_x[:, :, 1::2] = torch.cos(x * div_term)
291
+
292
+ # pe_y[:, :, 0::2] = torch.sin(y * div_term)
293
+ # pe_y[:, :, 1::2] = torch.cos(y * div_term)
294
+
295
+ # pe = torch.cat([pe_x, pe_y], dim=2) # B,N,C*2
296
+ # if cat_coords:
297
+ # pe = torch.cat([pe, xy], dim=2) # B,N,C*2+2
298
+ # return pe
299
+
300
+ # def get_3d_embedding(xyz, C, cat_coords=False):
301
+ # B, N, D = xyz.shape
302
+ # assert(D==3)
303
+
304
+ # x = xyz[:,:,0:1]
305
+ # y = xyz[:,:,1:2]
306
+ # z = xyz[:,:,2:3]
307
+ # div_term = (torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2))
308
+
309
+ # pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
310
+ # pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
311
+ # pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
312
+
313
+ # pe_x[:, :, 0::2] = torch.sin(x * div_term)
314
+ # pe_x[:, :, 1::2] = torch.cos(x * div_term)
315
+
316
+ # pe_y[:, :, 0::2] = torch.sin(y * div_term)
317
+ # pe_y[:, :, 1::2] = torch.cos(y * div_term)
318
+
319
+ # pe_z[:, :, 0::2] = torch.sin(z * div_term)
320
+ # pe_z[:, :, 1::2] = torch.cos(z * div_term)
321
+
322
+ # pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
323
+ # if cat_coords:
324
+ # pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
325
+ # return pe
326
+
327
+ class SimplePool():
328
+ def __init__(self, pool_size, version='pt', min_size=1):
329
+ self.pool_size = pool_size
330
+ self.version = version
331
+ self.items = []
332
+ self.min_size = min_size
333
+
334
+ if not (version=='pt' or version=='np'):
335
+ print('version = %s; please choose pt or np' % version)
336
+ assert(False) # please choose pt or np
337
+
338
+ def __len__(self):
339
+ return len(self.items)
340
+
341
+ def mean(self, min_size=None):
342
+ if min_size is None:
343
+ pool_size_thresh = self.min_size
344
+ elif min_size=='half':
345
+ pool_size_thresh = self.pool_size/2
346
+ else:
347
+ pool_size_thresh = min_size
348
+
349
+ if self.version=='np':
350
+ if len(self.items) >= pool_size_thresh:
351
+ return np.sum(self.items)/float(len(self.items))
352
+ else:
353
+ return np.nan
354
+ if self.version=='pt':
+ if len(self.items) >= pool_size_thresh:
+ # stack before reducing; torch.sum does not accept a python list
+ return torch.stack(self.items).float().mean()
+ else:
+ return torch.tensor(np.nan)
359
+
360
+ def sample(self, with_replacement=True):
361
+ idx = np.random.randint(len(self.items))
362
+ if with_replacement:
363
+ return self.items[idx]
364
+ else:
365
+ return self.items.pop(idx)
366
+
367
+ def fetch(self, num=None):
368
+ if self.version=='pt':
369
+ item_array = torch.stack(self.items)
370
+ elif self.version=='np':
371
+ item_array = np.stack(self.items)
372
+ if num is not None:
373
+ # there better be some items
374
+ assert(len(self.items) >= num)
375
+
376
+ # if there are not that many elements just return however many there are
377
+ if len(self.items) < num:
378
+ return item_array
379
+ else:
380
+ idxs = np.random.randint(len(self.items), size=num)
381
+ return item_array[idxs]
382
+ else:
383
+ return item_array
384
+
385
+ def is_full(self):
386
+ full = len(self.items)==self.pool_size
387
+ return full
388
+
389
+ def empty(self):
390
+ self.items = []
391
+
392
+ def have_min_size(self):
393
+ return len(self.items) >= self.min_size
394
+
395
+
396
+ def update(self, items):
397
+ for item in items:
398
+ if len(self.items) < self.pool_size:
399
+ # the pool is not full, so let's add this in
400
+ self.items.append(item)
401
+ else:
402
+ # the pool is full
403
+ # pop from the front
404
+ self.items.pop(0)
405
+ # add to the back
406
+ self.items.append(item)
407
+ return self.items
408
+
409
+
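+ # Illustrative note (not from the original source): SimplePool is a fixed-size FIFO
+ # buffer, e.g. for smoothing a scalar metric over recent iterations:
+ # loss_pool = SimplePool(100, version='np')
+ # loss_pool.update([0.5]); loss_pool.update([0.3])
+ # loss_pool.mean() # -> 0.4 once at least min_size items have been added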
410
+ class SimpleHeap():
411
+ def __init__(self, pool_size, version='pt'):
412
+ self.pool_size = pool_size
413
+ self.version = version
414
+ self.items = []
415
+ self.vals = []
416
+
417
+ if not (version=='pt' or version=='np'):
418
+ print('version = %s; please choose pt or np' % version)
419
+ assert(False) # please choose pt or np
420
+
421
+ def __len__(self):
422
+ return len(self.items)
423
+
424
+ def sample(self, random=True, with_replacement=True, semirandom=False):
425
+ vals_arr = np.stack(self.vals)
426
+ if random:
427
+ ind = np.random.randint(len(self.items))
428
+ else:
429
+ if semirandom and len(vals_arr)>1:
430
+ # choose from the harder half
431
+ inds = np.argsort(vals_arr) # ascending
432
+ inds = inds[len(vals_arr)//2:]
433
+ ind = np.random.choice(inds)
434
+ else:
435
+ # find the most valuable element
436
+ ind = np.argmax(vals_arr)
437
+
438
+
439
+ if with_replacement:
440
+ return self.items[ind]
441
+ else:
442
+ item = self.items.pop(ind)
443
+ val = self.vals.pop(ind)
444
+ return item
445
+
446
+ def fetch(self, num=None):
447
+ if self.version=='pt':
448
+ item_array = torch.stack(self.items)
449
+ elif self.version=='np':
450
+ item_array = np.stack(self.items)
451
+ if num is not None:
452
+ # there better be some items
453
+ assert(len(self.items) >= num)
454
+
455
+ # if there are not that many elements just return however many there are
456
+ if len(self.items) < num:
457
+ return item_array
458
+ else:
459
+ idxs = np.random.randint(len(self.items), size=num)
460
+ return item_array[idxs]
461
+ else:
462
+ return item_array
463
+
464
+ def is_full(self):
465
+ full = len(self.items)==self.pool_size
466
+ return full
467
+
468
+ def empty(self):
469
+ self.items = []
470
+
471
+ def update(self, vals, items):
472
+ for val,item in zip(vals, items):
473
+ if len(self.items) < self.pool_size:
474
+ # the pool is not full, so let's add this in
475
+ self.items.append(item)
476
+ self.vals.append(val)
477
+ else:
478
+ # the pool is full
479
+ # find our least-valuable element
480
+ # and see if we should replace it
481
+ vals_arr = np.stack(self.vals)
482
+ ind = np.argmin(vals_arr)
483
+ if vals_arr[ind] < val:
484
+ # pop the min
485
+ self.items.pop(ind)
486
+ self.vals.pop(ind)
487
+
488
+ # add to the back
489
+ self.items.append(item)
490
+ self.vals.append(val)
491
+ return self.items
492
+
493
+
494
+ def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
495
+ """
496
+ Input:
497
+ xyz: pointcloud data, [B, N, C], where C is probably 3
498
+ npoint: number of samples
499
+ Return:
500
+ inds: sampled pointcloud index, [B, npoint]
501
+ """
502
+ device = xyz.device
503
+ B, N, C = xyz.shape
504
+ xyz = xyz.float()
505
+ inds = torch.zeros(B, npoint, dtype=torch.long, device=device)
506
+ distance = torch.ones((B, N), dtype=torch.float32, device=device) * 1e10
507
+ if deterministic:
508
+ farthest = torch.randint(0, 1, (B,), dtype=torch.long, device=device)
509
+ else:
510
+ farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
511
+ batch_indices = torch.arange(B, dtype=torch.long, device=device)
512
+ for i in range(npoint):
513
+ if include_ends:
514
+ if i==0:
515
+ farthest = 0
516
+ elif i==1:
517
+ farthest = N-1
518
+ inds[:, i] = farthest
519
+ centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
520
+ dist = torch.sum((xyz - centroid) ** 2, -1)
521
+ mask = dist < distance
522
+ distance[mask] = dist[mask]
523
+ farthest = torch.max(distance, -1)[1]
524
+
525
+ if npoint > N:
526
+ # if we need more samples, make them random
527
+ distance += torch.randn_like(distance)
528
+ return inds
529
+
530
+ def farthest_point_sample_py(xyz, npoint, deterministic=False):
531
+ N,C = xyz.shape
532
+ inds = np.zeros(npoint, dtype=np.int32)
533
+ distance = np.ones(N) * 1e10
534
+ if deterministic:
535
+ farthest = 0
536
+ else:
537
+ farthest = np.random.randint(0, N, dtype=np.int32)
538
+ for i in range(npoint):
539
+ inds[i] = farthest
540
+ centroid = xyz[farthest, :].reshape(1,C)
541
+ dist = np.sum((xyz - centroid) ** 2, -1)
542
+ mask = dist < distance
543
+ distance[mask] = dist[mask]
544
+ farthest = np.argmax(distance, -1)
545
+ if npoint > N:
546
+ # if we need more samples, make them random
547
+ distance += np.random.randn(*distance.shape)
548
+ return inds
549
+
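+ # Illustrative note (not from the original source): farthest_point_sample_py greedily
+ # picks npoint indices so that each new pick maximizes its distance to the picks so far;
+ # with deterministic=True the first pick is index 0, so e.g.:
+ # farthest_point_sample_py(np.array([[0.,0.],[0.1,0.],[5.,5.]]), 2, deterministic=True)
+ # # -> array([0, 2], dtype=int32)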
550
+ def balanced_ce_loss(pred, gt, pos_weight=0.5, valid=None, dim=None, return_both=False, use_halfmask=False, H=64, W=64):
551
+ # # pred and gt are the same shape
552
+ # for (a,b) in zip(pred.size(), gt.size()):
553
+ # if not a==b:
554
+ # print('mismatch: pred, gt', pred.shape, gt.shape)
555
+ # assert(a==b) # some shape mismatch!
556
+
557
+ pred = pred.reshape(-1)
558
+ gt = gt.reshape(-1)
559
+ device = pred.device
560
+
561
+ if valid is not None:
562
+ valid = valid.reshape(-1)
563
+ for (a,b) in zip(pred.size(), valid.size()):
564
+ assert(a==b) # some shape mismatch!
565
+ else:
566
+ valid = torch.ones_like(gt)
567
+
568
+ pos = (gt > 0.95).float()
569
+ if use_halfmask:
570
+ pos_wide = (gt >= 0.5).float()
571
+ halfmask = (gt == 0.5).float()
572
+ else:
573
+ pos_wide = pos
574
+
575
+ neg = (gt < 0.05).float()
576
+
577
+ label = pos_wide*2.0 - 1.0
578
+ a = -label * pred
579
+ b = F.relu(a)
580
+ loss = b + torch.log(torch.exp(-b)+torch.exp(a-b))
581
+
582
+ if torch.sum(pos*valid)>0:
583
+ pos_loss = loss[(pos*valid) > 0].mean()
584
+ else:
585
+ pos_loss = torch.tensor(0.0, requires_grad=True, device=device)
586
+
587
+ if torch.sum(neg*valid)>0:
588
+ neg_loss = loss[(neg*valid) > 0].mean()
589
+ else:
590
+ neg_loss = torch.tensor(0.0, requires_grad=True, device=device)
591
+ balanced_loss = pos_weight*pos_loss + (1-pos_weight)*neg_loss
592
+ return balanced_loss
593
+
594
+ # pos_loss = utils.basic.reduce_masked_mean(loss, pos*valid, dim=dim)
595
+ # neg_loss = utils.basic.reduce_masked_mean(loss, neg*valid, dim=dim)
596
+
597
+ if use_halfmask:
598
+ # here we will find the pixels which are already leaning positive,
599
+ # and encourage them to be more positive
600
+ B = loss.shape[0]
601
+ loss_ = loss.reshape(B,-1)
602
+ mask_ = halfmask.reshape(B,-1) * valid.reshape(B,-1)
603
+
604
+ # to avoid the issue where spikes become spikier,
605
+ # we will only apply this loss on batch els where we predicted zero positives
606
+ pred_sig_ = torch.sigmoid(pred).reshape(B,-1)
607
+ no_pred_ = torch.max(pred_sig_.round(), axis=1)[0] < 1 # B
608
+ # and only on batch els where we have negatives available
609
+ have_neg_ = torch.sum(neg, dim=1)>0 # B
610
+
611
+ loss_ = loss_[no_pred_ & have_neg_] # N,H*W
612
+ mask_ = mask_[no_pred_ & have_neg_] # N,H*W
613
+ N = loss_.shape[0]
614
+
615
+ if N > 0:
616
+ # we want:
617
+ # in the neg pixels,
618
+ # set them to the max loss of the pos pixels,
619
+ # so that they do not contribute to the min
620
+ loss__ = loss_.reshape(-1)
621
+ mask__ = mask_.reshape(-1)
622
+ if torch.sum(mask__)>0:
623
+ # print('loss_', loss_.shape, 'mask_', mask_.shape, 'loss__', loss__.shape, 'mask__', mask__.shape)
624
+ mloss__ = loss__.detach()
625
+ mloss__[mask__==0] = torch.max(loss__[mask__==1])
626
+ mloss_ = mloss__.reshape(N,H*W)
627
+
628
+ # now, in each batch el, take a tiny region around the argmin, so we can boost this region
629
+ minloss_mask_ = torch.zeros_like(mloss_).scatter(1,mloss_.argmin(1,True),value=1)
630
+ minloss_mask_ = utils.improc.dilate2d(minloss_mask_.view(N,1,H,W), times=3).reshape(N,H*W)
631
+
632
+ loss__ = loss_.reshape(-1)
633
+ minloss_mask__ = minloss_mask_.reshape(-1)
634
+ half_loss = loss__[minloss_mask__>0].mean()
635
+
636
+ # print('N', N, 'half_loss', half_loss)
637
+ pos_loss = pos_loss + half_loss
638
+
639
+ # if False:
640
+ # min_pos = 8
641
+
642
+ # # only apply the loss when we have some negatives available,
643
+ # # otherwise it's a whole "ignore" frame, which may mean
644
+ # # we are unsure if the target is even there
645
+ # if torch.sum(mask__==0) > 0: # negatives available
646
+ # # only apply the loss when the halfmask is larger area than
647
+ # # min_pos (the number of pixels we want to boost),
648
+ # # so that indexing will work
649
+ # if torch.all(torch.sum(mask_==1, dim=1) >= min_pos): # topk indexing will work
650
+ # # in the pixels we will not use,
651
+ # # set them to the max of the pixels we may use,
652
+ # # so that they do not contribute to the min
653
+ # loss__[mask__==0] = torch.max(loss__[mask__==1])
654
+ # loss_ = loss__.reshape(B,-1)
655
+
656
+ # half_loss = torch.mean(torch.topk(loss_, min_pos, dim=1, largest=False)[0], dim=1) # B
657
+
658
+ # have_neg = (torch.sum(neg, dim=1)>0).float() # B
659
+ # pos_loss = pos_loss + half_loss*have_neg
660
+
661
+
662
+
663
+
664
+
665
+
666
+ # half_loss = []
667
+ # for b in range(B):
668
+ # loss_b = loss_[b]
669
+ # mask_b = mask_[b]
670
+ # if torch.sum(mask_b):
671
+ # inds = torch.nonzero(mask_b).reshape(-1)
672
+ # half_loss.append(torch.min(loss_b[inds]))
673
+ # if len(half_loss):
674
+ # # # half_loss_ = half_loss.reshape(-1)
675
+ # # half_loss = torch.min(half_loss, dim=1)[0] # B
676
+ # # half_loss = torch.mean(torch.topk(half_loss, 4, dim=1, largest=False)[0], dim=1) # B
677
+ # pos_loss = pos_loss + torch.stack(half_loss).mean()
678
+
679
+ if return_both:
680
+ return pos_loss, neg_loss
681
+ balanced_loss = pos_weight*pos_loss + (1-pos_weight)*neg_loss
682
+
683
+ return balanced_loss
684
+
685
+ def dice_loss(pred, gt):
686
+ # gt has ignores at 0.5
687
+ # pred and gt are the same shape
688
+ for (a,b) in zip(pred.size(), gt.size()):
689
+ assert(a==b) # some shape mismatch!
690
+
691
+ prob = pred.sigmoid()
692
+
693
+ # flatten everything except batch
694
+ prob = prob.flatten(1)
695
+ gt = gt.flatten(1)
696
+
697
+ pos = (gt > 0.95).float()
698
+ neg = (gt < 0.05).float()
699
+ valid = (pos+neg).float().clamp(0,1)
700
+
701
+ numerator = 2 * (prob * pos * valid).sum(1)
702
+ denominator = (prob*valid).sum(1) + (pos*valid).sum(1)
703
+ loss = 1 - (numerator + 1) / (denominator + 1)
704
+ return loss
705
+
706
+ def sigmoid_focal_loss(pred, gt, alpha=0.25, gamma=2):#, use_halfmask=False):
707
+ # gt has ignores at 0.5
708
+ # pred and gt are the same shape
709
+ for (a,b) in zip(pred.size(), gt.size()):
710
+ assert(a==b) # some shape mismatch!
711
+
712
+ # flatten everything except batch
713
+ pred = pred.flatten(1)
714
+ gt = gt.flatten(1)
715
+
716
+ pos = (gt > 0.95).float()
717
+ neg = (gt < 0.05).float()
718
+ # if use_halfmask:
719
+ # pos_wide = (gt >= 0.5).float()
720
+ # halfmask = (gt == 0.5).float()
721
+ # else:
722
+ # pos_wide = pos
723
+ valid = (pos+neg).float().clamp(0,1)
724
+
725
+ prob = pred.sigmoid()
726
+ ce_loss = F.binary_cross_entropy_with_logits(pred, pos, reduction="none")
727
+ p_t = prob * pos + (1 - prob) * (1 - pos)
728
+ loss = ce_loss * ((1 - p_t) ** gamma)
729
+
730
+ if alpha >= 0:
731
+ alpha_t = alpha * pos + (1 - alpha) * (1 - pos)
732
+ loss = alpha_t * loss
733
+
734
+ loss = (loss*valid).sum(1) / (1 + valid.sum(1))
735
+ return loss
736
+
737
+ # def dice_loss(inputs, targets, normalizer=1):
738
+ # inputs = inputs.sigmoid()
739
+ # inputs = inputs.flatten(1)
740
+ # numerator = 2 * (inputs * targets).sum(1)
741
+ # denominator = inputs.sum(-1) + targets.sum(-1)
742
+ # loss = 1 - (numerator + 1) / (denominator + 1)
743
+ # return loss.sum() / normalizer
744
+
745
+ # def sigmoid_focal_loss(inputs, targets, normalizer=1, alpha=0.25, gamma=2):
746
+ # prob = inputs.sigmoid()
747
+ # ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
748
+ # p_t = prob * targets + (1 - prob) * (1 - targets)
749
+ # loss = ce_loss * ((1 - p_t) ** gamma)
750
+
751
+ # if alpha >= 0:
752
+ # alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
753
+ # loss = alpha_t * loss
754
+
755
+ # return loss.mean(1).sum() / normalizer
756
+
757
+ def data_replace_with_nearest(xys, valids):
758
+ # replace invalid xys with nearby ones
759
+ invalid_idx = np.where(valids==0)[0]
760
+ valid_idx = np.where(valids==1)[0]
761
+ for idx in invalid_idx:
762
+ nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
763
+ xys[idx] = xys[nearest]
764
+ return xys
765
+
766
+ def data_get_traj_from_masks(masks):
767
+ if masks.ndim==4:
768
+ masks = masks[...,0]
769
+ S, H, W = masks.shape
770
+ masks = (masks > 0.1).astype(np.float32)
771
+ fills = np.zeros((S))
772
+ xy_means = np.zeros((S,2))
773
+ xy_rands = np.zeros((S,2))
774
+ valids = np.zeros((S))
775
+ for si, mask in enumerate(masks):
776
+ if np.sum(mask) > 0:
777
+ ys, xs = np.where(mask)
778
+ inds = np.random.permutation(len(xs))
779
+ xs, ys = xs[inds], ys[inds]
780
+ x0, x1 = np.min(xs), np.max(xs)+1
781
+ y0, y1 = np.min(ys), np.max(ys)+1
782
+ # if (x1-x0)>0 and (y1-y0)>0:
783
+ xy_means[si] = np.array([xs.mean(), ys.mean()])
784
+ xy_rands[si] = np.array([xs[0], ys[0]])
785
+ valids[si] = 1
786
+ crop = mask[y0:y1, x0:x1]
787
+ fill = np.mean(crop)
788
+ fills[si] = fill
789
+ # print('fills', fills)
790
+ return xy_means, xy_rands, valids, fills
791
+
792
+ def data_zoom(zoom, xys, visibs, rgbs, valids=None, masks=None, masks2=None, masks3=None, masks4=None):
793
+ S, H, W, C = rgbs.shape
794
+ S,N,D = xys.shape
795
+
796
+ _, H, W, C = rgbs.shape
797
+ assert(C==3)
798
+ crop_W = int(W//zoom)
799
+ crop_H = int(H//zoom)
800
+
801
+ if np.random.rand() < 0.25: # follow-crop
802
+ # start with xy traj
803
+ # smooth_xys = xys.copy()
804
+ smooth_xys = xys[:,np.random.randint(N)].reshape(S,1,2)
805
+ # make it inbounds
806
+ smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
807
+
808
+ # smooth it out, to remove info about the traj, and simulate camera motion
809
+ for _ in range(S*3):
810
+ for ii in range(S):
811
+ if ii==0:
812
+ smooth_xys[ii] = (smooth_xys[ii] + smooth_xys[ii+1])/2.0
813
+ elif ii==S-1:
814
+ smooth_xys[ii] = (smooth_xys[ii-1] + smooth_xys[ii])/2.0
815
+ else:
816
+ smooth_xys[ii] = (smooth_xys[ii-1] + smooth_xys[ii] + smooth_xys[ii+1])/3.0
817
+ else: # static (no-hint) crop
818
+ # zero-vel on random available coordinate
819
+
820
+ if valids is not None:
821
+ visval = visibs*valids # S,N
822
+ visval = np.sum(visval, axis=1) # S
823
+ else:
824
+ visval = np.sum(visibs, axis=1) # S
825
+
826
+ anchor_inds = np.nonzero(visval >= np.mean(visval))[0]
827
+ ind = anchor_inds[np.random.randint(len(anchor_inds))]
828
+ # print('ind', ind)
829
+ smooth_xys = xys[ind:ind+1].repeat(S,axis=0)
830
+ smooth_xys = smooth_xys.mean(axis=1, keepdims=True)
831
+ # xmid = np.random.randint(crop_W//2, W-crop_W//2)
832
+ # ymid = np.random.randint(crop_H//2, H-crop_H//2)
833
+ # smooth_xys = np.stack([xmid, ymid], axis=-1).reshape(1,1,2).repeat(S, axis=0) # S,1,2
834
+ smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
835
+
836
+ if np.random.rand() < 0.5:
837
+ # add a random alternate trajectory, to help push us off center
838
+ alt_xys = np.random.randint(-crop_H//8, crop_H//8, (S,1,2))
839
+ for _ in range(4): # smooth out
840
+ for ii in range(S):
841
+ if ii==0:
842
+ alt_xys[ii] = (alt_xys[ii] + alt_xys[ii+1])/2.0
843
+ elif ii==S-1:
844
+ alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii])/2.0
845
+ else:
846
+ alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii] + alt_xys[ii+1])/3.0
847
+ smooth_xys = smooth_xys + alt_xys
848
+
849
+ smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
850
+
851
+ rgbs_crop = []
852
+ if masks is not None:
853
+ masks_crop = []
854
+ if masks2 is not None:
855
+ masks2_crop = []
856
+ if masks3 is not None:
857
+ masks3_crop = []
858
+ if masks4 is not None:
859
+ masks4_crop = []
860
+
861
+ offsets = []
862
+ for si in range(S):
863
+ xy_mid = smooth_xys[si].squeeze(0).round().astype(np.int32) # 2
864
+
865
+ xmid, ymid = xy_mid[0], xy_mid[1]
866
+
867
+ x0, x1 = np.clip(xmid-crop_W//2, 0, W), np.clip(xmid+crop_W//2, 0, W)
868
+ y0, y1 = np.clip(ymid-crop_H//2, 0, H), np.clip(ymid+crop_H//2, 0, H)
869
+ offset = np.array([x0, y0]).reshape(1,2)
870
+
871
+ rgbs_crop.append(rgbs[si,y0:y1,x0:x1])
872
+ if masks is not None:
873
+ masks_crop.append(masks[si,y0:y1,x0:x1])
874
+ if masks2 is not None:
875
+ masks2_crop.append(masks2[si,y0:y1,x0:x1])
876
+ if masks3 is not None:
877
+ masks3_crop.append(masks3[si,y0:y1,x0:x1])
878
+ if masks4 is not None:
879
+ masks4_crop.append(masks4[si,y0:y1,x0:x1])
880
+ xys[si] -= offset
881
+
882
+ offsets.append(offset)
883
+
884
+ rgbs = np.stack(rgbs_crop, axis=0)
885
+ if masks is not None:
886
+ masks = np.stack(masks_crop, axis=0)
887
+ if masks2 is not None:
888
+ masks2 = np.stack(masks2_crop, axis=0)
889
+ if masks3 is not None:
890
+ masks3 = np.stack(masks3_crop, axis=0)
891
+ if masks4 is not None:
892
+ masks4 = np.stack(masks4_crop, axis=0)
893
+
894
+ # update visibility annotations
895
+ for si in range(S):
896
+ oob_inds = np.logical_or(
897
+ np.logical_or(xys[si,:,0] < 0, xys[si,:,0] > crop_W-1),
898
+ np.logical_or(xys[si,:,1] < 0, xys[si,:,1] > crop_H-1))
899
+ visibs[si,oob_inds] = 0
900
+
901
+ # if masks4 is not None:
902
+ # return xys, visibs, valids, rgbs, masks, masks2, masks3, masks4
903
+ # if masks3 is not None:
904
+ # return xys, visibs, valids, rgbs, masks, masks2, masks3
905
+ # if masks2 is not None:
906
+ # return xys, visibs, valids, rgbs, masks, masks2
907
+ # if masks is not None:
908
+ # return xys, visibs, valids, rgbs, masks
909
+ # else:
910
+ # return xys, visibs, valids, rgbs
911
+ if valids is not None:
912
+ return xys, visibs, rgbs, valids
913
+ else:
914
+ return xys, visibs, rgbs
915
+
916
+
917
+ def data_zoom_bbox(zoom, bboxes, visibs, rgbs):#, valids=None):
918
+ S, H, W, C = rgbs.shape
919
+
920
+ _, H, W, C = rgbs.shape
921
+ assert(C==3)
922
+ crop_W = int(W//zoom)
923
+ crop_H = int(H//zoom)
924
+
925
+ xys = bboxes[:,0:2]*0.5 + bboxes[:,2:4]*0.5
926
+
927
+ if np.random.rand() < 0.25: # follow-crop
928
+ # start with xy traj
929
+ smooth_xys = xys.copy()
930
+ # make it inbounds
931
+ smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
932
+ # smooth it out, to remove info about the traj, and simulate camera motion
933
+ for _ in range(S*3):
934
+ for ii in range(1,S-1):
935
+ smooth_xys[ii] = (smooth_xys[ii-1] + smooth_xys[ii] + smooth_xys[ii+1])/3.0
936
+ else: # static (no-hint) crop
937
+ # zero-vel on random available coordinate
938
+ anchor_inds = np.nonzero(visibs.reshape(-1)>0.5)[0]
939
+ ind = anchor_inds[np.random.randint(len(anchor_inds))]
940
+ smooth_xys = xys[ind:ind+1].repeat(S,axis=0)
941
+ # xmid = np.random.randint(crop_W//2, W-crop_W//2)
942
+ # ymid = np.random.randint(crop_H//2, H-crop_H//2)
943
+ # smooth_xys = np.stack([xmid, ymid], axis=-1).reshape(1,1,2).repeat(S, axis=0) # S,1,2
944
+ smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
945
+ # print('xys', xys)
946
+ # print('smooth_xys', smooth_xys)
947
+
948
+ if np.random.rand() < 0.5:
949
+ # add a random alternate trajectory, to help push us off center
950
+ alt_xys = np.random.randint(-crop_H//8, crop_H//8, (S,2))
951
+ for _ in range(3):
952
+ for ii in range(1,S-1):
953
+ alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii] + alt_xys[ii+1])/3.0
954
+ smooth_xys = smooth_xys + alt_xys
955
+
956
+ smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
957
+
958
+ rgbs_crop = []
959
+
960
+ offsets = []
961
+ for si in range(S):
962
+ xy_mid = smooth_xys[si].round().astype(np.int32)
963
+ xmid, ymid = xy_mid[0], xy_mid[1]
964
+
965
+ x0, x1 = np.clip(xmid-crop_W//2, 0, W), np.clip(xmid+crop_W//2, 0, W)
966
+ y0, y1 = np.clip(ymid-crop_H//2, 0, H), np.clip(ymid+crop_H//2, 0, H)
967
+ offset = np.array([x0, y0]).reshape(2)
968
+
969
+ rgbs_crop.append(rgbs[si,y0:y1,x0:x1])
970
+ xys[si] -= offset
971
+ bboxes[si,0:2] -= offset
972
+ bboxes[si,2:4] -= offset
973
+
974
+ offsets.append(offset)
975
+
976
+ rgbs = np.stack(rgbs_crop, axis=0)
977
+
978
+ # update visibility annotations
979
+ for si in range(S):
980
+ # avoid 1px edge
981
+ oob_inds = np.logical_or(
982
+ np.logical_or(xys[si,0] < 1, xys[si,0] > W-2),
983
+ np.logical_or(xys[si,1] < 1, xys[si,1] > H-2))
984
+ visibs[si,oob_inds] = 0
985
+
986
+ # clamp to image bounds
987
+ xys0 = np.minimum(np.maximum(bboxes[:,0:2], np.zeros((2,), dtype=int)), np.array([W, H]) - 1) # S,2
988
+ xys1 = np.minimum(np.maximum(bboxes[:,2:4], np.zeros((2,), dtype=int)), np.array([W, H]) - 1) # S,2
989
+ bboxes = np.concatenate([xys0, xys1], axis=1)
990
+ return bboxes, visibs, rgbs
991
+
992
+
993
+ def data_pad_if_necessary(rgbs, masks, masks2=None):
994
+ S,H,W,C = rgbs.shape
995
+
996
+ mask_areas = (masks > 0).reshape(S,-1).sum(axis=1)
997
+ mask_areas_norm = mask_areas / np.max(mask_areas)
998
+ visibs = mask_areas_norm
999
+
1000
+ bboxes = np.stack([mask2bbox(mask) for mask in masks])
1001
+ whs = bboxes[:,2:4] - bboxes[:,0:2]
1002
+ whs = whs[visibs > 0.5]
1003
+ # print('mean wh', np.mean(whs[:,0]), np.mean(whs[:,1]))
1004
+ if np.mean(whs[:,0]) >= W/2:
1005
+ # print('padding w')
1006
+ pad = ((0,0),(0,0),(W//4,W//4),(0,0))
1007
+ rgbs = np.pad(rgbs, pad, mode="constant")
1008
+ masks = np.pad(masks, pad[:3], mode="constant")
1009
+ if masks2 is not None:
1010
+ masks2 = np.pad(masks2, pad[:3], mode="constant")
1011
+ # print('rgbs', rgbs.shape)
1012
+ # print('masks', masks.shape)
1013
+ if np.mean(whs[:,1]) >= H/2:
1014
+ # print('padding h')
1015
+ pad = ((0,0),(H//4,H//4),(0,0),(0,0))
1016
+ rgbs = np.pad(rgbs, pad, mode="constant")
1017
+ masks = np.pad(masks, pad[:3], mode="constant")
1018
+ if masks2 is not None:
1019
+ masks2 = np.pad(masks2, pad[:3], mode="constant", constant_values=0.5)
1020
+
1021
+ if masks2 is not None:
1022
+ return rgbs, masks, masks2
1023
+ return rgbs, masks
1024
+
1025
+ def data_pad_if_necessary_b(rgbs, bboxes, visibs):
1026
+ S,H,W,C = rgbs.shape
1027
+ whs = bboxes[:,2:4] - bboxes[:,0:2]
1028
+ whs = whs[visibs > 0.5]
1029
+ if np.mean(whs[:,0]) >= W/2:
1030
+ pad = ((0,0),(0,0),(W//4,W//4),(0,0))
1031
+ rgbs = np.pad(rgbs, pad, mode="constant")
1032
+ bboxes[:,0] += W//4
1033
+ bboxes[:,2] += W//4
1034
+ if np.mean(whs[:,1]) >= H/2:
1035
+ pad = ((0,0),(H//4,H//4),(0,0),(0,0))
1036
+ rgbs = np.pad(rgbs, pad, mode="constant")
1037
+ bboxes[:,1] += H//4
1038
+ bboxes[:,3] += H//4
1039
+ return rgbs, bboxes
1040
+
1041
+ def posenc(x, min_deg, max_deg):
1042
+ """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
1043
+ Instead of computing [sin(x), cos(x)], we use the trig identity
1044
+ cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
1045
+ Args:
1046
+ x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
1047
+ min_deg: int, the minimum (inclusive) degree of the encoding.
1048
+ max_deg: int, the maximum (exclusive) degree of the encoding.
1049
1050
+ Returns:
1051
+ encoded: torch.Tensor, encoded variables.
1052
+ """
1053
+ if min_deg == max_deg:
1054
+ return x
1055
+ scales = torch.tensor(
1056
+ [2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
1057
+ )
1058
+
1059
+ xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
1060
+ four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
1061
+ return torch.cat([x] + [four_feat], dim=-1)
1062
+
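+ # Illustrative note (not from the original source): posenc concatenates x with
+ # sin/cos features at frequencies 2^min_deg .. 2^(max_deg-1), so the last dim grows
+ # from D to D + 2*D*(max_deg - min_deg), e.g.:
+ # posenc(torch.rand(4, 3), 0, 4).shape # -> torch.Size([4, 27])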
utils/py.py ADDED
@@ -0,0 +1,755 @@
+ import glob, math
2
+ import numpy as np
3
+ # from scipy import misc
4
+ # from scipy import linalg
5
+ from PIL import Image
6
+ import io
7
+ import matplotlib.pyplot as plt
8
+ EPS = 1e-6
9
+
10
+
11
+ XMIN = -64.0 # right (neg is left)
12
+ XMAX = 64.0 # right
13
+ YMIN = -64.0 # down (neg is up)
14
+ YMAX = 64.0 # down
15
+ ZMIN = -64.0 # forward
16
+ ZMAX = 64.0 # forward
17
+
18
+ def print_stats(name, tensor):
19
+ tensor = tensor.astype(np.float32)
20
+ print('%s min = %.2f, mean = %.2f, max = %.2f' % (name, np.min(tensor), np.mean(tensor), np.max(tensor)), tensor.shape)
21
+
22
+ def reduce_masked_mean(x, mask, axis=None, keepdims=False):
23
+ # x and mask are the same shape
24
+ # returns shape-1
25
+ # axis can be a list of axes
26
+ prod = x*mask
27
+ numer = np.sum(prod, axis=axis, keepdims=keepdims)
28
+ denom = EPS+np.sum(mask, axis=axis, keepdims=keepdims)
29
+ mean = numer/denom
30
+ return mean
31
+
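+ # Illustrative note (not from the original source): the mask selects which entries
+ # contribute to the mean, e.g.:
+ # reduce_masked_mean(np.array([1., 2., 100.]), np.array([1., 1., 0.])) # -> ~1.5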
32
+ def reduce_masked_sum(x, mask, axis=None, keepdims=False):
33
+ # x and mask are the same shape
34
+ # returns shape-1
35
+ # axis can be a list of axes
36
+ prod = x*mask
37
+ numer = np.sum(prod, axis=axis, keepdims=keepdims)
38
+ return numer
39
+
40
+ def reduce_masked_median(x, mask, keep_batch=False):
41
+ # x and mask are the same shape
42
+ # returns shape-1
43
+ # axis can be a list of axes
44
+
45
+ if not (x.shape == mask.shape):
46
+ print('reduce_masked_median: these shapes should match:', x.shape, mask.shape)
47
+ assert(False)
48
+ # assert(x.shape == mask.shape)
49
+
50
+ B = list(x.shape)[0]
51
+
52
+ if keep_batch:
53
+ x = np.reshape(x, [B, -1])
54
+ mask = np.reshape(mask, [B, -1])
55
+ meds = np.zeros([B], np.float32)
56
+ for b in list(range(B)):
57
+ xb = x[b]
58
+ mb = mask[b]
59
+ if np.sum(mb) > 0:
60
+ xb = xb[mb > 0]
61
+ meds[b] = np.median(xb)
62
+ else:
63
+ meds[b] = np.nan
64
+ return meds
65
+ else:
66
+ x = np.reshape(x, [-1])
67
+ mask = np.reshape(mask, [-1])
68
+ if np.sum(mask) > 0:
69
+ x = x[mask > 0]
70
+ med = np.median(x)
71
+ else:
72
+ med = np.nan
73
+ med = np.array([med], np.float32)
74
+ return med
75
+
76
+ def get_nFiles(path):
77
+ return len(glob.glob(path))
78
+
79
+ def get_file_list(path):
80
+ return glob.glob(path)
81
+
82
+ def rotm2eul(R):
83
+ # R is 3x3
84
+ sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
85
+ if sy > 1e-6: # not singular
86
+ x = math.atan2(R[2,1] , R[2,2])
87
+ y = math.atan2(-R[2,0], sy)
88
+ z = math.atan2(R[1,0], R[0,0])
89
+ else:
90
+ x = math.atan2(-R[1,2], R[1,1])
91
+ y = math.atan2(-R[2,0], sy)
92
+ z = 0
93
+ return x, y, z
94
+
95
+ def rad2deg(rad):
96
+ return rad*180.0/np.pi
97
+
98
+ def deg2rad(deg):
99
+ return deg/180.0*np.pi
100
+
101
+ def eul2rotm(rx, ry, rz):
102
+ # copy of matlab, but order of inputs is different
103
+ # R = [ cy*cz sy*sx*cz-sz*cx sy*cx*cz+sz*sx
104
+ # cy*sz sy*sx*sz+cz*cx sy*cx*sz-cz*sx
105
+ # -sy cy*sx cy*cx]
106
+ sinz = np.sin(rz)
107
+ siny = np.sin(ry)
108
+ sinx = np.sin(rx)
109
+ cosz = np.cos(rz)
110
+ cosy = np.cos(ry)
111
+ cosx = np.cos(rx)
112
+ r11 = cosy*cosz
113
+ r12 = sinx*siny*cosz - cosx*sinz
114
+ r13 = cosx*siny*cosz + sinx*sinz
115
+ r21 = cosy*sinz
116
+ r22 = sinx*siny*sinz + cosx*cosz
117
+ r23 = cosx*siny*sinz - sinx*cosz
118
+ r31 = -siny
119
+ r32 = sinx*cosy
120
+ r33 = cosx*cosy
121
+ r1 = np.stack([r11,r12,r13],axis=-1)
122
+ r2 = np.stack([r21,r22,r23],axis=-1)
123
+ r3 = np.stack([r31,r32,r33],axis=-1)
124
+ r = np.stack([r1,r2,r3],axis=0)
125
+ return r
126
+
127
+ def wrap2pi(rad_angle):
128
+ # puts the angle into the range [-pi, pi]
129
+ return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))
130
+
131
+ def rot2view(rx,ry,rz,x,y,z):
132
+ # takes rot angles and 3d position as input
133
+ # returns viewpoint angles as output
134
+ # (all in radians)
135
+ # it will perform strangely if z <= 0
136
+ az = wrap2pi(ry - (-np.arctan2(z, x) - 1.5*np.pi))
137
+ el = -wrap2pi(rx - (-np.arctan2(z, y) - 1.5*np.pi))
138
+ th = -rz
139
+ return az, el, th
140
+
141
+ def invAxB(a,b):
142
+ """
143
+ Compute the relative 3d transformation between a and b.
144
+
145
+ Input:
146
+ a -- first pose (homogeneous 4x4 matrix)
147
+ b -- second pose (homogeneous 4x4 matrix)
148
+
149
+ Output:
150
+ Relative 3d transformation from a to b.
151
+ """
152
+ return np.dot(np.linalg.inv(a),b)
153
+
154
+ def merge_rt(r, t):
155
+ # r is 3 x 3
156
+ # t is 3 or maybe 3 x 1
157
+ t = np.reshape(t, [3, 1])
158
+ rt = np.concatenate((r,t), axis=1)
159
+ # rt is 3 x 4
160
+ br = np.reshape(np.array([0,0,0,1], np.float32), [1, 4])
161
+ # br is 1 x 4
162
+ rt = np.concatenate((rt, br), axis=0)
163
+ # rt is 4 x 4
164
+ return rt
165
+
166
+ def split_rt(rt):
167
+ r = rt[:3,:3]
168
+ t = rt[:3,3]
169
+ r = np.reshape(r, [3, 3])
170
+ t = np.reshape(t, [3, 1])
171
+ return r, t
172
+
173
+ def split_intrinsics(K):
174
+ # K is 3 x 4 or 4 x 4
175
+ fx = K[0,0]
176
+ fy = K[1,1]
177
+ x0 = K[0,2]
178
+ y0 = K[1,2]
179
+ return fx, fy, x0, y0
180
+
181
+ def merge_intrinsics(fx, fy, x0, y0):
182
+ # inputs are shaped []
183
+ K = np.eye(4)
184
+ K[0,0] = fx
185
+ K[1,1] = fy
186
+ K[0,2] = x0
187
+ K[1,2] = y0
188
+ # K is shaped 4 x 4
189
+ return K
190
+
191
+ def scale_intrinsics(K, sx, sy):
192
+ fx, fy, x0, y0 = split_intrinsics(K)
193
+ fx *= sx
194
+ fy *= sy
195
+ x0 *= sx
196
+ y0 *= sy
197
+ return merge_intrinsics(fx, fy, x0, y0)
198
+
199
+ # def meshgrid(H, W):
200
+ # x = np.linspace(0, W-1, W)
201
+ # y = np.linspace(0, H-1, H)
202
+ # xv, yv = np.meshgrid(x, y)
203
+ # return xv, yv
204
+
205
+ def compute_distance(transform):
206
+ """
207
+ Compute the distance of the translational component of a 4x4 homogeneous matrix.
208
+ """
209
+ return np.linalg.norm(transform[0:3,3])
210
+
211
+ def radian_l1_dist(e, g):
212
+ # if our angles are in [0, 360] we can follow this stack overflow answer:
213
+ # https://gamedev.stackexchange.com/questions/4467/comparing-angles-and-working-out-the-difference
214
+ # wrap2pi brings the angles to [-180, 180]; adding pi puts them in [0, 360]
215
+ e = wrap2pi(e)+np.pi
216
+ g = wrap2pi(g)+np.pi
217
+ l = np.abs(np.pi - np.abs(np.abs(e-g) - np.pi))
218
+ return l
219
+
220
+ def apply_pix_T_cam(pix_T_cam, xyz):
221
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
222
+ # xyz is shaped N x 3
+ # returns xy, shaped N x 2
224
+ N, C = xyz.shape
225
+ x, y, z = np.split(xyz, 3, axis=-1)
226
+ EPS = 1e-4
227
+ z = np.clip(z, EPS, None)
228
+ x = (x*fx)/(z)+x0
229
+ y = (y*fy)/(z)+y0
230
+ xy = np.concatenate([x, y], axis=-1)
231
+ return xy
232
+
233
+ def apply_4x4(RT, XYZ):
234
+ # RT is 4 x 4
235
+ # XYZ is N x 3
236
+
237
+ # put into homogeneous coords
238
+ X, Y, Z = np.split(XYZ, 3, axis=1)
239
+ ones = np.ones_like(X)
240
+ XYZ1 = np.concatenate([X, Y, Z, ones], axis=1)
241
+ # XYZ1 is N x 4
242
+
243
+ XYZ1_t = np.transpose(XYZ1)
244
+ # this is 4 x N
245
+
246
+ XYZ2_t = np.dot(RT, XYZ1_t)
247
+ # this is 4 x N
248
+
249
+ XYZ2 = np.transpose(XYZ2_t)
250
+ # this is N x 4
251
+
252
+ XYZ2 = XYZ2[:,:3]
253
+ # this is N x 3
254
+
255
+ return XYZ2
256
+
257
+ def Ref2Mem(xyz, Z, Y, X):
258
+ # xyz is N x 3, in ref coordinates
259
+ # transforms ref coordinates into mem coordinates
260
+ N, C = xyz.shape
261
+ assert(C==3)
262
+ mem_T_ref = get_mem_T_ref(Z, Y, X)
263
+ xyz = apply_4x4(mem_T_ref, xyz)
264
+ return xyz
265
+
266
+ # def Mem2Ref(xyz_mem, MH, MW, MD):
267
+ # # xyz is B x N x 3, in mem coordinates
268
+ # # transforms mem coordinates into ref coordinates
269
+ # B, N, C = xyz_mem.get_shape().as_list()
270
+ # ref_T_mem = get_ref_T_mem(B, MH, MW, MD)
271
+ # xyz_ref = utils_geom.apply_4x4(ref_T_mem, xyz_mem)
272
+ # return xyz_ref
273
+
274
+ def get_mem_T_ref(Z, Y, X):
275
+ # sometimes we want the mat itself
276
+ # note this is not a rigid transform
277
+
278
+ # for interpretability, let's construct this in two steps...
279
+
280
+ # translation
281
+ center_T_ref = np.eye(4, dtype=np.float32)
282
+ center_T_ref[0,3] = -XMIN
283
+ center_T_ref[1,3] = -YMIN
284
+ center_T_ref[2,3] = -ZMIN
285
+
286
+ VOX_SIZE_X = (XMAX-XMIN)/float(X)
287
+ VOX_SIZE_Y = (YMAX-YMIN)/float(Y)
288
+ VOX_SIZE_Z = (ZMAX-ZMIN)/float(Z)
289
+
290
+ # scaling
291
+ mem_T_center = np.eye(4, dtype=np.float32)
292
+ mem_T_center[0,0] = 1./VOX_SIZE_X
293
+ mem_T_center[1,1] = 1./VOX_SIZE_Y
294
+ mem_T_center[2,2] = 1./VOX_SIZE_Z
295
+
296
+ mem_T_ref = np.dot(mem_T_center, center_T_ref)
297
+ return mem_T_ref
298
+
299
+ def safe_inverse(a):
300
+ r, t = split_rt(a)
301
+ t = np.reshape(t, [3, 1])
302
+ r_transpose = r.T
303
+ inv = np.concatenate([r_transpose, -np.matmul(r_transpose, t)], 1)
304
+ bottom_row = a[3:4, :] # this is [0, 0, 0, 1]
305
+ inv = np.concatenate([inv, bottom_row], 0)
306
+ return inv
307
+
308
+ def get_ref_T_mem(Z, Y, X):
309
+ mem_T_ref = get_mem_T_ref(Z, Y, X)
310
+ # note safe_inverse is inapplicable here,
311
+ # since the transform is nonrigid
312
+ ref_T_mem = np.linalg.inv(mem_T_ref)
313
+ return ref_T_mem
314
+
315
+ def voxelize_xyz(xyz_ref, Z, Y, X):
316
+ # xyz_ref is N x 3
317
+ xyz_mem = Ref2Mem(xyz_ref, Z, Y, X)
318
+ # this is N x 3
319
+ voxels = get_occupancy(xyz_mem, Z, Y, X)
320
+ voxels = np.reshape(voxels, [Z, Y, X, 1])
321
+ return voxels
322
+
323
+ def get_inbounds(xyz, Z, Y, X, already_mem=False):
324
+ # xyz is H*W x 3
325
+
326
+ if not already_mem:
327
+ xyz = Ref2Mem(xyz, Z, Y, X)
328
+
329
+ x_valid = np.logical_and(
330
+ np.greater_equal(xyz[:,0], -0.5),
331
+ np.less(xyz[:,0], float(X)-0.5))
332
+ y_valid = np.logical_and(
333
+ np.greater_equal(xyz[:,1], -0.5),
334
+ np.less(xyz[:,1], float(Y)-0.5))
335
+ z_valid = np.logical_and(
336
+ np.greater_equal(xyz[:,2], -0.5),
337
+ np.less(xyz[:,2], float(Z)-0.5))
338
+ inbounds = np.logical_and(np.logical_and(x_valid, y_valid), z_valid)
339
+ return inbounds
340
+
341
+ def sub2ind3d_zyx(depth, height, width, d, h, w):
342
+ # same as sub2ind3d, but inputs in zyx order
343
+ # when gathering/scattering with these inds, the tensor should be Z x Y x X
344
+ return d*height*width + h*width + w
345
+
346
+ def sub2ind3d_yxz(height, width, depth, h, w, d):
347
+ return h*width*depth + w*depth + d
348
+
349
+ # def ind2sub(height, width, ind):
350
+ # # int input
351
+ # y = int(ind / height)
352
+ # x = ind % height
353
+ # return y, x
354
+
355
+ def get_occupancy(xyz_mem, Z, Y, X):
356
+ # xyz_mem is N x 3
357
+ # we want to fill a voxel tensor with 1's at these inds
358
+
359
+ inbounds = get_inbounds(xyz_mem, Z, Y, X, already_mem=True)
360
+ inds = np.where(inbounds)
361
+
362
+ xyz_mem = np.reshape(xyz_mem[inds], [-1, 3])
363
+ # xyz_mem is N x 3
364
+
365
+ # this is more accurate than a cast/floor, but runs into issues when Y==0
366
+ xyz_mem = np.round(xyz_mem).astype(np.int32)
367
+ x = xyz_mem[:,0]
368
+ y = xyz_mem[:,1]
369
+ z = xyz_mem[:,2]
370
+
371
+ voxels = np.zeros([Z, Y, X], np.float32)
372
+ voxels[z, y, x] = 1.0
373
+
374
+ return voxels
375
+
376
+ def pixels2camera(x,y,z,fx,fy,x0,y0):
377
+ # x and y are locations in pixel coordinates, z is a depth image in meters
378
+ # their shapes are H x W
379
+ # fx, fy, x0, y0 are scalar camera intrinsics
380
+ # returns xyz, sized [H*W, 3]
381
+
382
+ H, W = z.shape
383
+
384
+ fx = np.reshape(fx, [1,1])
385
+ fy = np.reshape(fy, [1,1])
386
+ x0 = np.reshape(x0, [1,1])
387
+ y0 = np.reshape(y0, [1,1])
388
+
389
+ # unproject
390
+ x = ((z+EPS)/fx)*(x-x0)
391
+ y = ((z+EPS)/fy)*(y-y0)
392
+
393
+ x = np.reshape(x, [-1])
394
+ y = np.reshape(y, [-1])
395
+ z = np.reshape(z, [-1])
396
+ xyz = np.stack([x,y,z], axis=1)
397
+ return xyz
398
+
399
+ def depth2pointcloud(z, pix_T_cam):
400
+ H = z.shape[0]
401
+ W = z.shape[1]
402
+ y, x = meshgrid2d(H, W)
403
+ z = np.reshape(z, [H, W])
404
+
405
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
406
+ xyz = pixels2camera(x, y, z, fx, fy, x0, y0)
407
+ return xyz
408
+
409
+ def meshgrid2d(Y, X):
410
+ grid_y = np.linspace(0.0, Y-1, Y)
411
+ grid_y = np.reshape(grid_y, [Y, 1])
412
+ grid_y = np.tile(grid_y, [1, X])
413
+
414
+ grid_x = np.linspace(0.0, X-1, X)
415
+ grid_x = np.reshape(grid_x, [1, X])
416
+ grid_x = np.tile(grid_x, [Y, 1])
417
+
418
+ # outputs are Y x X
419
+ return grid_y, grid_x
420
+
421
+ def gridcloud3d(Y, X, Z):
422
+ x_ = np.linspace(0, X-1, X)
423
+ y_ = np.linspace(0, Y-1, Y)
424
+ z_ = np.linspace(0, Z-1, Z)
425
+ y, x, z = np.meshgrid(y_, x_, z_, indexing='ij')
426
+ x = np.reshape(x, [-1])
427
+ y = np.reshape(y, [-1])
428
+ z = np.reshape(z, [-1])
429
+ xyz = np.stack([x,y,z], axis=1).astype(np.float32)
430
+ return xyz
431
+
432
+ def gridcloud2d(Y, X):
433
+ x_ = np.linspace(0, X-1, X)
434
+ y_ = np.linspace(0, Y-1, Y)
435
+ y, x = np.meshgrid(y_, x_, indexing='ij')
436
+ x = np.reshape(x, [-1])
437
+ y = np.reshape(y, [-1])
438
+ xy = np.stack([x,y], axis=1).astype(np.float32)
439
+ return xy
440
+
441
+ def normalize(im):
442
+ im = im - np.min(im)
443
+ im = im / np.max(im)
444
+ return im
445
+
446
+ def wrap2pi(rad_angle):
447
+ # rad_angle can be any shape
448
+ # puts the angle into the range [-pi, pi]
449
+ return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))
450
+
451
+ def convert_occ_to_height(occ):
452
+ Z, Y, X, C = occ.shape
453
+ assert(C==1)
454
+
455
+ height = np.linspace(float(Y), 1.0, Y)
456
+ height = np.reshape(height, [1, Y, 1, 1])
457
+ height = np.max(occ*height, axis=1)/float(Y)
458
+ height = np.reshape(height, [Z, X, C])
459
+ return height
460
+
461
+ def create_depth_image(xy, Z, H, W):
462
+
463
+ # turn the xy coordinates into image inds
464
+ xy = np.round(xy)
465
+
466
+ # lidar reports a sphere of measurements
467
+ # only use the inds that are within the image bounds
468
+ # also, only use forward-pointing depths (Z > 0)
469
+ valid = (xy[:,0] < W-1) & (xy[:,1] < H-1) & (xy[:,0] >= 0) & (xy[:,1] >= 0) & (Z[:] > 0)
470
+
471
+ # gather these up
472
+ xy = xy[valid]
473
+ Z = Z[valid]
474
+
475
+ inds = (xy[:,1]*W + xy[:,0]).astype(np.int32) # flat row-major indices (sub2ind(H, W, y, x))
476
+ depth = np.zeros((H*W), np.float32)
477
+
478
+ for (index, replacement) in zip(inds, Z):
479
+ depth[index] = replacement
480
+ depth[np.where(depth == 0.0)] = 70.0
481
+ depth = np.reshape(depth, [H, W])
482
+
483
+ return depth
484
+
485
+ def vis_depth(depth, maxdepth=80.0, log_vis=True):
486
+ depth[depth<=0.0] = maxdepth
487
+ if log_vis:
488
+ depth = np.log(depth)
489
+ depth = np.clip(depth, 0, np.log(maxdepth))
490
+ else:
491
+ depth = np.clip(depth, 0, maxdepth)
492
+ depth = (depth*255.0).astype(np.uint8)
493
+ return depth
494
+
495
+ def preprocess_color(x):
496
+ return x.astype(np.float32) * 1./255 - 0.5
497
+
498
+ def convert_box_to_ref_T_obj(boxes):
499
+ shape = boxes.shape
500
+ boxes = boxes.reshape(-1,9)
501
+ rots = [eul2rotm(rx,ry,rz)
502
+ for rx,ry,rz in boxes[:,6:]]
503
+ rots = np.stack(rots,axis=0)
504
+ trans = boxes[:,:3]
505
+ ref_T_objs = [merge_rt(rot,tran)
506
+ for rot,tran in zip(rots,trans)]
507
+ ref_T_objs = np.stack(ref_T_objs,axis=0)
508
+ ref_T_objs = ref_T_objs.reshape(shape[:-1]+(4,4))
509
+ ref_T_objs = ref_T_objs.astype(np.float32)
510
+ return ref_T_objs
511
+
512
+ def get_rot_from_delta(delta, yaw_only=False):
513
+ dx = delta[:,0]
514
+ dy = delta[:,1]
515
+ dz = delta[:,2]
516
+
517
+ bot_hyp = np.sqrt(dz**2 + dx**2)
518
+ # top_hyp = np.sqrt(bot_hyp**2 + dy**2)
519
+
520
+ pitch = -np.arctan2(dy, bot_hyp)
521
+ yaw = np.arctan2(dz, dx)
522
+
523
+ if yaw_only:
524
+ rot = [eul2rotm(0,y,0) for y in yaw]
525
+ else:
526
+ rot = [eul2rotm(0,y,p) for (p,y) in zip(pitch,yaw)]
527
+
528
+ rot = np.stack(rot)
529
+ # rot is B x 3 x 3
530
+ return rot
531
+
532
+ def im2col(im, psize):
533
+ n_channels = 1 if len(im.shape) == 2 else im.shape[0]
534
+ (n_channels, rows, cols) = (1,) * (3 - len(im.shape)) + im.shape
535
+
536
+ im_pad = np.zeros((n_channels,
537
+ int(math.ceil(1.0 * rows / psize) * psize),
538
+ int(math.ceil(1.0 * cols / psize) * psize)))
539
+ im_pad[:, 0:rows, 0:cols] = im
540
+
541
+ final = np.zeros((im_pad.shape[1], im_pad.shape[2], n_channels,
542
+ psize, psize))
543
+ for c in np.arange(n_channels):
544
+ for x in np.arange(psize):
545
+ for y in np.arange(psize):
546
+ im_shift = np.vstack(
547
+ (im_pad[c, x:], im_pad[c, :x]))
548
+ im_shift = np.column_stack(
549
+ (im_shift[:, y:], im_shift[:, :y]))
550
+ final[x::psize, y::psize, c] = np.swapaxes(
551
+ im_shift.reshape(int(im_pad.shape[1] / psize), psize,
552
+ int(im_pad.shape[2] / psize), psize), 1, 2)
553
+
554
+ return np.squeeze(final[0:rows - psize + 1, 0:cols - psize + 1])
555
+
556
+ def filter_discontinuities(depth, filter_size=9, thresh=10):
557
+ H, W = list(depth.shape)
558
+
559
+ # Ensure that filter sizes are okay
560
+ assert filter_size % 2 == 1, "Can only use odd filter sizes."
561
+
562
+ # Compute discontinuities
563
+ offset = int((filter_size - 1) / 2)
564
+ patches = 1.0 * im2col(depth, filter_size)
565
+ mids = patches[:, :, offset, offset]
566
+ mins = np.min(patches, axis=(2, 3))
567
+ maxes = np.max(patches, axis=(2, 3))
568
+
569
+ discont = np.maximum(np.abs(mins - mids),
570
+ np.abs(maxes - mids))
571
+ mark = discont > thresh
572
+
573
+ # Account for offsets
574
+ final_mark = np.zeros((H, W), dtype=np.uint16)
575
+ final_mark[offset:offset + mark.shape[0],
576
+ offset:offset + mark.shape[1]] = mark
577
+
578
+ return depth * (1 - final_mark)
579
+
580
+ def argmax2d(tensor):
581
+ Y, X = list(tensor.shape)
582
+ # flatten the Tensor along the height and width axes
583
+ flat_tensor = tensor.reshape(-1)
584
+ # argmax of the flat tensor
585
+ argmax = np.argmax(flat_tensor)
586
+
587
+ # convert the indices into 2d coordinates
588
+ argmax_y = argmax // X # row
589
+ argmax_x = argmax % X # col
590
+
591
+ return argmax_y, argmax_x
592
+
593
+ def plot_traj_3d(traj):
594
+ # traj is S x 3
595
+
596
+ # print('traj', traj.shape)
597
+ S, C = list(traj.shape)
598
+ assert(C==3)
599
+
600
+ fig = plt.figure()
601
+ ax = fig.add_subplot(111, projection='3d')
602
+
603
+ colors = [plt.cm.RdYlBu(i) for i in np.linspace(0,1,S)]
604
+ # print('colors', colors)
605
+
606
+ xs = traj[:,0]
607
+ ys = -traj[:,1]
608
+ zs = traj[:,2]
609
+
610
+ ax.scatter(xs, zs, ys, s=30, c=colors, marker='o', alpha=1.0, edgecolors=(0,0,0))#, color=color_map[n])
611
+
612
+ ax.set_xlabel('X')
613
+ ax.set_ylabel('Z')
614
+ ax.set_zlabel('Y')
615
+
616
+ ax.set_xlim(0,1)
617
+ ax.set_ylim(0,1) # this is really Z
618
+ ax.set_zlim(-1,0) # this is really Y
619
+
620
+ buf = io.BytesIO()
621
+ plt.savefig(buf, format='png')
622
+ buf.seek(0)
623
+ image = np.array(Image.open(buf)) # H x W x 4
624
+ image = image[:,:,:3]
625
+
626
+ plt.close()
627
+ return image
628
+
629
+ def camera2pixels(xyz, pix_T_cam):
630
+ # xyz is shaped N x 3
631
+ # returns xy, shaped N x 2
632
+
633
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
634
+ x, y, z = xyz[:,0], xyz[:,1], xyz[:,2]
635
+
636
+ EPS = 1e-4
637
+ z = np.clip(z, EPS, None)
638
+ x = (x*fx)/z + x0
639
+ y = (y*fy)/z + y0
640
+ xy = np.stack([x, y], axis=-1)
641
+ return xy
642
+
643
+ def make_colorwheel():
644
+ """
645
+ Generates a color wheel for optical flow visualization as presented in:
646
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
647
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
648
+
649
+ Code follows the original C++ source code of Daniel Scharstein.
650
+ Code follows the the Matlab source code of Deqing Sun.
651
+
652
+ Returns:
653
+ np.ndarray: Color wheel
654
+ """
655
+
656
+ RY = 15
657
+ YG = 6
658
+ GC = 4
659
+ CB = 11
660
+ BM = 13
661
+ MR = 6
662
+
663
+ ncols = RY + YG + GC + CB + BM + MR
664
+ colorwheel = np.zeros((ncols, 3))
665
+ col = 0
666
+
667
+ # RY
668
+ colorwheel[0:RY, 0] = 255
669
+ colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
670
+ col = col+RY
671
+ # YG
672
+ colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
673
+ colorwheel[col:col+YG, 1] = 255
674
+ col = col+YG
675
+ # GC
676
+ colorwheel[col:col+GC, 1] = 255
677
+ colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
678
+ col = col+GC
679
+ # CB
680
+ colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
681
+ colorwheel[col:col+CB, 2] = 255
682
+ col = col+CB
683
+ # BM
684
+ colorwheel[col:col+BM, 2] = 255
685
+ colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
686
+ col = col+BM
687
+ # MR
688
+ colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
689
+ colorwheel[col:col+MR, 0] = 255
690
+ return colorwheel
691
+
692
+ def flow_uv_to_colors(u, v, convert_to_bgr=False):
693
+ """
694
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
695
+
696
+ According to the C++ source code of Daniel Scharstein
697
+ According to the Matlab source code of Deqing Sun
698
+
699
+ Args:
700
+ u (np.ndarray): Input horizontal flow of shape [H,W]
701
+ v (np.ndarray): Input vertical flow of shape [H,W]
702
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
703
+
704
+ Returns:
705
+ np.ndarray: Flow visualization image of shape [H,W,3]
706
+ """
707
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
708
+ colorwheel = make_colorwheel() # shape [55x3]
709
+ ncols = colorwheel.shape[0]
710
+ rad = np.sqrt(np.square(u) + np.square(v))
711
+ a = np.arctan2(-v, -u)/np.pi
712
+ fk = (a+1) / 2*(ncols-1)
713
+ k0 = np.floor(fk).astype(np.int32)
714
+ k1 = k0 + 1
715
+ k1[k1 == ncols] = 0
716
+ f = fk - k0
717
+ for i in range(colorwheel.shape[1]):
718
+ tmp = colorwheel[:,i]
719
+ col0 = tmp[k0] / 255.0
720
+ col1 = tmp[k1] / 255.0
721
+ col = (1-f)*col0 + f*col1
722
+ idx = (rad <= 1)
723
+ col[idx] = 1 - rad[idx] * (1-col[idx])
724
+ col[~idx] = col[~idx] * 0.75 # out of range
725
+ # Note the 2-i => BGR instead of RGB
726
+ ch_idx = 2-i if convert_to_bgr else i
727
+ flow_image[:,:,ch_idx] = np.floor(255 * col)
728
+ return flow_image
729
+
730
+ def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
731
+ """
732
+ Expects a two dimensional flow image of shape.
733
+
734
+ Args:
735
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
736
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
737
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
738
+
739
+ Returns:
740
+ np.ndarray: Flow visualization image of shape [H,W,3]
741
+ """
742
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
743
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
744
+ if clip_flow is not None:
745
+ flow_uv = np.clip(flow_uv, -clip_flow, clip_flow) / clip_flow
746
+ # flow_uv = np.clamp(flow, -clip, clip)/clip
747
+
748
+ u = flow_uv[:,:,0]
749
+ v = flow_uv[:,:,1]
750
+ rad = np.sqrt(np.square(u) + np.square(v))
751
+ rad_max = np.max(rad)
752
+ epsilon = 1e-5
753
+ u = u / (rad_max + epsilon)
754
+ v = v / (rad_max + epsilon)
755
+ return flow_uv_to_colors(u, v, convert_to_bgr)
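The functions above implement the standard Middlebury-style flow color coding. A minimal usage sketch (not part of the commit; the synthetic flow field and output path are assumptions):

import numpy as np
import imageio

# Synthetic flow field: every pixel points away from the image center.
H, W = 240, 320
ys, xs = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
flow = np.stack([xs - W / 2.0, ys - H / 2.0], axis=-1).astype(np.float32)  # [H,W,2]

rgb = flow_to_image(flow)             # uint8 [H,W,3]; hue encodes direction, saturation magnitude
imageio.imwrite("flow_vis.png", rgb)  # hypothetical output path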
utils/samp.py ADDED
@@ -0,0 +1,213 @@
1
+ import torch
2
+ import utils.basic
3
+ import torch.nn.functional as F
4
+
5
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
6
+ r"""Sample a tensor using bilinear interpolation
7
+
8
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
9
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
10
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
11
+ convention.
12
+
13
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
14
+ :math:`B` is the batch size, :math:`C` is the number of channels,
15
+ :math:`H` is the height of the image, and :math:`W` is the width of the
16
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
17
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
18
+
19
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
20
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
21
+ that in this case the order of the components is slightly different
22
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
23
+
24
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
25
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
26
+ left-most image pixel and :math:`W-1` to the center of the right-most
27
+ pixel.
28
+
29
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
30
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
31
+ the left-most pixel :math:`W` to the right edge of the right-most
32
+ pixel.
33
+
34
+ Similar conventions apply to the :math:`y` for the range
35
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
36
+ :math:`[0,T-1]` and :math:`[0,T]`.
37
+
38
+ Args:
39
+ input (Tensor): batch of input images.
40
+ coords (Tensor): batch of coordinates.
41
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
42
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
43
+
44
+ Returns:
45
+ Tensor: sampled points.
46
+ """
47
+
48
+ sizes = input.shape[2:]
49
+
50
+ assert len(sizes) in [2, 3]
51
+
52
+ if len(sizes) == 3:
53
+ # t x y -> x y t to match dimensions T H W in grid_sample
54
+ coords = coords[..., [1, 2, 0]]
55
+
56
+ if align_corners:
57
+ coords = coords * torch.tensor(
58
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
59
+ )
60
+ else:
61
+ coords = coords * torch.tensor(
62
+ [2 / size for size in reversed(sizes)], device=coords.device
63
+ )
64
+
65
+ coords -= 1
66
+
67
+ return F.grid_sample(
68
+ input, coords, align_corners=align_corners, padding_mode=padding_mode
69
+ )
70
+
71
+
72
+ def sample_features4d(input, coords):
73
+ r"""Sample spatial features
74
+
75
+ `sample_features4d(input, coords)` samples the spatial features
76
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
77
+
78
+ The field is sampled at coordinates :attr:`coords` using bilinear
79
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
80
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
81
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
82
+
83
+ The output tensor has one feature per point, and has shape :math:`(B,
84
+ R, C)`.
85
+
86
+ Args:
87
+ input (Tensor): spatial features.
88
+ coords (Tensor): points.
89
+
90
+ Returns:
91
+ Tensor: sampled features.
92
+ """
93
+
94
+ B, _, _, _ = input.shape
95
+
96
+ # B R 2 -> B R 1 2
97
+ coords = coords.unsqueeze(2)
98
+
99
+ # B C R 1
100
+ feats = bilinear_sampler(input, coords)
101
+
102
+ return feats.permute(0, 2, 1, 3).view(
103
+ B, -1, feats.shape[1] * feats.shape[3]
104
+ ) # B C R 1 -> B R C
105
+
106
+
107
+ def sample_features5d(input, coords):
108
+ r"""Sample spatio-temporal features
109
+
110
+ `sample_features5d(input, coords)` works in the same way as
111
+ :func:`sample_features4d` but for spatio-temporal features and points:
112
+ :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
113
+ a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
114
+ x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
115
+
116
+ Args:
117
+ input (Tensor): spatio-temporal features.
118
+ coords (Tensor): spatio-temporal points.
119
+
120
+ Returns:
121
+ Tensor: sampled features.
122
+ """
123
+
124
+ B, T, _, _, _ = input.shape
125
+
126
+ # B T C H W -> B C T H W
127
+ input = input.permute(0, 2, 1, 3, 4)
128
+
129
+ # B R1 R2 3 -> B R1 R2 1 3
130
+ coords = coords.unsqueeze(3)
131
+
132
+ # B C R1 R2 1
133
+ feats = bilinear_sampler(input, coords)
134
+
135
+ return feats.permute(0, 2, 3, 1, 4).view(
136
+ B, feats.shape[2], feats.shape[3], feats.shape[1]
137
+ ) # B C R1 R2 1 -> B R1 R2 C
138
+
139
+
140
+ def bilinear_sample2d(im, x, y, return_inbounds=False):
141
+ # x and y are each B, N
142
+ # output is B, C, N
143
+ B, C, H, W = list(im.shape)
144
+ N = list(x.shape)[1]
145
+
146
+ x = x.float()
147
+ y = y.float()
148
+ H_f = torch.tensor(H, dtype=torch.float32)
149
+ W_f = torch.tensor(W, dtype=torch.float32)
150
+
151
+ # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
152
+
153
+ max_y = (H_f - 1).int()
154
+ max_x = (W_f - 1).int()
155
+
156
+ x0 = torch.floor(x).int()
157
+ x1 = x0 + 1
158
+ y0 = torch.floor(y).int()
159
+ y1 = y0 + 1
160
+
161
+ x0_clip = torch.clamp(x0, 0, max_x)
162
+ x1_clip = torch.clamp(x1, 0, max_x)
163
+ y0_clip = torch.clamp(y0, 0, max_y)
164
+ y1_clip = torch.clamp(y1, 0, max_y)
165
+ dim2 = W
166
+ dim1 = W * H
167
+
168
+ base = torch.arange(0, B, dtype=torch.int64, device=x.device)*dim1
169
+ base = torch.reshape(base, [B, 1]).repeat([1, N])
170
+
171
+ base_y0 = base + y0_clip * dim2
172
+ base_y1 = base + y1_clip * dim2
173
+
174
+ idx_y0_x0 = base_y0 + x0_clip
175
+ idx_y0_x1 = base_y0 + x1_clip
176
+ idx_y1_x0 = base_y1 + x0_clip
177
+ idx_y1_x1 = base_y1 + x1_clip
178
+
179
+ # use the indices to lookup pixels in the flat image
180
+ # im is B x C x H x W
181
+ # move C out to last dim
182
+ im_flat = (im.permute(0, 2, 3, 1)).reshape(B*H*W, C)
183
+ i_y0_x0 = im_flat[idx_y0_x0.long()]
184
+ i_y0_x1 = im_flat[idx_y0_x1.long()]
185
+ i_y1_x0 = im_flat[idx_y1_x0.long()]
186
+ i_y1_x1 = im_flat[idx_y1_x1.long()]
187
+
188
+ # Finally calculate interpolated values.
189
+ x0_f = x0.float()
190
+ x1_f = x1.float()
191
+ y0_f = y0.float()
192
+ y1_f = y1.float()
193
+
194
+ w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
195
+ w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
196
+ w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
197
+ w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
198
+
199
+ output = w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + \
200
+ w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
201
+ # output is B*N x C
202
+ output = output.view(B, -1, C)
203
+ output = output.permute(0, 2, 1)
204
+ # output is B x C x N
205
+
206
+ if return_inbounds:
207
+ x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
208
+ y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
209
+ inbounds = (x_valid & y_valid).float()
210
+ inbounds = inbounds.reshape(B, N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
211
+ return output, inbounds
212
+
213
+ return output # B, C, N
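A small usage sketch for the samplers above (not part of the commit; the random feature map and point count are assumptions). Coordinates are pixel-space (x, y), matching the align_corners=True convention documented in bilinear_sampler:

import torch

B, C, H, W = 2, 16, 64, 96
fmap = torch.randn(B, C, H, W)

# sample_features4d: coords (B, R, 2) in (x, y) -> per-point features (B, R, C)
coords = torch.rand(B, 100, 2) * torch.tensor([W - 1.0, H - 1.0])
feats = sample_features4d(fmap, coords)
print(feats.shape)  # torch.Size([2, 100, 16])

# bilinear_sample2d: separate x and y vectors -> (B, C, N), plus an in-bounds mask
x, y = coords[..., 0], coords[..., 1]
out, inbounds = bilinear_sample2d(fmap, x, y, return_inbounds=True)
print(out.shape, inbounds.shape)  # torch.Size([2, 16, 100]) torch.Size([2, 100])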
utils/saveload.py ADDED
@@ -0,0 +1,59 @@
1
+ import pathlib
2
+ import os
3
+ import torch
4
+
5
+ def save(ckpt_dir, module, optimizer, scheduler, global_step, keep_latest=2, model_name='model'):
6
+ pathlib.Path(ckpt_dir).mkdir(exist_ok=True, parents=True)
7
+ prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*pth' % model_name))
8
+ prev_ckpts.sort(key=lambda p: p.stat().st_mtime,reverse=True)
9
+ if len(prev_ckpts) > keep_latest-1:
10
+ for f in prev_ckpts[keep_latest-1:]:
11
+ f.unlink()
12
+ save_path = '%s/%s-%09d.pth' % (ckpt_dir, model_name, global_step)
13
+ save_dict = {
14
+ "model": module.state_dict(),
15
+ "optimizer": optimizer.state_dict(),
16
+ "global_step": global_step,
17
+ }
18
+ if scheduler is not None:
19
+ save_dict['scheduler'] = scheduler.state_dict()
20
+ print(f"saving {save_path}")
21
+ torch.save(save_dict, save_path)
22
+ return False
23
+
24
+ def load(fabric, ckpt_dir, model, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None, strict=True, verbose=True, weights_only=False):
25
+ if verbose:
26
+ print('reading ckpt from %s' % ckpt_dir)
27
+ if not os.path.exists(ckpt_dir):
28
+ print('...there is no full checkpoint in %s' % ckpt_dir)
29
+ print('-- note this function no longer appends "saved_checkpoints/" before the ckpt_dir --')
30
+ assert(False)
31
+ else:
32
+ prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*pth' % model_name))
33
+ prev_ckpts.sort(key=lambda p: p.stat().st_mtime,reverse=True)
34
+ if len(prev_ckpts):
35
+ path = prev_ckpts[0]
36
+ # e.g., './checkpoints/2Ai4_5e-4_base18_1539/model-000050000.pth'
37
+ step = int(str(path).split('-')[-1].split('.')[0])
38
+ if verbose:
39
+ print('...found checkpoint %s; (parsed step %d from path)' % (path, step))
40
+ if fabric is not None:
41
+ checkpoint = fabric.load(path)
42
+ else:
43
+ checkpoint = torch.load(path, weights_only=weights_only)
44
+ if optimizer is not None:
45
+ optimizer.load_state_dict(checkpoint['optimizer'])
46
+ if scheduler is not None:
47
+ scheduler.load_state_dict(checkpoint['scheduler'])
48
+ assert ignore_load is None # not ready yet
49
+ if 'model' in checkpoint:
50
+ state_dict = checkpoint['model']
51
+ else:
52
+ state_dict = checkpoint
53
+ model.load_state_dict(state_dict, strict=strict)
54
+ else:
55
+ print('...there is no full checkpoint here!')
56
+ return step
57
+
58
+
59
+
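A rough sketch of how these checkpoint helpers are typically driven (not from the commit; the toy model, optimizer, and './checkpoints' directory are assumptions):

import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Writes ./checkpoints/model-000001000.pth and prunes older model-*.pth files.
save('./checkpoints', model, optimizer, scheduler=None, global_step=1000)

# Resumes from the newest model-*.pth; returns the step parsed from the filename.
step = load(None, './checkpoints', model, optimizer=optimizer)
print('resumed from step', step)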
utils/test.py ADDED
@@ -0,0 +1,194 @@
1
+ import torch
2
+ import cv2
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ def prep_frame_for_dino(img, scale_size=[192]):
7
+ """
8
+ read a single frame & preprocess
9
+ """
10
+ ori_h, ori_w, _ = img.shape
11
+ if len(scale_size) == 1:
12
+ if(ori_h > ori_w):
13
+ tw = scale_size[0]
14
+ th = (tw * ori_h) / ori_w
15
+ th = int((th // 64) * 64)
16
+ else:
17
+ th = scale_size[0]
18
+ tw = (th * ori_w) / ori_h
19
+ tw = int((tw // 64) * 64)
20
+ else:
21
+ th, tw = scale_size
22
+ img = cv2.resize(img, (tw, th))
23
+ img = img.astype(np.float32)
24
+ img = img / 255.0
25
+ img = img[:, :, ::-1]
26
+ img = np.transpose(img.copy(), (2, 0, 1))
27
+ img = torch.from_numpy(img).float()
28
+
29
+ def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]):
30
+ for t, m, s in zip(x, mean, std):
31
+ t.sub_(m)
32
+ t.div_(s)
33
+ return x
34
+
35
+ img = color_normalize(img)
36
+ return img, ori_h, ori_w
37
+
38
+ def get_feats_from_dino(model, frame):
39
+ # batch version of the other func
40
+ B = frame.shape[0]
41
+ patch_size = model.patch_embed.patch_size
42
+ h, w = int(frame.shape[2] / patch_size), int(frame.shape[3] / patch_size)
43
+ out = model.get_intermediate_layers(frame.cuda(), n=1)[0] # B, 1+h*w, dim
44
+ dim = out.shape[-1]
45
+ out = out[:, 1:, :] # discard the [CLS] token
46
+ outmap = out.permute(0, 2, 1).reshape(B, dim, h, w)
47
+ return out, outmap, h, w
48
+
49
+ def restrict_neighborhood(h, w):
50
+ size_mask_neighborhood = 12
51
+ # We restrict the set of source nodes considered to a spatial neighborhood of the query node (i.e. ``local attention'')
52
+ mask = torch.zeros(h, w, h, w)
53
+ for i in range(h):
54
+ for j in range(w):
55
+ for p in range(2 * size_mask_neighborhood + 1):
56
+ for q in range(2 * size_mask_neighborhood + 1):
57
+ if i - size_mask_neighborhood + p < 0 or i - size_mask_neighborhood + p >= h:
58
+ continue
59
+ if j - size_mask_neighborhood + q < 0 or j - size_mask_neighborhood + q >= w:
60
+ continue
61
+ mask[i, j, i - size_mask_neighborhood + p, j - size_mask_neighborhood + q] = 1
62
+
63
+ mask = mask.reshape(h * w, h * w)
64
+ return mask.cuda(non_blocking=True)
65
+
66
+ def label_propagation(h, w, feat_tar, list_frame_feats, list_segs, mask_neighborhood=None):
67
+ ncontext = len(list_frame_feats)
68
+ feat_sources = torch.stack(list_frame_feats) # nmb_context x dim x h*w
69
+
70
+ feat_tar = F.normalize(feat_tar, dim=1, p=2)
71
+ feat_sources = F.normalize(feat_sources, dim=1, p=2)
72
+
73
+ # print('feat_tar', feat_tar.shape)
74
+ # print('feat_sources', feat_sources.shape)
75
+
76
+ feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1)
77
+ aff = torch.exp(torch.bmm(feat_tar, feat_sources) / 0.1)
78
+
79
+ size_mask_neighborhood = 12
80
+ if size_mask_neighborhood > 0:
81
+ if mask_neighborhood is None:
82
+ mask_neighborhood = restrict_neighborhood(h, w)
83
+ mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1)
84
+ aff *= mask_neighborhood
85
+
86
+ aff = aff.transpose(2, 1).reshape(-1, h*w) # nmb_context*h*w (source: keys) x h*w (tar: queries)
87
+ topk = 5
88
+ tk_val, _ = torch.topk(aff, dim=0, k=topk)
89
+ tk_val_min, _ = torch.min(tk_val, dim=0)
90
+ aff[aff < tk_val_min] = 0
91
+
92
+ aff = aff / torch.sum(aff, keepdim=True, axis=0)
93
+
94
+ list_segs = [s.cuda() for s in list_segs]
95
+ segs = torch.cat(list_segs)
96
+ nmb_context, C, h, w = segs.shape
97
+ segs = segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T # C x nmb_context*h*w
98
+ seg_tar = torch.mm(segs, aff)
99
+ seg_tar = seg_tar.reshape(1, C, h, w)
100
+
101
+ return seg_tar, mask_neighborhood
102
+
103
+ def norm_mask(mask):
104
+ c, h, w = mask.size()
105
+ for cnt in range(c):
106
+ mask_cnt = mask[cnt,:,:]
107
+ if(mask_cnt.max() > 0):
108
+ mask_cnt = (mask_cnt - mask_cnt.min())
109
+ mask_cnt = mask_cnt/mask_cnt.max()
110
+ mask[cnt,:,:] = mask_cnt
111
+ return mask
112
+
113
+
114
+ def get_dino_output(dino, rgbs, trajs_g, vis_g):
115
+ B, S, C, H, W = rgbs.shape
116
+
117
+ B1, S1, N, D = trajs_g.shape
118
+ assert(B1==B)
119
+ assert(S1==S)
120
+ assert(D==2)
121
+
122
+ assert(B==1)
123
+ xy0 = trajs_g[:,0] # B, N, 2
124
+
125
+ # The queue stores the n preceeding frames
126
+ import queue
127
+ import copy
128
+ n_last_frames = 7
129
+ que = queue.Queue(n_last_frames)
130
+
131
+ # run dino
132
+ prep_rgbs = []
133
+ for s in range(S):
134
+ prep_rgb, ori_h, ori_w = prep_frame_for_dino(rgbs[0, s].permute(1,2,0).detach().cpu().numpy(), scale_size=[H])
135
+ prep_rgbs.append(prep_rgb)
136
+ prep_rgbs = torch.stack(prep_rgbs, dim=0) # S, 3, H, W
137
+ with torch.no_grad():
138
+ bs = 8
139
+ idx = 0
140
+ featmaps = []
141
+ while idx < S:
142
+ end_id = min(S, idx+bs)
143
+ _, featmaps_cur, h, w = get_feats_from_dino(dino, prep_rgbs[idx:end_id]) # S, C, h, w
144
+ idx = end_id
145
+ featmaps.append(featmaps_cur)
146
+ featmaps = torch.cat(featmaps, dim=0)
147
+ C = featmaps.shape[1]
148
+ featmaps = featmaps.unsqueeze(0) # 1, S, C, h, w
149
+ # featmaps = F.normalize(featmaps, dim=2, p=2)
150
+
151
+ xy0 = trajs_g[:, 0, :] # B, N, 2
152
+ patch_size = dino.patch_embed.patch_size
153
+ first_seg = torch.zeros((1, N, H//patch_size, W//patch_size))
154
+ for n in range(N):
155
+ first_seg[0, n, (xy0[0, n, 1]/patch_size).long(), (xy0[0, n, 0]/patch_size).long()] = 1
156
+
157
+ frame1_feat = featmaps[0, 0].reshape(C, h*w) # dim x h*w
158
+ mask_neighborhood = None
159
+ accs = []
160
+ trajs_e = torch.zeros_like(trajs_g)
161
+ trajs_e[0,0] = trajs_g[0,0]
162
+ for cnt in range(1, S):
163
+ used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)]
164
+ used_segs = [first_seg] + [pair[1] for pair in list(que.queue)]
165
+
166
+ feat_tar = featmaps[0, cnt].reshape(C, h*w)
167
+
168
+ frame_tar_avg, mask_neighborhood = label_propagation(h, w, feat_tar.T, used_frame_feats, used_segs, mask_neighborhood)
169
+
170
+ # pop out oldest frame if neccessary
171
+ if que.qsize() == n_last_frames:
172
+ que.get()
173
+ # push current results into queue
174
+ seg = copy.deepcopy(frame_tar_avg)
175
+ que.put([feat_tar, seg])
176
+
177
+ # upsampling & argmax
178
+ frame_tar_avg = F.interpolate(frame_tar_avg, scale_factor=patch_size, mode='bilinear', align_corners=False, recompute_scale_factor=False)[0]
179
+ frame_tar_avg = norm_mask(frame_tar_avg)
180
+ _, frame_tar_seg = torch.max(frame_tar_avg, dim=0)
181
+
182
+ for n in range(N):
183
+ vis = vis_g[0,cnt,n]
184
+ if len(torch.nonzero(frame_tar_avg[n])) > 0:
185
+ # weighted average
186
+ nz = torch.nonzero(frame_tar_avg[n])
187
+ coord_e = torch.sum(frame_tar_avg[n][nz[:,0], nz[:,1]].reshape(-1,1) * nz.float(), 0) / frame_tar_avg[n][nz[:,0], nz[:,1]].sum() # 2
188
+ coord_e = coord_e[[1,0]]
189
+ else:
190
+ # stay where it was
191
+ coord_e = trajs_e[0,cnt-1,n]
192
+
193
+ trajs_e[0, cnt, n] = coord_e
194
+ return trajs_e
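A hypothetical driver for the DINO label-propagation baseline above; the torch.hub checkpoint, CUDA device, and toy tensor shapes are assumptions, not part of the commit:

import torch

# Official DINO ViT-S/8 backbone (exposes patch_embed and get_intermediate_layers).
dino = torch.hub.load('facebookresearch/dino:main', 'dino_vits8').cuda().eval()

B, S, N, H, W = 1, 8, 16, 256, 256
rgbs = torch.randint(0, 256, (B, S, 3, H, W)).float()                # video frames in [0, 255]
trajs_g = torch.rand(B, S, N, 2) * torch.tensor([W - 1.0, H - 1.0])  # query tracks as (x, y)
vis_g = torch.ones(B, S, N)                                          # visibility flags

trajs_e = get_dino_output(dino, rgbs, trajs_g, vis_g)  # propagated tracks, (B, S, N, 2)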
utils/visualizer.py ADDED
@@ -0,0 +1,365 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import numpy as np
9
+ import imageio
10
+ import torch
11
+
12
+ from matplotlib import cm
13
+ import torch.nn.functional as F
14
+ import torchvision.transforms as transforms
15
+ import matplotlib.pyplot as plt
16
+ from PIL import Image, ImageDraw
17
+
18
+
19
+ def read_video_from_path(path):
20
+ try:
21
+ reader = imageio.get_reader(path)
22
+ except Exception as e:
23
+ print("Error opening video file: ", e)
24
+ return None
25
+ frames = []
26
+ for i, im in enumerate(reader):
27
+ frames.append(np.array(im))
28
+ return np.stack(frames)
29
+
30
+
31
+ def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True, color_alpha=None):
32
+ # Create a draw object
33
+ draw = ImageDraw.Draw(rgb)
34
+ # Calculate the bounding box of the circle
35
+ left_up_point = (coord[0] - radius, coord[1] - radius)
36
+ right_down_point = (coord[0] + radius, coord[1] + radius)
37
+ # Draw the circle
38
+ color = tuple(list(color) + [color_alpha if color_alpha is not None else 255])
39
+
40
+ draw.ellipse(
41
+ [left_up_point, right_down_point],
42
+ fill=tuple(color) if visible else None,
43
+ outline=tuple(color),
44
+ )
45
+ return rgb
46
+
47
+
48
+ def draw_line(rgb, coord_y, coord_x, color, linewidth):
49
+ draw = ImageDraw.Draw(rgb)
50
+ draw.line(
51
+ (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
52
+ fill=tuple(color),
53
+ width=linewidth,
54
+ )
55
+ return rgb
56
+
57
+
58
+ def add_weighted(rgb, alpha, original, beta, gamma):
59
+ return (rgb * alpha + original * beta + gamma).astype("uint8")
60
+
61
+
62
+ class Visualizer:
63
+ def __init__(
64
+ self,
65
+ save_dir: str = "./results",
66
+ grayscale: bool = False,
67
+ pad_value: int = 0,
68
+ fps: int = 10,
69
+ mode: str = "rainbow", # 'cool', 'optical_flow'
70
+ linewidth: int = 2,
71
+ show_first_frame: int = 10,
72
+ tracks_leave_trace: int = 0, # -1 for infinite
73
+ ):
74
+ self.mode = mode
75
+ self.save_dir = save_dir
76
+ if mode == "rainbow":
77
+ self.color_map = cm.get_cmap("gist_rainbow")
78
+ elif mode == "cool":
79
+ self.color_map = cm.get_cmap(mode)
80
+ self.show_first_frame = show_first_frame
81
+ self.grayscale = grayscale
82
+ self.tracks_leave_trace = tracks_leave_trace
83
+ self.pad_value = pad_value
84
+ self.linewidth = linewidth
85
+ self.fps = fps
86
+
87
+ def visualize(
88
+ self,
89
+ video: torch.Tensor, # (B,T,C,H,W)
90
+ tracks: torch.Tensor, # (B,T,N,2)
91
+ visibility: torch.Tensor = None, # (B, T, N, 1) bool
92
+ gt_tracks: torch.Tensor = None, # (B,T,N,2)
93
+ segm_mask: torch.Tensor = None, # (B,1,H,W)
94
+ filename: str = "video",
95
+ writer=None, # tensorboard Summary Writer, used for visualization during training
96
+ step: int = 0,
97
+ query_frame=0,
98
+ save_video: bool = True,
99
+ compensate_for_camera_motion: bool = False,
100
+ opacity: float = 1.0,
101
+ ):
102
+ if compensate_for_camera_motion:
103
+ assert segm_mask is not None
104
+ if segm_mask is not None:
105
+ coords = tracks[0, query_frame].round().long()
106
+ segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
107
+ if (segm_mask <= 0).all():
108
+ segm_mask = None
109
+
110
+ video = F.pad(
111
+ video,
112
+ (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
113
+ "constant",
114
+ 255,
115
+ )
116
+ color_alpha = int(opacity * 255)
117
+ tracks = tracks + self.pad_value
118
+
119
+ if self.grayscale:
120
+ transform = transforms.Grayscale()
121
+ video = transform(video)
122
+ video = video.repeat(1, 1, 3, 1, 1)
123
+
124
+ res_video = self.draw_tracks_on_video(
125
+ video=video,
126
+ tracks=tracks,
127
+ visibility=visibility,
128
+ segm_mask=segm_mask,
129
+ gt_tracks=gt_tracks,
130
+ query_frame=query_frame,
131
+ compensate_for_camera_motion=compensate_for_camera_motion,
132
+ color_alpha=color_alpha,
133
+ )
134
+ if save_video:
135
+ self.save_video(res_video, filename=filename, writer=writer, step=step)
136
+ return res_video
137
+
138
+ def save_video(self, video, filename, writer=None, step=0):
139
+ if writer is not None:
140
+ writer.add_video(
141
+ filename,
142
+ video.to(torch.uint8),
143
+ global_step=step,
144
+ fps=self.fps,
145
+ )
146
+ else:
147
+ os.makedirs(self.save_dir, exist_ok=True)
148
+ wide_list = list(video.unbind(1))
149
+ wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
150
+
151
+ # Prepare the video file path
152
+ save_path = os.path.join(self.save_dir, f"{filename}.mp4")
153
+
154
+ # Create a writer object
155
+ video_writer = imageio.get_writer(save_path, fps=self.fps)
156
+
157
+ # Write frames to the video file
158
+ for frame in wide_list[2:-1]:
159
+ video_writer.append_data(frame)
160
+
161
+ video_writer.close()
162
+
163
+ print(f"Video saved to {save_path}")
164
+
165
+ def draw_tracks_on_video(
166
+ self,
167
+ video: torch.Tensor,
168
+ tracks: torch.Tensor,
169
+ visibility: torch.Tensor = None,
170
+ segm_mask: torch.Tensor = None,
171
+ gt_tracks=None,
172
+ query_frame=0,
173
+ compensate_for_camera_motion=False,
174
+ color_alpha: int = 255,
175
+ ):
176
+ B, T, C, H, W = video.shape
177
+ _, _, N, D = tracks.shape
178
+
179
+ assert D == 2
180
+ assert C == 3
181
+ video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
182
+ tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
183
+ if gt_tracks is not None:
184
+ gt_tracks = gt_tracks[0].detach().cpu().numpy()
185
+
186
+ res_video = []
187
+
188
+ # process input video
189
+ for rgb in video:
190
+ res_video.append(rgb.copy())
191
+ vector_colors = np.zeros((T, N, 3))
192
+
193
+ if self.mode == "optical_flow":
194
+ import flow_vis
195
+
196
+ vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
197
+ elif segm_mask is None:
198
+ if self.mode == "rainbow":
199
+ y_min, y_max = (
200
+ tracks[query_frame, :, 1].min(),
201
+ tracks[query_frame, :, 1].max(),
202
+ )
203
+ norm = plt.Normalize(y_min, y_max)
204
+ for n in range(N):
205
+ if isinstance(query_frame, torch.Tensor):
206
+ query_frame_ = query_frame[n]
207
+ else:
208
+ query_frame_ = query_frame
209
+ color = self.color_map(norm(tracks[query_frame_, n, 1]))
210
+ color = np.array(color[:3])[None] * 255
211
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
212
+ else:
213
+ # color changes with time
214
+ for t in range(T):
215
+ color = np.array(self.color_map(t / T)[:3])[None] * 255
216
+ vector_colors[t] = np.repeat(color, N, axis=0)
217
+ else:
218
+ if self.mode == "rainbow":
219
+ vector_colors[:, segm_mask <= 0, :] = 255
220
+
221
+ y_min, y_max = (
222
+ tracks[0, segm_mask > 0, 1].min(),
223
+ tracks[0, segm_mask > 0, 1].max(),
224
+ )
225
+ norm = plt.Normalize(y_min, y_max)
226
+ for n in range(N):
227
+ if segm_mask[n] > 0:
228
+ color = self.color_map(norm(tracks[0, n, 1]))
229
+ color = np.array(color[:3])[None] * 255
230
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
231
+
232
+ else:
233
+ # color changes with segm class
234
+ segm_mask = segm_mask.cpu()
235
+ color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
236
+ color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
237
+ color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
238
+ vector_colors = np.repeat(color[None], T, axis=0)
239
+
240
+ # draw tracks
241
+ if self.tracks_leave_trace != 0:
242
+ for t in range(query_frame + 1, T):
243
+ first_ind = (
244
+ max(0, t - self.tracks_leave_trace)
245
+ if self.tracks_leave_trace >= 0
246
+ else 0
247
+ )
248
+ curr_tracks = tracks[first_ind : t + 1]
249
+ curr_colors = vector_colors[first_ind : t + 1]
250
+ if compensate_for_camera_motion:
251
+ diff = (
252
+ tracks[first_ind : t + 1, segm_mask <= 0]
253
+ - tracks[t : t + 1, segm_mask <= 0]
254
+ ).mean(1)[:, None]
255
+
256
+ curr_tracks = curr_tracks - diff
257
+ curr_tracks = curr_tracks[:, segm_mask > 0]
258
+ curr_colors = curr_colors[:, segm_mask > 0]
259
+
260
+ res_video[t] = self._draw_pred_tracks(
261
+ res_video[t],
262
+ curr_tracks,
263
+ curr_colors,
264
+ )
265
+ if gt_tracks is not None:
266
+ res_video[t] = self._draw_gt_tracks(
267
+ res_video[t], gt_tracks[first_ind : t + 1]
268
+ )
269
+
270
+ # draw points
271
+ for t in range(T):
272
+ img = Image.fromarray(np.uint8(res_video[t]))
273
+ for i in range(N):
274
+ coord = (tracks[t, i, 0], tracks[t, i, 1])
275
+ visibile = True
276
+ if visibility is not None:
277
+ visibile = visibility[0, t, i]
278
+ if coord[0] != 0 and coord[1] != 0:
279
+ if not compensate_for_camera_motion or (
280
+ compensate_for_camera_motion and segm_mask[i] > 0
281
+ ):
282
+ img = draw_circle(
283
+ img,
284
+ coord=coord,
285
+ radius=int(self.linewidth * 2),
286
+ color=vector_colors[t, i].astype(int),
287
+ visible=visibile,
288
+ color_alpha=color_alpha,
289
+ )
290
+ res_video[t] = np.array(img)
291
+
292
+ # construct the final rgb sequence
293
+ if self.show_first_frame > 0:
294
+ res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
295
+ return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
296
+
297
+ def _draw_pred_tracks(
298
+ self,
299
+ rgb: np.ndarray, # H x W x 3
300
+ tracks: np.ndarray, # T x 2
301
+ vector_colors: np.ndarray,
302
+ alpha: float = 0.5,
303
+ ):
304
+ T, N, _ = tracks.shape
305
+ rgb = Image.fromarray(np.uint8(rgb))
306
+ for s in range(T - 1):
307
+ vector_color = vector_colors[s]
308
+ original = rgb.copy()
309
+ alpha = (s / T) ** 2
310
+ for i in range(N):
311
+ coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
312
+ coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
313
+ if coord_y[0] != 0 and coord_y[1] != 0:
314
+ rgb = draw_line(
315
+ rgb,
316
+ coord_y,
317
+ coord_x,
318
+ vector_color[i].astype(int),
319
+ self.linewidth,
320
+ )
321
+ if self.tracks_leave_trace > 0:
322
+ rgb = Image.fromarray(
323
+ np.uint8(
324
+ add_weighted(
325
+ np.array(rgb), alpha, np.array(original), 1 - alpha, 0
326
+ )
327
+ )
328
+ )
329
+ rgb = np.array(rgb)
330
+ return rgb
331
+
332
+ def _draw_gt_tracks(
333
+ self,
334
+ rgb: np.ndarray, # H x W x 3,
335
+ gt_tracks: np.ndarray, # T x 2
336
+ ):
337
+ T, N, _ = gt_tracks.shape
338
+ color = np.array((211, 0, 0))
339
+ rgb = Image.fromarray(np.uint8(rgb))
340
+ for t in range(T):
341
+ for i in range(N):
342
+ gt_tracks = gt_tracks[t][i]
343
+ # draw a red cross
344
+ if gt_tracks[0] > 0 and gt_tracks[1] > 0:
345
+ length = self.linewidth * 3
346
+ coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
347
+ coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
348
+ rgb = draw_line(
349
+ rgb,
350
+ coord_y,
351
+ coord_x,
352
+ color,
353
+ self.linewidth,
354
+ )
355
+ coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
356
+ coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
357
+ rgb = draw_line(
358
+ rgb,
359
+ coord_y,
360
+ coord_x,
361
+ color,
362
+ self.linewidth,
363
+ )
364
+ rgb = np.array(rgb)
365
+ return rgb
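A short usage sketch for the Visualizer class above (not part of the commit; the random video and tracks are placeholder assumptions):

import torch

video = torch.randint(0, 256, (1, 12, 3, 256, 256)).float()   # (B, T, C, H, W), values in [0, 255]
tracks = torch.rand(1, 12, 64, 2) * 255.0                     # (B, T, N, 2) in pixel units
visibility = torch.ones(1, 12, 64, 1, dtype=torch.bool)       # (B, T, N, 1)

vis = Visualizer(save_dir='./results', fps=10, mode='rainbow', linewidth=2)
vis.visualize(video, tracks, visibility=visibility, filename='demo')  # writes ./results/demo.mp4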