diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..53aee21b9a286bf9d1904e9527c3f0b047f4f0ea
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000000000000000000000000000000000000..044e909664b324837cf74183fd91eee1a88829a5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,69 @@
+# ignore the multi media
+checkpoints
+**/checkpoints/
+**/temp/
+temp
+assets_dev
+assets/example0/results
+assets/example0/snowboard.npz
+assets/example1/results
+assets/davis_eval
+assets/*/results
+*gradio*
+#
+models/monoD/zoeDepth/ckpts/*
+models/monoD/depth_anything/ckpts/*
+vis_results
+dist_encrypted
+# remove the dependencies
+deps
+
+# filter the __pycache__ files
+__pycache__/
+/**/**/__pycache__
+/**/__pycache__
+
+outputs
+scripts/lauch_exp/config
+scripts/lauch_exp/submit_job.log
+scripts/lauch_exp/hydra_output
+scripts/lauch_wulan
+scripts/custom_video
+# ignore the visualizer
+viser
+viser_result
+benchmark/results
+benchmark
+
+ossutil_output
+
+prev_version
+spat_ceres
+wandb
+*.log
+seg_target.py
+
+eval_davis.py
+eval_multiple_gpu.py
+eval_pose_scan.py
+eval_single_gpu.py
+
+infer_cam.py
+infer_stream.py
+
+*.egg-info/
+**/*.egg-info
+
+eval_kinectics.py
+models/SpaTrackV2/datasets
+
+scripts
+config/fix_2d.yaml
+
+models/SpaTrackV2/datasets
+scripts/
+
+models/**/build
+models/**/dist
+
+temp_local
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a2fb57fa4b98c295b80fda2f321a9fe88d21fcf7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+---
+title: SpatialTrackerV2
+emoji: ⚡️
+colorFrom: yellow
+colorTo: red
+sdk: gradio
+sdk_version: 5.31.0
+app_file: app.py
+pinned: false
+license: mit
+short_description: Official Space for SpatialTrackerV2
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/_viz/viz_template.html b/_viz/viz_template.html
new file mode 100644
index 0000000000000000000000000000000000000000..218d0df73f3689e66413c53671b784e5bb89fea7
--- /dev/null
+++ b/_viz/viz_template.html
@@ -0,0 +1,1778 @@
+
+
+
+
+
+ 3D Point Cloud Visualizer
+
+
+
+
+
+
+
+
+
+
+
Initializing...
+
+
+
+
+
+
+
Frame 0 / 0
+
+
+
+
+
+
+
+
+ ☰
+ Visualization Settings
+
+
+
+
+
+
+
+
+
Camera
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Interactive Viewer of 3D Tracking
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0b661a61d8c91782ee0892a78f9324463c115d4
--- /dev/null
+++ b/app.py
@@ -0,0 +1,1118 @@
+import gradio as gr
+import os
+import json
+import numpy as np
+import cv2
+import base64
+import time
+import tempfile
+import shutil
+import glob
+import threading
+import subprocess
+import struct
+import zlib
+from pathlib import Path
+from einops import rearrange
+from typing import List, Tuple, Union
+try:
+ import spaces
+except ImportError:
+ # Fallback for local development
+ def spaces(func):
+ return func
+import torch
+import logging
+from concurrent.futures import ThreadPoolExecutor
+import atexit
+import uuid
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Import custom modules with error handling
+try:
+ from app_3rd.sam_utils.inference import SamPredictor, get_sam_predictor, run_inference
+ from app_3rd.spatrack_utils.infer_track import get_tracker_predictor, run_tracker, get_points_on_a_grid
+except ImportError as e:
+ logger.error(f"Failed to import custom modules: {e}")
+ raise
+
+# Constants
+MAX_FRAMES = 80
+COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
+MARKERS = [1, 5] # Cross for negative, Star for positive
+MARKER_SIZE = 8
+
+# Thread pool for delayed deletion
+thread_pool_executor = ThreadPoolExecutor(max_workers=2)
+
+def delete_later(path: Union[str, os.PathLike], delay: int = 600):
+ """Delete file or directory after specified delay (default 10 minutes)"""
+ def _delete():
+ try:
+ if os.path.isfile(path):
+ os.remove(path)
+ elif os.path.isdir(path):
+ shutil.rmtree(path)
+ except Exception as e:
+ logger.warning(f"Failed to delete {path}: {e}")
+
+ def _wait_and_delete():
+ time.sleep(delay)
+ _delete()
+
+ thread_pool_executor.submit(_wait_and_delete)
+ atexit.register(_delete)
+
+def create_user_temp_dir():
+ """Create a unique temporary directory for each user session"""
+ session_id = str(uuid.uuid4())[:8] # Short unique ID
+ temp_dir = os.path.join("temp_local", f"session_{session_id}")
+ os.makedirs(temp_dir, exist_ok=True)
+
+ # Schedule deletion after 10 minutes
+ delete_later(temp_dir, delay=600)
+
+ return temp_dir
+
+from huggingface_hub import hf_hub_download
+# init the model
+os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
+
+if os.environ.get("VGGT_DIR", None) is not None:
+ from models.vggt.vggt.models.vggt_moe import VGGT_MoE
+ from models.vggt.vggt.utils.load_fn import preprocess_image
+ vggt_model = VGGT_MoE()
+ vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
+ vggt_model.eval()
+ vggt_model = vggt_model.to("cuda")
+
+# Global model initialization
+print("🚀 Initializing local models...")
+tracker_model, _ = get_tracker_predictor(".", vo_points=756)
+predictor = get_sam_predictor()
+print("✅ Models loaded successfully!")
+
+gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
+
+@spaces.GPU
+def gpu_run_inference(predictor_arg, image, points, boxes):
+ """GPU-accelerated SAM inference"""
+ if predictor_arg is None:
+ print("Initializing SAM predictor inside GPU function...")
+ predictor_arg = get_sam_predictor(predictor=predictor)
+
+ # Ensure predictor is on GPU
+ try:
+ if hasattr(predictor_arg, 'model'):
+ predictor_arg.model = predictor_arg.model.cuda()
+ elif hasattr(predictor_arg, 'sam'):
+ predictor_arg.sam = predictor_arg.sam.cuda()
+ elif hasattr(predictor_arg, 'to'):
+ predictor_arg = predictor_arg.to('cuda')
+
+ if hasattr(image, 'cuda'):
+ image = image.cuda()
+
+ except Exception as e:
+ print(f"Warning: Could not move predictor to GPU: {e}")
+
+ return run_inference(predictor_arg, image, points, boxes)
+
+@spaces.GPU
+def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
+ """GPU-accelerated tracking"""
+ import torchvision.transforms as T
+ import decord
+
+ if tracker_model_arg is None or tracker_viser_arg is None:
+ print("Initializing tracker models inside GPU function...")
+ out_dir = os.path.join(temp_dir, "results")
+ os.makedirs(out_dir, exist_ok=True)
+ tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)
+
+ # Setup paths
+ video_path = os.path.join(temp_dir, f"{video_name}.mp4")
+ mask_path = os.path.join(temp_dir, f"{video_name}.png")
+ out_dir = os.path.join(temp_dir, "results")
+ os.makedirs(out_dir, exist_ok=True)
+
+ # Load video using decord
+ video_reader = decord.VideoReader(video_path)
+ video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2)
+
+ # Resize to ensure minimum side is 336
+ h, w = video_tensor.shape[2:]
+ scale = max(224 / h, 224 / w)
+ if scale < 1:
+ new_h, new_w = int(h * scale), int(w * scale)
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
+ video_tensor = video_tensor[::fps].float()[:MAX_FRAMES]
+
+ # Move to GPU
+ video_tensor = video_tensor.cuda()
+ print(f"Video tensor shape: {video_tensor.shape}, device: {video_tensor.device}")
+
+ depth_tensor = None
+ intrs = None
+ extrs = None
+ data_npz_load = {}
+
+ # run vggt
+ if os.environ.get("VGGT_DIR", None) is not None:
+ # process the image tensor
+ video_tensor = preprocess_image(video_tensor)[None]
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ # Predict attributes including cameras, depth maps, and point maps.
+ predictions = vggt_model(video_tensor.cuda()/255)
+ extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
+ depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
+
+ depth_tensor = depth_map.squeeze().cpu().numpy()
+ extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
+ extrs = extrinsic.squeeze().cpu().numpy()
+ intrs = intrinsic.squeeze().cpu().numpy()
+ video_tensor = video_tensor.squeeze()
+ #NOTE: 20% of the depth is not reliable
+ # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
+ unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
+
+ # Load and process mask
+ if os.path.exists(mask_path):
+ mask = cv2.imread(mask_path)
+ mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
+ mask = mask.sum(axis=-1)>0
+ else:
+ mask = np.ones_like(video_tensor[0,0].cpu().numpy())>0
+ grid_size = 10
+
+ # Get frame dimensions and create grid points
+ frame_H, frame_W = video_tensor.shape[2:]
+ grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cuda")
+
+ # Sample mask values at grid points and filter
+ if os.path.exists(mask_path):
+ grid_pts_int = grid_pts[0].long()
+ mask_values = mask[grid_pts_int.cpu()[...,1], grid_pts_int.cpu()[...,0]]
+ grid_pts = grid_pts[:, mask_values]
+
+ query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
+ print(f"Query points shape: {query_xyt.shape}")
+
+ # Run model inference
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ (
+ c2w_traj, intrs, point_map, conf_depth,
+ track3d_pred, track2d_pred, vis_pred, conf_pred, video
+ ) = tracker_model_arg.forward(video_tensor, depth=depth_tensor,
+ intrs=intrs, extrs=extrs,
+ queries=query_xyt,
+ fps=1, full_point=False, iters_track=4,
+ query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
+ support_frame=len(video_tensor)-1, replace_ratio=0.2)
+
+ # Resize results to avoid large I/O
+ max_size = 224
+ h, w = video.shape[2:]
+ scale = min(max_size / h, max_size / w)
+ if scale < 1:
+ new_h, new_w = int(h * scale), int(w * scale)
+ video = T.Resize((new_h, new_w))(video)
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
+ point_map = T.Resize((new_h, new_w))(point_map)
+ track2d_pred[...,:2] = track2d_pred[...,:2] * scale
+ intrs[:,:2,:] = intrs[:,:2,:] * scale
+ conf_depth = T.Resize((new_h, new_w))(conf_depth)
+
+ # Visualize tracks
+ tracker_viser_arg.visualize(video=video[None],
+ tracks=track2d_pred[None][...,:2],
+ visibility=vis_pred[None],filename="test")
+
+ # Save in tapip3d format
+ data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
+ data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
+ data_npz_load["intrinsics"] = intrs.cpu().numpy()
+ data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
+ data_npz_load["video"] = (video_tensor).cpu().numpy()/255
+ data_npz_load["visibs"] = vis_pred.cpu().numpy()
+ data_npz_load["confs"] = conf_pred.cpu().numpy()
+ data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
+ np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
+
+ return None
+
+def compress_and_write(filename, header, blob):
+ header_bytes = json.dumps(header).encode("utf-8")
+ header_len = struct.pack(" T H W C") * 255).astype(np.uint8)
+ rgb_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_AREA)
+ for frame in rgb_video])
+
+ depth_video = data["depths"].astype(np.float32)
+ if "confs_depth" in data.keys():
+ confs = (data["confs_depth"].astype(np.float32) > 0.5).astype(np.float32)
+ depth_video = depth_video * confs
+ depth_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_NEAREST)
+ for frame in depth_video])
+
+ scale_x = fixed_size[0] / W
+ scale_y = fixed_size[1] / H
+ intrinsics = intrinsics.copy()
+ intrinsics[:, 0, :] *= scale_x
+ intrinsics[:, 1, :] *= scale_y
+
+ min_depth = float(depth_video.min()) * 0.8
+ max_depth = float(depth_video.max()) * 1.5
+
+ depth_normalized = (depth_video - min_depth) / (max_depth - min_depth)
+ depth_int = (depth_normalized * ((1 << 16) - 1)).astype(np.uint16)
+
+ depths_rgb = np.zeros((T, fixed_size[1], fixed_size[0], 3), dtype=np.uint8)
+ depths_rgb[:, :, :, 0] = (depth_int & 0xFF).astype(np.uint8)
+ depths_rgb[:, :, :, 1] = ((depth_int >> 8) & 0xFF).astype(np.uint8)
+
+ first_frame_inv = np.linalg.inv(extrinsics[0])
+ normalized_extrinsics = np.array([first_frame_inv @ ext for ext in extrinsics])
+
+ normalized_trajs = np.zeros_like(trajs)
+ for t in range(T):
+ homogeneous_trajs = np.concatenate([trajs[t], np.ones((trajs.shape[1], 1))], axis=1)
+ transformed_trajs = (first_frame_inv @ homogeneous_trajs.T).T
+ normalized_trajs[t] = transformed_trajs[:, :3]
+
+ arrays = {
+ "rgb_video": rgb_video,
+ "depths_rgb": depths_rgb,
+ "intrinsics": intrinsics,
+ "extrinsics": normalized_extrinsics,
+ "inv_extrinsics": np.linalg.inv(normalized_extrinsics),
+ "trajectories": normalized_trajs.astype(np.float32),
+ "cameraZ": 0.0
+ }
+
+ header = {}
+ blob_parts = []
+ offset = 0
+ for key, arr in arrays.items():
+ arr = np.ascontiguousarray(arr)
+ arr_bytes = arr.tobytes()
+ header[key] = {
+ "dtype": str(arr.dtype),
+ "shape": arr.shape,
+ "offset": offset,
+ "length": len(arr_bytes)
+ }
+ blob_parts.append(arr_bytes)
+ offset += len(arr_bytes)
+
+ raw_blob = b"".join(blob_parts)
+ compressed_blob = zlib.compress(raw_blob, level=9)
+
+ header["meta"] = {
+ "depthRange": [min_depth, max_depth],
+ "totalFrames": int(T),
+ "resolution": fixed_size,
+ "baseFrameRate": fps,
+ "numTrajectoryPoints": normalized_trajs.shape[1],
+ "fov": float(fov_y),
+ "fov_x": float(fov_x),
+ "original_aspect_ratio": float(original_aspect_ratio),
+ "fixed_aspect_ratio": float(fixed_size[0]/fixed_size[1])
+ }
+
+ compress_and_write('./_viz/data.bin', header, compressed_blob)
+ with open('./_viz/data.bin', "rb") as f:
+ encoded_blob = base64.b64encode(f.read()).decode("ascii")
+ os.unlink('./_viz/data.bin')
+
+ random_path = f'./_viz/_{time.time()}.html'
+ with open('./_viz/viz_template.html') as f:
+ html_template = f.read()
+ html_out = html_template.replace(
+ "",
+ f"\n"
+ )
+ with open(random_path,'w') as f:
+ f.write(html_out)
+
+ return random_path
+
+def numpy_to_base64(arr):
+ """Convert numpy array to base64 string"""
+ return base64.b64encode(arr.tobytes()).decode('utf-8')
+
+def base64_to_numpy(b64_str, shape, dtype):
+ """Convert base64 string back to numpy array"""
+ return np.frombuffer(base64.b64decode(b64_str), dtype=dtype).reshape(shape)
+
+def get_video_name(video_path):
+ """Extract video name without extension"""
+ return os.path.splitext(os.path.basename(video_path))[0]
+
+def extract_first_frame(video_path):
+ """Extract first frame from video file"""
+ try:
+ cap = cv2.VideoCapture(video_path)
+ ret, frame = cap.read()
+ cap.release()
+
+ if ret:
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ return frame_rgb
+ else:
+ return None
+ except Exception as e:
+ print(f"Error extracting first frame: {e}")
+ return None
+
+def handle_video_upload(video):
+ """Handle video upload and extract first frame"""
+ if video is None:
+ return (None, None, [],
+ gr.update(value=50),
+ gr.update(value=756),
+ gr.update(value=3))
+
+ # Create user-specific temporary directory
+ user_temp_dir = create_user_temp_dir()
+
+ # Get original video name and copy to temp directory
+ if isinstance(video, str):
+ video_name = get_video_name(video)
+ video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
+ shutil.copy(video, video_path)
+ else:
+ video_name = get_video_name(video.name)
+ video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
+ with open(video_path, 'wb') as f:
+ f.write(video.read())
+
+ print(f"📁 Video saved to: {video_path}")
+
+ # Extract first frame
+ frame = extract_first_frame(video_path)
+ if frame is None:
+ return (None, None, [],
+ gr.update(value=50),
+ gr.update(value=756),
+ gr.update(value=3))
+
+ # Resize frame to have minimum side length of 336
+ h, w = frame.shape[:2]
+ scale = 336 / min(h, w)
+ new_h, new_w = int(h * scale)//2*2, int(w * scale)//2*2
+ frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+
+ # Store frame data with temp directory info
+ frame_data = {
+ 'data': numpy_to_base64(frame),
+ 'shape': frame.shape,
+ 'dtype': str(frame.dtype),
+ 'temp_dir': user_temp_dir,
+ 'video_name': video_name,
+ 'video_path': video_path
+ }
+
+ # Get video-specific settings
+ print(f"🎬 Video path: '{video}' -> Video name: '{video_name}'")
+ grid_size_val, vo_points_val, fps_val = get_video_settings(video_name)
+ print(f"🎬 Video settings for '{video_name}': grid_size={grid_size_val}, vo_points={vo_points_val}, fps={fps_val}")
+
+ return (json.dumps(frame_data), frame, [],
+ gr.update(value=grid_size_val),
+ gr.update(value=vo_points_val),
+ gr.update(value=fps_val))
+
+def save_masks(o_masks, video_name, temp_dir):
+ """Save binary masks to files in user-specific temp directory"""
+ o_files = []
+ for mask, _ in o_masks:
+ o_mask = np.uint8(mask.squeeze() * 255)
+ o_file = os.path.join(temp_dir, f"{video_name}.png")
+ cv2.imwrite(o_file, o_mask)
+ o_files.append(o_file)
+ return o_files
+
+def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
+ """Handle point selection for SAM"""
+ if original_img is None:
+ return None, []
+
+ try:
+ # Convert stored image data back to numpy array
+ frame_data = json.loads(original_img)
+ original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
+ temp_dir = frame_data.get('temp_dir', 'temp_local')
+ video_name = frame_data.get('video_name', 'video')
+
+ # Create a display image for visualization
+ display_img = original_img_array.copy()
+ new_sel_pix = sel_pix.copy() if sel_pix else []
+ new_sel_pix.append((evt.index, 1 if point_type == 'positive_point' else 0))
+
+ print(f"🎯 Running SAM inference for point: {evt.index}, type: {point_type}")
+ # Run SAM inference
+ o_masks = gpu_run_inference(None, original_img_array, new_sel_pix, [])
+
+ # Draw points on display image
+ for point, label in new_sel_pix:
+ cv2.drawMarker(display_img, point, COLORS[label], markerType=MARKERS[label], markerSize=MARKER_SIZE, thickness=2)
+
+ # Draw mask overlay on display image
+ if o_masks:
+ mask = o_masks[0][0]
+ overlay = display_img.copy()
+ overlay[mask.squeeze()!=0] = [20, 60, 200] # Light blue
+ display_img = cv2.addWeighted(overlay, 0.6, display_img, 0.4, 0)
+
+ # Save mask for tracking
+ save_masks(o_masks, video_name, temp_dir)
+ print(f"✅ Mask saved for video: {video_name}")
+
+ return display_img, new_sel_pix
+
+ except Exception as e:
+ print(f"❌ Error in select_point: {e}")
+ return None, []
+
+def reset_points(original_img: str, sel_pix):
+ """Reset all points and clear the mask"""
+ if original_img is None:
+ return None, []
+
+ try:
+ # Convert stored image data back to numpy array
+ frame_data = json.loads(original_img)
+ original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
+ temp_dir = frame_data.get('temp_dir', 'temp_local')
+
+ # Create a display image (just the original image)
+ display_img = original_img_array.copy()
+
+ # Clear all points
+ new_sel_pix = []
+
+ # Clear any existing masks
+ for mask_file in glob.glob(os.path.join(temp_dir, "*.png")):
+ try:
+ os.remove(mask_file)
+ except Exception as e:
+ logger.warning(f"Failed to remove mask file {mask_file}: {e}")
+
+ print("🔄 Points and masks reset")
+ return display_img, new_sel_pix
+
+ except Exception as e:
+ print(f"❌ Error in reset_points: {e}")
+ return None, []
+
+def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
+ """Launch visualization with user-specific temp directory"""
+ if original_image_state is None:
+ return None, None, None
+
+ try:
+ # Get user's temp directory from stored frame data
+ frame_data = json.loads(original_image_state)
+ temp_dir = frame_data.get('temp_dir', 'temp_local')
+ video_name = frame_data.get('video_name', 'video')
+
+ print(f"🚀 Starting tracking for video: {video_name}")
+ print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
+
+ # Check for mask files
+ mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
+ video_files = glob.glob(os.path.join(temp_dir, "*.mp4"))
+
+ if not video_files:
+ print("❌ No video file found")
+ return "❌ Error: No video file found", None, None
+
+ video_path = video_files[0]
+ mask_path = mask_files[0] if mask_files else None
+
+ # Run tracker
+ print("🎯 Running tracker...")
+ out_dir = os.path.join(temp_dir, "results")
+ os.makedirs(out_dir, exist_ok=True)
+
+ gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode)
+
+ # Process results
+ npz_path = os.path.join(out_dir, "result.npz")
+ track2d_video = os.path.join(out_dir, "test_pred_track.mp4")
+
+ if os.path.exists(npz_path):
+ print("📊 Processing 3D visualization...")
+ html_path = process_point_cloud_data(npz_path)
+
+ # Schedule deletion of generated files
+ delete_later(html_path, delay=600)
+ if os.path.exists(track2d_video):
+ delete_later(track2d_video, delay=600)
+ delete_later(npz_path, delay=600)
+
+ # Create iframe HTML
+ iframe_html = f"""
+
+
+
+ """
+
+ print("✅ Tracking completed successfully!")
+ return iframe_html, track2d_video if os.path.exists(track2d_video) else None, html_path
+ else:
+ print("❌ Tracking failed - no results generated")
+ return "❌ Error: Tracking failed to generate results", None, None
+
+ except Exception as e:
+ print(f"❌ Error in launch_viz: {e}")
+ return f"❌ Error: {str(e)}", None, None
+
+def clear_all():
+ """Clear all buffers and temporary files"""
+ return (None, None, [],
+ gr.update(value=50),
+ gr.update(value=756),
+ gr.update(value=3))
+
+def clear_all_with_download():
+ """Clear all buffers including both download components"""
+ return (None, None, [],
+ gr.update(value=50),
+ gr.update(value=756),
+ gr.update(value=3),
+ None, # tracking_video_download
+ None) # HTML download component
+
+def get_video_settings(video_name):
+ """Get video-specific settings based on video name"""
+ video_settings = {
+ "running": (50, 512, 2),
+ "backpack": (40, 600, 2),
+ "kitchen": (60, 800, 3),
+ "pillow": (35, 500, 2),
+ "handwave": (35, 500, 8),
+ "hockey": (45, 700, 2),
+ "drifting": (35, 1000, 6),
+ "basketball": (45, 1500, 5),
+ "ken_block_0": (45, 700, 2),
+ "ego_kc1": (45, 500, 4),
+ "vertical_place": (45, 500, 3),
+ "ego_teaser": (45, 1200, 10),
+ "robot_unitree": (45, 500, 4),
+ "robot_3": (35, 400, 5),
+ "teleop2": (45, 256, 7),
+ "pusht": (45, 256, 10),
+ "cinema_0": (45, 356, 5),
+ "cinema_1": (45, 756, 3),
+ "robot1": (45, 600, 2),
+ "robot2": (45, 600, 2),
+ "protein": (45, 600, 2),
+ "kitchen_egocentric": (45, 600, 2),
+ }
+
+ return video_settings.get(video_name, (50, 756, 3))
+
+# Create the Gradio interface
+print("🎨 Creating Gradio interface...")
+
+with gr.Blocks(
+ theme=gr.themes.Soft(),
+ title="🎯 [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)",
+ css="""
+ .gradio-container {
+ max-width: 1200px !important;
+ margin: auto !important;
+ }
+ .gr-button {
+ margin: 5px;
+ }
+ .gr-form {
+ background: white;
+ border-radius: 10px;
+ padding: 20px;
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
+ }
+ /* 移除 gr.Group 的默认灰色背景 */
+ .gr-form {
+ background: transparent !important;
+ border: none !important;
+ box-shadow: none !important;
+ padding: 0 !important;
+ }
+ /* 固定3D可视化器尺寸 */
+ #viz_container {
+ height: 650px !important;
+ min-height: 650px !important;
+ max-height: 650px !important;
+ width: 100% !important;
+ margin: 0 !important;
+ padding: 0 !important;
+ overflow: hidden !important;
+ }
+ #viz_container > div {
+ height: 650px !important;
+ min-height: 650px !important;
+ max-height: 650px !important;
+ width: 100% !important;
+ margin: 0 !important;
+ padding: 0 !important;
+ box-sizing: border-box !important;
+ }
+ #viz_container iframe {
+ height: 650px !important;
+ min-height: 650px !important;
+ max-height: 650px !important;
+ width: 100% !important;
+ border: none !important;
+ display: block !important;
+ margin: 0 !important;
+ padding: 0 !important;
+ box-sizing: border-box !important;
+ }
+ /* 固定视频上传组件高度 */
+ .gr-video {
+ height: 300px !important;
+ min-height: 300px !important;
+ max-height: 300px !important;
+ }
+ .gr-video video {
+ height: 260px !important;
+ max-height: 260px !important;
+ object-fit: contain !important;
+ background: #f8f9fa;
+ }
+ .gr-video .gr-video-player {
+ height: 260px !important;
+ max-height: 260px !important;
+ }
+ /* 强力移除examples的灰色背景 - 使用更通用的选择器 */
+ .horizontal-examples,
+ .horizontal-examples > *,
+ .horizontal-examples * {
+ background: transparent !important;
+ background-color: transparent !important;
+ border: none !important;
+ }
+
+ /* Examples组件水平滚动样式 */
+ .horizontal-examples [data-testid="examples"] {
+ background: transparent !important;
+ background-color: transparent !important;
+ }
+
+ .horizontal-examples [data-testid="examples"] > div {
+ background: transparent !important;
+ background-color: transparent !important;
+ overflow-x: auto !important;
+ overflow-y: hidden !important;
+ scrollbar-width: thin;
+ scrollbar-color: #667eea transparent;
+ padding: 0 !important;
+ margin-top: 10px;
+ border: none !important;
+ }
+
+ .horizontal-examples [data-testid="examples"] table {
+ display: flex !important;
+ flex-wrap: nowrap !important;
+ min-width: max-content !important;
+ gap: 15px !important;
+ padding: 10px 0;
+ background: transparent !important;
+ border: none !important;
+ }
+
+ .horizontal-examples [data-testid="examples"] tbody {
+ display: flex !important;
+ flex-direction: row !important;
+ flex-wrap: nowrap !important;
+ gap: 15px !important;
+ background: transparent !important;
+ }
+
+ .horizontal-examples [data-testid="examples"] tr {
+ display: flex !important;
+ flex-direction: column !important;
+ min-width: 160px !important;
+ max-width: 160px !important;
+ margin: 0 !important;
+ background: white !important;
+ border-radius: 12px;
+ box-shadow: 0 3px 12px rgba(0,0,0,0.12);
+ transition: all 0.3s ease;
+ cursor: pointer;
+ overflow: hidden;
+ border: none !important;
+ }
+
+ .horizontal-examples [data-testid="examples"] tr:hover {
+ transform: translateY(-4px);
+ box-shadow: 0 8px 20px rgba(102, 126, 234, 0.25);
+ }
+
+ .horizontal-examples [data-testid="examples"] td {
+ text-align: center !important;
+ padding: 0 !important;
+ border: none !important;
+ background: transparent !important;
+ }
+
+ .horizontal-examples [data-testid="examples"] td:first-child {
+ padding: 0 !important;
+ background: transparent !important;
+ }
+
+ .horizontal-examples [data-testid="examples"] video {
+ border-radius: 8px 8px 0 0 !important;
+ width: 100% !important;
+ height: 90px !important;
+ object-fit: cover !important;
+ background: #f8f9fa !important;
+ }
+
+ .horizontal-examples [data-testid="examples"] td:last-child {
+ font-size: 11px !important;
+ font-weight: 600 !important;
+ color: #333 !important;
+ padding: 8px 12px !important;
+ background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%) !important;
+ border-radius: 0 0 8px 8px;
+ }
+
+ /* 滚动条样式 */
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar {
+ height: 8px;
+ }
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-track {
+ background: transparent;
+ border-radius: 4px;
+ }
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb {
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+ border-radius: 4px;
+ }
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb:hover {
+ background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%);
+ }
+ """
+) as demo:
+
+ # Add prominent main title
+
+ gr.Markdown("""
+ # ✨ SpatialTrackerV2
+
+ Welcome to [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)! This interface allows you to track any pixels in 3D using our model.
+
+ **⚡ Quick Start:** Upload video → Click "Start Tracking Now!"
+
+ **🔬 Advanced Usage with SAM:**
+ 1. Upload a video file or select from examples below
+ 2. Expand "Manual Point Selection" to click on specific objects for SAM-guided tracking
+ 3. Adjust tracking parameters for optimal performance
+ 4. Click "Start Tracking Now!" to begin 3D tracking with SAM guidance
+
+ """)
+
+ # Status indicator
+ gr.Markdown("**Status:** 🟢 Local Processing Mode")
+
+ # Main content area - video upload left, 3D visualization right
+ with gr.Row():
+ with gr.Column(scale=1):
+ # Video upload section
+ gr.Markdown("### 📂 Select Video")
+
+ # Define video_input here so it can be referenced in examples
+ video_input = gr.Video(
+ label="Upload Video or Select Example",
+ format="mp4",
+ height=250 # Matched height with 3D viz
+ )
+
+
+ # Traditional examples but with horizontal scroll styling
+ gr.Markdown("🎨**Examples:** (scroll horizontally to see all videos)")
+ with gr.Row(elem_classes=["horizontal-examples"]):
+ # Horizontal video examples with slider
+ # gr.HTML("")
+ gr.Examples(
+ examples=[
+ ["./examples/robot1.mp4"],
+ ["./examples/robot2.mp4"],
+ ["./examples/protein.mp4"],
+ ["./examples/kitchen_egocentric.mp4"],
+ ["./examples/hockey.mp4"],
+ ["./examples/running.mp4"],
+ ["./examples/robot_3.mp4"],
+ ["./examples/backpack.mp4"],
+ ["./examples/kitchen.mp4"],
+ ["./examples/pillow.mp4"],
+ ["./examples/handwave.mp4"],
+ ["./examples/drifting.mp4"],
+ ["./examples/basketball.mp4"],
+ ["./examples/ken_block_0.mp4"],
+ ["./examples/ego_kc1.mp4"],
+ ["./examples/vertical_place.mp4"],
+ ["./examples/ego_teaser.mp4"],
+ ["./examples/robot_unitree.mp4"],
+ ["./examples/teleop2.mp4"],
+ ["./examples/pusht.mp4"],
+ ["./examples/cinema_0.mp4"],
+ ["./examples/cinema_1.mp4"],
+ ],
+ inputs=[video_input],
+ outputs=[video_input],
+ fn=None,
+ cache_examples=False,
+ label="",
+ examples_per_page=6 # Show 6 examples per page so they can wrap to multiple rows
+ )
+
+ with gr.Column(scale=2):
+ # 3D Visualization - wider and taller to match left side
+ with gr.Group():
+ gr.Markdown("### 🌐 3D Trajectory Visualization")
+ viz_html = gr.HTML(
+ label="3D Trajectory Visualization",
+ value="""
+
+
🌐
+
+ 3D Trajectory Visualization
+
+
+ Track any pixels in 3D space with camera motion
+
+
+
+ ⚡ Powered by SpatialTracker V2
+
+
+
+ """,
+ elem_id="viz_container"
+ )
+
+ # Start button section - below video area
+ with gr.Row():
+ with gr.Column(scale=3):
+ launch_btn = gr.Button("🚀 Start Tracking Now!", variant="primary", size="lg")
+ with gr.Column(scale=1):
+ clear_all_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")
+
+ # Tracking parameters section
+ with gr.Row():
+ gr.Markdown("### ⚙️ Tracking Parameters")
+ with gr.Row():
+ grid_size = gr.Slider(
+ minimum=10, maximum=100, step=10, value=50,
+ label="Grid Size", info="Tracking detail level"
+ )
+ vo_points = gr.Slider(
+ minimum=100, maximum=2000, step=50, value=756,
+ label="VO Points", info="Motion accuracy"
+ )
+ fps = gr.Slider(
+ minimum=1, maximum=20, step=1, value=3,
+ label="FPS", info="Processing speed"
+ )
+
+ # Advanced Point Selection with SAM - Collapsed by default
+ with gr.Row():
+ gr.Markdown("### 🎯 Advanced: Manual Point Selection with SAM")
+ with gr.Accordion("🔬 SAM Point Selection Controls", open=False):
+ gr.HTML("""
+
+
+ - Click on target objects in the image for SAM-guided segmentation
+ - Positive points: include these areas | Negative points: exclude these areas
+ - Get more accurate 3D tracking results with SAM's powerful segmentation
+
+
+ """)
+
+ with gr.Row():
+ with gr.Column():
+ interactive_frame = gr.Image(
+ label="Click to select tracking points with SAM guidance",
+ type="numpy",
+ interactive=True,
+ height=300
+ )
+
+ with gr.Row():
+ point_type = gr.Radio(
+ choices=["positive_point", "negative_point"],
+ value="positive_point",
+ label="Point Type",
+ info="Positive: track these areas | Negative: avoid these areas"
+ )
+
+ with gr.Row():
+ reset_points_btn = gr.Button("🔄 Reset Points", variant="secondary", size="sm")
+
+ # Downloads section - hidden but still functional for local processing
+ with gr.Row(visible=False):
+ with gr.Column(scale=1):
+ tracking_video_download = gr.File(
+ label="📹 Download 2D Tracking Video",
+ interactive=False,
+ visible=False
+ )
+ with gr.Column(scale=1):
+ html_download = gr.File(
+ label="📄 Download 3D Visualization HTML",
+ interactive=False,
+ visible=False
+ )
+
+ # GitHub Star Section
+ gr.HTML("""
+
+ """)
+
+ # Acknowledgments Section
+ gr.HTML("""
+
+
+
+ 📚 Acknowledgments
+
+
+ Our 3D visualizer is adapted from TAPIP3D. We thank the authors for their excellent work and contribution to the computer vision community!
+
+
+ 📚 Visit TAPIP3D Repository
+
+
+
+ """)
+
+ # Footer
+ gr.HTML("""
+
+
+ Powered by SpatialTracker V2 | Built with ❤️ for the Computer Vision Community
+
+
+ """)
+
+ # Hidden state variables
+ original_image_state = gr.State(None)
+ selected_points = gr.State([])
+
+ # Event handlers
+ video_input.change(
+ fn=handle_video_upload,
+ inputs=[video_input],
+ outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
+ )
+
+ interactive_frame.select(
+ fn=select_point,
+ inputs=[original_image_state, selected_points, point_type],
+ outputs=[interactive_frame, selected_points]
+ )
+
+ reset_points_btn.click(
+ fn=reset_points,
+ inputs=[original_image_state, selected_points],
+ outputs=[interactive_frame, selected_points]
+ )
+
+ clear_all_btn.click(
+ fn=clear_all_with_download,
+ outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
+ )
+
+ launch_btn.click(
+ fn=launch_viz,
+ inputs=[grid_size, vo_points, fps, original_image_state],
+ outputs=[viz_html, tracking_video_download, html_download]
+ )
+
+# Launch the interface
+if __name__ == "__main__":
+ print("🌟 Launching SpatialTracker V2 Local Version...")
+ print("🔗 Running in Local Processing Mode")
+
+ demo.launch(
+ server_name="0.0.0.0",
+ server_port=7860,
+ share=True,
+ debug=True,
+ show_error=True
+ )
\ No newline at end of file
diff --git a/app_3rd/README.md b/app_3rd/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f4f73be4ade195219e78a23077cb2e10249e144b
--- /dev/null
+++ b/app_3rd/README.md
@@ -0,0 +1,12 @@
+# 🌟 SpatialTrackerV2 Integrated with SAM 🌟
+SAM receives a point prompt and generates a mask for the target object, facilitating easy interaction to obtain the object's 3D trajectories with SpaTrack2.
+
+## Installation
+```
+
+python -m pip install git+https://github.com/facebookresearch/segment-anything.git
+cd app_3rd/sam_utils
+mkdir checkpoints
+cd checkpoints
+wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
+```
\ No newline at end of file
diff --git a/app_3rd/sam_utils/hf_sam_predictor.py b/app_3rd/sam_utils/hf_sam_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5008aa230ebf8bc04adb4befe55b83f2f24f7f13
--- /dev/null
+++ b/app_3rd/sam_utils/hf_sam_predictor.py
@@ -0,0 +1,129 @@
+import gc
+import numpy as np
+import torch
+from typing import Optional, Tuple, List, Union
+import warnings
+import cv2
+try:
+ from transformers import SamModel, SamProcessor
+ from huggingface_hub import hf_hub_download
+ HF_AVAILABLE = True
+except ImportError:
+ HF_AVAILABLE = False
+ warnings.warn("transformers or huggingface_hub not available. HF SAM models will not work.")
+
+# Hugging Face model mapping
+HF_MODELS = {
+ 'vit_b': 'facebook/sam-vit-base',
+ 'vit_l': 'facebook/sam-vit-large',
+ 'vit_h': 'facebook/sam-vit-huge'
+}
+
+class HFSamPredictor:
+ """
+ Hugging Face version of SamPredictor that wraps the transformers SAM models.
+ This class provides the same interface as the original SamPredictor for seamless integration.
+ """
+
+ def __init__(self, model: SamModel, processor: SamProcessor, device: Optional[str] = None):
+ """
+ Initialize the HF SAM predictor.
+
+ Args:
+ model: The SAM model from transformers
+ processor: The SAM processor from transformers
+ device: Device to run the model on ('cuda', 'cpu', etc.)
+ """
+ self.model = model
+ self.processor = processor
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+ self.model.to(self.device)
+ self.model.eval()
+
+ # Store the current image and its features
+ self.original_size = None
+ self.input_size = None
+ self.features = None
+ self.image = None
+
+ @classmethod
+ def from_pretrained(cls, model_name: str, device: Optional[str] = None) -> 'HFSamPredictor':
+ """
+ Load a SAM model from Hugging Face Hub.
+
+ Args:
+ model_name: Model name from HF_MODELS or direct HF model path
+ device: Device to load the model on
+
+ Returns:
+ HFSamPredictor instance
+ """
+ if not HF_AVAILABLE:
+ raise ImportError("transformers and huggingface_hub are required for HF SAM models")
+
+ # Map model type to HF model name if needed
+ if model_name in HF_MODELS:
+ model_name = HF_MODELS[model_name]
+
+ print(f"Loading SAM model from Hugging Face: {model_name}")
+
+ # Load model and processor
+ model = SamModel.from_pretrained(model_name)
+ processor = SamProcessor.from_pretrained(model_name)
+ return cls(model, processor, device)
+
+ def preprocess(self, image: np.ndarray,
+ input_points: List[List[float]], input_labels: List[int]) -> None:
+ """
+ Set the image for prediction. This preprocesses the image and extracts features.
+
+ Args:
+ image: Input image as numpy array (H, W, C) in RGB format
+ """
+ if image.dtype != np.uint8:
+ image = (image * 255).astype(np.uint8)
+
+ self.image = image
+ self.original_size = image.shape[:2]
+
+ # Use dummy point to ensure processor returns original_sizes & reshaped_input_sizes
+ inputs = self.processor(
+ images=image,
+ input_points=input_points,
+ input_labels=input_labels,
+ return_tensors="pt"
+ )
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+ self.input_size = inputs['pixel_values'].shape[-2:]
+ self.features = inputs
+ return inputs
+
+
+def get_hf_sam_predictor(model_type: str = 'vit_h', device: Optional[str] = None,
+ image: Optional[np.ndarray] = None) -> HFSamPredictor:
+ """
+ Get a Hugging Face SAM predictor with the same interface as the original get_sam_predictor.
+
+ Args:
+ model_type: Model type ('vit_b', 'vit_l', 'vit_h')
+ device: Device to run the model on
+ image: Optional image to set immediately
+
+ Returns:
+ HFSamPredictor instance
+ """
+ if not HF_AVAILABLE:
+ raise ImportError("transformers and huggingface_hub are required for HF SAM models")
+
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Load the predictor
+ predictor = HFSamPredictor.from_pretrained(model_type, device)
+
+ # Set image if provided
+ if image is not None:
+ predictor.set_image(image)
+
+ return predictor
\ No newline at end of file
diff --git a/app_3rd/sam_utils/inference.py b/app_3rd/sam_utils/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c5ecbc2e98f9569e8396b6c293a7c619b9c7016
--- /dev/null
+++ b/app_3rd/sam_utils/inference.py
@@ -0,0 +1,123 @@
+import gc
+
+import numpy as np
+import torch
+from segment_anything import SamPredictor, sam_model_registry
+
+# Try to import HF SAM support
+try:
+ from app_3rd.sam_utils.hf_sam_predictor import get_hf_sam_predictor, HFSamPredictor
+ HF_AVAILABLE = True
+except ImportError:
+ HF_AVAILABLE = False
+
+models = {
+ 'vit_b': 'app_3rd/sam_utils/checkpoints/sam_vit_b_01ec64.pth',
+ 'vit_l': 'app_3rd/sam_utils/checkpoints/sam_vit_l_0b3195.pth',
+ 'vit_h': 'app_3rd/sam_utils/checkpoints/sam_vit_h_4b8939.pth'
+}
+
+
+def get_sam_predictor(model_type='vit_b', device=None, image=None, use_hf=True, predictor=None):
+ """
+ Get SAM predictor with option to use HuggingFace version
+
+ Args:
+ model_type: Model type ('vit_b', 'vit_l', 'vit_h')
+ device: Device to run on
+ image: Optional image to set immediately
+ use_hf: Whether to use HuggingFace SAM instead of original SAM
+ """
+ if predictor is not None:
+ return predictor
+ if use_hf:
+ if not HF_AVAILABLE:
+ raise ImportError("HuggingFace SAM not available. Install transformers and huggingface_hub.")
+ return get_hf_sam_predictor(model_type, device, image)
+
+ # Original SAM logic
+ if device is None and torch.cuda.is_available():
+ device = 'cuda'
+ elif device is None:
+ device = 'cpu'
+ # sam model
+ sam = sam_model_registry[model_type](checkpoint=models[model_type])
+ sam = sam.to(device)
+
+ predictor = SamPredictor(sam)
+ if image is not None:
+ predictor.set_image(image)
+ return predictor
+
+
+def run_inference(predictor, input_x, selected_points, multi_object: bool = False):
+ """
+ Run inference with either original SAM or HF SAM predictor
+
+ Args:
+ predictor: SamPredictor or HFSamPredictor instance
+ input_x: Input image
+ selected_points: List of (point, label) tuples
+ multi_object: Whether to handle multiple objects
+ """
+ if len(selected_points) == 0:
+ return []
+
+ # Check if using HF SAM
+ if isinstance(predictor, HFSamPredictor):
+ return _run_hf_inference(predictor, input_x, selected_points, multi_object)
+ else:
+ return _run_original_inference(predictor, input_x, selected_points, multi_object)
+
+
+def _run_original_inference(predictor: SamPredictor, input_x, selected_points, multi_object: bool = False):
+ """Run inference with original SAM"""
+ points = torch.Tensor(
+ [p for p, _ in selected_points]
+ ).to(predictor.device).unsqueeze(1)
+
+ labels = torch.Tensor(
+ [int(l) for _, l in selected_points]
+ ).to(predictor.device).unsqueeze(1)
+
+ transformed_points = predictor.transform.apply_coords_torch(
+ points, input_x.shape[:2])
+
+ masks, scores, logits = predictor.predict_torch(
+ point_coords=transformed_points[:,0][None],
+ point_labels=labels[:,0][None],
+ multimask_output=False,
+ )
+ masks = masks[0].cpu().numpy() # N 1 H W N is the number of points
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return [(masks, 'final_mask')]
+
+
+def _run_hf_inference(predictor: HFSamPredictor, input_x, selected_points, multi_object: bool = False):
+ """Run inference with HF SAM"""
+ # Prepare points and labels for HF SAM
+ select_pts = [[list(p) for p, _ in selected_points]]
+ select_lbls = [[int(l) for _, l in selected_points]]
+
+ # Preprocess inputs
+ inputs = predictor.preprocess(input_x, select_pts, select_lbls)
+
+ # Run inference
+ with torch.no_grad():
+ outputs = predictor.model(**inputs)
+
+ # Post-process masks
+ masks = predictor.processor.image_processor.post_process_masks(
+ outputs.pred_masks.cpu(),
+ inputs["original_sizes"].cpu(),
+ inputs["reshaped_input_sizes"].cpu(),
+ )
+ masks = masks[0][:,:1,...].cpu().numpy()
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return [(masks, 'final_mask')]
\ No newline at end of file
diff --git a/app_3rd/spatrack_utils/infer_track.py b/app_3rd/spatrack_utils/infer_track.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852b3d23cf9e5e1d79340ef6945ccc9652802f1
--- /dev/null
+++ b/app_3rd/spatrack_utils/infer_track.py
@@ -0,0 +1,194 @@
+from models.SpaTrackV2.models.predictor import Predictor
+import yaml
+import easydict
+import os
+import numpy as np
+import cv2
+import torch
+import torchvision.transforms as T
+from PIL import Image
+import io
+import moviepy.editor as mp
+from models.SpaTrackV2.utils.visualizer import Visualizer
+import tqdm
+from models.SpaTrackV2.models.utils import get_points_on_a_grid
+import glob
+from rich import print
+import argparse
+import decord
+from huggingface_hub import hf_hub_download
+
+config = {
+ "ckpt_dir": "Yuxihenry/SpatialTrackerCkpts", # HuggingFace repo ID
+ "cfg_dir": "config/magic_infer_moge.yaml",
+}
+
+def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=None):
+ """
+ Initialize and return the tracker predictor and visualizer
+ Args:
+ output_dir: Directory to save visualization results
+ vo_points: Number of points for visual odometry
+ Returns:
+ Tuple of (tracker_predictor, visualizer)
+ """
+ viz = True
+ os.makedirs(output_dir, exist_ok=True)
+
+ with open(config["cfg_dir"], "r") as f:
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
+ cfg = easydict.EasyDict(cfg)
+ cfg.out_dir = output_dir
+ cfg.model.track_num = vo_points
+
+ # Check if it's a local path or HuggingFace repo
+ if tracker_model is not None:
+ model = tracker_model
+ model.spatrack.track_num = vo_points
+ else:
+ if os.path.exists(config["ckpt_dir"]):
+ # Local file
+ model = Predictor.from_pretrained(config["ckpt_dir"], model_cfg=cfg["model"])
+ else:
+ # HuggingFace repo - download the model
+ print(f"Downloading model from HuggingFace: {config['ckpt_dir']}")
+ checkpoint_path = hf_hub_download(
+ repo_id=config["ckpt_dir"],
+ repo_type="model",
+ filename="SpaTrack3_offline.pth"
+ )
+ model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
+ model.eval()
+ model.to("cuda")
+
+ viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
+ fps=10, pad_value=0, tracks_leave_trace=5)
+
+ return model, viser
+
+def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3):
+ """
+ Run tracking on a video sequence
+ Args:
+ model: Tracker predictor instance
+ viser: Visualizer instance
+ temp_dir: Directory containing temporary files
+ video_name: Name of the video file (without extension)
+ grid_size: Size of the tracking grid
+ vo_points: Number of points for visual odometry
+ fps: Frames per second for visualization
+ """
+ # Setup paths
+ video_path = os.path.join(temp_dir, f"{video_name}.mp4")
+ mask_path = os.path.join(temp_dir, f"{video_name}.png")
+ out_dir = os.path.join(temp_dir, "results")
+ os.makedirs(out_dir, exist_ok=True)
+
+ # Load video using decord
+ video_reader = decord.VideoReader(video_path)
+ video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2) # Convert to tensor and permute to (N, C, H, W)
+
+ # resize make sure the shortest side is 336
+ h, w = video_tensor.shape[2:]
+ scale = max(336 / h, 336 / w)
+ if scale < 1:
+ new_h, new_w = int(h * scale), int(w * scale)
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
+ video_tensor = video_tensor[::fps].float()
+ depth_tensor = None
+ intrs = None
+ extrs = None
+ data_npz_load = {}
+
+ # Load and process mask
+ if os.path.exists(mask_path):
+ mask = cv2.imread(mask_path)
+ mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
+ mask = mask.sum(axis=-1)>0
+ else:
+ mask = np.ones_like(video_tensor[0,0].numpy())>0
+
+ # Get frame dimensions and create grid points
+ frame_H, frame_W = video_tensor.shape[2:]
+ grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
+
+ # Sample mask values at grid points and filter out points where mask=0
+ if os.path.exists(mask_path):
+ grid_pts_int = grid_pts[0].long()
+ mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
+ grid_pts = grid_pts[:, mask_values]
+
+ query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
+
+ # run vggt
+ if os.environ.get("VGGT_DIR", None) is not None:
+ vggt_model = VGGT()
+ vggt_model.load_state_dict(torch.load(VGGT_DIR))
+ vggt_model.eval()
+ vggt_model = vggt_model.to("cuda")
+ # process the image tensor
+ video_tensor = preprocess_image(video_tensor)[None]
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ # Predict attributes including cameras, depth maps, and point maps.
+ aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_tensor.cuda()/255)
+ pose_enc = vggt_model.camera_head(aggregated_tokens_list)[-1]
+ # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, video_tensor.shape[-2:])
+ # Predict Depth Maps
+ depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_tensor.cuda()/255, ps_idx)
+ # clear the cache
+ del vggt_model, aggregated_tokens_list, ps_idx, pose_enc
+ torch.cuda.empty_cache()
+ depth_tensor = depth_map.squeeze().cpu().numpy()
+ extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
+ extrs[:, :3, :4] = extrinsic.squeeze().cpu().numpy()
+ intrs = intrinsic.squeeze().cpu().numpy()
+ video_tensor = video_tensor.squeeze()
+ #NOTE: 20% of the depth is not reliable
+ # threshold = depth_conf.squeeze().view(-1).quantile(0.5)
+ unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
+
+ # Run model inference
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ (
+ c2w_traj, intrs, point_map, conf_depth,
+ track3d_pred, track2d_pred, vis_pred, conf_pred, video
+ ) = model.forward(video_tensor, depth=depth_tensor,
+ intrs=intrs, extrs=extrs,
+ queries=query_xyt,
+ fps=1, full_point=False, iters_track=4,
+ query_no_BA=True, fixed_cam=False, stage=1,
+ support_frame=len(video_tensor)-1, replace_ratio=0.2)
+
+ # Resize results to avoid too large I/O Burden
+ max_size = 336
+ h, w = video.shape[2:]
+ scale = min(max_size / h, max_size / w)
+ if scale < 1:
+ new_h, new_w = int(h * scale), int(w * scale)
+ video = T.Resize((new_h, new_w))(video)
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
+ point_map = T.Resize((new_h, new_w))(point_map)
+ track2d_pred[...,:2] = track2d_pred[...,:2] * scale
+ intrs[:,:2,:] = intrs[:,:2,:] * scale
+ if depth_tensor is not None:
+ depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
+ conf_depth = T.Resize((new_h, new_w))(conf_depth)
+
+ # Visualize tracks
+ viser.visualize(video=video[None],
+ tracks=track2d_pred[None][...,:2],
+ visibility=vis_pred[None],filename="test")
+
+ # Save in tapip3d format
+ data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
+ data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
+ data_npz_load["intrinsics"] = intrs.cpu().numpy()
+ data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
+ data_npz_load["video"] = (video_tensor).cpu().numpy()/255
+ data_npz_load["visibs"] = vis_pred.cpu().numpy()
+ data_npz_load["confs"] = conf_pred.cpu().numpy()
+ data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
+ np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
+
+ print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")
\ No newline at end of file
diff --git a/config/__init__.py b/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/config/magic_infer_moge.yaml b/config/magic_infer_moge.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b6634c0b1f72cd9d24ab4ed22841ce9b5c71b826
--- /dev/null
+++ b/config/magic_infer_moge.yaml
@@ -0,0 +1,48 @@
+seed: 0
+# config the hydra logger, only in hydra `$` can be decoded as cite
+data: ./assets/room
+vis_track: false
+hydra:
+ run:
+ dir: .
+ output_subdir: null
+ job_logging: {}
+ hydra_logging: {}
+mixed_precision: bf16
+visdom:
+ viz_ip: "localhost"
+ port: 6666
+relax_load: false
+res_all: 336
+# config the ckpt path
+# ckpts: "/mnt/bn/xyxdata/home/codes/my_projs/SpaTrack2/checkpoints/new_base.pth"
+ckpts: "Yuxihenry/SpatialTracker_Files"
+batch_size: 1
+input:
+ type: image
+fps: 1
+model_wind_size: 32
+model:
+ backbone_cfg:
+ ckpt_dir: "checkpoints/model.pt"
+ chunk_size: 24 # downsample factor for patchified features
+ ckpt_fwd: true
+ ft_cfg:
+ mode: "fix"
+ paras_name: []
+ resolution: 336
+ max_len: 512
+ Track_cfg:
+ base_ckpt: "checkpoints/scaled_offline.pth"
+ base:
+ stride: 4
+ corr_radius: 3
+ window_len: 60
+ stablizer: True
+ mode: "online"
+ s_wind: 200
+ overlap: 4
+ track_num: 0
+
+dist_train:
+ num_nodes: 1
\ No newline at end of file
diff --git a/examples/backpack.mp4 b/examples/backpack.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..907644929bcfa0485684069f4090aef73a663f48
--- /dev/null
+++ b/examples/backpack.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b5ac6b2285ffb48e3a740e419e38c781df9c963589a5fd894e5b4e13dd6a8b8
+size 1208738
diff --git a/examples/ball.mp4 b/examples/ball.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..58abc8e1dd7070dad7a7e649037425cfd35bd5a8
--- /dev/null
+++ b/examples/ball.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:31f6e3bf875a85284b376c05170b4c08b546b7d5e95106848b1e3818a9d0db91
+size 3030268
diff --git a/examples/basketball.mp4 b/examples/basketball.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..416b16aaab4c02cb5e804a726a5688ff9853bd4d
--- /dev/null
+++ b/examples/basketball.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0df3b429d5fd64c298f2d79b2d818a4044e7341a71d70b957f60b24e313c3760
+size 2487837
diff --git a/examples/biker.mp4 b/examples/biker.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3ccc385b273999a46e8fcaeab2bd4a519ee65562
--- /dev/null
+++ b/examples/biker.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fba880c24bdb8fa3b84b1b491d52f2c1f426fb09e34c3013603e5a549cf3b22b
+size 249196
diff --git a/examples/cinema_0.mp4 b/examples/cinema_0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2ba38c3c44ad330020d0c473351f59b6a4549cbb
--- /dev/null
+++ b/examples/cinema_0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a68a5643c14f61c05d48e25a98ddf5cf0344d3ffcda08ad4a0adc989d49d7a9c
+size 1774022
diff --git a/examples/cinema_1.mp4 b/examples/cinema_1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..21842ed76be0e96f4c2d9e86c3d07cd7f7a4c371
--- /dev/null
+++ b/examples/cinema_1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:99624e2d0fb2e9f994e46aefb904e884de37a6d78e7f6b6670e286eaa397e515
+size 2370749
diff --git a/examples/drifting.mp4 b/examples/drifting.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..d173677a65b49940ae8db51831fba2ba94f74d1f
--- /dev/null
+++ b/examples/drifting.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f3937871117d3cc5d7da3ef31d1edf5626fc8372126b73590f75f05713fe97c
+size 4695804
diff --git a/examples/ego_kc1.mp4 b/examples/ego_kc1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a7a4b0fafa5c5a863b8f1bf9e43e0df2e55095cf
--- /dev/null
+++ b/examples/ego_kc1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:22fe64e458e329e8b3c3e20b3725ffd85c3a2e725fd03909cf883d3fd02c80b3
+size 1365980
diff --git a/examples/ego_teaser.mp4 b/examples/ego_teaser.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..6e9453e8944b9ddce4d17bb471c5b2302fc8b912
--- /dev/null
+++ b/examples/ego_teaser.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8780b291b48046b1c7dea90712c1c3f59d60c03216df1c489f6f03e8d61fae5c
+size 7365665
diff --git a/examples/handwave.mp4 b/examples/handwave.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a2b56b27488b32cb075bc77c7c6aff636f820b9e
--- /dev/null
+++ b/examples/handwave.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6dde7cf4ffa7c66b6861bb5abdedc49dfc4b5b4dd9dd46ee8415dd4953935b6
+size 2099369
diff --git a/examples/hockey.mp4 b/examples/hockey.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c5c78106eb0634f44dd2b0b5d80ff48250072dcd
--- /dev/null
+++ b/examples/hockey.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c3be095777b442dc401e7d1f489b749611ffade3563a01e4e3d1e511311bd86
+size 1795810
diff --git a/examples/ken_block_0.mp4 b/examples/ken_block_0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..1504473a1f3f63cae0cb705c24a2bbb1f3f4802c
--- /dev/null
+++ b/examples/ken_block_0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b788faeb4d3206fa604d622a05268f1321ad6a229178fe12319d20c9438deb1
+size 196343
diff --git a/examples/kiss.mp4 b/examples/kiss.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..52e31f18c993b8abfc60a63e228a095fd22d0654
--- /dev/null
+++ b/examples/kiss.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f78fffc5108d95d4e5837d7607226f3dd9796615ea3481f2629c69ccd2ccb12f
+size 1073570
diff --git a/examples/kitchen.mp4 b/examples/kitchen.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c6faa07535345c47fbd20a0cb1842e2ffcc8bfc8
--- /dev/null
+++ b/examples/kitchen.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3120e942a9b3d7b300928e43113b000fb5ccc209012a2c560ec26b8a04c2d5f9
+size 543970
diff --git a/examples/kitchen_egocentric.mp4 b/examples/kitchen_egocentric.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e0e78bdad5b3423f10cd77087bde0aedab710867
--- /dev/null
+++ b/examples/kitchen_egocentric.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5468ab10d8d39b68b51fa616adc3d099dab7543e38dd221a0a7a20a2401824a2
+size 2176685
diff --git a/examples/pillow.mp4 b/examples/pillow.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..4bb67dd98c42a50cb734ef7ecbdea3e13c59af01
--- /dev/null
+++ b/examples/pillow.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f05818f586d7b0796fcd4714ea4be489c93701598cadc86ce7973fc24655fee
+size 1407147
diff --git a/examples/protein.mp4 b/examples/protein.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..ff5c9adce98fedb1b62726b3004a774b1afb28a1
--- /dev/null
+++ b/examples/protein.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2dc9cfceb0984b61ebc62fda4c826135ebe916c8c966a8123dcc3315d43b73f
+size 2002300
diff --git a/examples/pusht.mp4 b/examples/pusht.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..24d1e91e8adb7cd3dbebc1540e422c7c12b406ac
--- /dev/null
+++ b/examples/pusht.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:996d1923e36811a1069e4d6b5e8c0338d9068c0870ea09c4c04e13e9fbcd207a
+size 5256495
diff --git a/examples/robot1.mp4 b/examples/robot1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..71eb477af3b2f47cde219452c94039f832b53895
--- /dev/null
+++ b/examples/robot1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a3b9e4449572129fdd96a751938e211241cdd86bcc56ffd33bfd23fc4d6e9c0
+size 1178671
diff --git a/examples/robot2.mp4 b/examples/robot2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..dcf8d0a2439c90a5c5d748db85d2fe6183565ba3
--- /dev/null
+++ b/examples/robot2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:188b2d8824ce345c86a603bff210639a6158d72cf6119cc1d3f79d409ac68bb3
+size 867261
diff --git a/examples/robot_3.mp4 b/examples/robot_3.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..f60276ef0d411a7d8c9f94ccb3b5669d0315a9a9
--- /dev/null
+++ b/examples/robot_3.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:784a0f9c36a316d0da5745075dbc8cefd9ce60c25b067d3d80a1d52830df8a37
+size 1153015
diff --git a/examples/robot_unitree.mp4 b/examples/robot_unitree.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..6dbbfd885b4ef3ff5580578ec38fd6deab6a27c7
--- /dev/null
+++ b/examples/robot_unitree.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:99bc274f7613a665c6135085fe01691ebfaa9033101319071f37c550ab21d1ea
+size 1964268
diff --git a/examples/running.mp4 b/examples/running.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..632eaeb9e446fb3d76924f4ac62960c89287c905
--- /dev/null
+++ b/examples/running.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ceb96b287fefb1c090dcd2f5db7634f808d2079413500beeb7b33023dfae51b
+size 7307897
diff --git a/examples/teleop2.mp4 b/examples/teleop2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3b79c83909de3b30540e8ff7276fd23e565368bd
--- /dev/null
+++ b/examples/teleop2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:59ea006a18227da8cf5db1fa50cd48e71ec7eb66fef48ea2158c325088bd9fee
+size 1077267
diff --git a/examples/vertical_place.mp4 b/examples/vertical_place.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a824311aad89b904245ac1cf721e5a67076adb67
--- /dev/null
+++ b/examples/vertical_place.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6c8061ae449f986113c2ecb17aefc2c13f737aecbcd41d6c057c88e6d41ac3ee
+size 719810
diff --git a/models/SpaTrackV2/models/SpaTrack.py b/models/SpaTrackV2/models/SpaTrack.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c4b832401ef866601de3be9b652034c7572b926
--- /dev/null
+++ b/models/SpaTrackV2/models/SpaTrack.py
@@ -0,0 +1,759 @@
+#python
+"""
+SpaTrackerV2, which is an unified model to estimate 'intrinsic',
+'video depth', 'extrinsic' and '3D Tracking' from casual video frames.
+
+Contact: DM yuxixiao@zju.edu.cn
+"""
+
+import os
+import numpy as np
+from typing import Literal, Union, List, Tuple, Dict
+import cv2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from depth anything v2
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
+from einops import rearrange
+from models.monoD.depth_anything_v2.dpt import DepthAnythingV2
+from models.moge.model.v1 import MoGeModel
+import copy
+from functools import partial
+from models.SpaTrackV2.models.tracker3D.TrackRefiner import TrackRefiner3D
+import kornia
+from models.SpaTrackV2.utils.model_utils import sample_features5d
+import utils3d
+from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
+from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
+import random
+
+class SpaTrack2(nn.Module, PyTorchModelHubMixin):
+ def __init__(
+ self,
+ loggers: list, # include [ viz, logger_tf, logger]
+ backbone_cfg,
+ Track_cfg=None,
+ chunk_size=24,
+ ckpt_fwd: bool = False,
+ ft_cfg=None,
+ resolution=518,
+ max_len=600, # the maximum video length we can preprocess,
+ track_num=768,
+ ):
+
+ self.chunk_size = chunk_size
+ self.max_len = max_len
+ self.resolution = resolution
+ # config the T-Lora Dinov2
+ #NOTE: initial the base model
+ base_cfg = copy.deepcopy(backbone_cfg)
+ backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)
+
+ super(SpaTrack2, self).__init__()
+ if os.path.exists(backbone_ckpt_dir)==False:
+ base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
+ else:
+ checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
+ base_model = MoGeModel(**checkpoint["model_config"])
+ base_model.load_state_dict(checkpoint['model'])
+ # avoid the base_model is a member of SpaTrack2
+ object.__setattr__(self, 'base_model', base_model)
+
+ # Tracker model
+ self.Track3D = TrackRefiner3D(Track_cfg)
+ track_base_ckpt_dir = Track_cfg.base_ckpt
+ if os.path.exists(track_base_ckpt_dir):
+ track_pretrain = torch.load(track_base_ckpt_dir)
+ self.Track3D.load_state_dict(track_pretrain, strict=False)
+
+ # wrap the function of make lora trainable
+ self.make_paras_trainable = partial(self.make_paras_trainable,
+ mode=ft_cfg.mode,
+ paras_name=ft_cfg.paras_name)
+ self.track_num = track_num
+
+ def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
+ # gradient required for the lora_experts and gate
+ for name, param in self.named_parameters():
+ if any(x in name for x in paras_name):
+ if mode == 'fix':
+ param.requires_grad = False
+ else:
+ param.requires_grad = True
+ else:
+ if mode == 'fix':
+ param.requires_grad = True
+ else:
+ param.requires_grad = False
+ total_params = sum(p.numel() for p in self.parameters())
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+ print(f"Total parameters: {total_params}")
+ print(f"Trainable parameters: {trainable_params/total_params*100:.2f}%")
+
+ def ProcVid(self,
+ x: torch.Tensor):
+ """
+ split the video into several overlapped windows.
+
+ args:
+ x: the input video frames. [B, T, C, H, W]
+ outputs:
+ patch_size: the patch size of the video features
+ raises:
+ ValueError: if the input video is longer than `max_len`.
+
+ """
+ # normalize the input images
+ num_types = x.dtype
+ x = normalize_rgb(x, input_size=self.resolution)
+ x = x.to(num_types)
+ # get the video features
+ B, T, C, H, W = x.size()
+ if T > self.max_len:
+ raise ValueError(f"the video length should no more than {self.max_len}.")
+ # get the video features
+ patch_h, patch_w = H // 14, W // 14
+ patch_size = (patch_h, patch_w)
+ # resize and get the video features
+ x = x.view(B * T, C, H, W)
+ # operate the temporal encoding
+ return patch_size, x
+
+ def forward_stream(
+ self,
+ video: torch.Tensor,
+ queries: torch.Tensor = None,
+ T_org: int = None,
+ depth: torch.Tensor|np.ndarray|str=None,
+ unc_metric_in: torch.Tensor|np.ndarray|str=None,
+ intrs: torch.Tensor|np.ndarray|str=None,
+ extrs: torch.Tensor|np.ndarray|str=None,
+ queries_3d: torch.Tensor = None,
+ window_len: int = 16,
+ overlap_len: int = 4,
+ full_point: bool = False,
+ track2d_gt: torch.Tensor = None,
+ fixed_cam: bool = False,
+ query_no_BA: bool = False,
+ stage: int = 0,
+ support_frame: int = 0,
+ replace_ratio: float = 0.6,
+ annots_train: Dict = None,
+ iters_track=4,
+ **kwargs,
+ ):
+ # step 1 allocate the query points on the grid
+ T, C, H, W = video.shape
+
+ if annots_train is not None:
+ vis_gt = annots_train["vis"]
+ _, _, N = vis_gt.shape
+ number_visible = vis_gt.sum(dim=1)
+ ratio_rand = torch.rand(1, N, device=vis_gt.device)
+ first_positive_inds = get_nth_visible_time_index(vis_gt, (number_visible*ratio_rand).long().clamp(min=1, max=T))
+ assert (torch.gather(vis_gt, 1, first_positive_inds[:, None, :].repeat(1, T, 1)) < 0).sum() == 0
+
+ first_positive_inds = first_positive_inds.long()
+ gather = torch.gather(
+ annots_train["traj_3d"][...,:2], 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
+ )
+ xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
+ queries = torch.cat([first_positive_inds[:, :, None], xys], dim=-1)[0].cpu().numpy()
+
+
+ # Unfold video into segments of window_len with overlap_len
+ step_slide = window_len - overlap_len
+ if T < window_len:
+ video_unf = video.unsqueeze(0)
+ if depth is not None:
+ depth_unf = depth.unsqueeze(0)
+ else:
+ depth_unf = None
+ if unc_metric_in is not None:
+ unc_metric_unf = unc_metric_in.unsqueeze(0)
+ else:
+ unc_metric_unf = None
+ if intrs is not None:
+ intrs_unf = intrs.unsqueeze(0)
+ else:
+ intrs_unf = None
+ if extrs is not None:
+ extrs_unf = extrs.unsqueeze(0)
+ else:
+ extrs_unf = None
+ else:
+ video_unf = video.unfold(0, window_len, step_slide).permute(0, 4, 1, 2, 3) # [B, S, C, H, W]
+ if depth is not None:
+ depth_unf = depth.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
+ intrs_unf = intrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
+ else:
+ depth_unf = None
+ intrs_unf = None
+ if extrs is not None:
+ extrs_unf = extrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
+ else:
+ extrs_unf = None
+ if unc_metric_in is not None:
+ unc_metric_unf = unc_metric_in.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
+ else:
+ unc_metric_unf = None
+
+ # parallel
+ # Get number of segments
+ B = video_unf.shape[0]
+ #TODO: Process each segment in parallel using torch.nn.DataParallel
+ c2w_traj = torch.eye(4, 4)[None].repeat(T, 1, 1)
+ intrs_out = torch.eye(3, 3)[None].repeat(T, 1, 1)
+ point_map = torch.zeros(T, 3, H, W).cuda()
+ unc_metric = torch.zeros(T, H, W).cuda()
+ # set the queries
+ N, _ = queries.shape
+ track3d_pred = torch.zeros(T, N, 6).cuda()
+ track2d_pred = torch.zeros(T, N, 3).cuda()
+ vis_pred = torch.zeros(T, N, 1).cuda()
+ conf_pred = torch.zeros(T, N, 1).cuda()
+ dyn_preds = torch.zeros(T, N, 1).cuda()
+ # sort the queries by time
+ sorted_indices = np.argsort(queries[...,0])
+ sorted_inv_indices = np.argsort(sorted_indices)
+ sort_query = queries[sorted_indices]
+ sort_query = torch.from_numpy(sort_query).cuda()
+ if queries_3d is not None:
+ sort_query_3d = queries_3d[sorted_indices]
+ sort_query_3d = torch.from_numpy(sort_query_3d).cuda()
+
+ queries_len = 0
+ overlap_d = None
+ cache = None
+ loss = 0.0
+
+ for i in range(B):
+ segment = video_unf[i:i+1].cuda()
+ # Forward pass through model
+ # detect the key points for each frames
+
+ queries_new_mask = (sort_query[...,0] < i * step_slide + window_len) * (sort_query[...,0] >= (i * step_slide + overlap_len if i > 0 else 0))
+ if queries_3d is not None:
+ queries_new_3d = sort_query_3d[queries_new_mask]
+ queries_new_3d = queries_new_3d.float()
+ else:
+ queries_new_3d = None
+ queries_new = sort_query[queries_new_mask.bool()]
+ queries_new = queries_new.float()
+ if i > 0:
+ overlap2d = track2d_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
+ overlapvis = vis_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
+ overlapconf = conf_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
+ overlap_query = (overlapvis * overlapconf).max(dim=0)[1][None, ...]
+ overlap_xy = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,2))
+ overlap_d = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,3))[...,2].detach()
+ overlap_query = torch.cat([overlap_query[...,:1], overlap_xy], dim=-1)[0]
+ queries_new[...,0] -= i*step_slide
+ queries_new = torch.cat([overlap_query, queries_new], dim=0).detach()
+
+ if annots_train is None:
+ annots = {}
+ else:
+ annots = copy.deepcopy(annots_train)
+ annots["traj_3d"] = annots["traj_3d"][:, i*step_slide:i*step_slide+window_len, sorted_indices,:][...,:len(queries_new),:]
+ annots["vis"] = annots["vis"][:, i*step_slide:i*step_slide+window_len, sorted_indices][...,:len(queries_new)]
+ annots["poses_gt"] = annots["poses_gt"][:, i*step_slide:i*step_slide+window_len]
+ annots["depth_gt"] = annots["depth_gt"][:, i*step_slide:i*step_slide+window_len]
+ annots["intrs"] = annots["intrs"][:, i*step_slide:i*step_slide+window_len]
+ annots["traj_mat"] = annots["traj_mat"][:,i*step_slide:i*step_slide+window_len]
+
+ if depth is not None:
+ annots["depth_gt"] = depth_unf[i:i+1].to(segment.device).to(segment.dtype)
+ if unc_metric_in is not None:
+ annots["unc_metric"] = unc_metric_unf[i:i+1].to(segment.device).to(segment.dtype)
+ if intrs is not None:
+ intr_seg = intrs_unf[i:i+1].to(segment.device).to(segment.dtype)[0].clone()
+ focal = (intr_seg[:,0,0] / segment.shape[-1] + intr_seg[:,1,1]/segment.shape[-2]) / 2
+ pose_fake = torch.zeros(1, 8).to(depth.device).to(depth.dtype).repeat(segment.shape[1], 1)
+ pose_fake[:, -1] = focal
+ pose_fake[:,3]=1
+ annots["intrs_gt"] = intr_seg
+ if extrs is not None:
+ extrs_unf_norm = extrs_unf[i:i+1][0].clone()
+ extrs_unf_norm = torch.inverse(extrs_unf_norm[:1,...]) @ extrs_unf[i:i+1][0]
+ rot_vec = matrix_to_quaternion(extrs_unf_norm[:,:3,:3])
+ annots["poses_gt"] = torch.zeros(1, rot_vec.shape[0], 7).to(segment.device).to(segment.dtype)
+ annots["poses_gt"][:, :, 3:7] = rot_vec.to(segment.device).to(segment.dtype)[None]
+ annots["poses_gt"][:, :, :3] = extrs_unf_norm[:,:3,3].to(segment.device).to(segment.dtype)[None]
+ annots["use_extr"] = True
+
+ kwargs.update({"stage": stage})
+
+ #TODO: DEBUG
+ out = self.forward(segment, pts_q=queries_new,
+ pts_q_3d=queries_new_3d, overlap_d=overlap_d,
+ full_point=full_point,
+ fixed_cam=fixed_cam, query_no_BA=query_no_BA,
+ support_frame=segment.shape[1]-1,
+ cache=cache, replace_ratio=replace_ratio,
+ iters_track=iters_track,
+ **kwargs, annots=annots)
+ if self.training:
+ loss += out["loss"].squeeze()
+ # from models.SpaTrackV2.utils.visualizer import Visualizer
+ # vis_track = Visualizer(grayscale=False,
+ # fps=10, pad_value=50, tracks_leave_trace=0)
+ # vis_track.visualize(video=segment,
+ # tracks=out["traj_est"][...,:2],
+ # visibility=out["vis_est"],
+ # save_video=True)
+ # # visualize 4d
+ # import os, json
+ # import os.path as osp
+ # viser4d_dir = os.path.join("viser_4d_results")
+ # os.makedirs(viser4d_dir, exist_ok=True)
+ # depth_est = annots["depth_gt"][0]
+ # unc_metric = out["unc_metric"]
+ # mask = (unc_metric > 0.5).squeeze(1)
+ # # pose_est = out["poses_pred"].squeeze(0)
+ # pose_est = annots["traj_mat"][0]
+ # rgb_tracks = out["rgb_tracks"].squeeze(0)
+ # intrinsics = out["intrs"].squeeze(0)
+ # for i_k in range(out["depth"].shape[0]):
+ # img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
+ # img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
+ # cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
+ # if stage == 1:
+ # depth = depth_est[i_k].squeeze().cpu().numpy()
+ # np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
+ # else:
+ # point_map_vis = out["points_map"][i_k].cpu().numpy()
+ # np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
+ # np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
+ # np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
+ # np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
+ # np.save(os.path.join(viser4d_dir, f'colored_track3d.npy'), rgb_tracks.cpu().numpy())
+
+ queries_len = len(queries_new)
+ # update the track3d and track2d
+ left_len = len(track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :])
+ track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["rgb_tracks"][0,:left_len,:queries_len,:]
+ track2d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["traj_est"][0,:left_len,:queries_len,:3]
+ vis_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["vis_est"][0,:left_len,:queries_len,None]
+ conf_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["conf_pred"][0,:left_len,:queries_len,None]
+ dyn_preds[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["dyn_preds"][0,:left_len,:queries_len,None]
+
+ # process the output for each segment
+ seg_c2w = out["poses_pred"][0]
+ seg_intrs = out["intrs"][0]
+ seg_point_map = out["points_map"]
+ seg_conf_depth = out["unc_metric"]
+
+ # cache management
+ cache = out["cache"]
+ for k in cache.keys():
+ if "_pyramid" in k:
+ for j in range(len(cache[k])):
+ if len(cache[k][j].shape) == 5:
+ cache[k][j] = cache[k][j][:,:,:,:queries_len,:]
+ elif len(cache[k][j].shape) == 4:
+ cache[k][j] = cache[k][j][:,:1,:queries_len,:]
+ elif "_pred_cache" in k:
+ cache[k] = cache[k][-overlap_len:,:queries_len,:]
+ else:
+ cache[k] = cache[k][-overlap_len:]
+
+ # update the results
+ idx_glob = i * step_slide
+ # refine part
+ # mask_update = sort_query[..., 0] < i * step_slide + window_len
+ # sort_query_pick = sort_query[mask_update]
+ intrs_out[idx_glob:idx_glob+window_len] = seg_intrs
+ point_map[idx_glob:idx_glob+window_len] = seg_point_map
+ unc_metric[idx_glob:idx_glob+window_len] = seg_conf_depth
+ # update the camera poses
+
+ # if using the ground truth pose
+ # if extrs_unf is not None:
+ # c2w_traj[idx_glob:idx_glob+window_len] = extrs_unf[i:i+1][0].to(c2w_traj.device).to(c2w_traj.dtype)
+ # else:
+ prev_c2w = c2w_traj[idx_glob:idx_glob+window_len][:1]
+ c2w_traj[idx_glob:idx_glob+window_len] = prev_c2w@seg_c2w.to(c2w_traj.device).to(c2w_traj.dtype)
+
+ track2d_pred = track2d_pred[:T_org,sorted_inv_indices,:]
+ track3d_pred = track3d_pred[:T_org,sorted_inv_indices,:]
+ vis_pred = vis_pred[:T_org,sorted_inv_indices,:]
+ conf_pred = conf_pred[:T_org,sorted_inv_indices,:]
+ dyn_preds = dyn_preds[:T_org,sorted_inv_indices,:]
+ unc_metric = unc_metric[:T_org,:]
+ point_map = point_map[:T_org,:]
+ intrs_out = intrs_out[:T_org,:]
+ c2w_traj = c2w_traj[:T_org,:]
+ if self.training:
+ ret = {
+ "loss": loss,
+ "depth_loss": 0.0,
+ "ab_loss": 0.0,
+ "vis_loss": out["vis_loss"],
+ "track_loss": out["track_loss"],
+ "conf_loss": out["conf_loss"],
+ "dyn_loss": out["dyn_loss"],
+ "sync_loss": out["sync_loss"],
+ "poses_pred": c2w_traj[None],
+ "intrs": intrs_out[None],
+ "points_map": point_map,
+ "track3d_pred": track3d_pred[None],
+ "rgb_tracks": track3d_pred[None],
+ "track2d_pred": track2d_pred[None],
+ "traj_est": track2d_pred[None],
+ "vis_est": vis_pred[None], "conf_pred": conf_pred[None],
+ "dyn_preds": dyn_preds[None],
+ "imgs_raw": video[None],
+ "unc_metric": unc_metric,
+ }
+
+ return ret
+ else:
+ return c2w_traj, intrs_out, point_map, unc_metric, track3d_pred, track2d_pred, vis_pred, conf_pred
+ def forward(self,
+ x: torch.Tensor,
+ annots: Dict = {},
+ pts_q: torch.Tensor = None,
+ pts_q_3d: torch.Tensor = None,
+ overlap_d: torch.Tensor = None,
+ full_point = False,
+ fixed_cam = False,
+ support_frame = 0,
+ query_no_BA = False,
+ cache = None,
+ replace_ratio = 0.6,
+ iters_track=4,
+ **kwargs):
+ """
+ forward the video camera model, which predict (
+ `intr` `camera poses` `video depth`
+ )
+
+ args:
+ x: the input video frames. [B, T, C, H, W]
+ annots: the annotations for video frames.
+ {
+ "poses_gt": the pose encoding for the video frames. [B, T, 7]
+ "depth_gt": the ground truth depth for the video frames. [B, T, 1, H, W],
+ "metric": bool, whether to calculate the metric for the video frames.
+ }
+ """
+ self.support_frame = support_frame
+
+ #TODO: to adjust a little bit
+ track_loss=ab_loss=vis_loss=track_loss=conf_loss=dyn_loss=0.0
+ B, T, _, H, W = x.shape
+ imgs_raw = x.clone()
+ # get the video split and features for each segment
+ patch_size, x_resize = self.ProcVid(x)
+ x_resize = rearrange(x_resize, "(b t) c h w -> b t c h w", b=B)
+ H_resize, W_resize = x_resize.shape[-2:]
+
+ prec_fx = W / W_resize
+ prec_fy = H / H_resize
+ # get patch size
+ P_H, P_W = patch_size
+
+ # get the depth, pointmap and mask
+ #TODO: Release DepthAnything Version
+ points_map_gt = None
+ with torch.no_grad():
+ if_gt_depth = (("depth_gt" in annots.keys())) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3)
+ if if_gt_depth==False:
+ if cache is not None:
+ T_cache = cache["points_map"].shape[0]
+ T_new = T - T_cache
+ x_resize_new = x_resize[:, T_cache:]
+ else:
+ T_new = T
+ x_resize_new = x_resize
+ # infer with chunk
+ chunk_size = self.chunk_size
+ metric_depth = []
+ intrs = []
+ unc_metric = []
+ mask = []
+ points_map = []
+ normals = []
+ normals_mask = []
+ for i in range(0, B*T_new, chunk_size):
+ output = self.base_model.infer(x_resize_new.view(B*T_new, -1, H_resize, W_resize)[i:i+chunk_size])
+ metric_depth.append(output['depth'])
+ intrs.append(output['intrinsics'])
+ unc_metric.append(output['mask_prob'])
+ mask.append(output['mask'])
+ points_map.append(output['points'])
+ normals_i, normals_mask_i = utils3d.torch.points_to_normals(output['points'], mask=output['mask'])
+ normals.append(normals_i)
+ normals_mask.append(normals_mask_i)
+
+ metric_depth = torch.cat(metric_depth, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
+ intrs = torch.cat(intrs, dim=0).view(B, T_new, 3, 3).to(x.dtype)
+ intrs[:,:,0,:] *= W_resize
+ intrs[:,:,1,:] *= H_resize
+ # points_map = torch.cat(points_map, dim=0)
+ mask = torch.cat(mask, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
+ # cat the normals
+ normals = torch.cat(normals, dim=0)
+ normals_mask = torch.cat(normals_mask, dim=0)
+
+ metric_depth = metric_depth.clone()
+ metric_depth[metric_depth == torch.inf] = 0
+ _depths = metric_depth[metric_depth > 0].reshape(-1)
+ q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
+ q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
+ iqr = q75 - q25
+ upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
+ _depth_roi = torch.tensor(
+ [1e-1, upper_bound.item()],
+ dtype=metric_depth.dtype,
+ device=metric_depth.device
+ )
+ mask_roi = (metric_depth > _depth_roi[0]) & (metric_depth < _depth_roi[1])
+ mask = mask * mask_roi
+ mask = mask * (~(utils3d.torch.depth_edge(metric_depth, rtol=0.03, mask=mask.bool()))) * normals_mask[:,None,...]
+ points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T_new, 3, 3))
+ unc_metric = torch.cat(unc_metric, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
+ unc_metric *= mask
+ if full_point:
+ unc_metric = (~(utils3d.torch.depth_edge(metric_depth, rtol=0.1, mask=torch.ones_like(metric_depth).bool()))).float() * (metric_depth != 0)
+ if cache is not None:
+ assert B==1, "only support batch size 1 right now."
+ unc_metric = torch.cat([cache["unc_metric"], unc_metric], dim=0)
+ intrs = torch.cat([cache["intrs"][None], intrs], dim=1)
+ points_map = torch.cat([cache["points_map"].permute(0,2,3,1), points_map], dim=0)
+ metric_depth = torch.cat([cache["metric_depth"], metric_depth], dim=0)
+
+ if "poses_gt" in annots.keys():
+ intrs, c2w_traj_gt = pose_enc2mat(annots["poses_gt"],
+ H_resize, W_resize, self.resolution)
+ else:
+ c2w_traj_gt = None
+
+ if "intrs_gt" in annots.keys():
+ intrs = annots["intrs_gt"].view(B, T, 3, 3)
+ fx_factor = W_resize / W
+ fy_factor = H_resize / H
+ intrs[:,:,0,:] *= fx_factor
+ intrs[:,:,1,:] *= fy_factor
+
+ if "depth_gt" in annots.keys():
+
+ metric_depth_gt = annots['depth_gt'].view(B*T, 1, H, W)
+ metric_depth_gt = F.interpolate(metric_depth_gt,
+ size=(H_resize, W_resize), mode='nearest')
+
+ _depths = metric_depth_gt[metric_depth_gt > 0].reshape(-1)
+ q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
+ q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
+ iqr = q75 - q25
+ upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
+ _depth_roi = torch.tensor(
+ [1e-1, upper_bound.item()],
+ dtype=metric_depth_gt.dtype,
+ device=metric_depth_gt.device
+ )
+ mask_roi = (metric_depth_gt > _depth_roi[0]) & (metric_depth_gt < _depth_roi[1])
+ # if (upper_bound > 200).any():
+ # import pdb; pdb.set_trace()
+ if (kwargs.get('stage', 0) == 2):
+ unc_metric = ((metric_depth_gt > 0)*(mask_roi) * (unc_metric > 0.5)).float()
+ metric_depth_gt[metric_depth_gt > 10*q25] = 0
+ else:
+ unc_metric = ((metric_depth_gt > 0)*(mask_roi)).float()
+ unc_metric *= (~(utils3d.torch.depth_edge(metric_depth_gt, rtol=0.03, mask=mask_roi.bool()))).float()
+ # filter the sky
+ metric_depth_gt[metric_depth_gt > 10*q25] = 0
+ if "unc_metric" in annots.keys():
+ unc_metric_ = F.interpolate(annots["unc_metric"].permute(1,0,2,3),
+ size=(H_resize, W_resize), mode='nearest')
+ unc_metric = unc_metric * unc_metric_
+ if if_gt_depth:
+ points_map = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
+ metric_depth = metric_depth_gt
+ points_map_gt = points_map
+ else:
+ points_map_gt = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
+
+ # track the 3d points
+ ret_track = None
+ regular_track = True
+ dyn_preds, final_tracks = None, None
+
+ if "use_extr" in annots.keys():
+ init_pose = True
+ else:
+ init_pose = False
+ # set the custom vid and valid only
+ custom_vid = annots.get("custom_vid", False)
+ valid_only = annots.get("data_dir", [""])[0] == "stereo4d"
+ if self.training:
+ if (annots["vis"].sum() > 0) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3):
+ traj3d = annots['traj_3d']
+ if (kwargs.get('stage', 0)==1) and (annots.get("custom_vid", False)==False):
+ support_pts_q = get_track_points(H_resize, W_resize,
+ T, x.device, query_size=self.track_num // 2,
+ support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
+ else:
+ support_pts_q = get_track_points(H_resize, W_resize,
+ T, x.device, query_size=random.randint(32, 256),
+ support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
+ if pts_q is not None:
+ pts_q = pts_q[None,None]
+ ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
+ metric_depth,
+ unc_metric.detach(), points_map, pts_q,
+ intrs=intrs.clone(), cache=cache,
+ prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
+ vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
+ cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
+ init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
+ points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
+ else:
+ ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
+ metric_depth,
+ unc_metric.detach(), points_map, traj3d[..., :2],
+ intrs=intrs.clone(), cache=cache,
+ prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
+ vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
+ cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
+ init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
+ points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
+ regular_track = False
+
+
+ if regular_track:
+ if pts_q is None:
+ pts_q = get_track_points(H_resize, W_resize,
+ T, x.device, query_size=self.track_num,
+ support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental" if self.training else "incremental")[None]
+ support_pts_q = None
+ else:
+ pts_q = pts_q[None,None]
+ # resize the query points
+ pts_q[...,1] *= W_resize / W
+ pts_q[...,2] *= H_resize / H
+
+ if pts_q_3d is not None:
+ pts_q_3d = pts_q_3d[None,None]
+ # resize the query points
+ pts_q_3d[...,1] *= W_resize / W
+ pts_q_3d[...,2] *= H_resize / H
+ else:
+ # adjust the query with uncertainty
+ if (full_point==False) and (overlap_d is None):
+ pts_q_unc = sample_features5d(unc_metric[None], pts_q).squeeze()
+ pts_q = pts_q[:,:,pts_q_unc>0.5,:]
+ if (pts_q_unc<0.5).sum() > 0:
+ # pad the query points
+ pad_num = pts_q_unc.shape[0] - pts_q.shape[2]
+ # pick the random indices
+ indices = torch.randint(0, pts_q.shape[2], (pad_num,), device=pts_q.device)
+ pad_pts = indices
+ pts_q = torch.cat([pts_q, pts_q[:,:,pad_pts,:]], dim=-2)
+
+ support_pts_q = get_track_points(H_resize, W_resize,
+ T, x.device, query_size=self.track_num,
+ support_frame=self.support_frame,
+ unc_metric=unc_metric, mode="mixed")[None]
+
+ points_map[points_map>1e3] = 0
+ points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T, 3, 3))
+ ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
+ metric_depth,
+ unc_metric.detach(), points_map, pts_q,
+ pts_q_3d=pts_q_3d, intrs=intrs.clone(),cache=cache,
+ overlap_d=overlap_d, cam_gt=c2w_traj_gt if kwargs.get('stage', 0)==1 else None,
+ prec_fx=prec_fx, prec_fy=prec_fy, support_pts_q=support_pts_q, custom_vid=custom_vid, valid_only=valid_only,
+ fixed_cam=fixed_cam, query_no_BA=query_no_BA, init_pose=init_pose, iters=iters_track,
+ stage=kwargs.get('stage', 0), points_map_gt=points_map_gt, replace_ratio=replace_ratio)
+ intrs = intrs_org
+ points_map = point_map_org_refined
+ c2w_traj = ret_track["cam_pred"]
+
+ if ret_track is not None:
+ if ret_track["loss"] is not None:
+ track_loss, conf_loss, dyn_loss, vis_loss, point_map_loss, scale_loss, shift_loss, sync_loss= ret_track["loss"]
+
+ # update the cache
+ cache.update({"metric_depth": metric_depth, "unc_metric": unc_metric, "points_map": points_map, "intrs": intrs[0]})
+ # output
+ depth = F.interpolate(metric_depth,
+ size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
+ points_map = F.interpolate(points_map,
+ size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
+ unc_metric = F.interpolate(unc_metric,
+ size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
+
+ if self.training:
+
+ loss = track_loss + conf_loss + dyn_loss + sync_loss + vis_loss + point_map_loss + (scale_loss + shift_loss)*50
+ ret = {"loss": loss,
+ "depth_loss": point_map_loss,
+ "ab_loss": (scale_loss + shift_loss)*50,
+ "vis_loss": vis_loss, "track_loss": track_loss,
+ "poses_pred": c2w_traj, "dyn_preds": dyn_preds, "traj_est": final_tracks, "conf_loss": conf_loss,
+ "imgs_raw": imgs_raw, "rgb_tracks": rgb_tracks, "vis_est": ret_track['vis_pred'],
+ "depth": depth, "points_map": points_map, "unc_metric": unc_metric, "intrs": intrs, "dyn_loss": dyn_loss,
+ "sync_loss": sync_loss, "conf_pred": ret_track['conf_pred'], "cache": cache,
+ }
+
+ else:
+
+ if ret_track is not None:
+ traj_est = ret_track['preds']
+ traj_est[..., 0] *= W / W_resize
+ traj_est[..., 1] *= H / H_resize
+ vis_est = ret_track['vis_pred']
+ else:
+ traj_est = torch.zeros(B, self.track_num // 2, 3).to(x.device)
+ vis_est = torch.zeros(B, self.track_num // 2).to(x.device)
+
+ if intrs is not None:
+ intrs[..., 0, :] *= W / W_resize
+ intrs[..., 1, :] *= H / H_resize
+ ret = {"poses_pred": c2w_traj, "dyn_preds": dyn_preds,
+ "depth": depth, "traj_est": traj_est, "vis_est": vis_est, "imgs_raw": imgs_raw,
+ "rgb_tracks": rgb_tracks, "intrs": intrs, "unc_metric": unc_metric, "points_map": points_map,
+ "conf_pred": ret_track['conf_pred'], "cache": cache,
+ }
+
+ return ret
+
+
+
+
+# three stages of training
+
+# stage 1:
+# gt depth and intrinsics synthetic (includes Dynamic Replica, Kubric, Pointodyssey, Vkitti, TartanAir and Indoor() ) Motion Patern (tapvid3d)
+# Tracking and Pose as well -> based on gt depth and intrinsics
+# (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keep the same setting as tapip3d.)
+
+# stage 2: fixed 3D tracking
+# Joint depth refiner
+# input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
+# estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
+# ongoing two days
+
+# stage 3: train multi windows by propagation
+# 4 frames overlapped -> train on 64 -> fozen image encoder and finetuning the transformer (learnable parameters pretty small)
+
+# types of scenarioes:
+# 1. auto driving (waymo open dataset)
+# 2. robot
+# 3. internet ego video
+
+
+
+# Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
+# Update Variables:
+# 1. 3D tracks B T N 3 xyz.
+# 2. 2D tracks B T N 2 x y.
+# 3. Dynamic Mask B T H W.
+# 4. Camera Pose B T 4 4.
+# 5. Video Depth.
+
+# (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
+# Campatiablity by product.
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/__init__.py b/models/SpaTrackV2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/SpaTrackV2/models/blocks.py b/models/SpaTrackV2/models/blocks.py
new file mode 100755
index 0000000000000000000000000000000000000000..1b7de1b75d3af76907e5989494cb38b9ae6cd295
--- /dev/null
+++ b/models/SpaTrackV2/models/blocks.py
@@ -0,0 +1,519 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+from einops import rearrange
+import collections
+from functools import partial
+from itertools import repeat
+import torchvision.models as tvm
+from torch.utils.checkpoint import checkpoint
+from models.monoD.depth_anything.dpt import DPTHeadEnc, DPTHead
+from typing import Union, Tuple
+from torch import Tensor
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+class Attention(nn.Module):
+ def __init__(self, query_dim, context_dim=None,
+ num_heads=8, dim_head=48, qkv_bias=False, flash=False):
+ super().__init__()
+ inner_dim = self.inner_dim = dim_head * num_heads
+ context_dim = default(context_dim, query_dim)
+ self.scale = dim_head**-0.5
+ self.heads = num_heads
+ self.flash = flash
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
+ self.to_out = nn.Linear(inner_dim, query_dim)
+
+ def forward(self, x, context=None, attn_bias=None):
+ B, N1, _ = x.shape
+ C = self.inner_dim
+ h = self.heads
+ q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
+ context = default(context, x)
+ k, v = self.to_kv(context).chunk(2, dim=-1)
+
+ N2 = context.shape[1]
+ k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
+ v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
+
+ with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+ if self.flash==False:
+ sim = (q @ k.transpose(-2, -1)) * self.scale
+ if attn_bias is not None:
+ sim = sim + attn_bias
+ if sim.abs().max()>1e2:
+ import pdb; pdb.set_trace()
+ attn = sim.softmax(dim=-1)
+ x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
+ else:
+ input_args = [x.contiguous() for x in [q, k, v]]
+ x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore
+
+ if self.to_out.bias.dtype != x.dtype:
+ x = x.to(self.to_out.bias.dtype)
+
+ return self.to_out(x)
+
+
+class VGG19(nn.Module):
+ def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
+ super().__init__()
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
+ self.amp = amp
+ self.amp_dtype = amp_dtype
+
+ def forward(self, x, **kwargs):
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+ feats = {}
+ scale = 1
+ for layer in self.layers:
+ if isinstance(layer, nn.MaxPool2d):
+ feats[scale] = x
+ scale = scale*2
+ x = layer(x)
+ return feats
+
+class CNNandDinov2(nn.Module):
+ def __init__(self, cnn_kwargs = None, amp = True, amp_dtype = torch.float16):
+ super().__init__()
+ # in case the Internet connection is not stable, please load the DINOv2 locally
+ self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
+ 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
+
+ state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
+ self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
+
+
+ cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
+ self.cnn = VGG19(**cnn_kwargs)
+ self.amp = amp
+ self.amp_dtype = amp_dtype
+ if self.amp:
+ dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
+ self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
+
+
+ def train(self, mode: bool = True):
+ return self.cnn.train(mode)
+
+ def forward(self, x, upsample = False):
+ B,C,H,W = x.shape
+ feature_pyramid = self.cnn(x)
+
+ if not upsample:
+ with torch.no_grad():
+ if self.dinov2_vitl14[0].device != x.device:
+ self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
+ dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
+ del dinov2_features_16
+ feature_pyramid[16] = features_16
+ return feature_pyramid
+
+class Dinov2(nn.Module):
+ def __init__(self, amp = True, amp_dtype = torch.float16):
+ super().__init__()
+ # in case the Internet connection is not stable, please load the DINOv2 locally
+ self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
+ 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
+
+ state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
+ self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
+
+ self.amp = amp
+ self.amp_dtype = amp_dtype
+ if self.amp:
+ self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)
+
+ def forward(self, x, upsample = False):
+ B,C,H,W = x.shape
+ mean_ = torch.tensor([0.485, 0.456, 0.406],
+ device=x.device).view(1, 3, 1, 1)
+ std_ = torch.tensor([0.229, 0.224, 0.225],
+ device=x.device).view(1, 3, 1, 1)
+ x = (x+1)/2
+ x = (x - mean_)/std_
+ h_re, w_re = 560, 560
+ x_resize = F.interpolate(x, size=(h_re, w_re),
+ mode='bilinear', align_corners=True)
+ if not upsample:
+ with torch.no_grad():
+ dinov2_features_16 = self.dinov2_vitl14.forward_features(x_resize.to(self.amp_dtype))
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,h_re//14, w_re//14)
+ del dinov2_features_16
+ features_16 = F.interpolate(features_16, size=(H//8, W//8), mode="bilinear", align_corners=True)
+ return features_16
+
+class AttnBlock(nn.Module):
+ """
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0,
+ flash=False, ckpt_fwd=False, debug=False, **block_kwargs):
+ super().__init__()
+ self.debug=debug
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.flash=flash
+
+ self.attn = Attention(
+ hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash,
+ **block_kwargs
+ )
+ self.ls = LayerScale(hidden_size, init_values=0.005)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ )
+ self.ckpt_fwd = ckpt_fwd
+ def forward(self, x):
+ if self.debug:
+ print(x.max(), x.min(), x.mean())
+ if self.ckpt_fwd:
+ x = x + checkpoint(self.attn, self.norm1(x), use_reentrant=False)
+ else:
+ x = x + self.attn(self.norm1(x))
+
+ x = x + self.ls(self.mlp(self.norm2(x)))
+ return x
+
+class CrossAttnBlock(nn.Module):
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, head_dim=48,
+ flash=False, ckpt_fwd=False, **block_kwargs):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.norm_context = nn.LayerNorm(hidden_size)
+
+ self.cross_attn = Attention(
+ hidden_size, context_dim=context_dim, dim_head=head_dim,
+ num_heads=num_heads, qkv_bias=True, **block_kwargs, flash=flash,
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+ self.ckpt_fwd = ckpt_fwd
+ def forward(self, x, context):
+ if self.ckpt_fwd:
+ with autocast():
+ x = x + checkpoint(self.cross_attn,
+ self.norm1(x), self.norm_context(context), use_reentrant=False)
+ else:
+ with autocast():
+ x = x + self.cross_attn(
+ self.norm1(x), self.norm_context(context)
+ )
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+def bilinear_sampler(img, coords, mode="bilinear", mask=False):
+ """Wrapper for grid_sample, uses pixel coordinates"""
+ H, W = img.shape[-2:]
+ xgrid, ygrid = coords.split([1, 1], dim=-1)
+ # go to 0,1 then 0,2 then -1,1
+ xgrid = 2 * xgrid / (W - 1) - 1
+ ygrid = 2 * ygrid / (H - 1) - 1
+
+ grid = torch.cat([xgrid, ygrid], dim=-1)
+ img = F.grid_sample(img, grid, align_corners=True, mode=mode)
+
+ if mask:
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+ return img, mask.float()
+
+ return img
+
+
+class CorrBlock:
+ def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None):
+ B, S, C, H_prev, W_prev = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H_prev, W_prev
+
+ self.num_levels = num_levels
+ self.radius = radius
+ self.fmaps_pyramid = []
+ self.depth_pyramid = []
+ self.fmaps_pyramid.append(fmaps)
+ if depths_dnG is not None:
+ self.depth_pyramid.append(depths_dnG)
+ for i in range(self.num_levels - 1):
+ if depths_dnG is not None:
+ depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev)
+ depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2)
+ _, _, H, W = depths_dnG_.shape
+ depths_dnG = depths_dnG_.reshape(B, S, 1, H, W)
+ self.depth_pyramid.append(depths_dnG)
+ fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev)
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
+ _, _, H, W = fmaps_.shape
+ fmaps = fmaps_.reshape(B, S, C, H, W)
+ H_prev = H
+ W_prev = W
+ self.fmaps_pyramid.append(fmaps)
+
+ def sample(self, coords):
+ r = self.radius
+ B, S, N, D = coords.shape
+ assert D == 2
+
+ H, W = self.H, self.W
+ out_pyramid = []
+ for i in range(self.num_levels):
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
+ _, _, _, H, W = corrs.shape
+
+ dx = torch.linspace(-r, r, 2 * r + 1)
+ dy = torch.linspace(-r, r, 2 * r + 1)
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
+ coords.device
+ )
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
+ coords_lvl = centroid_lvl + delta_lvl
+ corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
+ corrs = corrs.view(B, S, N, -1)
+ out_pyramid.append(corrs)
+
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
+ return out.contiguous().float()
+
+ def corr(self, targets):
+ B, S, N, C = targets.shape
+ assert C == self.C
+ assert S == self.S
+
+ fmap1 = targets
+
+ self.corrs_pyramid = []
+ for fmaps in self.fmaps_pyramid:
+ _, _, _, H, W = fmaps.shape
+ fmap2s = fmaps.view(B, S, C, H * W)
+ corrs = torch.matmul(fmap1, fmap2s)
+ corrs = corrs.view(B, S, N, H, W)
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
+ self.corrs_pyramid.append(corrs)
+
+ def corr_sample(self, targets, coords, coords_dp=None):
+ B, S, N, C = targets.shape
+ r = self.radius
+ Dim_c = (2*r+1)**2
+ assert C == self.C
+ assert S == self.S
+
+ out_pyramid = []
+ out_pyramid_dp = []
+ for i in range(self.num_levels):
+ dx = torch.linspace(-r, r, 2 * r + 1)
+ dy = torch.linspace(-r, r, 2 * r + 1)
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
+ coords.device
+ )
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
+ coords_lvl = centroid_lvl + delta_lvl
+ fmaps = self.fmaps_pyramid[i]
+ _, _, _, H, W = fmaps.shape
+ fmap2s = fmaps.view(B*S, C, H, W)
+ if len(self.depth_pyramid)>0:
+ depths_dnG_i = self.depth_pyramid[i]
+ depths_dnG_i = depths_dnG_i.view(B*S, 1, H, W)
+ dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B*S,1,N*Dim_c,2))
+ dp_corrs = (dnG_sample.view(B*S,N,-1) - coords_dp[0]).abs()/coords_dp[0]
+ out_pyramid_dp.append(dp_corrs)
+ fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B*S,1,N*Dim_c,2))
+ fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1
+ corrs = torch.matmul(targets.reshape(B*S*N, 1, -1), fmap2s_sample.reshape(B*S*N, Dim_c, -1).permute(0, 2, 1))
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
+ corrs = corrs.view(B, S, N, -1)
+ out_pyramid.append(corrs)
+
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
+ if len(self.depth_pyramid)>0:
+ out_dp = torch.cat(out_pyramid_dp, dim=-1)
+ self.fcorrD = out_dp.contiguous().float()
+ else:
+ self.fcorrD = torch.zeros_like(out).contiguous().float()
+ return out.contiguous().float()
+
+
+class EUpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
+ """
+
+ def __init__(
+ self,
+ space_depth=12,
+ time_depth=12,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ vq_depth=3,
+ add_space_attn=True,
+ add_time_attn=True,
+ flash=True
+ ):
+ super().__init__()
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.add_space_attn = add_space_attn
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+ self.flash = flash
+ self.flow_head = nn.Sequential(
+ nn.Linear(hidden_size, output_dim, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(output_dim, output_dim, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(output_dim, output_dim, bias=True)
+ )
+ self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ cfg = xLSTMBlockStackConfig(
+ mlstm_block=mLSTMBlockConfig(
+ mlstm=mLSTMLayerConfig(
+ conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
+ )
+ ),
+ slstm_block=sLSTMBlockConfig(
+ slstm=sLSTMLayerConfig(
+ backend="cuda",
+ num_heads=4,
+ conv1d_kernel_size=4,
+ bias_init="powerlaw_blockdependent",
+ ),
+ feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
+ ),
+ context_length=50,
+ num_blocks=7,
+ embedding_dim=384,
+ slstm_at=[1],
+
+ )
+ self.xlstm_fwd = xLSTMBlockStack(cfg)
+ self.xlstm_bwd = xLSTMBlockStack(cfg)
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ def forward(self,
+ input_tensor,
+ track_mask=None):
+ """ Updating with Transformer
+
+ Args:
+ input_tensor: B, N, T, C
+ arap_embed: B, N, T, C
+ """
+ B, N, T, C = input_tensor.shape
+ x = self.input_transform(input_tensor)
+
+ track_mask = track_mask.permute(0,2,1,3).float()
+ fwd_x = x*track_mask
+ bwd_x = x.flip(2)*track_mask.flip(2)
+ feat_fwd = self.xlstm_fwd(self.norm(fwd_x.view(B*N, T, -1)))
+ feat_bwd = self.xlstm_bwd(self.norm(bwd_x.view(B*N, T, -1)))
+ feat = (feat_bwd.flip(1) + feat_fwd).view(B, N, T, -1)
+
+ flow = self.flow_head(feat)
+
+ return flow[..., :2], flow[..., 2:]
+
diff --git a/models/SpaTrackV2/models/camera_transform.py b/models/SpaTrackV2/models/camera_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba5377bb774a9bf8fa776e3e439ab0a6be9a5d39
--- /dev/null
+++ b/models/SpaTrackV2/models/camera_transform.py
@@ -0,0 +1,248 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Adapted from https://github.com/amyxlase/relpose-plus-plus
+
+import torch
+import numpy as np
+import math
+
+
+
+
+def bbox_xyxy_to_xywh(xyxy):
+ wh = xyxy[2:] - xyxy[:2]
+ xywh = np.concatenate([xyxy[:2], wh])
+ return xywh
+
+
+def adjust_camera_to_bbox_crop_(fl, pp, image_size_wh: torch.Tensor, clamp_bbox_xywh: torch.Tensor):
+ focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, image_size_wh)
+
+ principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]
+
+ focal_length, principal_point_cropped = _convert_pixels_to_ndc(
+ focal_length_px, principal_point_px_cropped, clamp_bbox_xywh[2:]
+ )
+
+ return focal_length, principal_point_cropped
+
+
+def adjust_camera_to_image_scale_(fl, pp, original_size_wh: torch.Tensor, new_size_wh: torch.LongTensor):
+ focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, original_size_wh)
+
+ # now scale and convert from pixels to NDC
+ image_size_wh_output = new_size_wh.float()
+ scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values
+ focal_length_px_scaled = focal_length_px * scale
+ principal_point_px_scaled = principal_point_px * scale
+
+ focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc(
+ focal_length_px_scaled, principal_point_px_scaled, image_size_wh_output
+ )
+ return focal_length_scaled, principal_point_scaled
+
+
+def _convert_ndc_to_pixels(focal_length: torch.Tensor, principal_point: torch.Tensor, image_size_wh: torch.Tensor):
+ half_image_size = image_size_wh / 2
+ rescale = half_image_size.min()
+ principal_point_px = half_image_size - principal_point * rescale
+ focal_length_px = focal_length * rescale
+ return focal_length_px, principal_point_px
+
+
+def _convert_pixels_to_ndc(
+ focal_length_px: torch.Tensor, principal_point_px: torch.Tensor, image_size_wh: torch.Tensor
+):
+ half_image_size = image_size_wh / 2
+ rescale = half_image_size.min()
+ principal_point = (half_image_size - principal_point_px) / rescale
+ focal_length = focal_length_px / rescale
+ return focal_length, principal_point
+
+
+def normalize_cameras(
+ cameras, compute_optical=True, first_camera=True, normalize_trans=True, scale=1.0, points=None, max_norm=False,
+ pose_mode="C2W"
+):
+ """
+ Normalizes cameras such that
+ (1) the optical axes point to the origin and the average distance to the origin is 1
+ (2) the first camera is the origin
+ (3) the translation vector is normalized
+
+ TODO: some transforms overlap with others. no need to do so many transforms
+ Args:
+ cameras (List[camera]).
+ """
+ # Let distance from first camera to origin be unit
+ new_cameras = cameras.clone()
+ scale = 1.0
+
+ if compute_optical:
+ new_cameras, points = compute_optical_transform(new_cameras, points=points)
+ if first_camera:
+ new_cameras, points = first_camera_transform(new_cameras, points=points, pose_mode=pose_mode)
+ if normalize_trans:
+ new_cameras, points, scale = normalize_translation(new_cameras,
+ points=points, max_norm=max_norm)
+ return new_cameras, points, scale
+
+
+def compute_optical_transform(new_cameras, points=None):
+ """
+ adapted from https://github.com/amyxlase/relpose-plus-plus
+ """
+
+ new_transform = new_cameras.get_world_to_view_transform()
+ p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(new_cameras)
+ t = Translate(p_intersect)
+ scale = dist.squeeze()[0]
+
+ if points is not None:
+ points = t.inverse().transform_points(points)
+ points = points / scale
+
+ # Degenerate case
+ if scale == 0:
+ scale = torch.norm(new_cameras.T, dim=(0, 1))
+ scale = torch.sqrt(scale)
+ new_cameras.T = new_cameras.T / scale
+ else:
+ new_matrix = t.compose(new_transform).get_matrix()
+ new_cameras.R = new_matrix[:, :3, :3]
+ new_cameras.T = new_matrix[:, 3, :3] / scale
+
+ return new_cameras, points
+
+
+def compute_optical_axis_intersection(cameras):
+ centers = cameras.get_camera_center()
+ principal_points = cameras.principal_point
+
+ one_vec = torch.ones((len(cameras), 1))
+ optical_axis = torch.cat((principal_points, one_vec), -1)
+
+ pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
+
+ pp2 = pp[torch.arange(pp.shape[0]), torch.arange(pp.shape[0])]
+
+ directions = pp2 - centers
+ centers = centers.unsqueeze(0).unsqueeze(0)
+ directions = directions.unsqueeze(0).unsqueeze(0)
+
+ p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(p=centers, r=directions, mask=None)
+
+ p_intersect = p_intersect.squeeze().unsqueeze(0)
+ dist = (p_intersect - centers).norm(dim=-1)
+
+ return p_intersect, dist, p_line_intersect, pp2, r
+
+
+def intersect_skew_line_groups(p, r, mask):
+ # p, r both of shape (B, N, n_intersected_lines, 3)
+ # mask of shape (B, N, n_intersected_lines)
+ p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
+ _, p_line_intersect = _point_line_distance(p, r, p_intersect[..., None, :].expand_as(p))
+ intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(dim=-1)
+ return p_intersect, p_line_intersect, intersect_dist_squared, r
+
+
+def intersect_skew_lines_high_dim(p, r, mask=None):
+ # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
+ dim = p.shape[-1]
+ # make sure the heading vectors are l2-normed
+ if mask is None:
+ mask = torch.ones_like(p[..., 0])
+ r = torch.nn.functional.normalize(r, dim=-1)
+
+ eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
+ I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
+ sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
+ p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
+
+ if torch.any(torch.isnan(p_intersect)):
+ print(p_intersect)
+ raise ValueError(f"p_intersect is NaN")
+
+ return p_intersect, r
+
+
+def _point_line_distance(p1, r1, p2):
+ df = p2 - p1
+ proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
+ line_pt_nearest = p2 - proj_vector
+ d = (proj_vector).norm(dim=-1)
+ return d, line_pt_nearest
+
+
+def first_camera_transform(cameras, rotation_only=False,
+ points=None, pose_mode="C2W"):
+ """
+ Transform so that the first camera is the origin
+ """
+
+ new_cameras = cameras.clone()
+ # new_transform = new_cameras.get_world_to_view_transform()
+
+ R = cameras.R
+ T = cameras.T
+ Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B, 3, 4]
+ Tran_M = torch.cat([Tran_M,
+ torch.tensor([[[0, 0, 0, 1]]], device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)], dim=1)
+ if pose_mode == "C2W":
+ Tran_M_new = (Tran_M[:1,...].inverse())@Tran_M
+ elif pose_mode == "W2C":
+ Tran_M_new = Tran_M@(Tran_M[:1,...].inverse())
+
+ if False:
+ tR = Rotate(new_cameras.R[0].unsqueeze(0))
+ if rotation_only:
+ t = tR.inverse()
+ else:
+ tT = Translate(new_cameras.T[0].unsqueeze(0))
+ t = tR.compose(tT).inverse()
+
+ if points is not None:
+ points = t.inverse().transform_points(points)
+
+ if pose_mode == "C2W":
+ new_matrix = new_transform.compose(t).get_matrix()
+ else:
+ import ipdb; ipdb.set_trace()
+ new_matrix = t.compose(new_transform).get_matrix()
+
+ new_cameras.R = Tran_M_new[:, :3, :3]
+ new_cameras.T = Tran_M_new[:, :3, 3]
+
+ return new_cameras, points
+
+
+def normalize_translation(new_cameras, points=None, max_norm=False):
+ t_gt = new_cameras.T.clone()
+ t_gt = t_gt[1:, :]
+
+ if max_norm:
+ t_gt_norm = torch.norm(t_gt, dim=(-1))
+ t_gt_scale = t_gt_norm.max()
+ if t_gt_norm.max() < 0.001:
+ t_gt_scale = torch.ones_like(t_gt_scale)
+ t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
+ else:
+ t_gt_norm = torch.norm(t_gt, dim=(0, 1))
+ t_gt_scale = t_gt_norm / math.sqrt(len(t_gt))
+ t_gt_scale = t_gt_scale / 2
+ if t_gt_norm.max() < 0.001:
+ t_gt_scale = torch.ones_like(t_gt_scale)
+ t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
+
+ new_cameras.T = new_cameras.T / t_gt_scale
+
+ if points is not None:
+ points = points / t_gt_scale
+
+ return new_cameras, points, t_gt_scale
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/depth_refiner/backbone.py b/models/SpaTrackV2/models/depth_refiner/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ccec449a4a2c9306ab84182a29ddac5e68c36b9
--- /dev/null
+++ b/models/SpaTrackV2/models/depth_refiner/backbone.py
@@ -0,0 +1,472 @@
+# ---------------------------------------------------------------
+# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+#
+# This work is licensed under the NVIDIA Source Code License
+# ---------------------------------------------------------------
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+from timm.layers import DropPath, to_2tuple, trunc_normal_
+from timm.models import register_model
+from timm.models.vision_transformer import _cfg
+import math
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.dwconv = DWConv(hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ x = self.fc1(x)
+ x = self.dwconv(x, H, W)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
+ super().__init__()
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
+
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.sr_ratio = sr_ratio
+ if sr_ratio > 1:
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+ self.norm = nn.LayerNorm(dim)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+
+ if self.sr_ratio > 1:
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
+ x_ = self.norm(x_)
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ else:
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ k, v = kv[0], kv[1]
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
+
+ return x
+
+
+class OverlapPatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
+ self.num_patches = self.H * self.W
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
+ self.norm = nn.LayerNorm(embed_dim)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ x = self.proj(x)
+ _, _, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+
+ return x, H, W
+
+
+
+
+class OverlapPatchEmbed43(nn.Module):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
+ self.num_patches = self.H * self.W
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
+ self.norm = nn.LayerNorm(embed_dim)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ if x.shape[1]==4:
+ x = self.proj_4c(x)
+ else:
+ x = self.proj(x)
+ _, _, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+
+ return x, H, W
+
+class MixVisionTransformer(nn.Module):
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
+ super().__init__()
+ self.num_classes = num_classes
+ self.depths = depths
+
+ # patch_embed 43
+ self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
+ embed_dim=embed_dims[0])
+ self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
+ embed_dim=embed_dims[1])
+ self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
+ embed_dim=embed_dims[2])
+ self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
+ embed_dim=embed_dims[3])
+
+ # transformer encoder
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+ cur = 0
+ self.block1 = nn.ModuleList([Block(
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[0])
+ for i in range(depths[0])])
+ self.norm1 = norm_layer(embed_dims[0])
+
+ cur += depths[0]
+ self.block2 = nn.ModuleList([Block(
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[1])
+ for i in range(depths[1])])
+ self.norm2 = norm_layer(embed_dims[1])
+
+ cur += depths[1]
+ self.block3 = nn.ModuleList([Block(
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[2])
+ for i in range(depths[2])])
+ self.norm3 = norm_layer(embed_dims[2])
+
+ cur += depths[2]
+ self.block4 = nn.ModuleList([Block(
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[3])
+ for i in range(depths[3])])
+ self.norm4 = norm_layer(embed_dims[3])
+
+ # classification head
+ # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
+
+ def reset_drop_path(self, drop_path_rate):
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
+ cur = 0
+ for i in range(self.depths[0]):
+ self.block1[i].drop_path.drop_prob = dpr[cur + i]
+
+ cur += self.depths[0]
+ for i in range(self.depths[1]):
+ self.block2[i].drop_path.drop_prob = dpr[cur + i]
+
+ cur += self.depths[1]
+ for i in range(self.depths[2]):
+ self.block3[i].drop_path.drop_prob = dpr[cur + i]
+
+ cur += self.depths[2]
+ for i in range(self.depths[3]):
+ self.block4[i].drop_path.drop_prob = dpr[cur + i]
+
+ def freeze_patch_emb(self):
+ self.patch_embed1.requires_grad = False
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ outs = []
+
+ # stage 1
+ x, H, W = self.patch_embed1(x)
+ for i, blk in enumerate(self.block1):
+ x = blk(x, H, W)
+ x = self.norm1(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ outs.append(x)
+
+ # stage 2
+ x, H, W = self.patch_embed2(x)
+ for i, blk in enumerate(self.block2):
+ x = blk(x, H, W)
+ x = self.norm2(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ outs.append(x)
+
+ # stage 3
+ x, H, W = self.patch_embed3(x)
+ for i, blk in enumerate(self.block3):
+ x = blk(x, H, W)
+ x = self.norm3(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ outs.append(x)
+
+ # stage 4
+ x, H, W = self.patch_embed4(x)
+ for i, blk in enumerate(self.block4):
+ x = blk(x, H, W)
+ x = self.norm4(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ outs.append(x)
+
+ return outs
+
+ def forward(self, x):
+ if x.dim() == 5:
+ x = x.reshape(x.shape[0]*x.shape[1],x.shape[2],x.shape[3],x.shape[4])
+ x = self.forward_features(x)
+ # x = self.head(x)
+
+ return x
+
+
+class DWConv(nn.Module):
+ def __init__(self, dim=768):
+ super(DWConv, self).__init__()
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ x = x.transpose(1, 2).view(B, C, H, W)
+ x = self.dwconv(x)
+ x = x.flatten(2).transpose(1, 2)
+
+ return x
+
+
+
+#@BACKBONES.register_module()
+class mit_b0(MixVisionTransformer):
+ def __init__(self, **kwargs):
+ super(mit_b0, self).__init__(
+ patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
+ drop_rate=0.0, drop_path_rate=0.1)
+
+
+#@BACKBONES.register_module()
+class mit_b1(MixVisionTransformer):
+ def __init__(self, **kwargs):
+ super(mit_b1, self).__init__(
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
+ drop_rate=0.0, drop_path_rate=0.1)
+
+
+#@BACKBONES.register_module()
+class mit_b2(MixVisionTransformer):
+ def __init__(self, **kwargs):
+ super(mit_b2, self).__init__(
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
+ drop_rate=0.0, drop_path_rate=0.1)
+
+
+#@BACKBONES.register_module()
+class mit_b3(MixVisionTransformer):
+ def __init__(self, **kwargs):
+ super(mit_b3, self).__init__(
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
+ drop_rate=0.0, drop_path_rate=0.1)
+
+
+#@BACKBONES.register_module()
+class mit_b4(MixVisionTransformer):
+ def __init__(self, **kwargs):
+ super(mit_b4, self).__init__(
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
+ drop_rate=0.0, drop_path_rate=0.1)
+
+
+#@BACKBONES.register_module()
+class mit_b5(MixVisionTransformer):
+ def __init__(self, **kwargs):
+ super(mit_b5, self).__init__(
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
+ drop_rate=0.0, drop_path_rate=0.1)
+
+
diff --git a/models/SpaTrackV2/models/depth_refiner/decode_head.py b/models/SpaTrackV2/models/depth_refiner/decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b733c367f0459a790c182633024e5d6d851bd1fe
--- /dev/null
+++ b/models/SpaTrackV2/models/depth_refiner/decode_head.py
@@ -0,0 +1,619 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+import torch.nn as nn
+
+# from mmcv.cnn import normal_init
+# from mmcv.runner import auto_fp16, force_fp32
+
+# from mmseg.core import build_pixel_sampler
+# from mmseg.ops import resize
+
+
+class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
+ """Base class for BaseDecodeHead.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ num_classes (int): Number of classes.
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
+ conv_cfg (dict|None): Config of conv layers. Default: None.
+ norm_cfg (dict|None): Config of norm layers. Default: None.
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU')
+ in_index (int|Sequence[int]): Input feature index. Default: -1
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resize to the
+ same size as first one and than concat together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundle into
+ a list and passed into decode head.
+ None: Only one select feature map is allowed.
+ Default: None.
+ loss_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss').
+ ignore_index (int | None): The label index to be ignored. When using
+ masked BCE loss, ignore_index should be set to None. Default: 255
+ sampler (dict|None): The config of segmentation map sampler.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ *,
+ num_classes,
+ dropout_ratio=0.1,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ in_index=-1,
+ input_transform=None,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ decoder_params=None,
+ ignore_index=255,
+ sampler=None,
+ align_corners=False):
+ super(BaseDecodeHead, self).__init__()
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.channels = channels
+ self.num_classes = num_classes
+ self.dropout_ratio = dropout_ratio
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.in_index = in_index
+ self.ignore_index = ignore_index
+ self.align_corners = align_corners
+
+ if sampler is not None:
+ self.sampler = build_pixel_sampler(sampler, context=self)
+ else:
+ self.sampler = None
+
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
+ if dropout_ratio > 0:
+ self.dropout = nn.Dropout2d(dropout_ratio)
+ else:
+ self.dropout = None
+ self.fp16_enabled = False
+
+ def extra_repr(self):
+ """Extra repr."""
+ s = f'input_transform={self.input_transform}, ' \
+ f'ignore_index={self.ignore_index}, ' \
+ f'align_corners={self.align_corners}'
+ return s
+
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+
+ The in_channels, in_index and input_transform must match.
+ Specifically, when input_transform is None, only single feature map
+ will be selected. So in_channels and in_index must be of type int.
+ When input_transform
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resize to the
+ same size as first one and than concat together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundle into
+ a list and passed into decode head.
+ None: Only one select feature map is allowed.
+ """
+
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ self.in_channels = sum(in_channels)
+ else:
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.conv_seg, mean=0, std=0.01)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ # @auto_fp16()
+ @abstractmethod
+ def forward(self, inputs):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ seg_logits = self.forward(inputs)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+ return losses
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ return self.forward(inputs)
+
+ def cls_seg(self, feat):
+ """Classify each pixel."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.conv_seg(feat)
+ return output
+
+
+class BaseDecodeHead_clips(nn.Module, metaclass=ABCMeta):
+ """Base class for BaseDecodeHead_clips.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ num_classes (int): Number of classes.
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
+ conv_cfg (dict|None): Config of conv layers. Default: None.
+ norm_cfg (dict|None): Config of norm layers. Default: None.
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU')
+ in_index (int|Sequence[int]): Input feature index. Default: -1
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resize to the
+ same size as first one and than concat together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundle into
+ a list and passed into decode head.
+ None: Only one select feature map is allowed.
+ Default: None.
+ loss_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss').
+ ignore_index (int | None): The label index to be ignored. When using
+ masked BCE loss, ignore_index should be set to None. Default: 255
+ sampler (dict|None): The config of segmentation map sampler.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ *,
+ num_classes,
+ dropout_ratio=0.1,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ in_index=-1,
+ input_transform=None,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ decoder_params=None,
+ ignore_index=255,
+ sampler=None,
+ align_corners=False,
+ num_clips=5):
+ super(BaseDecodeHead_clips, self).__init__()
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.channels = channels
+ self.num_classes = num_classes
+ self.dropout_ratio = dropout_ratio
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.in_index = in_index
+ self.ignore_index = ignore_index
+ self.align_corners = align_corners
+ self.num_clips=num_clips
+
+ if sampler is not None:
+ self.sampler = build_pixel_sampler(sampler, context=self)
+ else:
+ self.sampler = None
+
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
+ if dropout_ratio > 0:
+ self.dropout = nn.Dropout2d(dropout_ratio)
+ else:
+ self.dropout = None
+ self.fp16_enabled = False
+
+ def extra_repr(self):
+ """Extra repr."""
+ s = f'input_transform={self.input_transform}, ' \
+ f'ignore_index={self.ignore_index}, ' \
+ f'align_corners={self.align_corners}'
+ return s
+
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+
+ The in_channels, in_index and input_transform must match.
+ Specifically, when input_transform is None, only single feature map
+ will be selected. So in_channels and in_index must be of type int.
+ When input_transform
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resize to the
+ same size as first one and than concat together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundle into
+ a list and passed into decode head.
+ None: Only one select feature map is allowed.
+ """
+
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ self.in_channels = sum(in_channels)
+ else:
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.conv_seg, mean=0, std=0.01)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ # @auto_fp16()
+ @abstractmethod
+ def forward(self, inputs):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ seg_logits = self.forward(inputs,batch_size, num_clips)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+ return losses
+
+ def forward_test(self, inputs, img_metas, test_cfg, batch_size, num_clips):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ return self.forward(inputs, batch_size, num_clips)
+
+ def cls_seg(self, feat):
+ """Classify each pixel."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.conv_seg(feat)
+ return output
+
+class BaseDecodeHead_clips_flow(nn.Module, metaclass=ABCMeta):
+ """Base class for BaseDecodeHead_clips_flow.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ num_classes (int): Number of classes.
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
+ conv_cfg (dict|None): Config of conv layers. Default: None.
+ norm_cfg (dict|None): Config of norm layers. Default: None.
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU')
+ in_index (int|Sequence[int]): Input feature index. Default: -1
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resize to the
+ same size as first one and than concat together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundle into
+ a list and passed into decode head.
+ None: Only one select feature map is allowed.
+ Default: None.
+ loss_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss').
+ ignore_index (int | None): The label index to be ignored. When using
+ masked BCE loss, ignore_index should be set to None. Default: 255
+ sampler (dict|None): The config of segmentation map sampler.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ *,
+ num_classes,
+ dropout_ratio=0.1,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ in_index=-1,
+ input_transform=None,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ decoder_params=None,
+ ignore_index=255,
+ sampler=None,
+ align_corners=False,
+ num_clips=5):
+ super(BaseDecodeHead_clips_flow, self).__init__()
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.channels = channels
+ self.num_classes = num_classes
+ self.dropout_ratio = dropout_ratio
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.in_index = in_index
+ self.ignore_index = ignore_index
+ self.align_corners = align_corners
+ self.num_clips=num_clips
+
+ if sampler is not None:
+ self.sampler = build_pixel_sampler(sampler, context=self)
+ else:
+ self.sampler = None
+
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
+ if dropout_ratio > 0:
+ self.dropout = nn.Dropout2d(dropout_ratio)
+ else:
+ self.dropout = None
+ self.fp16_enabled = False
+
+ def extra_repr(self):
+ """Extra repr."""
+ s = f'input_transform={self.input_transform}, ' \
+ f'ignore_index={self.ignore_index}, ' \
+ f'align_corners={self.align_corners}'
+ return s
+
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+
+ The in_channels, in_index and input_transform must match.
+ Specifically, when input_transform is None, only single feature map
+ will be selected. So in_channels and in_index must be of type int.
+ When input_transform
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resize to the
+ same size as first one and than concat together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundle into
+ a list and passed into decode head.
+ None: Only one select feature map is allowed.
+ """
+
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ self.in_channels = sum(in_channels)
+ else:
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.conv_seg, mean=0, std=0.01)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ # @auto_fp16()
+ @abstractmethod
+ def forward(self, inputs):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips,img=None):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ seg_logits = self.forward(inputs,batch_size, num_clips,img)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+ return losses
+
+ def forward_test(self, inputs, img_metas, test_cfg, batch_size=None, num_clips=None, img=None):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ return self.forward(inputs, batch_size, num_clips,img)
+
+ def cls_seg(self, feat):
+ """Classify each pixel."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.conv_seg(feat)
+ return output
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/depth_refiner/depth_refiner.py b/models/SpaTrackV2/models/depth_refiner/depth_refiner.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a98a8202bb7fcd1cca51b2460a989dca5e5ac00
--- /dev/null
+++ b/models/SpaTrackV2/models/depth_refiner/depth_refiner.py
@@ -0,0 +1,115 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from models.monoD.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed
+from models.SpaTrackV2.models.depth_refiner.backbone import mit_b3
+from models.SpaTrackV2.models.depth_refiner.stablizer import Stabilization_Network_Cross_Attention
+from einops import rearrange
+class TrackStablizer(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.backbone = mit_b3()
+
+ old_conv = self.backbone.patch_embed1.proj
+ new_conv = nn.Conv2d(old_conv.in_channels + 4, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding)
+
+ new_conv.weight[:, :3, :, :].data.copy_(old_conv.weight.clone())
+ self.backbone.patch_embed1.proj = new_conv
+
+ self.Track_Stabilizer = Stabilization_Network_Cross_Attention(in_channels=[64, 128, 320, 512],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=0.1,
+ num_classes=1,
+ align_corners=False,
+ decoder_params=dict(embed_dim=256, depths=4),
+ num_clips=16,
+ norm_cfg = dict(type='SyncBN', requires_grad=True))
+
+ self.edge_conv = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=64, kernel_size=3, padding=1, stride=1, bias=True),\
+ nn.ReLU(inplace=True))
+ self.edge_conv1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2, bias=True),\
+ nn.ReLU(inplace=True))
+ self.success = False
+ self.x = None
+
+ def buffer_forward(self, inputs, num_clips=16):
+ """
+ buffer forward for getting the pointmap and image features
+ """
+ B, T, C, H, W = inputs.shape
+ self.x = self.backbone(inputs)
+ scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips)
+ self.success = True
+ return scale, shift
+
+ def forward(self, inputs, tracks, tracks_uvd, num_clips=16, imgs=None, vis_track=None):
+
+ """
+ Args:
+ inputs: [B, T, C, H, W], RGB + PointMap + Mask
+ tracks: [B, T, N, 4], 3D tracks in camera coordinate + visibility
+ num_clips: int, number of clips to use
+ """
+ B, T, C, H, W = inputs.shape
+ edge_feat = self.edge_conv(inputs.view(B*T,4,H,W))
+ edge_feat1 = self.edge_conv1(edge_feat)
+
+ if not self.success:
+ scale, shift = self.Track_Stabilizer.buffer_forward(self.x,num_clips=num_clips)
+ self.success = True
+ update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
+ else:
+ update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
+
+ return update
+
+ def reset_success(self):
+ self.success = False
+ self.x = None
+ self.Track_Stabilizer.reset_success()
+
+
+if __name__ == "__main__":
+ # Create test input tensors
+ batch_size = 1
+ seq_len = 16
+ channels = 7 # 3 for RGB + 3 for PointMap + 1 for Mask
+ height = 384
+ width = 512
+
+ # Create random input tensor with shape [B, T, C, H, W]
+ inputs = torch.randn(batch_size, seq_len, channels, height, width)
+
+ # Create random tracks
+ tracks = torch.randn(batch_size, seq_len, 1024, 4)
+
+ # Create random test images
+ test_imgs = torch.randn(batch_size, seq_len, 3, height, width)
+
+ # Initialize model and move to GPU
+ model = TrackStablizer().cuda()
+
+ # Move inputs to GPU and run forward pass
+ inputs = inputs.cuda()
+ tracks = tracks.cuda()
+ outputs = model.buffer_forward(inputs, num_clips=seq_len)
+ import time
+ start_time = time.time()
+ outputs = model(inputs, tracks, num_clips=seq_len)
+ end_time = time.time()
+ print(f"Time taken: {end_time - start_time} seconds")
+ import pdb; pdb.set_trace()
+ # # Print shapes for verification
+ # print(f"Input shape: {inputs.shape}")
+ # print(f"Output shape: {outputs.shape}")
+
+ # # Basic tests
+ # assert outputs.shape[0] == batch_size, "Batch size mismatch"
+ # assert len(outputs.shape) == 4, "Output should be 4D: [B,C,H,W]"
+ # assert torch.all(outputs >= 0), "Output should be non-negative after ReLU"
+
+ # print("All tests passed!")
+
diff --git a/models/SpaTrackV2/models/depth_refiner/network.py b/models/SpaTrackV2/models/depth_refiner/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9e70f085e34fb94431191723c1b27bf06e77e1e
--- /dev/null
+++ b/models/SpaTrackV2/models/depth_refiner/network.py
@@ -0,0 +1,429 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+'''
+Author: Ke Xian
+Email: kexian@hust.edu.cn
+Date: 2020/07/20
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+
+# ==============================================================================================================
+
+class FTB(nn.Module):
+ def __init__(self, inchannels, midchannels=512):
+ super(FTB, self).__init__()
+ self.in1 = inchannels
+ self.mid = midchannels
+
+ self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True)
+ self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),\
+ nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),\
+ #nn.BatchNorm2d(num_features=self.mid),\
+ nn.ReLU(inplace=True),\
+ nn.Conv2d(in_channels=self.mid, out_channels= self.mid, kernel_size=3, padding=1, stride=1, bias=True))
+ self.relu = nn.ReLU(inplace=True)
+
+ self.init_params()
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = x + self.conv_branch(x)
+ x = self.relu(x)
+
+ return x
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ # init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ # init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+class ATA(nn.Module):
+ def __init__(self, inchannels, reduction = 8):
+ super(ATA, self).__init__()
+ self.inchannels = inchannels
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(nn.Linear(self.inchannels*2, self.inchannels // reduction),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.inchannels // reduction, self.inchannels),
+ nn.Sigmoid())
+ self.init_params()
+
+ def forward(self, low_x, high_x):
+ n, c, _, _ = low_x.size()
+ x = torch.cat([low_x, high_x], 1)
+ x = self.avg_pool(x)
+ x = x.view(n, -1)
+ x = self.fc(x).view(n,c,1,1)
+ x = low_x * x + high_x
+
+ return x
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ #init.normal(m.weight, std=0.01)
+ init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ #init.normal_(m.weight, std=0.01)
+ init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+
+class FFM(nn.Module):
+ def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
+ super(FFM, self).__init__()
+ self.inchannels = inchannels
+ self.midchannels = midchannels
+ self.outchannels = outchannels
+ self.upfactor = upfactor
+
+ self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
+ self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
+
+ self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
+
+ self.init_params()
+ #self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
+ #self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
+ #self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, low_x, high_x):
+ x = self.ftb1(low_x)
+
+ '''
+ x = torch.cat((x,high_x),1)
+ if x.shape[2] == 12:
+ x = self.p1(x)
+ elif x.shape[2] == 24:
+ x = self.p2(x)
+ elif x.shape[2] == 48:
+ x = self.p3(x)
+ '''
+ x = x + high_x ###high_x
+ x = self.ftb2(x)
+ x = self.upsample(x)
+
+ return x
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ #init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ #init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+
+
+class noFFM(nn.Module):
+ def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
+ super(noFFM, self).__init__()
+ self.inchannels = inchannels
+ self.midchannels = midchannels
+ self.outchannels = outchannels
+ self.upfactor = upfactor
+
+ self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
+
+ self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
+
+ self.init_params()
+ #self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
+ #self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
+ #self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, low_x, high_x):
+
+ #x = self.ftb1(low_x)
+ x = high_x ###high_x
+ x = self.ftb2(x)
+ x = self.upsample(x)
+
+ return x
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ #init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ #init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+
+
+
+class AO(nn.Module):
+ # Adaptive output module
+ def __init__(self, inchannels, outchannels, upfactor=2):
+ super(AO, self).__init__()
+ self.inchannels = inchannels
+ self.outchannels = outchannels
+ self.upfactor = upfactor
+
+ """
+ self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
+ nn.BatchNorm2d(num_features=self.inchannels//2),\
+ nn.ReLU(inplace=True),\
+ nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),\
+ nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) )#,\
+ #nn.ReLU(inplace=True)) ## get positive values
+ """
+ self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
+ #nn.BatchNorm2d(num_features=self.inchannels//2),\
+ nn.ReLU(inplace=True),\
+ nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True), \
+ nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=1, padding=0, stride=1))
+
+ #nn.ReLU(inplace=True)) ## get positive values
+
+ self.init_params()
+
+ def forward(self, x):
+ x = self.adapt_conv(x)
+ return x
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ #init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ #init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+class ASPP(nn.Module):
+ def __init__(self, inchannels=256, planes=128, rates = [1, 6, 12, 18]):
+ super(ASPP, self).__init__()
+ self.inchannels = inchannels
+ self.planes = planes
+ self.rates = rates
+ self.kernel_sizes = []
+ self.paddings = []
+ for rate in self.rates:
+ if rate == 1:
+ self.kernel_sizes.append(1)
+ self.paddings.append(0)
+ else:
+ self.kernel_sizes.append(3)
+ self.paddings.append(rate)
+ self.atrous_0 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[0],
+ stride=1, padding=self.paddings[0], dilation=self.rates[0], bias=True),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(num_features=self.planes)
+ )
+ self.atrous_1 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[1],
+ stride=1, padding=self.paddings[1], dilation=self.rates[1], bias=True),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(num_features=self.planes),
+ )
+ self.atrous_2 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[2],
+ stride=1, padding=self.paddings[2], dilation=self.rates[2], bias=True),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(num_features=self.planes),
+ )
+ self.atrous_3 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[3],
+ stride=1, padding=self.paddings[3], dilation=self.rates[3], bias=True),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(num_features=self.planes),
+ )
+
+ #self.conv = nn.Conv2d(in_channels=self.planes * 4, out_channels=self.inchannels, kernel_size=3, padding=1, stride=1, bias=True)
+ def forward(self, x):
+ x = torch.cat([self.atrous_0(x), self.atrous_1(x), self.atrous_2(x), self.atrous_3(x)],1)
+ #x = self.conv(x)
+
+ return x
+
+# ==============================================================================================================
+
+
+class ResidualConv(nn.Module):
+ def __init__(self, inchannels):
+ super(ResidualConv, self).__init__()
+ #nn.BatchNorm2d
+ self.conv = nn.Sequential(
+ #nn.BatchNorm2d(num_features=inchannels),
+ nn.ReLU(inplace=False),
+ #nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
+ #nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
+ nn.Conv2d(in_channels=inchannels, out_channels=inchannels//2, kernel_size=3, padding=1, stride=1, bias=False),
+ nn.BatchNorm2d(num_features=inchannels//2),
+ nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=inchannels//2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, bias=False)
+ )
+ self.init_params()
+
+ def forward(self, x):
+ x = self.conv(x)+x
+ return x
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ #init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ #init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+
+class FeatureFusion(nn.Module):
+ def __init__(self, inchannels, outchannels):
+ super(FeatureFusion, self).__init__()
+ self.conv = ResidualConv(inchannels=inchannels)
+ #nn.BatchNorm2d
+ self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
+ nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,stride=2, padding=1, output_padding=1),
+ nn.BatchNorm2d(num_features=outchannels),
+ nn.ReLU(inplace=True))
+
+ def forward(self, lowfeat, highfeat):
+ return self.up(highfeat + self.conv(lowfeat))
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ #init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ #init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ #init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+
+class SenceUnderstand(nn.Module):
+ def __init__(self, channels):
+ super(SenceUnderstand, self).__init__()
+ self.channels = channels
+ self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
+ nn.ReLU(inplace = True))
+ self.pool = nn.AdaptiveAvgPool2d(8)
+ self.fc = nn.Sequential(nn.Linear(512*8*8, self.channels),
+ nn.ReLU(inplace = True))
+ self.conv2 = nn.Sequential(nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
+ nn.ReLU(inplace=True))
+ self.initial_params()
+
+ def forward(self, x):
+ n,c,h,w = x.size()
+ x = self.conv1(x)
+ x = self.pool(x)
+ x = x.view(n,-1)
+ x = self.fc(x)
+ x = x.view(n, self.channels, 1, 1)
+ x = self.conv2(x)
+ x = x.repeat(1,1,h,w)
+ return x
+
+ def initial_params(self, dev=0.01):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ #print torch.sum(m.weight)
+ m.weight.data.normal_(0, dev)
+ if m.bias is not None:
+ m.bias.data.fill_(0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ #print torch.sum(m.weight)
+ m.weight.data.normal_(0, dev)
+ if m.bias is not None:
+ m.bias.data.fill_(0)
+ elif isinstance(m, nn.Linear):
+ m.weight.data.normal_(0, dev)
diff --git a/models/SpaTrackV2/models/depth_refiner/stablilization_attention.py b/models/SpaTrackV2/models/depth_refiner/stablilization_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b2f27a6ebedd000dafb0c64d267094f858523e6
--- /dev/null
+++ b/models/SpaTrackV2/models/depth_refiner/stablilization_attention.py
@@ -0,0 +1,1187 @@
+import math
+import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.layers import DropPath, to_2tuple, trunc_normal_
+from einops import rearrange
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+def window_partition_noreshape(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (B, num_windows_h, num_windows_w, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+ return windows
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+def get_roll_masks(H, W, window_size, shift_size):
+ #####################################
+ # move to top-left
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, H-window_size),
+ slice(H-window_size, H-shift_size),
+ slice(H-shift_size, H))
+ w_slices = (slice(0, W-window_size),
+ slice(W-window_size, W-shift_size),
+ slice(W-shift_size, W))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, window_size * window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask_tl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ ####################################
+ # move to top right
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, H-window_size),
+ slice(H-window_size, H-shift_size),
+ slice(H-shift_size, H))
+ w_slices = (slice(0, shift_size),
+ slice(shift_size, window_size),
+ slice(window_size, W))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, window_size * window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask_tr = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ ####################################
+ # move to bottom left
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, shift_size),
+ slice(shift_size, window_size),
+ slice(window_size, H))
+ w_slices = (slice(0, W-window_size),
+ slice(W-window_size, W-shift_size),
+ slice(W-shift_size, W))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, window_size * window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask_bl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ ####################################
+ # move to bottom right
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, shift_size),
+ slice(shift_size, window_size),
+ slice(window_size, H))
+ w_slices = (slice(0, shift_size),
+ slice(shift_size, window_size),
+ slice(window_size, W))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, window_size * window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask_br = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ # append all
+ attn_mask_all = torch.cat((attn_mask_tl, attn_mask_tr, attn_mask_bl, attn_mask_br), -1)
+ return attn_mask_all
+
+def get_relative_position_index(q_windows, k_windows):
+ """
+ Args:
+ q_windows: tuple (query_window_height, query_window_width)
+ k_windows: tuple (key_window_height, key_window_width)
+
+ Returns:
+ relative_position_index: query_window_height*query_window_width, key_window_height*key_window_width
+ """
+ # get pair-wise relative position index for each token inside the window
+ coords_h_q = torch.arange(q_windows[0])
+ coords_w_q = torch.arange(q_windows[1])
+ coords_q = torch.stack(torch.meshgrid([coords_h_q, coords_w_q])) # 2, Wh_q, Ww_q
+
+ coords_h_k = torch.arange(k_windows[0])
+ coords_w_k = torch.arange(k_windows[1])
+ coords_k = torch.stack(torch.meshgrid([coords_h_k, coords_w_k])) # 2, Wh, Ww
+
+ coords_flatten_q = torch.flatten(coords_q, 1) # 2, Wh_q*Ww_q
+ coords_flatten_k = torch.flatten(coords_k, 1) # 2, Wh_k*Ww_k
+
+ relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :] # 2, Wh_q*Ww_q, Wh_k*Ww_k
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh_q*Ww_q, Wh_k*Ww_k, 2
+ relative_coords[:, :, 0] += k_windows[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += k_windows[1] - 1
+ relative_coords[:, :, 0] *= (q_windows[1] + k_windows[1]) - 1
+ relative_position_index = relative_coords.sum(-1) # Wh_q*Ww_q, Wh_k*Ww_k
+ return relative_position_index
+
+def get_relative_position_index3d(q_windows, k_windows, num_clips):
+ """
+ Args:
+ q_windows: tuple (query_window_height, query_window_width)
+ k_windows: tuple (key_window_height, key_window_width)
+
+ Returns:
+ relative_position_index: query_window_height*query_window_width, key_window_height*key_window_width
+ """
+ # get pair-wise relative position index for each token inside the window
+ coords_d_q = torch.arange(num_clips)
+ coords_h_q = torch.arange(q_windows[0])
+ coords_w_q = torch.arange(q_windows[1])
+ coords_q = torch.stack(torch.meshgrid([coords_d_q, coords_h_q, coords_w_q])) # 2, Wh_q, Ww_q
+
+ coords_d_k = torch.arange(num_clips)
+ coords_h_k = torch.arange(k_windows[0])
+ coords_w_k = torch.arange(k_windows[1])
+ coords_k = torch.stack(torch.meshgrid([coords_d_k, coords_h_k, coords_w_k])) # 2, Wh, Ww
+
+ coords_flatten_q = torch.flatten(coords_q, 1) # 2, Wh_q*Ww_q
+ coords_flatten_k = torch.flatten(coords_k, 1) # 2, Wh_k*Ww_k
+
+ relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :] # 2, Wh_q*Ww_q, Wh_k*Ww_k
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh_q*Ww_q, Wh_k*Ww_k, 2
+ relative_coords[:, :, 0] += num_clips - 1 # shift to start from 0
+ relative_coords[:, :, 1] += k_windows[0] - 1
+ relative_coords[:, :, 2] += k_windows[1] - 1
+ relative_coords[:, :, 0] *= (q_windows[0] + k_windows[0] - 1)*(q_windows[1] + k_windows[1] - 1)
+ relative_coords[:, :, 1] *= (q_windows[1] + k_windows[1] - 1)
+ relative_position_index = relative_coords.sum(-1) # Wh_q*Ww_q, Wh_k*Ww_k
+ return relative_position_index
+
+
+class WindowAttention3d3(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+
+ Args:
+ dim (int): Number of input channels.
+ expand_size (int): The expand size at focal level 1.
+ window_size (tuple[int]): The height and width of the window.
+ focal_window (int): Focal region size.
+ focal_level (int): Focal attention level.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ pool_method (str): window pooling method. Default: none
+ """
+
+ def __init__(self, dim, expand_size, window_size, focal_window, focal_level, num_heads,
+ qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pool_method="none", focal_l_clips=[7,1,2], focal_kernel_clips=[7,5,3]):
+
+ super().__init__()
+ self.dim = dim
+ self.expand_size = expand_size
+ self.window_size = window_size # Wh, Ww
+ self.pool_method = pool_method
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+
+ # define a parameter table of relative position bias for each window
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ num_clips=4
+ # # define a parameter table of relative position bias
+ # self.relative_position_bias_table = nn.Parameter(
+ # torch.zeros((2 * num_clips - 1) * (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
+
+ # # get pair-wise relative position index for each token inside the window
+ # coords_d = torch.arange(num_clips)
+ # coords_h = torch.arange(self.window_size[0])
+ # coords_w = torch.arange(self.window_size[1])
+ # coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww
+ # coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww
+ # relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww
+ # relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3
+ # relative_coords[:, :, 0] += num_clips - 1 # shift to start from 0
+ # relative_coords[:, :, 1] += self.window_size[0] - 1
+ # relative_coords[:, :, 2] += self.window_size[1] - 1
+
+ # relative_coords[:, :, 0] *= (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1)
+ # relative_coords[:, :, 1] *= (2 * self.window_size[1] - 1)
+ # relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww
+ # self.register_buffer("relative_position_index", relative_position_index)
+
+
+ if self.expand_size > 0 and focal_level > 0:
+ # define a parameter table of position bias between window and its fine-grained surroundings
+ self.window_size_of_key = self.window_size[0] * self.window_size[1] if self.expand_size == 0 else \
+ (4 * self.window_size[0] * self.window_size[1] - 4 * (self.window_size[0] - self.expand_size) * (self.window_size[0] - self.expand_size))
+ self.relative_position_bias_table_to_neighbors = nn.Parameter(
+ torch.zeros(1, num_heads, self.window_size[0] * self.window_size[1], self.window_size_of_key)) # Wh*Ww, nH, nSurrounding
+ trunc_normal_(self.relative_position_bias_table_to_neighbors, std=.02)
+
+ # get mask for rolled k and rolled v
+ mask_tl = torch.ones(self.window_size[0], self.window_size[1]); mask_tl[:-self.expand_size, :-self.expand_size] = 0
+ mask_tr = torch.ones(self.window_size[0], self.window_size[1]); mask_tr[:-self.expand_size, self.expand_size:] = 0
+ mask_bl = torch.ones(self.window_size[0], self.window_size[1]); mask_bl[self.expand_size:, :-self.expand_size] = 0
+ mask_br = torch.ones(self.window_size[0], self.window_size[1]); mask_br[self.expand_size:, self.expand_size:] = 0
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0)
+ self.register_buffer("valid_ind_rolled", mask_rolled.nonzero().view(-1))
+
+ if pool_method != "none" and focal_level > 1:
+ #self.relative_position_bias_table_to_windows = nn.ParameterList()
+ #self.relative_position_bias_table_to_windows_clips = nn.ParameterList()
+ #self.register_parameter('relative_position_bias_table_to_windows',[])
+ #self.register_parameter('relative_position_bias_table_to_windows_clips',[])
+ self.unfolds = nn.ModuleList()
+ self.unfolds_clips=nn.ModuleList()
+
+ # build relative position bias between local patch and pooled windows
+ for k in range(focal_level-1):
+ stride = 2**k
+ kernel_size = 2*(self.focal_window // 2) + 2**k + (2**k-1)
+ # define unfolding operations
+ self.unfolds += [nn.Unfold(
+ kernel_size=(kernel_size, kernel_size),
+ stride=stride, padding=kernel_size // 2)
+ ]
+
+ # define relative position bias table
+ relative_position_bias_table_to_windows = nn.Parameter(
+ torch.zeros(
+ self.num_heads,
+ (self.window_size[0] + self.focal_window + 2**k - 2) * (self.window_size[1] + self.focal_window + 2**k - 2),
+ )
+ )
+ trunc_normal_(relative_position_bias_table_to_windows, std=.02)
+ #self.relative_position_bias_table_to_windows.append(relative_position_bias_table_to_windows)
+ self.register_parameter('relative_position_bias_table_to_windows_{}'.format(k),relative_position_bias_table_to_windows)
+
+ # define relative position bias index
+ relative_position_index_k = get_relative_position_index(self.window_size, to_2tuple(self.focal_window + 2**k - 1))
+ # relative_position_index_k = get_relative_position_index3d(self.window_size, to_2tuple(self.focal_window + 2**k - 1), num_clips)
+ self.register_buffer("relative_position_index_{}".format(k), relative_position_index_k)
+
+ # define unfolding index for focal_level > 0
+ if k > 0:
+ mask = torch.zeros(kernel_size, kernel_size); mask[(2**k)-1:, (2**k)-1:] = 1
+ self.register_buffer("valid_ind_unfold_{}".format(k), mask.flatten(0).nonzero().view(-1))
+
+ for k in range(len(focal_l_clips)):
+ # kernel_size=focal_kernel_clips[k]
+ focal_l_big_flag=False
+ if focal_l_clips[k]>self.window_size[0]:
+ stride=1
+ padding=0
+ kernel_size=focal_kernel_clips[k]
+ kernel_size_true=kernel_size
+ focal_l_big_flag=True
+ # stride=math.ceil(self.window_size/focal_l_clips[k])
+ # padding=(kernel_size-stride)/2
+ else:
+ stride = focal_l_clips[k]
+ # kernel_size
+ # kernel_size = 2*(focal_kernel_clips[k]// 2) + 2**focal_l_clips[k] + (2**focal_l_clips[k]-1)
+ kernel_size = focal_kernel_clips[k] ## kernel_size must be jishu
+ assert kernel_size%2==1
+ padding=kernel_size // 2
+ # kernel_size_true=focal_kernel_clips[k]+2**focal_l_clips[k]-1
+ kernel_size_true=kernel_size
+ # stride=math.ceil(self.window_size/focal_l_clips[k])
+
+ self.unfolds_clips += [nn.Unfold(
+ kernel_size=(kernel_size, kernel_size),
+ stride=stride,
+ padding=padding)
+ ]
+ relative_position_bias_table_to_windows = nn.Parameter(
+ torch.zeros(
+ self.num_heads,
+ (self.window_size[0] + kernel_size_true - 1) * (self.window_size[0] + kernel_size_true - 1),
+ )
+ )
+ trunc_normal_(relative_position_bias_table_to_windows, std=.02)
+ #self.relative_position_bias_table_to_windows_clips.append(relative_position_bias_table_to_windows)
+ self.register_parameter('relative_position_bias_table_to_windows_clips_{}'.format(k),relative_position_bias_table_to_windows)
+ relative_position_index_k = get_relative_position_index(self.window_size, to_2tuple(kernel_size_true))
+ self.register_buffer("relative_position_index_clips_{}".format(k), relative_position_index_k)
+ # if (not focal_l_big_flag) and focal_l_clips[k]>0:
+ # mask = torch.zeros(kernel_size, kernel_size); mask[(2**focal_l_clips[k])-1:, (2**focal_l_clips[k])-1:] = 1
+ # self.register_buffer("valid_ind_unfold_clips_{}".format(k), mask.flatten(0).nonzero().view(-1))
+
+
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.softmax = nn.Softmax(dim=-1)
+ self.focal_l_clips=focal_l_clips
+ self.focal_kernel_clips=focal_kernel_clips
+
+ def forward(self, x_all, mask_all=None, batch_size=None, num_clips=None):
+ """
+ Args:
+ x_all (list[Tensors]): input features at different granularity
+ mask_all (list[Tensors/None]): masks for input features at different granularity
+ """
+ x = x_all[0][0] #
+
+ B0, nH, nW, C = x.shape
+ # assert B==batch_size*num_clips
+ assert B0==batch_size
+ qkv = self.qkv(x).reshape(B0, nH, nW, 3, C).permute(3, 0, 1, 2, 4).contiguous()
+ q, k, v = qkv[0], qkv[1], qkv[2] # B0, nH, nW, C
+
+ # partition q map
+ # print("x.shape: ", x.shape)
+ # print("q.shape: ", q.shape) # [4, 126, 126, 256]
+ (q_windows, k_windows, v_windows) = map(
+ lambda t: window_partition(t, self.window_size[0]).view(
+ -1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads
+ ).transpose(1, 2),
+ (q, k, v)
+ )
+
+ # q_dim0, q_dim1, q_dim2, q_dim3=q_windows.shape
+ # q_windows=q_windows.view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]), q_dim1, q_dim2, q_dim3)
+ # q_windows=q_windows[:,-1].contiguous().view(-1, q_dim1, q_dim2, q_dim3) # query for the last frame (target frame)
+
+ # k_windows.shape [1296, 8, 49, 32]
+
+ if self.expand_size > 0 and self.focal_level > 0:
+ (k_tl, v_tl) = map(
+ lambda t: torch.roll(t, shifts=(-self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
+ )
+ (k_tr, v_tr) = map(
+ lambda t: torch.roll(t, shifts=(-self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
+ )
+ (k_bl, v_bl) = map(
+ lambda t: torch.roll(t, shifts=(self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
+ )
+ (k_br, v_br) = map(
+ lambda t: torch.roll(t, shifts=(self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
+ )
+
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
+ lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads),
+ (k_tl, k_tr, k_bl, k_br)
+ )
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
+ lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads),
+ (v_tl, v_tr, v_bl, v_br)
+ )
+ k_rolled = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 1).transpose(1, 2)
+ v_rolled = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 1).transpose(1, 2)
+
+ # mask out tokens in current window
+ # print("self.valid_ind_rolled.shape: ", self.valid_ind_rolled.shape) # [132]
+ # print("k_rolled.shape: ", k_rolled.shape) # [1296, 8, 196, 32]
+ k_rolled = k_rolled[:, :, self.valid_ind_rolled]
+ v_rolled = v_rolled[:, :, self.valid_ind_rolled]
+ k_rolled = torch.cat((k_windows, k_rolled), 2)
+ v_rolled = torch.cat((v_windows, v_rolled), 2)
+ else:
+ k_rolled = k_windows; v_rolled = v_windows;
+
+ # print("k_rolled.shape: ", k_rolled.shape) # [1296, 8, 181, 32]
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ k_pooled = []
+ v_pooled = []
+ for k in range(self.focal_level-1):
+ stride = 2**k
+ x_window_pooled = x_all[0][k+1] # B0, nWh, nWw, C
+ nWh, nWw = x_window_pooled.shape[1:3]
+
+ # generate mask for pooled windows
+ # print("x_window_pooled.shape: ", x_window_pooled.shape)
+ mask = x_window_pooled.new(nWh, nWw).fill_(1)
+ # print("here: ",x_window_pooled.shape, self.unfolds[k].kernel_size, self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).shape)
+ # print(mask.unique())
+ unfolded_mask = self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).view(
+ 1, 1, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
+ view(nWh*nWw // stride // stride, -1, 1)
+
+ if k > 0:
+ valid_ind_unfold_k = getattr(self, "valid_ind_unfold_{}".format(k))
+ unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
+
+ # print("unfolded_mask.shape: ", unfolded_mask.shape, unfolded_mask.unique())
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
+ # print((x_window_masks == 0).sum(), (x_window_masks > 0).sum(), x_window_masks.unique())
+ x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
+ # print(x_window_masks.shape)
+ mask_all[0][k+1] = x_window_masks
+
+ # generate k and v for pooled windows
+ qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw
+
+
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: self.unfolds[k](t).view(
+ B0, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
+ view(-1, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
+ (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
+ )
+
+ # print("k_pooled_k.shape: ", k_pooled_k.shape)
+ # print("valid_ind_unfold_k.shape: ", valid_ind_unfold_k.shape)
+
+ if k > 0:
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k)
+ )
+
+ # print("k_pooled_k.shape: ", k_pooled_k.shape)
+
+ k_pooled += [k_pooled_k]
+ v_pooled += [v_pooled_k]
+
+ for k in range(len(self.focal_l_clips)):
+ focal_l_big_flag=False
+ if self.focal_l_clips[k]>self.window_size[0]:
+ stride=1
+ focal_l_big_flag=True
+ else:
+ stride = self.focal_l_clips[k]
+ # if self.window_size>=focal_l_clips[k]:
+ # stride=math.ceil(self.window_size/focal_l_clips[k])
+ # # padding=(kernel_size-stride)/2
+ # else:
+ # stride=1
+ # padding=0
+ x_window_pooled = x_all[k+1]
+ nWh, nWw = x_window_pooled.shape[1:3]
+ mask = x_window_pooled.new(nWh, nWw).fill_(1)
+
+ # import pdb; pdb.set_trace()
+ # print(x_window_pooled.shape, self.unfolds_clips[k].kernel_size, self.unfolds_clips[k](mask.unsqueeze(0).unsqueeze(1)).shape)
+
+ unfolded_mask = self.unfolds_clips[k](mask.unsqueeze(0).unsqueeze(1)).view(
+ 1, 1, self.unfolds_clips[k].kernel_size[0], self.unfolds_clips[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
+ view(nWh*nWw // stride // stride, -1, 1)
+
+ # if (not focal_l_big_flag) and self.focal_l_clips[k]>0:
+ # valid_ind_unfold_k = getattr(self, "valid_ind_unfold_clips_{}".format(k))
+ # unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
+
+ # print("unfolded_mask.shape: ", unfolded_mask.shape, unfolded_mask.unique())
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
+ # print((x_window_masks == 0).sum(), (x_window_masks > 0).sum(), x_window_masks.unique())
+ x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
+ # print(x_window_masks.shape)
+ mask_all[k+1] = x_window_masks
+
+ # generate k and v for pooled windows
+ qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw
+
+ if (not focal_l_big_flag):
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: self.unfolds_clips[k](t).view(
+ B0, C, self.unfolds_clips[k].kernel_size[0], self.unfolds_clips[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
+ view(-1, self.unfolds_clips[k].kernel_size[0]*self.unfolds_clips[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
+ (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
+ )
+ else:
+
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: self.unfolds_clips[k](t),
+ (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
+ )
+ LLL=k_pooled_k.size(2)
+ LLL_h=int(LLL**0.5)
+ assert LLL_h**2==LLL
+ k_pooled_k=k_pooled_k.reshape(B0, -1, LLL_h, LLL_h)
+ v_pooled_k=v_pooled_k.reshape(B0, -1, LLL_h, LLL_h)
+
+
+
+ # print("k_pooled_k.shape: ", k_pooled_k.shape)
+ # print("valid_ind_unfold_k.shape: ", valid_ind_unfold_k.shape)
+ # if (not focal_l_big_flag) and self.focal_l_clips[k]:
+ # (k_pooled_k, v_pooled_k) = map(
+ # lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k)
+ # )
+
+ # print("k_pooled_k.shape: ", k_pooled_k.shape)
+
+ k_pooled += [k_pooled_k]
+ v_pooled += [v_pooled_k]
+
+ # qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
+ # k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw
+ # (k_pooled_k, v_pooled_k) = map(
+ # lambda t: self.unfolds[k](t).view(
+ # B0, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
+ # view(-1, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
+ # (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
+ # )
+ # k_pooled += [k_pooled_k]
+ # v_pooled += [v_pooled_k]
+
+
+ k_all = torch.cat([k_rolled] + k_pooled, 2)
+ v_all = torch.cat([v_rolled] + v_pooled, 2)
+ else:
+ k_all = k_rolled
+ v_all = v_rolled
+
+ N = k_all.shape[-2]
+ q_windows = q_windows * self.scale
+ # print(q_windows.shape, k_all.shape, v_all.shape)
+ # exit()
+ # k_all_dim0, k_all_dim1, k_all_dim2, k_all_dim3=k_all.shape
+ # k_all=k_all.contiguous().view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]),
+ # k_all_dim1, k_all_dim2, k_all_dim3).permute(0,2,3,4,1,5).contiguous().view(-1, k_all_dim1, k_all_dim2*num_clips, k_all_dim3)
+ # v_all=v_all.contiguous().view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]),
+ # k_all_dim1, k_all_dim2, k_all_dim3).permute(0,2,3,4,1,5).contiguous().view(-1, k_all_dim1, k_all_dim2*num_clips, k_all_dim3)
+
+ # print(q_windows.shape, k_all.shape, v_all.shape, k_rolled.shape)
+ # exit()
+ attn = (q_windows @ k_all.transpose(-2, -1)) # B0*nW, nHead, window_size*window_size, focal_window_size*focal_window_size
+
+ window_area = self.window_size[0] * self.window_size[1]
+ # window_area_clips= num_clips*self.window_size[0] * self.window_size[1]
+ window_area_rolled = k_rolled.shape[2]
+
+ # add relative position bias for tokens inside window
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ # print(relative_position_bias.shape, attn.shape)
+ attn[:, :, :window_area, :window_area] = attn[:, :, :window_area, :window_area] + relative_position_bias.unsqueeze(0)
+
+ # relative_position_bias = self.relative_position_bias_table[self.relative_position_index[-window_area:, :window_area_clips].reshape(-1)].view(
+ # window_area, window_area_clips, -1) # Wh*Ww,Wd*Wh*Ww,nH
+ # relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().view(self.num_heads,window_area,num_clips,window_area
+ # ).permute(0,1,3,2).contiguous().view(self.num_heads,window_area,window_area_clips).contiguous() # nH, Wh*Ww, Wh*Ww*Wd
+ # # attn_dim0, attn_dim1, attn_dim2, attn_dim3=attn.shape
+ # # attn=attn.view(attn_dim0,attn_dim1,attn_dim2,num_clips,-1)
+ # # print(attn.shape, relative_position_bias.shape)
+ # attn[:,:,:window_area, :window_area_clips]=attn[:,:,:window_area, :window_area_clips] + relative_position_bias.unsqueeze(0)
+ # attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N
+
+ # add relative position bias for patches inside a window
+ if self.expand_size > 0 and self.focal_level > 0:
+ attn[:, :, :window_area, window_area:window_area_rolled] = attn[:, :, :window_area, window_area:window_area_rolled] + self.relative_position_bias_table_to_neighbors
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ # add relative position bias for different windows in an image
+ offset = window_area_rolled
+ # print(offset)
+ for k in range(self.focal_level-1):
+ # add relative position bias
+ relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k))
+ relative_position_bias_to_windows = getattr(self,'relative_position_bias_table_to_windows_{}'.format(k))[:, relative_position_index_k.view(-1)].view(
+ -1, self.window_size[0] * self.window_size[1], (self.focal_window+2**k-1)**2,
+ ) # nH, NWh*NWw,focal_region*focal_region
+ attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
+ attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
+ # add attentional mask
+ if mask_all[0][k+1] is not None:
+ attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
+ attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + \
+ mask_all[0][k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[0][k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[0][k+1].shape[-1])
+
+ offset += (self.focal_window+2**k-1)**2
+ # print(offset)
+ for k in range(len(self.focal_l_clips)):
+ focal_l_big_flag=False
+ if self.focal_l_clips[k]>self.window_size[0]:
+ stride=1
+ padding=0
+ kernel_size=self.focal_kernel_clips[k]
+ kernel_size_true=kernel_size
+ focal_l_big_flag=True
+ # stride=math.ceil(self.window_size/focal_l_clips[k])
+ # padding=(kernel_size-stride)/2
+ else:
+ stride = self.focal_l_clips[k]
+ # kernel_size
+ # kernel_size = 2*(self.focal_kernel_clips[k]// 2) + 2**self.focal_l_clips[k] + (2**self.focal_l_clips[k]-1)
+ kernel_size = self.focal_kernel_clips[k]
+ padding=kernel_size // 2
+ # kernel_size_true=self.focal_kernel_clips[k]+2**self.focal_l_clips[k]-1
+ kernel_size_true=kernel_size
+ relative_position_index_k = getattr(self, 'relative_position_index_clips_{}'.format(k))
+ relative_position_bias_to_windows = getattr(self,'relative_position_bias_table_to_windows_clips_{}'.format(k))[:, relative_position_index_k.view(-1)].view(
+ -1, self.window_size[0] * self.window_size[1], (kernel_size_true)**2,
+ )
+ attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] = \
+ attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] + relative_position_bias_to_windows.unsqueeze(0)
+ if mask_all[k+1] is not None:
+ attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] = \
+ attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] + \
+ mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
+ offset += (kernel_size_true)**2
+ # print(offset)
+ # relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k))
+ # # relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:, relative_position_index_k.view(-1)].view(
+ # # -1, self.window_size[0] * self.window_size[1], (self.focal_window+2**k-1)**2,
+ # # ) # nH, NWh*NWw,focal_region*focal_region
+ # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
+ # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
+ # relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:, relative_position_index_k[-window_area:, :].view(-1)].view(
+ # -1, self.window_size[0] * self.window_size[1], num_clips*(self.focal_window+2**k-1)**2,
+ # ).contiguous() # nH, NWh*NWw, num_clips*focal_region*focal_region
+ # relative_position_bias_to_windows = relative_position_bias_to_windows.view(self.num_heads,
+ # window_area,num_clips,-1).permute(0,1,3,2).contiguous().view(self.num_heads,window_area,-1)
+ # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] = \
+ # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
+ # # add attentional mask
+ # if mask_all[k+1] is not None:
+ # # print("inside the mask, be careful 1")
+ # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
+ # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + \
+ # # mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
+ # # print("here: ", mask_all[k+1].shape, mask_all[k+1][:, :, None, None, :].shape)
+
+ # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] = \
+ # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] + \
+ # mask_all[k+1][:, :, None, None, :,None].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1, num_clips).view(-1, 1, 1, mask_all[k+1].shape[-1]*num_clips)
+ # # print()
+
+ # offset += (self.focal_window+2**k-1)**2
+
+ # print("mask_all[0]: ", mask_all[0])
+ # exit()
+ if mask_all[0][0] is not None:
+ print("inside the mask, be careful 0")
+ nW = mask_all[0].shape[0]
+ attn = attn.view(attn.shape[0] // nW, nW, self.num_heads, window_area, N)
+ attn[:, :, :, :, :window_area] = attn[:, :, :, :, :window_area] + mask_all[0][None, :, None, :, :]
+ attn = attn.view(-1, self.num_heads, window_area, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ # print(x.shape)
+ # x = x.view(B/num_clips, nH, nW, C )
+ # exit()
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N, window_size, unfold_size):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ if self.pool_method != "none" and self.focal_level > 1:
+ flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size)
+ if self.expand_size > 0 and self.focal_level > 0:
+ flops += self.num_heads * N * (self.dim // self.num_heads) * ((window_size + 2*self.expand_size)**2-window_size**2)
+
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ if self.pool_method != "none" and self.focal_level > 1:
+ flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size)
+ if self.expand_size > 0 and self.focal_level > 0:
+ flops += self.num_heads * N * (self.dim // self.num_heads) * ((window_size + 2*self.expand_size)**2-window_size**2)
+
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+
+class CffmTransformerBlock3d3(nn.Module):
+ r""" Focal Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resulotion.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ expand_size (int): expand size at first focal level (finest level).
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pool_method (str): window pooling method. Default: none, options: [none|fc|conv]
+ focal_level (int): number of focal levels. Default: 1.
+ focal_window (int): region size of focal attention. Default: 1
+ use_layerscale (bool): whether use layer scale for training stability. Default: False
+ layerscale_value (float): scaling value for layer scale. Default: 1e-4
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, expand_size=0, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pool_method="none",
+ focal_level=1, focal_window=1, use_layerscale=False, layerscale_value=1e-4, focal_l_clips=[7,2,4], focal_kernel_clips=[7,5,3]):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.expand_size = expand_size
+ self.mlp_ratio = mlp_ratio
+ self.pool_method = pool_method
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+ self.use_layerscale = use_layerscale
+ self.focal_l_clips=focal_l_clips
+ self.focal_kernel_clips=focal_kernel_clips
+
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.expand_size = 0
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.window_size_glo = self.window_size
+
+ self.pool_layers = nn.ModuleList()
+ self.pool_layers_clips = nn.ModuleList()
+ if self.pool_method != "none":
+ for k in range(self.focal_level-1):
+ window_size_glo = math.floor(self.window_size_glo / (2 ** k))
+ if self.pool_method == "fc":
+ self.pool_layers.append(nn.Linear(window_size_glo * window_size_glo, 1))
+ self.pool_layers[-1].weight.data.fill_(1./(window_size_glo * window_size_glo))
+ self.pool_layers[-1].bias.data.fill_(0)
+ elif self.pool_method == "conv":
+ self.pool_layers.append(nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim))
+ for k in range(len(focal_l_clips)):
+ # window_size_glo = math.floor(self.window_size_glo / (2 ** k))
+ if focal_l_clips[k]>self.window_size:
+ window_size_glo = focal_l_clips[k]
+ else:
+ window_size_glo = math.floor(self.window_size_glo / (focal_l_clips[k]))
+ # window_size_glo = focal_l_clips[k]
+ if self.pool_method == "fc":
+ self.pool_layers_clips.append(nn.Linear(window_size_glo * window_size_glo, 1))
+ self.pool_layers_clips[-1].weight.data.fill_(1./(window_size_glo * window_size_glo))
+ self.pool_layers_clips[-1].bias.data.fill_(0)
+ elif self.pool_method == "conv":
+ self.pool_layers_clips.append(nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim))
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = WindowAttention3d3(
+ dim, expand_size=self.expand_size, window_size=to_2tuple(self.window_size),
+ focal_window=focal_window, focal_level=focal_level, num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pool_method=pool_method, focal_l_clips=focal_l_clips, focal_kernel_clips=focal_kernel_clips)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ # print("******self.shift_size: ", self.shift_size)
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ # print("here mask none")
+ attn_mask = None
+ self.register_buffer("attn_mask", attn_mask)
+
+ if self.use_layerscale:
+ self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
+
+ def forward(self, x):
+ H0, W0 = self.input_resolution
+ # B, L, C = x.shape
+ B0, D0, H0, W0, C = x.shape
+ shortcut = x
+ # assert L == H * W, "input feature has wrong size"
+ x=x.reshape(B0*D0,H0,W0,C).reshape(B0*D0,H0*W0,C)
+
+
+ x = self.norm1(x)
+ x = x.reshape(B0*D0, H0, W0, C)
+ # print("here")
+ # exit()
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W0 % self.window_size) % self.window_size
+ pad_b = (self.window_size - H0 % self.window_size) % self.window_size
+ if pad_r > 0 or pad_b > 0:
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+
+ B, H, W, C = x.shape ## B=B0*D0
+
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # print("shifted_x.shape: ", shifted_x.shape)
+ shifted_x=shifted_x.view(B0,D0,H,W,C)
+ x_windows_all = [shifted_x[:,-1]]
+ x_windows_all_clips=[]
+ x_window_masks_all = [self.attn_mask]
+ x_window_masks_all_clips=[]
+
+ if self.focal_level > 1 and self.pool_method != "none":
+ # if we add coarser granularity and the pool method is not none
+ # pooling_index=0
+ for k in range(self.focal_level-1):
+ window_size_glo = math.floor(self.window_size_glo / (2 ** k))
+ pooled_h = math.ceil(H / self.window_size) * (2 ** k)
+ pooled_w = math.ceil(W / self.window_size) * (2 ** k)
+ H_pool = pooled_h * window_size_glo
+ W_pool = pooled_w * window_size_glo
+
+ x_level_k = shifted_x[:,-1]
+ # trim or pad shifted_x depending on the required size
+ if H > H_pool:
+ trim_t = (H - H_pool) // 2
+ trim_b = H - H_pool - trim_t
+ x_level_k = x_level_k[:, trim_t:-trim_b]
+ elif H < H_pool:
+ pad_t = (H_pool - H) // 2
+ pad_b = H_pool - H - pad_t
+ x_level_k = F.pad(x_level_k, (0,0,0,0,pad_t,pad_b))
+
+ if W > W_pool:
+ trim_l = (W - W_pool) // 2
+ trim_r = W - W_pool - trim_l
+ x_level_k = x_level_k[:, :, trim_l:-trim_r]
+ elif W < W_pool:
+ pad_l = (W_pool - W) // 2
+ pad_r = W_pool - W - pad_l
+ x_level_k = F.pad(x_level_k, (0,0,pad_l,pad_r))
+
+ x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo) # B0, nw, nw, window_size, window_size, C
+ nWh, nWw = x_windows_noreshape.shape[1:3]
+ if self.pool_method == "mean":
+ x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B0, nWh, nWw, C
+ elif self.pool_method == "max":
+ x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B0, nWh, nWw, C) # B0, nWh, nWw, C
+ elif self.pool_method == "fc":
+ x_windows_noreshape = x_windows_noreshape.view(B0, nWh, nWw, window_size_glo*window_size_glo, C).transpose(3, 4) # B0, nWh, nWw, C, wsize**2
+ x_windows_pooled = self.pool_layers[k](x_windows_noreshape).flatten(-2) # B0, nWh, nWw, C
+ elif self.pool_method == "conv":
+ x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous() # B0 * nw * nw, C, wsize, wsize
+ x_windows_pooled = self.pool_layers[k](x_windows_noreshape).view(B0, nWh, nWw, C) # B0, nWh, nWw, C
+
+ x_windows_all += [x_windows_pooled]
+ # print(x_windows_pooled.shape)
+ x_window_masks_all += [None]
+ # pooling_index=pooling_index+1
+
+ x_windows_all_clips += [x_windows_all]
+ x_window_masks_all_clips += [x_window_masks_all]
+ for k in range(len(self.focal_l_clips)):
+ if self.focal_l_clips[k]>self.window_size:
+ window_size_glo = self.focal_l_clips[k]
+ else:
+ window_size_glo = math.floor(self.window_size_glo / (self.focal_l_clips[k]))
+
+ pooled_h = math.ceil(H / self.window_size) * (self.focal_l_clips[k])
+ pooled_w = math.ceil(W / self.window_size) * (self.focal_l_clips[k])
+
+ H_pool = pooled_h * window_size_glo
+ W_pool = pooled_w * window_size_glo
+
+ x_level_k = shifted_x[:,k]
+ if H!=H_pool or W!=W_pool:
+ x_level_k=F.interpolate(x_level_k.permute(0,3,1,2), size=(H_pool, W_pool), mode='bilinear').permute(0,2,3,1)
+
+ # print(x_level_k.shape)
+ x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo) # B0, nw, nw, window_size, window_size, C
+ nWh, nWw = x_windows_noreshape.shape[1:3]
+ if self.pool_method == "mean":
+ x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B0, nWh, nWw, C
+ elif self.pool_method == "max":
+ x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B0, nWh, nWw, C) # B0, nWh, nWw, C
+ elif self.pool_method == "fc":
+ x_windows_noreshape = x_windows_noreshape.view(B0, nWh, nWw, window_size_glo*window_size_glo, C).transpose(3, 4) # B0, nWh, nWw, C, wsize**2
+ x_windows_pooled = self.pool_layers_clips[k](x_windows_noreshape).flatten(-2) # B0, nWh, nWw, C
+ elif self.pool_method == "conv":
+ x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous() # B0 * nw * nw, C, wsize, wsize
+ x_windows_pooled = self.pool_layers_clips[k](x_windows_noreshape).view(B0, nWh, nWw, C) # B0, nWh, nWw, C
+
+ x_windows_all_clips += [x_windows_pooled]
+ # print(x_windows_pooled.shape)
+ x_window_masks_all_clips += [None]
+ # pooling_index=pooling_index+1
+ # exit()
+
+ attn_windows = self.attn(x_windows_all_clips, mask_all=x_window_masks_all_clips, batch_size=B0, num_clips=D0) # nW*B0, window_size*window_size, C
+
+ attn_windows = attn_windows[:, :self.window_size ** 2]
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H(padded) W(padded) C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ # x = x[:, :self.input_resolution[0], :self.input_resolution[1]].contiguous().view(B, -1, C)
+ x = x[:, :H0, :W0].contiguous().view(B0, -1, C)
+
+ # FFN
+ # x = shortcut + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x))
+ # x = x + self.drop_path(self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x))))
+
+ # print(x.shape, shortcut[:,-1].view(B0, -1, C).shape)
+ x = shortcut[:,-1].view(B0, -1, C) + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x))
+ x = x + self.drop_path(self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x))))
+
+ # x=torch.cat([shortcut[:,:-1],x.view(B0,self.input_resolution[0],self.input_resolution[1],C).unsqueeze(1)],1)
+ x=torch.cat([shortcut[:,:-1],x.view(B0,H0,W0,C).unsqueeze(1)],1)
+
+ assert x.shape==shortcut.shape
+
+ # exit()
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size, self.window_size, self.focal_window)
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ for k in range(self.focal_level-1):
+ window_size_glo = math.floor(self.window_size_glo / (2 ** k))
+ nW_glo = nW * (2**k)
+ # (sub)-window pooling
+ flops += nW_glo * self.dim * window_size_glo * window_size_glo
+ # qkv for global levels
+ # NOTE: in our implementation, we pass the pooled window embedding to qkv embedding layer,
+ # but theoritically, we only need to compute k and v.
+ flops += nW_glo * self.dim * 3 * self.dim
+
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class BasicLayer3d3(nn.Module):
+ """ A basic Focal Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ expand_size (int): expand size for focal level 1.
+ expand_layer (str): expand layer. Default: all
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pool_method (str): Window pooling method. Default: none.
+ focal_level (int): Number of focal levels. Default: 1.
+ focal_window (int): region size at each focal level. Default: 1.
+ use_conv_embed (bool): whether use overlapped convolutional patch embedding layer. Default: False
+ use_shift (bool): Whether use window shift as in Swin Transformer. Default: False
+ use_pre_norm (bool): Whether use pre-norm before patch embedding projection for stability. Default: False
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ use_layerscale (bool): Whether use layer scale for stability. Default: False.
+ layerscale_value (float): Layerscale value. Default: 1e-4.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size, expand_size, expand_layer="all",
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, pool_method="none",
+ focal_level=1, focal_window=1, use_conv_embed=False, use_shift=False, use_pre_norm=False,
+ downsample=None, use_checkpoint=False, use_layerscale=False, layerscale_value=1e-4, focal_l_clips=[16,8,2], focal_kernel_clips=[7,5,3]):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ if expand_layer == "even":
+ expand_factor = 0
+ elif expand_layer == "odd":
+ expand_factor = 1
+ elif expand_layer == "all":
+ expand_factor = -1
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ CffmTransformerBlock3d3(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=(0 if (i % 2 == 0) else window_size // 2) if use_shift else 0,
+ expand_size=0 if (i % 2 == expand_factor) else expand_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ pool_method=pool_method,
+ focal_level=focal_level,
+ focal_window=focal_window,
+ use_layerscale=use_layerscale,
+ layerscale_value=layerscale_value,
+ focal_l_clips=focal_l_clips,
+ focal_kernel_clips=focal_kernel_clips)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ img_size=input_resolution, patch_size=2, in_chans=dim, embed_dim=2*dim,
+ use_conv_embed=use_conv_embed, norm_layer=norm_layer, use_pre_norm=use_pre_norm,
+ is_stem=False
+ )
+ else:
+ self.downsample = None
+
+ def forward(self, x, batch_size=None, num_clips=None, reg_tokens=None):
+ B, D, C, H, W = x.shape
+ x = rearrange(x, 'b d c h w -> b d h w c')
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+
+ if self.downsample is not None:
+ x = x.view(x.shape[0], self.input_resolution[0], self.input_resolution[1], -1).permute(0, 3, 1, 2).contiguous()
+ x = self.downsample(x)
+ x = rearrange(x, 'b d h w c -> b d c h w')
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
diff --git a/models/SpaTrackV2/models/depth_refiner/stablizer.py b/models/SpaTrackV2/models/depth_refiner/stablizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..54656109147dd4a2bb789a447195962c5d7bf167
--- /dev/null
+++ b/models/SpaTrackV2/models/depth_refiner/stablizer.py
@@ -0,0 +1,342 @@
+import numpy as np
+import torch.nn as nn
+import torch
+# from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from collections import OrderedDict
+# from mmseg.ops import resize
+from torch.nn.functional import interpolate as resize
+# from builder import HEADS
+from models.SpaTrackV2.models.depth_refiner.decode_head import BaseDecodeHead, BaseDecodeHead_clips, BaseDecodeHead_clips_flow
+# from mmseg.models.utils import *
+import attr
+from IPython import embed
+from models.SpaTrackV2.models.depth_refiner.stablilization_attention import BasicLayer3d3
+import cv2
+from models.SpaTrackV2.models.depth_refiner.network import *
+import warnings
+# from mmcv.utils import Registry, build_from_cfg
+from torch import nn
+from einops import rearrange
+import torch.nn.functional as F
+from models.SpaTrackV2.models.blocks import (
+ AttnBlock, CrossAttnBlock, Mlp
+)
+
+class MLP(nn.Module):
+ """
+ Linear Embedding
+ """
+ def __init__(self, input_dim=2048, embed_dim=768):
+ super().__init__()
+ self.proj = nn.Linear(input_dim, embed_dim)
+
+ def forward(self, x):
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+def scatter_multiscale_fast(
+ track2d: torch.Tensor,
+ trackfeature: torch.Tensor,
+ H: int,
+ W: int,
+ kernel_sizes = [1]
+) -> torch.Tensor:
+ """
+ Scatter sparse track features onto a dense image grid with weighted multi-scale pooling to handle zero-value gaps.
+
+ This function scatters sparse track features into a dense image grid and applies multi-scale average pooling
+ while excluding zero-value holes. The weight mask ensures that only valid feature regions contribute to the pooling,
+ avoiding dilution by empty pixels.
+
+ Args:
+ track2d (torch.Tensor): Float tensor of shape (B, T, N, 2) containing (x, y) pixel coordinates
+ for each track point across batches, frames, and points.
+ trackfeature (torch.Tensor): Float tensor of shape (B, T, N, C) with C-dimensional features
+ for each track point.
+ H (int): Height of the target output image.
+ W (int): Width of the target output image.
+ kernel_sizes (List[int]): List of odd integers for average pooling kernel sizes. Default: [3, 5, 7].
+
+ Returns:
+ torch.Tensor: Multi-scale fused feature map of shape (B, T, C, H, W) with hole-resistant pooling.
+ """
+ B, T, N, C = trackfeature.shape
+ device = trackfeature.device
+
+ # 1. Flatten coordinates and filter valid points within image bounds
+ coords_flat = track2d.round().long().reshape(-1, 2) # (B*T*N, 2)
+ x = coords_flat[:, 0] # x coordinates
+ y = coords_flat[:, 1] # y coordinates
+ feat_flat = trackfeature.reshape(-1, C) # Flatten features
+
+ valid_mask = (x >= 0) & (x < W) & (y >= 0) & (y < H)
+ x = x[valid_mask]
+ y = y[valid_mask]
+ feat_flat = feat_flat[valid_mask]
+ valid_count = x.shape[0]
+
+ if valid_count == 0:
+ return torch.zeros(B, T, C, H, W, device=device) # Handle no-valid-point case
+
+ # 2. Calculate linear indices and batch-frame indices for scattering
+ lin_idx = y * W + x # Linear index within a single frame (H*W range)
+
+ # Generate batch-frame indices (e.g., 0~B*T-1 for each frame in batch)
+ bt_idx_raw = (
+ torch.arange(B * T, device=device)
+ .view(B, T, 1)
+ .expand(B, T, N)
+ .reshape(-1)
+ )
+ bt_idx = bt_idx_raw[valid_mask] # Indices for valid points across batch and frames
+
+ # 3. Create accumulation buffers for features and weights
+ total_space = B * T * H * W
+ img_accum_flat = torch.zeros(total_space, C, device=device) # Feature accumulator
+ weight_accum_flat = torch.zeros(total_space, 1, device=device) # Weight accumulator (counts)
+
+ # 4. Scatter features and weights into accumulation buffers
+ idx_in_accum = bt_idx * (H * W) + lin_idx # Global index: batch_frame * H*W + pixel_index
+
+ # Add features to corresponding indices (index_add_ is efficient for sparse updates)
+ img_accum_flat.index_add_(0, idx_in_accum, feat_flat)
+ weight_accum_flat.index_add_(0, idx_in_accum, torch.ones((valid_count, 1), device=device))
+
+ # 5. Normalize features by valid weights, keep zeros for invalid regions
+ valid_mask_flat = weight_accum_flat > 0 # Binary mask for valid pixels
+ img_accum_flat = img_accum_flat / (weight_accum_flat + 1e-6) # Avoid division by zero
+ img_accum_flat = img_accum_flat * valid_mask_flat.float() # Mask out invalid regions
+
+ # 6. Reshape to (B, T, C, H, W) for further processing
+ img = (
+ img_accum_flat.view(B, T, H, W, C)
+ .permute(0, 1, 4, 2, 3)
+ .contiguous()
+ ) # Shape: (B, T, C, H, W)
+
+ # 7. Multi-scale pooling with weight masking to exclude zero holes
+ blurred_outputs = []
+ for k in kernel_sizes:
+ pad = k // 2
+ img_bt = img.view(B*T, C, H, W) # Flatten batch and time for pooling
+
+ # Create weight mask for valid regions (1 where features exist, 0 otherwise)
+ weight_mask = (
+ weight_accum_flat.view(B, T, 1, H, W) > 0
+ ).float().view(B*T, 1, H, W) # Shape: (B*T, 1, H, W)
+
+ # Calculate number of valid neighbors in each pooling window
+ weight_sum = F.conv2d(
+ weight_mask,
+ torch.ones((1, 1, k, k), device=device),
+ stride=1,
+ padding=pad
+ ) # Shape: (B*T, 1, H, W)
+
+ # Sum features only in valid regions
+ feat_sum = F.conv2d(
+ img_bt * weight_mask, # Mask out invalid regions before summing
+ torch.ones((1, 1, k, k), device=device).expand(C, 1, k, k),
+ stride=1,
+ padding=pad,
+ groups=C
+ ) # Shape: (B*T, C, H, W)
+
+ # Compute average only over valid neighbors
+ feat_avg = feat_sum / (weight_sum + 1e-6)
+ blurred_outputs.append(feat_avg)
+
+ # 8. Fuse multi-scale results by averaging across kernel sizes
+ fused = torch.stack(blurred_outputs).mean(dim=0) # Average over kernel sizes
+ return fused.view(B, T, C, H, W) # Restore original shape
+
+#@HEADS.register_module()
+class Stabilization_Network_Cross_Attention(BaseDecodeHead_clips_flow):
+
+ def __init__(self, feature_strides, **kwargs):
+ super(Stabilization_Network_Cross_Attention, self).__init__(input_transform='multiple_select', **kwargs)
+ self.training = False
+ assert len(feature_strides) == len(self.in_channels)
+ assert min(feature_strides) == feature_strides[0]
+ self.feature_strides = feature_strides
+
+ c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
+
+ decoder_params = kwargs['decoder_params']
+ embedding_dim = decoder_params['embed_dim']
+
+ self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
+ self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
+ self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
+ self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
+
+ self.linear_fuse = nn.Sequential(nn.Conv2d(embedding_dim*4, embedding_dim, kernel_size=(1, 1), stride=(1, 1), bias=False),\
+ nn.ReLU(inplace=True))
+
+ self.proj_track = nn.Conv2d(100, 128, kernel_size=(1, 1), stride=(1, 1), bias=True)
+
+ depths = decoder_params['depths']
+
+ self.reg_tokens = nn.Parameter(torch.zeros(1, 2, embedding_dim))
+ self.global_patch = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=(8, 8), stride=(8, 8), bias=True)
+
+ self.att_temporal = nn.ModuleList(
+ [
+ AttnBlock(embedding_dim, 8,
+ mlp_ratio=4, flash=True, ckpt_fwd=True)
+ for _ in range(8)
+ ]
+ )
+ self.att_spatial = nn.ModuleList(
+ [
+ AttnBlock(embedding_dim, 8,
+ mlp_ratio=4, flash=True, ckpt_fwd=True)
+ for _ in range(8)
+ ]
+ )
+ self.scale_shift_head = nn.Sequential(nn.Linear(embedding_dim, embedding_dim), nn.GELU(), nn.Linear(embedding_dim, 4))
+
+
+ # Initialize reg tokens
+ nn.init.trunc_normal_(self.reg_tokens, std=0.02)
+
+ self.decoder_focal=BasicLayer3d3(dim=embedding_dim,
+ input_resolution=(96,
+ 96),
+ depth=depths,
+ num_heads=8,
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ pool_method='fc',
+ downsample=None,
+ focal_level=2,
+ focal_window=5,
+ expand_size=3,
+ expand_layer="all",
+ use_conv_embed=False,
+ use_shift=False,
+ use_pre_norm=False,
+ use_checkpoint=False,
+ use_layerscale=False,
+ layerscale_value=1e-4,
+ focal_l_clips=[7,4,2],
+ focal_kernel_clips=[7,5,3])
+
+ self.ffm2 = FFM(inchannels= 256, midchannels= 256, outchannels = 128)
+ self.ffm1 = FFM(inchannels= 128, midchannels= 128, outchannels = 64)
+ self.ffm0 = FFM(inchannels= 64, midchannels= 64, outchannels = 32,upfactor=1)
+ self.AO = AO(32, outchannels=3, upfactor=1)
+ self._c2 = None
+ self._c_further = None
+
+ def buffer_forward(self, inputs, num_clips=None, imgs=None):#,infermode=1):
+
+ # input: B T 7 H W (7 means 3 rgb + 3 pointmap + 1 uncertainty) normalized
+ if self.training:
+ assert self.num_clips==num_clips
+
+ x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32
+ c1, c2, c3, c4 = x
+
+ ############## MLP decoder on C1-C4 ###########
+ n, _, h, w = c4.shape
+ batch_size = n // num_clips
+
+ _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
+ _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
+
+ _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
+ _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
+
+ _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
+ _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
+
+ _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
+ _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
+
+ _, _, h, w=_c.shape
+ _c_further=_c.reshape(batch_size, num_clips, -1, h, w) #h2w2
+
+ # Expand reg_tokens to match batch size
+ reg_tokens = self.reg_tokens.expand(batch_size*num_clips, -1, -1) # [B, 2, C]
+
+ _c2=self.decoder_focal(_c_further, batch_size=batch_size, num_clips=num_clips, reg_tokens=reg_tokens)
+
+ assert _c_further.shape==_c2.shape
+ self._c2 = _c2
+ self._c_further = _c_further
+
+ # compute the scale and shift of the global patch
+ global_patch = self.global_patch(_c2.view(batch_size*num_clips, -1, h, w)).view(batch_size*num_clips, _c2.shape[2], -1).permute(0,2,1)
+ global_patch = torch.cat([global_patch, reg_tokens], dim=1)
+ for i in range(8):
+ global_patch = self.att_temporal[i](global_patch)
+ global_patch = rearrange(global_patch, '(b t) n c -> (b n) t c', b=batch_size, t=num_clips, c=_c2.shape[2])
+ global_patch = self.att_spatial[i](global_patch)
+ global_patch = rearrange(global_patch, '(b n) t c -> (b t) n c', b=batch_size, t=num_clips, c=_c2.shape[2])
+
+ reg_tokens = global_patch[:, -2:, :]
+ s_ = self.scale_shift_head(reg_tokens)
+ scale = 1 + s_[:, 0, :1].view(batch_size, num_clips, 1, 1, 1)
+ shift = s_[:, 1, 1:].view(batch_size, num_clips, 3, 1, 1)
+ shift[:,:,:2,...] = 0
+ return scale, shift
+
+ def forward(self, inputs, edge_feat, edge_feat1, tracks, tracks_uvd, num_clips=None, imgs=None, vis_track=None):#,infermode=1):
+
+ if self._c2 is None:
+ scale, shift = self.buffer_forward(inputs,num_clips,imgs)
+
+ B, T, N, _ = tracks.shape
+
+ _c2 = self._c2
+ _c_further = self._c_further
+
+ # skip and head
+ _c_further = rearrange(_c_further, 'b t c h w -> (b t) c h w', b=B, t=T)
+ _c2 = rearrange(_c2, 'b t c h w -> (b t) c h w', b=B, t=T)
+
+ outframe = self.ffm2(_c_further, _c2)
+
+ tracks_uv = tracks_uvd[...,:2].clone()
+ track_feature = scatter_multiscale_fast(tracks_uv/2, tracks, outframe.shape[-2], outframe.shape[-1], kernel_sizes=[1, 3, 5])
+ # visualize track_feature as video
+ # import cv2
+ # import imageio
+ # import os
+ # BT, C, H, W = outframe.shape
+ # track_feature_vis = track_feature.view(B, T, 3, H, W).float().detach().cpu().numpy()
+ # track_feature_vis = track_feature_vis.transpose(0,1,3,4,2)
+ # track_feature_vis = (track_feature_vis - track_feature_vis.min()) / (track_feature_vis.max() - track_feature_vis.min() + 1e-6)
+ # track_feature_vis = (track_feature_vis * 255).astype(np.uint8)
+ # imgs =(imgs.detach() + 1) * 127.5
+ # vis_track.visualize(video=imgs, tracks=tracks_uv, filename="test")
+ # for b in range(B):
+ # frames = []
+ # for t in range(T):
+ # frame = track_feature_vis[b,t]
+ # frame = cv2.applyColorMap(frame[...,0], cv2.COLORMAP_JET)
+ # frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ # frames.append(frame)
+ # # Save as gif
+ # imageio.mimsave(f'track_feature_b{b}.gif', frames, duration=0.1)
+ # import pdb; pdb.set_trace()
+ track_feature = rearrange(track_feature, 'b t c h w -> (b t) c h w')
+ track_feature = self.proj_track(track_feature)
+ outframe = self.ffm1(edge_feat1 + track_feature,outframe)
+ outframe = self.ffm0(edge_feat,outframe)
+ outframe = self.AO(outframe)
+
+ return outframe
+
+ def reset_success(self):
+ self._c2 = None
+ self._c_further = None
diff --git a/models/SpaTrackV2/models/predictor.py b/models/SpaTrackV2/models/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..10dd942d42ebaa8465a5ad309f3fd8d59751131a
--- /dev/null
+++ b/models/SpaTrackV2/models/predictor.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+
+from tqdm import tqdm
+from models.SpaTrackV2.models.SpaTrack import SpaTrack2
+from typing import Literal
+import numpy as np
+from pathlib import Path
+from typing import Union, Optional
+import cv2
+import os
+import decord
+
+class Predictor(torch.nn.Module):
+ def __init__(self, args=None):
+ super().__init__()
+ self.args = args
+ self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
+ self.S_wind = args.Track_cfg.s_wind
+ self.overlap = args.Track_cfg.overlap
+
+ def to(self, device: Union[str, torch.device]):
+ self.spatrack.to(device)
+ self.spatrack.base_model.to(device)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, Path],
+ *,
+ force_download: bool = False,
+ cache_dir: Optional[str] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ model_cfg: Optional[dict] = None,
+ **kwargs,
+ ) -> "SpaTrack2":
+ """
+ Load a pretrained model from a local file or a remote repository.
+
+ Args:
+ pretrained_model_name_or_path (str or Path):
+ - Path to a local model file (e.g., `./model.pth`).
+ - HuggingFace Hub model ID (e.g., `username/model-name`).
+ force_download (bool, optional):
+ Whether to force re-download even if cached. Default: False.
+ cache_dir (str, optional):
+ Custom cache directory. Default: None (use default cache).
+ device (str or torch.device, optional):
+ Target device (e.g., "cuda", "cpu"). Default: None (keep original).
+ **kwargs:
+ Additional config overrides.
+
+ Returns:
+ SpaTrack2: Loaded pretrained model.
+ """
+ # (1) check the path is local or remote
+ if isinstance(pretrained_model_name_or_path, Path):
+ model_path = str(pretrained_model_name_or_path)
+ else:
+ model_path = pretrained_model_name_or_path
+ # (2) if the path is remote, download it
+ if not os.path.exists(model_path):
+ raise NotImplementedError("Remote download not implemented yet. Use a local path.")
+ # (3) load the model weights
+
+ state_dict = torch.load(model_path, map_location="cpu")
+ # (4) initialize the model (can load config.json if exists)
+ config_path = os.path.join(os.path.dirname(model_path), "config.json")
+ config = {}
+ if os.path.exists(config_path):
+ import json
+ with open(config_path, "r") as f:
+ config.update(json.load(f))
+ config.update(kwargs) # allow override the config
+ if model_cfg is not None:
+ config = model_cfg
+ model = cls(config)
+ if "model" in state_dict:
+ model.spatrack.load_state_dict(state_dict["model"], strict=False)
+ else:
+ model.spatrack.load_state_dict(state_dict, strict=False)
+ # (5) device management
+ if device is not None:
+ model.to(device)
+
+ return model
+
+ def forward(self, video: str|torch.Tensor|np.ndarray,
+ depth: str|torch.Tensor|np.ndarray=None,
+ unc_metric: str|torch.Tensor|np.ndarray=None,
+ intrs: str|torch.Tensor|np.ndarray=None,
+ extrs: str|torch.Tensor|np.ndarray=None,
+ queries=None, queries_3d=None, iters_track=4,
+ full_point=False, fps=30, track2d_gt=None,
+ fixed_cam=False, query_no_BA=False, stage=0,
+ support_frame=0, replace_ratio=0.6):
+ """
+ video: this could be a path to a video, a tensor of shape (T, C, H, W) or a numpy array of shape (T, C, H, W)
+ queries: (B, N, 2)
+ """
+
+ if isinstance(video, str):
+ video = decord.VideoReader(video)
+ video = video[::fps].asnumpy() # Convert to numpy array
+ video = np.array(video) # Ensure numpy array
+ video = torch.from_numpy(video).permute(0, 3, 1, 2).float()
+ elif isinstance(video, np.ndarray):
+ video = torch.from_numpy(video).float()
+
+ if isinstance(depth, np.ndarray):
+ depth = torch.from_numpy(depth).float()
+ if isinstance(intrs, np.ndarray):
+ intrs = torch.from_numpy(intrs).float()
+ if isinstance(extrs, np.ndarray):
+ extrs = torch.from_numpy(extrs).float()
+ if isinstance(unc_metric, np.ndarray):
+ unc_metric = torch.from_numpy(unc_metric).float()
+
+ T_, C, H, W = video.shape
+ step_slide = self.S_wind - self.overlap
+ if T_ > self.S_wind:
+
+ num_windows = (T_ - self.S_wind + step_slide) // step_slide
+ T = num_windows * step_slide + self.S_wind
+ pad_len = T - T_
+
+ video = torch.cat([video, video[-1:].repeat(T-video.shape[0], 1, 1, 1)], dim=0)
+ if depth is not None:
+ depth = torch.cat([depth, depth[-1:].repeat(T-depth.shape[0], 1, 1)], dim=0)
+ if intrs is not None:
+ intrs = torch.cat([intrs, intrs[-1:].repeat(T-intrs.shape[0], 1, 1)], dim=0)
+ if extrs is not None:
+ extrs = torch.cat([extrs, extrs[-1:].repeat(T-extrs.shape[0], 1, 1)], dim=0)
+ if unc_metric is not None:
+ unc_metric = torch.cat([unc_metric, unc_metric[-1:].repeat(T-unc_metric.shape[0], 1, 1)], dim=0)
+ with torch.no_grad():
+ ret = self.spatrack.forward_stream(video, queries, T_org=T_,
+ depth=depth, intrs=intrs, unc_metric_in=unc_metric, extrs=extrs, queries_3d=queries_3d,
+ window_len=self.S_wind, overlap_len=self.overlap, track2d_gt=track2d_gt, full_point=full_point, iters_track=iters_track,
+ fixed_cam=fixed_cam, query_no_BA=query_no_BA, stage=stage, support_frame=support_frame, replace_ratio=replace_ratio) + (video[:T_],)
+
+
+ return ret
+
+
+
+
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/tracker3D/TrackRefiner.py b/models/SpaTrackV2/models/tracker3D/TrackRefiner.py
new file mode 100644
index 0000000000000000000000000000000000000000..251e8644b517b09369fb83ee25b0f2bfd68205ea
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/TrackRefiner.py
@@ -0,0 +1,1478 @@
+import os, sys
+import torch
+import torch.amp
+from models.SpaTrackV2.models.tracker3D.co_tracker.cotracker_base import CoTrackerThreeOffline, get_1d_sincos_pos_embed_from_grid
+import torch.nn.functional as F
+from models.SpaTrackV2.utils.visualizer import Visualizer
+from models.SpaTrackV2.utils.model_utils import sample_features5d
+from models.SpaTrackV2.models.blocks import bilinear_sampler
+import torch.nn as nn
+from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
+ EfficientUpdateFormer, AttnBlock, Attention, CrossAttnBlock,
+ sequence_BCE_loss, sequence_loss, sequence_prob_loss, sequence_dyn_prob_loss, sequence_loss_xyz, balanced_binary_cross_entropy
+)
+from torchvision.io import write_video
+import math
+from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
+ Mlp, BasicEncoder, EfficientUpdateFormer, GeometryEncoder, NeighborTransformer, CorrPointformer
+)
+from models.SpaTrackV2.utils.embeddings import get_3d_sincos_pos_embed_from_grid
+from einops import rearrange, repeat
+from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import (
+ EfficientUpdateFormer3D, weighted_procrustes_torch, posenc, key_fr_wprocrustes, get_topo_mask,
+ TrackFusion, get_nth_visible_time_index
+)
+from models.SpaTrackV2.models.tracker3D.spatrack_modules.ba import extract_static_from_3DTracks, ba_pycolmap
+from models.SpaTrackV2.models.tracker3D.spatrack_modules.pointmap_updator import PointMapUpdator
+from models.SpaTrackV2.models.depth_refiner.depth_refiner import TrackStablizer
+from models.SpaTrackV2.models.tracker3D.spatrack_modules.alignment import affine_invariant_global_loss
+from models.SpaTrackV2.models.tracker3D.delta_utils.upsample_transformer import UpsampleTransformerAlibi
+
+class TrackRefiner3D(CoTrackerThreeOffline):
+
+ def __init__(self, args=None):
+ super().__init__(**args.base)
+
+ """
+ This is 3D warpper from cotracker, which load the cotracker pretrain and
+ jointly refine the `camera pose`, `3D tracks`, `video depth`, `visibility` and `conf`
+ """
+ self.updateformer3D = EfficientUpdateFormer3D(self.updateformer)
+ self.corr_depth_mlp = Mlp(in_features=256, hidden_features=256, out_features=256)
+ self.rel_pos_mlp = Mlp(in_features=75, hidden_features=128, out_features=128)
+ self.rel_pos_glob_mlp = Mlp(in_features=75, hidden_features=128, out_features=256)
+ self.corr_xyz_mlp = Mlp(in_features=256, hidden_features=128, out_features=128)
+ self.xyz_mlp = Mlp(in_features=126, hidden_features=128, out_features=84)
+ # self.track_feat_mlp = Mlp(in_features=1110, hidden_features=128, out_features=128)
+ self.proj_xyz_embed = Mlp(in_features=1210+50, hidden_features=1110, out_features=1110)
+ # get the anchor point's embedding, and init the pts refiner
+ update_pts = True
+ # self.corr_transformer = nn.ModuleList([
+ # CorrPointformer(
+ # dim=128,
+ # num_heads=8,
+ # head_dim=128 // 8,
+ # mlp_ratio=4.0,
+ # )
+ # for _ in range(self.corr_levels)
+ # ])
+ self.corr_transformer = nn.ModuleList([
+ CorrPointformer(
+ dim=128,
+ num_heads=8,
+ head_dim=128 // 8,
+ mlp_ratio=4.0,
+ )
+ ]
+ )
+ self.fnet = BasicEncoder(input_dim=3,
+ output_dim=self.latent_dim, stride=self.stride)
+ self.corr3d_radius = 3
+
+ if args.stablizer:
+ self.scale_shift_tokens = nn.Parameter(torch.randn(1, 2, self.latent_dim, requires_grad=True))
+ self.upsample_kernel_size = 5
+ self.residual_embedding = nn.Parameter(torch.randn(
+ self.latent_dim, self.model_resolution[0]//16,
+ self.model_resolution[1]//16, requires_grad=True))
+ self.dense_mlp = nn.Conv2d(2*self.latent_dim+63, self.latent_dim, kernel_size=1, stride=1, padding=0)
+ self.upsample_factor = 4
+ self.upsample_transformer = UpsampleTransformerAlibi(
+ kernel_size=self.upsample_kernel_size, # kernel_size=3, #
+ stride=self.stride,
+ latent_dim=self.latent_dim,
+ num_attn_blocks=2,
+ upsample_factor=4,
+ )
+ else:
+ self.update_pointmap = None
+
+ self.mode = args.mode
+ if self.mode == "online":
+ self.s_wind = args.s_wind
+ self.overlap = args.overlap
+
+ def upsample_with_mask(
+ self, inp: torch.Tensor, mask: torch.Tensor
+ ) -> torch.Tensor:
+ """Upsample flow field [H/P, W/P, 2] -> [H, W, 2] using convex combination"""
+ H, W = inp.shape[-2:]
+ up_inp = F.unfold(
+ inp, [self.upsample_kernel_size, self.upsample_kernel_size], padding=(self.upsample_kernel_size - 1) // 2
+ )
+ up_inp = rearrange(up_inp, "b c (h w) -> b c h w", h=H, w=W)
+ up_inp = F.interpolate(up_inp, scale_factor=self.upsample_factor, mode="nearest")
+ up_inp = rearrange(
+ up_inp, "b (c i j) h w -> b c (i j) h w", i=self.upsample_kernel_size, j=self.upsample_kernel_size
+ )
+
+ up_inp = torch.sum(mask * up_inp, dim=2)
+ return up_inp
+
+ def track_from_cam(self, queries, c2w_traj, intrs,
+ rgbs=None, visualize=False):
+ """
+ This function will generate tracks by camera transform
+
+ Args:
+ queries: B T N 4
+ c2w_traj: B T 4 4
+ intrs: B T 3 3
+ """
+ B, T, N, _ = queries.shape
+ query_t = queries[:,0,:,0].to(torch.int64) # B N
+ query_c2w = torch.gather(c2w_traj,
+ dim=1, index=query_t[..., None, None].expand(-1, -1, 4, 4)) # B N 4 4
+ query_intr = torch.gather(intrs,
+ dim=1, index=query_t[..., None, None].expand(-1, -1, 3, 3)) # B N 3 3
+ query_pts = queries[:,0,:,1:4].clone() # B N 3
+ query_d = queries[:,0,:,3:4] # B N 3
+ query_pts[...,2] = 1
+
+ cam_pts = torch.einsum("bnij,bnj->bni", torch.inverse(query_intr), query_pts)*query_d # B N 3
+ # convert to world
+ cam_pts_h = torch.zeros(B, N, 4, device=cam_pts.device)
+ cam_pts_h[..., :3] = cam_pts
+ cam_pts_h[..., 3] = 1
+ world_pts = torch.einsum("bnij,bnj->bni", query_c2w, cam_pts_h)
+ # convert to other frames
+ cam_other_pts_ = torch.einsum("btnij,btnj->btni",
+ torch.inverse(c2w_traj[:,:,None].float().repeat(1,1,N,1,1)),
+ world_pts[:,None].repeat(1,T,1,1))
+ cam_depth = cam_other_pts_[...,2:3]
+ cam_other_pts = cam_other_pts_[...,:3] / (cam_other_pts_[...,2:3].abs()+1e-6)
+ cam_other_pts = torch.einsum("btnij,btnj->btni", intrs[:,:,None].repeat(1,1,N,1,1), cam_other_pts[...,:3])
+ cam_other_pts[..., 2:] = cam_depth
+
+ if visualize:
+ viser = Visualizer(save_dir=".", grayscale=True,
+ fps=10, pad_value=50, tracks_leave_trace=0)
+ cam_other_pts[..., 0] /= self.factor_x
+ cam_other_pts[..., 1] /= self.factor_y
+ viser.visualize(video=rgbs, tracks=cam_other_pts[..., :2], filename="test")
+
+
+ init_xyzs = cam_other_pts
+
+ return init_xyzs, world_pts[..., :3], cam_other_pts_[..., :3]
+
+ def cam_from_track(self, tracks, intrs,
+ dyn_prob=None, metric_unc=None,
+ vis_est=None, only_cam_pts=False,
+ track_feat_concat=None,
+ tracks_xyz=None,
+ query_pts=None,
+ fixed_cam=False,
+ depth_unproj=None,
+ cam_gt=None,
+ init_pose=False,
+ ):
+ """
+ This function will generate tracks by camera transform
+
+ Args:
+ queries: B T N 3
+ scale_est: 1 1
+ shift_est: 1 1
+ intrs: B T 3 3
+ dyn_prob: B T N
+ metric_unc: B N 1
+ query_pts: B T N 3
+ """
+ if tracks_xyz is not None:
+ B, T, N, _ = tracks.shape
+ cam_pts = tracks_xyz
+ intr_repeat = intrs[:,:,None].repeat(1,1,N,1,1)
+ else:
+ B, T, N, _ = tracks.shape
+ # get the pts in cam coordinate
+ tracks_xy = tracks[...,:3].clone().detach() # B T N 3
+ # tracks_z = 1/(tracks[...,2:] * scale_est + shift_est) # B T N 1
+ tracks_z = tracks[...,2:].detach() # B T N 1
+ tracks_xy[...,2] = 1
+ intr_repeat = intrs[:,:,None].repeat(1,1,N,1,1)
+ cam_pts = torch.einsum("bnij,bnj->bni",
+ torch.inverse(intr_repeat.view(B*T,N,3,3)).float(),
+ tracks_xy.view(B*T, N, 3))*(tracks_z.view(B*T,N,1).abs()) # B*T N 3
+ cam_pts[...,2] *= torch.sign(tracks_z.view(B*T,N))
+ # get the normalized cam pts, and pts refiner
+ mask_z = (tracks_z.max(dim=1)[0]<200).squeeze()
+ cam_pts = cam_pts.view(B, T, N, 3)
+
+ if only_cam_pts:
+ return cam_pts
+ dyn_prob = dyn_prob.mean(dim=1)[..., None]
+ # B T N 3 -> local frames coordinates. transformer static points B T N 3 -> B T N 3 static (B T N 3) -> same -> dynamic points @ C2T.inverse()
+ # get the cam pose
+ vis_est_ = vis_est[:,:,None,:]
+ graph_matrix = (vis_est_*vis_est_.permute(0, 2,1,3)).detach()
+ # find the max connected component
+ key_fr_idx = [0]
+ weight_final = (metric_unc) # * vis_est
+
+
+ with torch.amp.autocast(enabled=False, device_type='cuda'):
+ if fixed_cam:
+ c2w_traj_init = self.c2w_est_curr
+ c2w_traj_glob = c2w_traj_init
+ cam_pts_refine = cam_pts
+ intrs_refine = intrs
+ xy_refine = query_pts[...,1:3]
+ world_tracks_init = torch.einsum("btij,btnj->btni", c2w_traj_init[:,:,:3,:3], cam_pts) + c2w_traj_init[:,:,None,:3,3]
+ world_tracks_refined = world_tracks_init
+ # extract the stable static points for refine the camera pose
+ intrs_dn = intrs.clone()
+ intrs_dn[...,0,:] *= self.factor_x
+ intrs_dn[...,1,:] *= self.factor_y
+ _, query_world_pts, _ = self.track_from_cam(query_pts, c2w_traj_init, intrs_dn)
+ world_tracks_static, mask_static, mask_topk, vis_mask_static, tracks2d_static = extract_static_from_3DTracks(world_tracks_init,
+ dyn_prob, query_world_pts,
+ vis_est, tracks, img_size=self.image_size,
+ K=0)
+ world_static_refine = world_tracks_static
+
+ else:
+
+ if (not self.training):
+ # if (self.c2w_est_curr==torch.eye(4, device=cam_pts.device).repeat(B, T, 1, 1)).all():
+ campts_update = torch.einsum("btij,btnj->btni", self.c2w_est_curr[...,:3,:3], cam_pts) + self.c2w_est_curr[...,None,:3,3]
+ # campts_update = cam_pts
+ c2w_traj_init_update = key_fr_wprocrustes(campts_update, graph_matrix,
+ (weight_final*(1-dyn_prob)).permute(0,2,1), vis_est_.permute(0,1,3,2))
+ c2w_traj_init = c2w_traj_init_update@self.c2w_est_curr
+ # else:
+ # c2w_traj_init = self.c2w_est_curr # extract the stable static points for refine the camera pose
+ else:
+ # if (self.c2w_est_curr==torch.eye(4, device=cam_pts.device).repeat(B, T, 1, 1)).all():
+ campts_update = torch.einsum("btij,btnj->btni", self.c2w_est_curr[...,:3,:3], cam_pts) + self.c2w_est_curr[...,None,:3,3]
+ # campts_update = cam_pts
+ c2w_traj_init_update = key_fr_wprocrustes(campts_update, graph_matrix,
+ (weight_final*(1-dyn_prob)).permute(0,2,1), vis_est_.permute(0,1,3,2))
+ c2w_traj_init = c2w_traj_init_update@self.c2w_est_curr
+ # else:
+ # c2w_traj_init = self.c2w_est_curr # extract the stable static points for refine the camera pose
+
+ intrs_dn = intrs.clone()
+ intrs_dn[...,0,:] *= self.factor_x
+ intrs_dn[...,1,:] *= self.factor_y
+ _, query_world_pts, _ = self.track_from_cam(query_pts, c2w_traj_init, intrs_dn)
+ # refine the world tracks
+ world_tracks_init = torch.einsum("btij,btnj->btni", c2w_traj_init[:,:,:3,:3], cam_pts) + c2w_traj_init[:,:,None,:3,3]
+ world_tracks_static, mask_static, mask_topk, vis_mask_static, tracks2d_static = extract_static_from_3DTracks(world_tracks_init,
+ dyn_prob, query_world_pts,
+ vis_est, tracks, img_size=self.image_size,
+ K=150 if self.training else 1500)
+ # calculate the efficient ba
+ cam_tracks_static = cam_pts[:,:,mask_static.squeeze(),:][:,:,mask_topk.squeeze(),:]
+ cam_tracks_static[...,2] = depth_unproj.view(B, T, N)[:,:,mask_static.squeeze()][:,:,mask_topk.squeeze()]
+
+ c2w_traj_glob, world_static_refine, intrs_refine = ba_pycolmap(world_tracks_static, intrs,
+ c2w_traj_init, vis_mask_static,
+ tracks2d_static, self.image_size,
+ cam_tracks_static=cam_tracks_static,
+ training=self.training, query_pts=query_pts)
+ c2w_traj_glob = c2w_traj_glob.view(B, T, 4, 4)
+ world_tracks_refined = world_tracks_init
+
+ #NOTE: merge the index of static points and topk points
+ # merge_idx = torch.where(mask_static.squeeze()>0)[0][mask_topk.squeeze()]
+ # world_tracks_refined[:,:,merge_idx] = world_static_refine
+
+ # test the procrustes
+ w2c_traj_glob = torch.inverse(c2w_traj_init.detach())
+ cam_pts_refine = torch.einsum("btij,btnj->btni", w2c_traj_glob[:,:,:3,:3], world_tracks_refined) + w2c_traj_glob[:,:,None,:3,3]
+ # get the xyz_refine
+ #TODO: refiner
+ cam_pts4_proj = cam_pts_refine.clone()
+ cam_pts4_proj[...,2] *= torch.sign(cam_pts4_proj[...,2:3].view(B*T,N))
+ xy_refine = torch.einsum("btnij,btnj->btni", intrs_refine.view(B,T,1,3,3).repeat(1,1,N,1,1), cam_pts4_proj/cam_pts4_proj[...,2:3].abs())
+ xy_refine[..., 2] = cam_pts4_proj[...,2:3].view(B*T,N)
+ # xy_refine = torch.zeros_like(cam_pts_refine)[...,:2]
+ return c2w_traj_glob, cam_pts_refine, intrs_refine, xy_refine, world_tracks_init, world_tracks_refined, c2w_traj_init
+
+ def extract_img_feat(self, video, fmaps_chunk_size=200):
+ B, T, C, H, W = video.shape
+ dtype = video.dtype
+ H4, W4 = H // self.stride, W // self.stride
+ # Compute convolutional features for the video or for the current chunk in case of online mode
+ if T > fmaps_chunk_size:
+ fmaps = []
+ for t in range(0, T, fmaps_chunk_size):
+ video_chunk = video[:, t : t + fmaps_chunk_size]
+ fmaps_chunk = self.fnet(video_chunk.reshape(-1, C, H, W))
+ T_chunk = video_chunk.shape[1]
+ C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
+ fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
+ fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
+ else:
+ fmaps = self.fnet(video.reshape(-1, C, H, W))
+ fmaps = fmaps.permute(0, 2, 3, 1)
+ fmaps = fmaps / torch.sqrt(
+ torch.maximum(
+ torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
+ torch.tensor(1e-12, device=fmaps.device),
+ )
+ )
+ fmaps = fmaps.permute(0, 3, 1, 2).reshape(
+ B, -1, self.latent_dim, H // self.stride, W // self.stride
+ )
+ fmaps = fmaps.to(dtype)
+
+ return fmaps
+
+ def norm_xyz(self, xyz):
+ """
+ xyz can be (B T N 3) or (B T 3 H W) or (B N 3)
+ """
+ if xyz.ndim == 3:
+ min_pts = self.min_pts
+ max_pts = self.max_pts
+ return (xyz - min_pts[None,None,:]) / (max_pts - min_pts)[None,None,:] * 2 - 1
+ elif xyz.ndim == 4:
+ min_pts = self.min_pts
+ max_pts = self.max_pts
+ return (xyz - min_pts[None,None,None,:]) / (max_pts - min_pts)[None,None,None,:] * 2 - 1
+ elif xyz.ndim == 5:
+ if xyz.shape[2] == 3:
+ min_pts = self.min_pts
+ max_pts = self.max_pts
+ return (xyz - min_pts[None,None,:,None,None]) / (max_pts - min_pts)[None,None,:,None,None] * 2 - 1
+ elif xyz.shape[-1] == 3:
+ min_pts = self.min_pts
+ max_pts = self.max_pts
+ return (xyz - min_pts[None,None,None,None,:]) / (max_pts - min_pts)[None,None,None,None,:] * 2 - 1
+
+ def denorm_xyz(self, xyz):
+ """
+ xyz can be (B T N 3) or (B T 3 H W) or (B N 3)
+ """
+ if xyz.ndim == 3:
+ min_pts = self.min_pts
+ max_pts = self.max_pts
+ return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,:] + min_pts[None,None,:]
+ elif xyz.ndim == 4:
+ min_pts = self.min_pts
+ max_pts = self.max_pts
+ return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,None,:] + min_pts[None,None,None,:]
+ elif xyz.ndim == 5:
+ if xyz.shape[2] == 3:
+ min_pts = self.min_pts
+ max_pts = self.max_pts
+ return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,:,None,None] + min_pts[None,None,:,None,None]
+ elif xyz.shape[-1] == 3:
+ min_pts = self.min_pts
+ max_pts = self.max_pts
+ return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,None,None,:] + min_pts[None,None,None,None,:]
+
+ def forward(
+ self,
+ video,
+ metric_depth,
+ metric_unc,
+ point_map,
+ queries,
+ pts_q_3d=None,
+ overlap_d=None,
+ iters=4,
+ add_space_attn=True,
+ fmaps_chunk_size=200,
+ intrs=None,
+ traj3d_gt=None,
+ custom_vid=False,
+ vis_gt=None,
+ prec_fx=None,
+ prec_fy=None,
+ cam_gt=None,
+ init_pose=False,
+ support_pts_q=None,
+ update_pointmap=True,
+ fixed_cam=False,
+ query_no_BA=False,
+ stage=0,
+ cache=None,
+ points_map_gt=None,
+ valid_only=False,
+ replace_ratio=0.6,
+ ):
+ """Predict tracks
+
+ Args:
+ video (FloatTensor[B, T, 3 H W]): input videos.
+ queries (FloatTensor[B, N, 3]): point queries.
+ iters (int, optional): number of updates. Defaults to 4.
+ vdp_feats_cache: last layer's feature of depth
+ tracks_init: B T N 3 the initialization of 3D tracks computed by cam pose
+ Returns:
+ - coords_predicted (FloatTensor[B, T, N, 2]):
+ - vis_predicted (FloatTensor[B, T, N]):
+ - train_data: `None` if `is_train` is false, otherwise:
+ - all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
+ - all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
+ - mask (BoolTensor[B, T, N]):
+ """
+ self.stage = stage
+
+ if cam_gt is not None:
+ cam_gt = cam_gt.clone()
+ cam_gt = torch.inverse(cam_gt[:,:1,...])@cam_gt
+ B, T, C, _, _ = video.shape
+ _, _, H_, W_ = metric_depth.shape
+ _, _, N, __ = queries.shape
+ if (vis_gt is not None)&(queries.shape[1] == T):
+ aug_visb = True
+ if aug_visb:
+ number_visible = vis_gt.sum(dim=1)
+ ratio_rand = torch.rand(B, N, device=vis_gt.device)
+ # first_positive_inds = get_nth_visible_time_index(vis_gt, 1)
+ first_positive_inds = get_nth_visible_time_index(vis_gt, (number_visible*ratio_rand).long().clamp(min=1, max=T))
+
+ assert (torch.gather(vis_gt, 1, first_positive_inds[:, None, :].repeat(1, T, 1)) < 0).sum() == 0
+ else:
+ __, first_positive_inds = torch.max(vis_gt, dim=1)
+ first_positive_inds = first_positive_inds.long()
+ gather = torch.gather(
+ queries, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
+ )
+ xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
+ gather_xyz = torch.gather(
+ traj3d_gt, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 3)
+ )
+ z_gt_query = torch.diagonal(gather_xyz, dim1=1, dim2=2).permute(0, 2, 1)[...,2]
+ queries = torch.cat([first_positive_inds[:, :, None], xys], dim=-1)
+ queries = torch.cat([queries, support_pts_q[:,0]], dim=1)
+ else:
+ # Generate the 768 points randomly in the whole video
+ queries = queries.squeeze(1)
+ ba_len = queries.shape[1]
+ z_gt_query = None
+ if support_pts_q is not None:
+ queries = torch.cat([queries, support_pts_q[:,0]], dim=1)
+
+ if (abs(prec_fx-1.0) > 1e-4) & (self.training) & (traj3d_gt is not None):
+ traj3d_gt[..., 0] /= prec_fx
+ traj3d_gt[..., 1] /= prec_fy
+ queries[...,1] /= prec_fx
+ queries[...,2] /= prec_fy
+
+ video_vis = F.interpolate(video.clone().view(B*T, 3, video.shape[-2], video.shape[-1]), (H_, W_), mode="bilinear", align_corners=False).view(B, T, 3, H_, W_)
+
+ self.image_size = torch.tensor([H_, W_])
+ # self.model_resolution = (H_, W_)
+ # resize the queries and intrs
+ self.factor_x = self.model_resolution[1]/W_
+ self.factor_y = self.model_resolution[0]/H_
+ queries[...,1] *= self.factor_x
+ queries[...,2] *= self.factor_y
+ intrs_org = intrs.clone()
+ intrs[...,0,:] *= self.factor_x
+ intrs[...,1,:] *= self.factor_y
+
+ # get the fmaps and color features
+ video = F.interpolate(video.view(B*T, 3, video.shape[-2], video.shape[-1]),
+ (self.model_resolution[0], self.model_resolution[1])).view(B, T, 3, self.model_resolution[0], self.model_resolution[1])
+ _, _, _, H, W = video.shape
+ if cache is not None:
+ T_cache = cache["fmaps"].shape[0]
+ fmaps = self.extract_img_feat(video[:,T_cache:], fmaps_chunk_size=fmaps_chunk_size)
+ fmaps = torch.cat([cache["fmaps"][None], fmaps], dim=1)
+ else:
+ fmaps = self.extract_img_feat(video, fmaps_chunk_size=fmaps_chunk_size)
+ fmaps_org = fmaps.clone()
+
+ metric_depth = F.interpolate(metric_depth.view(B*T, 1, H_, W_),
+ (self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 1, self.model_resolution[0], self.model_resolution[1]).clamp(0.01, 200)
+ self.metric_unc_org = metric_unc.clone()
+ metric_unc = F.interpolate(metric_unc.view(B*T, 1, H_, W_),
+ (self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 1, self.model_resolution[0], self.model_resolution[1])
+ if (self.stage == 2) & (self.training):
+ scale_rand = (torch.rand(B, T, device=video.device) - 0.5) + 1
+ point_map = scale_rand.view(B*T,1,1,1) * point_map
+
+ point_map_org = point_map.permute(0,3,1,2).view(B*T, 3, H_, W_).clone()
+ point_map = F.interpolate(point_map_org.clone(),
+ (self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 3, self.model_resolution[0], self.model_resolution[1])
+ # align the point map
+ point_map_org_train = point_map_org.view(B*T, 3, H_, W_).clone()
+
+ if (stage == 2):
+ # align the point map
+ try:
+ self.pred_points, scale_gt, shift_gt = affine_invariant_global_loss(
+ point_map_org_train.permute(0,2,3,1),
+ points_map_gt,
+ mask=self.metric_unc_org[:,0]>0.5,
+ align_resolution=32,
+ only_align=True
+ )
+ except:
+ scale_gt, shift_gt = torch.ones(B*T).to(video.device), torch.zeros(B*T,3).to(video.device)
+ self.scale_gt, self.shift_gt = scale_gt, shift_gt
+ else:
+ scale_est, shift_est = None, None
+
+ # extract the pts features
+ device = queries.device
+ assert H % self.stride == 0 and W % self.stride == 0
+
+ B, N, __ = queries.shape
+ queries_z = sample_features5d(metric_depth.view(B, T, 1, H, W),
+ queries[:,None], interp_mode="nearest").squeeze(1)
+ queries_z_unc = sample_features5d(metric_unc.view(B, T, 1, H, W),
+ queries[:,None], interp_mode="nearest").squeeze(1)
+
+ queries_rgb = sample_features5d(video.view(B, T, C, H, W),
+ queries[:,None], interp_mode="nearest").squeeze(1)
+ queries_point_map = sample_features5d(point_map.view(B, T, 3, H, W),
+ queries[:,None], interp_mode="nearest").squeeze(1)
+ if ((queries_z > 100)*(queries_z == 0)).sum() > 0:
+ import pdb; pdb.set_trace()
+
+ if overlap_d is not None:
+ queries_z[:,:overlap_d.shape[1],:] = overlap_d[...,None]
+ queries_point_map[:,:overlap_d.shape[1],2:] = overlap_d[...,None]
+
+ if pts_q_3d is not None:
+ scale_factor = (pts_q_3d[...,-1].permute(0,2,1) / queries_z[:,:pts_q_3d.shape[2],:]).squeeze().median()
+ queries_z[:,:pts_q_3d.shape[2],:] = pts_q_3d[...,-1].permute(0,2,1) / scale_factor
+ queries_point_map[:,:pts_q_3d.shape[2],2:] = pts_q_3d[...,-1].permute(0,2,1) / scale_factor
+
+ # normalize the points
+ self.min_pts, self.max_pts = queries_point_map.mean(dim=(0,1)) - 3*queries_point_map.std(dim=(0,1)), queries_point_map.mean(dim=(0,1)) + 3*queries_point_map.std(dim=(0,1))
+ queries_point_map = self.norm_xyz(queries_point_map)
+ queries_point_map_ = queries_point_map.reshape(B, 1, N, 3).expand(B, T, N, 3).clone()
+ point_map = self.norm_xyz(point_map.view(B, T, 3, H, W)).view(B*T, 3, H, W)
+
+ if z_gt_query is not None:
+ queries_z[:,:z_gt_query.shape[1],:] = z_gt_query[:,:,None]
+ mask_traj_gt = ((queries_z[:,:z_gt_query.shape[1],:] - z_gt_query[:,:,None])).abs() < 0.1
+ else:
+ if traj3d_gt is not None:
+ mask_traj_gt = torch.ones_like(queries_z[:, :traj3d_gt.shape[2]]).bool()
+ else:
+ mask_traj_gt = torch.ones_like(queries_z).bool()
+
+ queries_xyz = torch.cat([queries, queries_z], dim=-1)[:,None].repeat(1, T, 1, 1)
+ if cache is not None:
+ cache_T, cache_N = cache["track2d_pred_cache"].shape[0], cache["track2d_pred_cache"].shape[1]
+ cachexy = cache["track2d_pred_cache"].clone()
+ cachexy[...,0] = cachexy[...,0] * self.factor_x
+ cachexy[...,1] = cachexy[...,1] * self.factor_y
+ # initialize the 2d points with cache
+ queries_xyz[:,:cache_T,:cache_N,1:] = cachexy
+ queries_xyz[:,cache_T:,:cache_N,1:] = cachexy[-1:]
+ # initialize the 3d points with cache
+ queries_point_map_[:,:cache_T,:cache_N,:] = self.norm_xyz(cache["track3d_pred_cache"][None])
+ queries_point_map_[:,cache_T:,:cache_N,:] = self.norm_xyz(cache["track3d_pred_cache"][-1:][None])
+
+ if cam_gt is not None:
+ q_static_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz, cam_gt,
+ intrs, rgbs=video_vis, visualize=False)
+ q_static_proj[..., 0] /= self.factor_x
+ q_static_proj[..., 1] /= self.factor_y
+
+
+ assert T >= 1 # A tracker needs at least two frames to track something
+ video = 2 * (video / 255.0) - 1.0
+ dtype = video.dtype
+ queried_frames = queries[:, :, 0].long()
+
+ queried_coords = queries[..., 1:3]
+ queried_coords = queried_coords / self.stride
+
+ # We store our predictions here
+ (all_coords_predictions, all_coords_xyz_predictions,all_vis_predictions,
+ all_confidence_predictions, all_cam_predictions, all_dynamic_prob_predictions,
+ all_cam_pts_predictions, all_world_tracks_predictions, all_world_tracks_refined_predictions,
+ all_scale_est, all_shift_est) = (
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ []
+ )
+
+ # We compute track features
+ fmaps_pyramid = []
+ point_map_pyramid = []
+ track_feat_pyramid = []
+ track_feat_support_pyramid = []
+ track_feat3d_pyramid = []
+ track_feat_support3d_pyramid = []
+ track_depth_support_pyramid = []
+ track_point_map_pyramid = []
+ track_point_map_support_pyramid = []
+ fmaps_pyramid.append(fmaps)
+ metric_depth = metric_depth
+ point_map = point_map
+ metric_depth_align = F.interpolate(metric_depth, scale_factor=0.25, mode='nearest')
+ point_map_align = F.interpolate(point_map, scale_factor=0.25, mode='nearest')
+ point_map_pyramid.append(point_map_align.view(B, T, 3, point_map_align.shape[-2], point_map_align.shape[-1]))
+ for i in range(self.corr_levels - 1):
+ fmaps_ = fmaps.reshape(
+ B * T, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
+ )
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
+ fmaps = fmaps_.reshape(
+ B, T, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
+ )
+ fmaps_pyramid.append(fmaps)
+ # downsample the depth
+ metric_depth_ = metric_depth_align.reshape(B*T,1,metric_depth_align.shape[-2],metric_depth_align.shape[-1])
+ metric_depth_ = F.interpolate(metric_depth_, scale_factor=0.5, mode='nearest')
+ metric_depth_align = metric_depth_.reshape(B,T,1,metric_depth_.shape[-2], metric_depth_.shape[-1])
+ # downsample the point map
+ point_map_ = point_map_align.reshape(B*T,3,point_map_align.shape[-2],point_map_align.shape[-1])
+ point_map_ = F.interpolate(point_map_, scale_factor=0.5, mode='nearest')
+ point_map_align = point_map_.reshape(B,T,3,point_map_.shape[-2], point_map_.shape[-1])
+ point_map_pyramid.append(point_map_align)
+
+ for i in range(self.corr_levels):
+ if cache is not None:
+ cache_N = cache["track_feat_pyramid"][i].shape[2]
+ track_feat_cached, track_feat_support_cached = cache["track_feat_pyramid"][i], cache["track_feat_support_pyramid"][i]
+ track_feat3d_cached, track_feat_support3d_cached = cache["track_feat3d_pyramid"][i], cache["track_feat_support3d_pyramid"][i]
+ track_point_map_cached, track_point_map_support_cached = self.norm_xyz(cache["track_point_map_pyramid"][i]), self.norm_xyz(cache["track_point_map_support_pyramid"][i])
+ queried_coords_new = queried_coords[:,cache_N:,:] / 2**i
+ queried_frames_new = queried_frames[:,cache_N:]
+ else:
+ queried_coords_new = queried_coords / 2**i
+ queried_frames_new = queried_frames
+ track_feat, track_feat_support = self.get_track_feat(
+ fmaps_pyramid[i],
+ queried_frames_new,
+ queried_coords_new,
+ support_radius=self.corr_radius,
+ )
+ # get 3d track feat
+ track_point_map, track_point_map_support = self.get_track_feat(
+ point_map_pyramid[i],
+ queried_frames_new,
+ queried_coords_new,
+ support_radius=self.corr3d_radius,
+ )
+ track_feat3d, track_feat_support3d = self.get_track_feat(
+ fmaps_pyramid[i],
+ queried_frames_new,
+ queried_coords_new,
+ support_radius=self.corr3d_radius,
+ )
+ if cache is not None:
+ track_feat = torch.cat([track_feat_cached, track_feat], dim=2)
+ track_point_map = torch.cat([track_point_map_cached, track_point_map], dim=2)
+ track_feat_support = torch.cat([track_feat_support_cached[:,0], track_feat_support], dim=2)
+ track_point_map_support = torch.cat([track_point_map_support_cached[:,0], track_point_map_support], dim=2)
+ track_feat3d = torch.cat([track_feat3d_cached, track_feat3d], dim=2)
+ track_feat_support3d = torch.cat([track_feat_support3d_cached[:,0], track_feat_support3d], dim=2)
+ track_feat_pyramid.append(track_feat.repeat(1, T, 1, 1))
+ track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))
+ track_feat3d_pyramid.append(track_feat3d.repeat(1, T, 1, 1))
+ track_feat_support3d_pyramid.append(track_feat_support3d.unsqueeze(1))
+ track_point_map_pyramid.append(track_point_map.repeat(1, T, 1, 1))
+ track_point_map_support_pyramid.append(track_point_map_support.unsqueeze(1))
+
+
+ D_coords = 2
+ (coord_preds, coords_xyz_preds, vis_preds, confidence_preds,
+ dynamic_prob_preds, cam_preds, pts3d_cam_pred, world_tracks_pred,
+ world_tracks_refined_pred, point_map_preds, scale_ests, shift_ests) = (
+ [], [], [], [], [], [], [], [], [], [], [], []
+ )
+
+ c2w_ests = []
+ vis = torch.zeros((B, T, N), device=device).float()
+ confidence = torch.zeros((B, T, N), device=device).float()
+ dynamic_prob = torch.zeros((B, T, N), device=device).float()
+ pro_analysis_w = torch.zeros((B, T, N), device=device).float()
+
+ coords = queries_xyz[...,1:].clone()
+ coords[...,:2] /= self.stride
+ # coords[...,:2] = queried_coords.reshape(B, 1, N, 2).expand(B, T, N, 2).float()[...,:2]
+ # initialize the 3d points
+ coords_xyz = queries_point_map_.clone()
+
+ # if cache is not None:
+ # viser = Visualizer(save_dir=".", grayscale=True,
+ # fps=10, pad_value=50, tracks_leave_trace=0)
+ # coords_clone = coords.clone()
+ # coords_clone[...,:2] *= self.stride
+ # coords_clone[..., 0] /= self.factor_x
+ # coords_clone[..., 1] /= self.factor_y
+ # viser.visualize(video=video_vis, tracks=coords_clone[..., :2], filename="test")
+ # import pdb; pdb.set_trace()
+
+ if init_pose:
+ q_init_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz, cam_gt,
+ intrs, rgbs=video_vis, visualize=False)
+ q_init_proj[..., 0] /= self.stride
+ q_init_proj[..., 1] /= self.stride
+
+ r = 2 * self.corr_radius + 1
+ r_depth = 2 * self.corr3d_radius + 1
+ anchor_loss = 0
+ # two current states
+ self.c2w_est_curr = torch.eye(4, device=device).repeat(B, T , 1, 1)
+ coords_proj_curr = coords.view(B * T, N, 3)[...,:2]
+ if init_pose:
+ self.c2w_est_curr = cam_gt.to(coords_proj_curr.device).to(coords_proj_curr.dtype)
+ sync_loss = 0
+ if stage == 2:
+ extra_sparse_tokens = self.scale_shift_tokens[:,:,None,:].repeat(B, 1, T, 1)
+ extra_dense_tokens = self.residual_embedding[None,None].repeat(B, T, 1, 1, 1)
+ xyz_pos_enc = posenc(point_map_pyramid[-2].permute(0,1,3,4,2), min_deg=0, max_deg=10).permute(0,1,4,2,3)
+ extra_dense_tokens = torch.cat([xyz_pos_enc, extra_dense_tokens, fmaps_pyramid[-2]], dim=2)
+ extra_dense_tokens = rearrange(extra_dense_tokens, 'b t c h w -> (b t) c h w')
+ extra_dense_tokens = self.dense_mlp(extra_dense_tokens)
+ extra_dense_tokens = rearrange(extra_dense_tokens, '(b t) c h w -> b t c h w', b=B, t=T)
+ else:
+ extra_sparse_tokens = None
+ extra_dense_tokens = None
+
+ scale_est, shift_est = torch.ones(B, T, 1, 1, device=device), torch.zeros(B, T, 1, 3, device=device)
+ residual_point = torch.zeros(B, T, 3, self.model_resolution[0]//self.stride,
+ self.model_resolution[1]//self.stride, device=device)
+
+ for it in range(iters):
+ # query points scale and shift
+ scale_est_query = torch.gather(scale_est, dim=1, index=queries[:,:,None,:1].long())
+ shift_est_query = torch.gather(shift_est, dim=1, index=queries[:,:,None,:1].long().repeat(1, 1, 1, 3))
+
+ coords = coords.detach() # B T N 3
+ coords_xyz = coords_xyz.detach()
+ vis = vis.detach()
+ confidence = confidence.detach()
+ dynamic_prob = dynamic_prob.detach()
+ pro_analysis_w = pro_analysis_w.detach()
+ coords_init = coords.view(B * T, N, 3)
+ coords_xyz_init = coords_xyz.view(B * T, N, 3)
+ corr_embs = []
+ corr_depth_embs = []
+ corr_feats = []
+ for i in range(self.corr_levels):
+ # K_level = int(32*0.8**(i))
+ K_level = 16
+ corr_feat = self.get_correlation_feat(
+ fmaps_pyramid[i], coords_init[...,:2] / 2**i
+ )
+ #NOTE: update the point map
+ residual_point_i = F.interpolate(residual_point.view(B*T,3,residual_point.shape[-2],residual_point.shape[-1]),
+ size=(point_map_pyramid[i].shape[-2], point_map_pyramid[i].shape[-1]), mode='nearest')
+ point_map_pyramid_i = (self.denorm_xyz(point_map_pyramid[i]) * scale_est[...,None]
+ + shift_est.permute(0,1,3,2)[...,None] + residual_point_i.view(B,T,3,point_map_pyramid[i].shape[-2], point_map_pyramid[i].shape[-1])).clone().detach()
+
+ corr_point_map = self.get_correlation_feat(
+ self.norm_xyz(point_map_pyramid_i), coords_proj_curr / 2**i, radius=self.corr3d_radius
+ )
+
+ corr_point_feat = self.get_correlation_feat(
+ fmaps_pyramid[i], coords_proj_curr / 2**i, radius=self.corr3d_radius
+ )
+ track_feat_support = (
+ track_feat_support_pyramid[i]
+ .view(B, 1, r, r, N, self.latent_dim)
+ .squeeze(1)
+ .permute(0, 3, 1, 2, 4)
+ )
+ track_feat_support3d = (
+ track_feat_support3d_pyramid[i]
+ .view(B, 1, r_depth, r_depth, N, self.latent_dim)
+ .squeeze(1)
+ .permute(0, 3, 1, 2, 4)
+ )
+ #NOTE: update the point map
+ track_point_map_support_pyramid_i = (self.denorm_xyz(track_point_map_support_pyramid[i]) * scale_est_query.view(B,1,1,N,1)
+ + shift_est_query.view(B,1,1,N,3)).clone().detach()
+
+ track_point_map_support = (
+ self.norm_xyz(track_point_map_support_pyramid_i)
+ .view(B, 1, r_depth, r_depth, N, 3)
+ .squeeze(1)
+ .permute(0, 3, 1, 2, 4)
+ )
+ corr_volume = torch.einsum(
+ "btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
+ )
+ corr_emb = self.corr_mlp(corr_volume.reshape(B, T, N, r * r * r * r))
+
+ with torch.no_grad():
+ rel_pos_query_ = track_point_map_support - track_point_map_support[:,:,self.corr3d_radius,self.corr3d_radius,:][...,None,None,:]
+ rel_pos_target_ = corr_point_map - coords_xyz_init.view(B, T, N, 1, 1, 3)
+ # select the top 9 points
+ rel_pos_query_idx = rel_pos_query_.norm(dim=-1).view(B, N, -1).topk(K_level+1, dim=-1, largest=False)[1][...,1:,None]
+ rel_pos_target_idx = rel_pos_target_.norm(dim=-1).view(B, T, N, -1).topk(K_level+1, dim=-1, largest=False)[1][...,1:,None]
+ rel_pos_query_ = torch.gather(rel_pos_query_.view(B, N, -1, 3), dim=-2, index=rel_pos_query_idx.expand(B, N, K_level, 3))
+ rel_pos_target_ = torch.gather(rel_pos_target_.view(B, T, N, -1, 3), dim=-2, index=rel_pos_target_idx.expand(B, T, N, K_level, 3))
+ rel_pos_query = rel_pos_query_
+ rel_pos_target = rel_pos_target_
+ rel_pos_query = posenc(rel_pos_query, min_deg=0, max_deg=12)
+ rel_pos_target = posenc(rel_pos_target, min_deg=0, max_deg=12)
+ rel_pos_target = self.rel_pos_mlp(rel_pos_target)
+ rel_pos_query = self.rel_pos_mlp(rel_pos_query)
+ with torch.no_grad():
+ # integrate with feature
+ track_feat_support_ = rearrange(track_feat_support3d, 'b n r k c -> b n (r k) c', r=r_depth, k=r_depth, n=N, b=B)
+ track_feat_support_ = torch.gather(track_feat_support_, dim=-2, index=rel_pos_query_idx.expand(B, N, K_level, 128))
+ queried_feat = torch.cat([rel_pos_query, track_feat_support_], dim=-1)
+ corr_feat_ = rearrange(corr_point_feat, 'b t n r k c -> b t n (r k) c', t=T, n=N, b=B)
+ corr_feat_ = torch.gather(corr_feat_, dim=-2, index=rel_pos_target_idx.expand(B, T, N, K_level, 128))
+ target_feat = torch.cat([rel_pos_target, corr_feat_], dim=-1)
+
+ # 3d attention
+ queried_feat = self.corr_xyz_mlp(queried_feat)
+ target_feat = self.corr_xyz_mlp(target_feat)
+ queried_feat = repeat(queried_feat, 'b n k c -> b t n k c', k=K_level, t=T, n=N, b=B)
+ corr_depth_emb = self.corr_transformer[0](queried_feat.reshape(B*T*N,-1,128),
+ target_feat.reshape(B*T*N,-1,128),
+ target_rel_pos=rel_pos_target.reshape(B*T*N,-1,128))
+ corr_depth_emb = rearrange(corr_depth_emb, '(b t n) 1 c -> b t n c', t=T, n=N, b=B)
+ corr_depth_emb = self.corr_depth_mlp(corr_depth_emb)
+ valid_mask = self.denorm_xyz(coords_xyz_init).view(B, T, N, -1)[...,2:3] > 0
+ corr_depth_embs.append(corr_depth_emb*valid_mask)
+
+ corr_embs.append(corr_emb)
+ corr_embs = torch.cat(corr_embs, dim=-1)
+ corr_embs = corr_embs.view(B, T, N, corr_embs.shape[-1])
+ corr_depth_embs = torch.cat(corr_depth_embs, dim=-1)
+ corr_depth_embs = corr_depth_embs.view(B, T, N, corr_depth_embs.shape[-1])
+ transformer_input = [vis[..., None], confidence[..., None], corr_embs]
+ transformer_input_depth = [vis[..., None], confidence[..., None], corr_depth_embs]
+
+ rel_coords_forward = coords[:,:-1,...,:2] - coords[:,1:,...,:2]
+ rel_coords_backward = coords[:, 1:,...,:2] - coords[:, :-1,...,:2]
+
+ rel_xyz_forward = coords_xyz[:,:-1,...,:3] - coords_xyz[:,1:,...,:3]
+ rel_xyz_backward = coords_xyz[:, 1:,...,:3] - coords_xyz[:, :-1,...,:3]
+
+ rel_coords_forward = torch.nn.functional.pad(
+ rel_coords_forward, (0, 0, 0, 0, 0, 1)
+ )
+ rel_coords_backward = torch.nn.functional.pad(
+ rel_coords_backward, (0, 0, 0, 0, 1, 0)
+ )
+ rel_xyz_forward = torch.nn.functional.pad(
+ rel_xyz_forward, (0, 0, 0, 0, 0, 1)
+ )
+ rel_xyz_backward = torch.nn.functional.pad(
+ rel_xyz_backward, (0, 0, 0, 0, 1, 0)
+ )
+
+ scale = (
+ torch.tensor(
+ [self.model_resolution[1], self.model_resolution[0]],
+ device=coords.device,
+ )
+ / self.stride
+ )
+ rel_coords_forward = rel_coords_forward / scale
+ rel_coords_backward = rel_coords_backward / scale
+
+ rel_pos_emb_input = posenc(
+ torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
+ min_deg=0,
+ max_deg=10,
+ ) # batch, num_points, num_frames, 84
+ rel_xyz_emb_input = posenc(
+ torch.cat([rel_xyz_forward, rel_xyz_backward], dim=-1),
+ min_deg=0,
+ max_deg=10,
+ ) # batch, num_points, num_frames, 126
+ rel_xyz_emb_input = self.xyz_mlp(rel_xyz_emb_input)
+ transformer_input.append(rel_pos_emb_input)
+ transformer_input_depth.append(rel_xyz_emb_input)
+ # get the queries world
+ with torch.no_grad():
+ # update the query points with scale and shift
+ queries_xyz_i = queries_xyz.clone().detach()
+ queries_xyz_i[..., -1] = queries_xyz_i[..., -1] * scale_est_query.view(B,1,N) + shift_est_query.view(B,1,N,3)[...,2]
+ _, _, q_xyz_cam = self.track_from_cam(queries_xyz_i, self.c2w_est_curr,
+ intrs, rgbs=None, visualize=False)
+ q_xyz_cam = self.norm_xyz(q_xyz_cam)
+
+ query_t = queries[:,None,:,:1].repeat(B, T, 1, 1)
+ q_xyz_cam = torch.cat([query_t/T, q_xyz_cam], dim=-1)
+ T_all = torch.arange(T, device=device)[None,:,None,None].repeat(B, 1, N, 1)
+ current_xyzt = torch.cat([T_all/T, coords_xyz_init.view(B, T, N, -1)], dim=-1)
+ rel_pos_query_glob = q_xyz_cam - current_xyzt
+ # embed the confidence and dynamic probability
+ confidence_curr = torch.sigmoid(confidence[...,None])
+ dynamic_prob_curr = torch.sigmoid(dynamic_prob[...,None]).mean(dim=1, keepdim=True).repeat(1,T,1,1)
+ # embed the confidence and dynamic probability
+ rel_pos_query_glob = torch.cat([rel_pos_query_glob, confidence_curr, dynamic_prob_curr], dim=-1)
+ rel_pos_query_glob = posenc(rel_pos_query_glob, min_deg=0, max_deg=12)
+ transformer_input_depth.append(rel_pos_query_glob)
+
+ x = (
+ torch.cat(transformer_input, dim=-1)
+ .permute(0, 2, 1, 3)
+ .reshape(B * N, T, -1)
+ )
+ x_depth = (
+ torch.cat(transformer_input_depth, dim=-1)
+ .permute(0, 2, 1, 3)
+ .reshape(B * N, T, -1)
+ )
+ x_depth = self.proj_xyz_embed(x_depth)
+
+ x = x + self.interpolate_time_embed(x, T)
+ x = x.view(B, N, T, -1) # (B N) T D -> B N T D
+ x_depth = x_depth + self.interpolate_time_embed(x_depth, T)
+ x_depth = x_depth.view(B, N, T, -1) # (B N) T D -> B N T D
+ delta, delta_depth, delta_dynamic_prob, delta_pro_analysis_w, scale_shift_out, dense_res_out = self.updateformer3D(
+ x,
+ x_depth,
+ self.updateformer,
+ add_space_attn=add_space_attn,
+ extra_sparse_tokens=extra_sparse_tokens,
+ extra_dense_tokens=extra_dense_tokens,
+ )
+ # update the scale and shift
+ if scale_shift_out is not None:
+ extra_sparse_tokens = extra_sparse_tokens + scale_shift_out[...,:128]
+ scale_update = scale_shift_out[:,:1,:,-1].permute(0,2,1)[...,None]
+ shift_update = scale_shift_out[:,1:,:,-1].permute(0,2,1)[...,None]
+ scale_est = scale_est + scale_update
+ shift_est[...,2:] = shift_est[...,2:] + shift_update / 10
+ # dense tokens update
+ extra_dense_tokens = extra_dense_tokens + dense_res_out[:,:,-128:]
+ res_low = dense_res_out[:,:,:3]
+ up_mask = self.upsample_transformer(extra_dense_tokens.mean(dim=1), res_low)
+ up_mask = repeat(up_mask, "b k h w -> b s k h w", s=T)
+ up_mask = rearrange(up_mask, "b s c h w -> (b s) 1 c h w")
+ res_up = self.upsample_with_mask(
+ rearrange(res_low, 'b t c h w -> (b t) c h w'),
+ up_mask,
+ )
+ res_up = rearrange(res_up, "(b t) c h w -> b t c h w", b=B, t=T)
+ # residual_point = residual_point + res_up
+
+ delta_coords = delta[..., :D_coords].permute(0, 2, 1, 3)
+ delta_vis = delta[..., D_coords].permute(0, 2, 1)
+ delta_confidence = delta[..., D_coords + 1].permute(0, 2, 1)
+
+ vis = vis + delta_vis
+ confidence = confidence + delta_confidence
+ dynamic_prob = dynamic_prob + delta_dynamic_prob[...,0].permute(0, 2, 1)
+ pro_analysis_w = pro_analysis_w + delta_pro_analysis_w[...,0].permute(0, 2, 1)
+ # update the depth
+ vis_est = torch.sigmoid(vis.detach())
+
+ delta_xyz = delta_depth[...,:3].permute(0,2,1,3)
+ denorm_delta_depth = (self.denorm_xyz(coords_xyz+delta_xyz)-self.denorm_xyz(coords_xyz))[...,2:3]
+
+
+ delta_depth_ = denorm_delta_depth.detach()
+ delta_coords = torch.cat([delta_coords, delta_depth_],dim=-1)
+ coords = coords + delta_coords
+ coords_append = coords.clone()
+ coords_xyz_append = self.denorm_xyz(coords_xyz + delta_xyz).clone()
+
+ coords_append[..., :2] = coords_append[..., :2] * float(self.stride)
+ coords_append[..., 0] /= self.factor_x
+ coords_append[..., 1] /= self.factor_y
+
+ # get the camera pose from tracks
+ dynamic_prob_curr = torch.sigmoid(dynamic_prob.detach())*torch.sigmoid(pro_analysis_w)
+ mask_out = (coords_append[...,0]0)&(coords_append[...,1]0)
+ if query_no_BA:
+ dynamic_prob_curr[:,:,:ba_len] = torch.ones_like(dynamic_prob_curr[:,:,:ba_len])
+ point_map_org_i = scale_est.view(B*T,1,1,1)*point_map_org.clone().detach() + shift_est.view(B*T,3,1,1)
+ # depth_unproj = bilinear_sampler(point_map_org_i, coords_append[...,:2].view(B*T, N, 1, 2), mode="nearest")[:,2,:,0].detach()
+
+ depth_unproj_neg = self.get_correlation_feat(
+ point_map_org_i.view(B,T,3,point_map_org_i.shape[-2], point_map_org_i.shape[-1]),
+ coords_append[...,:2].view(B*T, N, 2), radius=self.corr3d_radius
+ )[..., 2]
+ depth_diff = (depth_unproj_neg.view(B,T,N,-1) - coords_append[...,2:]).abs()
+ idx_neg = torch.argmin(depth_diff, dim=-1)
+ depth_unproj = depth_unproj_neg.view(B,T,N,-1)[torch.arange(B)[:, None, None, None],
+ torch.arange(T)[None, :, None, None],
+ torch.arange(N)[None, None, :, None],
+ idx_neg.view(B,T,N,1)].view(B*T, N)
+
+ unc_unproj = bilinear_sampler(self.metric_unc_org, coords_append[...,:2].view(B*T, N, 1, 2), mode="nearest")[:,0,:,0].detach()
+ depth_unproj[unc_unproj<0.5] = 0.0
+
+ # replace the depth for visible and solid points
+ conf_est = torch.sigmoid(confidence.detach())
+ replace_mask = (depth_unproj.view(B,T,N)>0.0) * (vis_est>0.5) # * (conf_est>0.5)
+ #NOTE: way1: find the jitter points
+ depth_rel = (depth_unproj.view(B, T, N) - queries_z.permute(0, 2, 1))
+ depth_ddt1 = depth_rel[:, 1:, :] - depth_rel[:, :-1, :]
+ depth_ddt2 = depth_rel[:, 2:, :] - 2 * depth_rel[:, 1:-1, :] + depth_rel[:, :-2, :]
+ jitter_mask = torch.zeros_like(depth_rel, dtype=torch.bool)
+ if depth_ddt2.abs().max()>0:
+ thre2 = torch.quantile(depth_ddt2.abs()[depth_ddt2.abs()>0], replace_ratio)
+ jitter_mask[:, 1:-1, :] = (depth_ddt2.abs() < thre2)
+ thre1 = torch.quantile(depth_ddt1.abs()[depth_ddt1.abs()>0], replace_ratio)
+ jitter_mask[:, :-1, :] *= (depth_ddt1.abs() < thre1)
+ replace_mask = replace_mask * jitter_mask
+
+ #NOTE: way2: top k topological change detection
+ # coords_2d_lift = coords_append.clone()
+ # coords_2d_lift[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
+ # coords_2d_lift = self.cam_from_track(coords_2d_lift.clone(), intrs_org, only_cam_pts=True)
+ # coords_2d_lift[~replace_mask] = coords_xyz_append[~replace_mask]
+ # import pdb; pdb.set_trace()
+ # jitter_mask = get_topo_mask(coords_xyz_append, coords_2d_lift, replace_ratio)
+ # replace_mask = replace_mask * jitter_mask
+
+ # replace the depth
+ if self.training:
+ replace_mask = torch.zeros_like(replace_mask)
+ coords_append[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
+ coords_xyz_unproj = self.cam_from_track(coords_append.clone(), intrs_org, only_cam_pts=True)
+ coords[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
+ # coords_xyz_append[replace_mask] = coords_xyz_unproj[replace_mask]
+ coords_xyz_append_refine = coords_xyz_append.clone()
+ coords_xyz_append_refine[replace_mask] = coords_xyz_unproj[replace_mask]
+
+ c2w_traj_est, cam_pts_est, intrs_refine, coords_refine, world_tracks, world_tracks_refined, c2w_traj_init = self.cam_from_track(coords_append.clone(),
+ intrs_org, dynamic_prob_curr, queries_z_unc, conf_est*vis_est*mask_out.float(),
+ track_feat_concat=x_depth, tracks_xyz=coords_xyz_append_refine, init_pose=init_pose,
+ query_pts=queries_xyz_i, fixed_cam=fixed_cam, depth_unproj=depth_unproj, cam_gt=cam_gt)
+ intrs_org = intrs_refine.view(B, T, 3, 3).to(intrs_org.dtype)
+
+ # get the queries world
+ self.c2w_est_curr = c2w_traj_est.detach()
+
+ # update coords and coords_append
+ coords[..., 2] = (cam_pts_est)[...,2]
+ coords_append[..., 2] = (cam_pts_est)[...,2]
+
+ # update coords_xyz_append
+ # coords_xyz_append = cam_pts_est
+ coords_xyz = self.norm_xyz(cam_pts_est)
+
+
+ # proj
+ coords_xyz_de = coords_xyz_append.clone()
+ coords_xyz_de[coords_xyz_de[...,2].abs()<1e-6] = -1e-4
+ mask_nan = coords_xyz_de[...,2].abs()<1e-2
+ coords_proj = torch.einsum("btij,btnj->btni", intrs_org, coords_xyz_de/coords_xyz_de[...,2:3].abs())[...,:2]
+ coords_proj[...,0] *= self.factor_x
+ coords_proj[...,1] *= self.factor_y
+ coords_proj[...,:2] /= float(self.stride)
+ # make sure it is aligned with 2d tracking
+ coords_proj_curr = coords[...,:2].view(B*T, N, 2).detach()
+ vis_est = (vis_est>0.5).float()
+ sync_loss += (vis_est.detach()[...,None]*(coords_proj_curr - coords_proj).norm(dim=-1, keepdim=True)*(1-mask_nan[...,None].float())).mean()
+ # coords_proj_curr[~mask_nan.view(B*T, N)] = coords_proj.view(B*T, N, 2)[~mask_nan.view(B*T, N)].to(coords_proj_curr.dtype)
+ # if torch.isnan(coords_proj_curr).sum()>0:
+ # import pdb; pdb.set_trace()
+
+ if False:
+ point_map_resize = point_map.clone().view(B, T, 3, H, W)
+ update_input = torch.cat([point_map_resize, metric_unc.view(B,T,1,H,W)], dim=2)
+ coords_append_resize = coords.clone().detach()
+ coords_append_resize[..., :2] = coords_append_resize[..., :2] * float(self.stride)
+ update_track_input = self.norm_xyz(cam_pts_est)*5
+ update_track_input = torch.cat([update_track_input, vis_est[...,None]], dim=-1)
+ update_track_input = posenc(update_track_input, min_deg=0, max_deg=12)
+ update = self.update_pointmap.stablizer(update_input,
+ update_track_input, coords_append_resize)#, imgs=video, vis_track=viser)
+ #NOTE: update the point map
+ point_map_resize += update
+ point_map_refine_out = F.interpolate(point_map_resize.view(B*T, -1, H, W),
+ size=(self.image_size[0].item(), self.image_size[1].item()), mode='nearest')
+ point_map_refine_out = rearrange(point_map_refine_out, '(b t) c h w -> b t c h w', t=T, b=B)
+ point_map_preds.append(self.denorm_xyz(point_map_refine_out))
+ point_map_org = self.denorm_xyz(point_map_refine_out).view(B*T, 3, H_, W_)
+
+ # if torch.isnan(coords).sum()>0:
+ # import pdb; pdb.set_trace()
+ #NOTE: the 2d tracking + unproject depth
+ fix_cam_est = coords_append.clone()
+ fix_cam_est[...,2] = depth_unproj
+ fix_cam_pts = self.cam_from_track(
+ fix_cam_est, intrs_org, only_cam_pts=True
+ )
+
+ coord_preds.append(coords_append)
+ coords_xyz_preds.append(coords_xyz_append)
+ vis_preds.append(vis)
+ cam_preds.append(c2w_traj_init)
+ pts3d_cam_pred.append(cam_pts_est)
+ world_tracks_pred.append(world_tracks)
+ world_tracks_refined_pred.append(world_tracks_refined)
+ confidence_preds.append(confidence)
+ dynamic_prob_preds.append(dynamic_prob)
+ scale_ests.append(scale_est)
+ shift_ests.append(shift_est)
+
+ if stage!=0:
+ all_coords_predictions.append([coord for coord in coord_preds])
+ all_coords_xyz_predictions.append([coord_xyz for coord_xyz in coords_xyz_preds])
+ all_vis_predictions.append(vis_preds)
+ all_confidence_predictions.append(confidence_preds)
+ all_dynamic_prob_predictions.append(dynamic_prob_preds)
+ all_cam_predictions.append([cam for cam in cam_preds])
+ all_cam_pts_predictions.append([pts for pts in pts3d_cam_pred])
+ all_world_tracks_predictions.append([world_tracks for world_tracks in world_tracks_pred])
+ all_world_tracks_refined_predictions.append([world_tracks_refined for world_tracks_refined in world_tracks_refined_pred])
+ all_scale_est.append(scale_ests)
+ all_shift_est.append(shift_ests)
+ if stage!=0:
+ train_data = (
+ all_coords_predictions,
+ all_coords_xyz_predictions,
+ all_vis_predictions,
+ all_confidence_predictions,
+ all_dynamic_prob_predictions,
+ all_cam_predictions,
+ all_cam_pts_predictions,
+ all_world_tracks_predictions,
+ all_world_tracks_refined_predictions,
+ all_scale_est,
+ all_shift_est,
+ torch.ones_like(vis_preds[-1], device=vis_preds[-1].device),
+ )
+ else:
+ train_data = None
+ # resize back
+ # init the trajectories by camera motion
+
+ # if cache is not None:
+ # viser = Visualizer(save_dir=".", grayscale=True,
+ # fps=10, pad_value=50, tracks_leave_trace=0)
+ # coords_clone = coords.clone()
+ # coords_clone[...,:2] *= self.stride
+ # coords_clone[..., 0] /= self.factor_x
+ # coords_clone[..., 1] /= self.factor_y
+ # viser.visualize(video=video_vis, tracks=coords_clone[..., :2], filename="test_refine")
+ # import pdb; pdb.set_trace()
+
+ if train_data is not None:
+ # get the gt pts in the world coordinate
+ self_supervised = False
+ if (traj3d_gt is not None):
+ if traj3d_gt[...,2].abs().max()>0:
+ gt_cam_pts = self.cam_from_track(
+ traj3d_gt, intrs_org, only_cam_pts=True
+ )
+ else:
+ self_supervised = True
+ else:
+ self_supervised = True
+
+ if self_supervised:
+ gt_cam_pts = self.cam_from_track(
+ coord_preds[-1].detach(), intrs_org, only_cam_pts=True
+ )
+
+ if cam_gt is not None:
+ gt_world_pts = torch.einsum(
+ "btij,btnj->btni",
+ cam_gt[...,:3,:3],
+ gt_cam_pts
+ ) + cam_gt[...,None, :3,3] # B T N 3
+ else:
+ gt_world_pts = torch.einsum(
+ "btij,btnj->btni",
+ self.c2w_est_curr[...,:3,:3],
+ gt_cam_pts
+ ) + self.c2w_est_curr[...,None, :3,3] # B T N 3
+ # update the query points with scale and shift
+ queries_xyz_i = queries_xyz.clone().detach()
+ queries_xyz_i[..., -1] = queries_xyz_i[..., -1] * scale_est_query.view(B,1,N) + shift_est_query.view(B,1,N,3)[...,2]
+ q_static_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz_i,
+ self.c2w_est_curr,
+ intrs, rgbs=video_vis, visualize=False)
+
+ q_static_proj[..., 0] /= self.factor_x
+ q_static_proj[..., 1] /= self.factor_y
+ cam_gt = self.c2w_est_curr[:,:,:3,:]
+
+ if traj3d_gt is not None:
+ ret_loss = self.loss(train_data, traj3d_gt,
+ vis_gt, None, cam_gt, queries_z_unc,
+ q_xyz_world, q_static_proj, anchor_loss=anchor_loss, fix_cam_pts=fix_cam_pts, video_vis=video_vis, stage=stage,
+ gt_world_pts=gt_world_pts, mask_traj_gt=mask_traj_gt, intrs=intrs_org, custom_vid=custom_vid, valid_only=valid_only,
+ c2w_ests=c2w_ests, point_map_preds=point_map_preds, points_map_gt=points_map_gt, metric_unc=metric_unc, scale_est=scale_est,
+ shift_est=shift_est, point_map_org_train=point_map_org_train)
+ else:
+ ret_loss = self.loss(train_data, traj3d_gt,
+ vis_gt, None, cam_gt, queries_z_unc,
+ q_xyz_world, q_static_proj, anchor_loss=anchor_loss, fix_cam_pts=fix_cam_pts, video_vis=video_vis, stage=stage,
+ gt_world_pts=gt_world_pts, mask_traj_gt=mask_traj_gt, intrs=intrs_org, custom_vid=custom_vid, valid_only=valid_only,
+ c2w_ests=c2w_ests, point_map_preds=point_map_preds, points_map_gt=points_map_gt, metric_unc=metric_unc, scale_est=scale_est,
+ shift_est=shift_est, point_map_org_train=point_map_org_train)
+ if custom_vid:
+ sync_loss = 0*sync_loss
+ if (sync_loss > 50) and (stage==1):
+ ret_loss = (0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss) + (0*sync_loss,)
+ else:
+ ret_loss = ret_loss+(10*sync_loss,)
+
+ else:
+ ret_loss = None
+
+ color_pts = torch.cat([pts3d_cam_pred[-1], queries_rgb[:,None].repeat(1, T, 1, 1)], dim=-1)
+
+ #TODO: For evaluation. We found our model have some bias on invisible points after training. (to be fixed)
+ vis_pred_out = torch.sigmoid(vis_preds[-1]) + 0.2
+
+ ret = {"preds": coord_preds[-1], "vis_pred": vis_pred_out,
+ "conf_pred": torch.sigmoid(confidence_preds[-1]),
+ "cam_pred": self.c2w_est_curr,"loss": ret_loss}
+
+ cache = {
+ "fmaps": fmaps_org[0].detach(),
+ "track_feat_support3d_pyramid": [track_feat_support3d_pyramid[i].detach() for i in range(len(track_feat_support3d_pyramid))],
+ "track_point_map_support_pyramid": [self.denorm_xyz(track_point_map_support_pyramid[i].detach()) for i in range(len(track_point_map_support_pyramid))],
+ "track_feat3d_pyramid": [track_feat3d_pyramid[i].detach() for i in range(len(track_feat3d_pyramid))],
+ "track_point_map_pyramid": [self.denorm_xyz(track_point_map_pyramid[i].detach()) for i in range(len(track_point_map_pyramid))],
+ "track_feat_pyramid": [track_feat_pyramid[i].detach() for i in range(len(track_feat_pyramid))],
+ "track_feat_support_pyramid": [track_feat_support_pyramid[i].detach() for i in range(len(track_feat_support_pyramid))],
+ "track2d_pred_cache": coord_preds[-1][0].clone().detach(),
+ "track3d_pred_cache": pts3d_cam_pred[-1][0].clone().detach(),
+ }
+ #NOTE: update the point map
+ point_map_org = scale_est.view(B*T,1,1,1)*point_map_org + shift_est.view(B*T,3,1,1)
+ point_map_org_refined = point_map_org
+ return ret, torch.sigmoid(dynamic_prob_preds[-1])*queries_z_unc[:,None,:,0], coord_preds[-1], color_pts, intrs_org, point_map_org_refined, cache
+
+ def track_d2_loss(self, tracks3d, stride=[1,2,3], dyn_prob=None, mask=None):
+ """
+ tracks3d: B T N 3
+ dyn_prob: B T N 1
+ """
+ r = 0.8
+ t_diff_total = 0.0
+ for i, s_ in enumerate(stride):
+ w_ = r**i
+ tracks3d_stride = tracks3d[:, ::s_, :, :] # B T//s_ N 3
+ t_diff_tracks3d = (tracks3d_stride[:, 1:, :, :] - tracks3d_stride[:, :-1, :, :])
+ t_diff2 = (t_diff_tracks3d[:, 1:, :, :] - t_diff_tracks3d[:, :-1, :, :])
+ t_diff_total += w_*(t_diff2.norm(dim=-1).mean())
+
+ return 1e2*t_diff_total
+
+ def loss(self, train_data, traj3d_gt=None,
+ vis_gt=None, static_tracks_gt=None, cam_gt=None,
+ z_unc=None, q_xyz_world=None, q_static_proj=None, anchor_loss=0, valid_only=False,
+ gt_world_pts=None, mask_traj_gt=None, intrs=None, c2w_ests=None, custom_vid=False, video_vis=None, stage=0,
+ fix_cam_pts=None, point_map_preds=None, points_map_gt=None, metric_unc=None, scale_est=None, shift_est=None, point_map_org_train=None):
+ """
+ Compute the loss of 3D tracking problem
+
+ """
+
+ (
+ coord_predictions, coords_xyz_predictions, vis_predictions, confidence_predicitons,
+ dynamic_prob_predictions, camera_predictions, cam_pts_predictions, world_tracks_predictions,
+ world_tracks_refined_predictions, scale_ests, shift_ests, valid_mask
+ ) = train_data
+ B, T, _, _ = cam_gt.shape
+ if (stage == 2) and self.training:
+ # get the scale and shift gt
+ self.metric_unc_org[:,0] = self.metric_unc_org[:,0] * (points_map_gt.norm(dim=-1)>0).float() * (self.metric_unc_org[:,0]>0.5).float()
+ if not (self.scale_gt==torch.ones(B*T).to(self.scale_gt.device)).all():
+ scale_gt, shift_gt = self.scale_gt, self.shift_gt
+ scale_re = scale_gt[:4].mean()
+ scale_loss = 0.0
+ shift_loss = 0.0
+ for i_scale in range(len(scale_ests[0])):
+ scale_loss += 0.8**(len(scale_ests[0])-i_scale-1)*10*(scale_gt - scale_re*scale_ests[0][i_scale].view(-1)).abs().mean()
+ shift_loss += 0.8**(len(shift_ests[0])-i_scale-1)*10*(shift_gt - scale_re*shift_ests[0][i_scale].view(-1,3)).abs().mean()
+ else:
+ scale_loss = 0.0 * scale_ests[0][0].mean()
+ shift_loss = 0.0 * shift_ests[0][0].mean()
+ scale_re = 1.0
+ else:
+ scale_loss = 0.0
+ shift_loss = 0.0
+
+ if len(point_map_preds)>0:
+ point_map_loss = 0.0
+ for i in range(len(point_map_preds)):
+ point_map_preds_i = point_map_preds[i]
+ point_map_preds_i = rearrange(point_map_preds_i, 'b t c h w -> (b t) c h w', b=B, t=T)
+ base_loss = ((self.pred_points - points_map_gt).norm(dim=-1) * self.metric_unc_org[:,0]).mean()
+ point_map_loss_i = ((point_map_preds_i - points_map_gt.permute(0,3,1,2)).norm(dim=1) * self.metric_unc_org[:,0]).mean()
+ point_map_loss += point_map_loss_i
+ # point_map_loss += ((point_map_org_train - points_map_gt.permute(0,3,1,2)).norm(dim=1) * self.metric_unc_org[:,0]).mean()
+ if scale_loss == 0.0:
+ point_map_loss = 0*point_map_preds_i.sum()
+ else:
+ point_map_loss = 0.0
+
+ # camera loss
+ cam_loss = 0.0
+ dyn_loss = 0.0
+ N_gt = gt_world_pts.shape[2]
+
+ # self supervised dynamic mask
+ H_org, W_org = self.image_size[0], self.image_size[1]
+ q_static_proj[torch.isnan(q_static_proj)] = -200
+ in_view_mask = (q_static_proj[...,0]>0) & (q_static_proj[...,0]0) & (q_static_proj[...,1] 6
+
+ for iter_, cam_pred_i in enumerate(camera_predictions[0]):
+ # points loss
+ pts_i_world = world_tracks_predictions[0][iter_].view(B, T, -1, 3)
+
+ coords_xyz_i_world = coords_xyz_predictions[0][iter_].view(B, T, -1, 3)
+ coords_i = coord_predictions[0][iter_].view(B, T, -1, 3)[..., :2]
+ pts_i_world_refined = torch.einsum(
+ "btij,btnj->btni",
+ cam_gt[...,:3,:3],
+ coords_xyz_i_world
+ ) + cam_gt[...,None, :3,3] # B T N 3
+
+ # pts_i_world_refined = world_tracks_refined_predictions[0][iter_].view(B, T, -1, 3)
+ pts_world = pts_i_world
+ dyn_prob_i_logits = dynamic_prob_predictions[0][iter_].mean(dim=1)
+ dyn_prob_i = torch.sigmoid(dyn_prob_i_logits).detach()
+ mask = pts_world.norm(dim=-1) < 200
+
+ # general
+ vis_i_logits = vis_predictions[0][iter_]
+ vis_i = torch.sigmoid(vis_i_logits).detach()
+ if mask_traj_gt is not None:
+ try:
+ N_gt_mask = mask_traj_gt.shape[1]
+ align_loss = (gt_world_pts - q_xyz_world[:,None,:N_gt,:,]).norm(dim=-1)[...,:N_gt_mask] * (mask_traj_gt.permute(0,2,1))
+ visb_traj = (align_loss * vis_i[:,:,:N_gt_mask]).sum(dim=1)/vis_i[:,:,:N_gt_mask].sum(dim=1)
+ except:
+ import pdb; pdb.set_trace()
+ else:
+ visb_traj = ((gt_world_pts - q_xyz_world[:,None,:N_gt,:,]).norm(dim=-1) * vis_i[:,:,:N_gt]).sum(dim=1)/vis_i[:,:,:N_gt].sum(dim=1)
+
+ # pts_loss = ((q_xyz_world[:,None,...] - pts_world)[:,:,:N_gt,:].norm(dim=-1)*(1-dyn_prob_i[:,None,:N_gt])) # - 0.1*(1-dyn_prob_i[:,None,:N_gt]).log()
+ pts_loss = 0
+ static_mask = ~dyn_mask_final # more strict for static points
+ dyn_mask = dyn_mask_final
+ pts_loss_refined = ((q_xyz_world[:,None,...] - pts_i_world_refined).norm(dim=-1)*static_mask[:,None,:]).sum()/static_mask.sum() # - 0.1*(1-dyn_prob_i[:,None,:N_gt]).log()
+ vis_logits_final = vis_predictions[0][-1].detach()
+ vis_final = torch.sigmoid(vis_logits_final)+0.2 > 0.5 # more strict for visible points
+ dyn_vis_mask = dyn_mask*vis_final * (fix_cam_pts[...,2] > 0.1)
+ pts_loss_dynamic = ((fix_cam_pts - coords_xyz_i_world).norm(dim=-1)*dyn_vis_mask[:,None,:]).sum()/dyn_vis_mask.sum()
+
+ # pts_loss_refined = 0
+ if traj3d_gt is not None:
+ tap_traj = (gt_world_pts[:,:-1,...] - gt_world_pts[:,1:,...]).norm(dim=-1).sum(dim=1)[...,:N_gt_mask]
+ mask_dyn = tap_traj>0.5
+ if mask_traj_gt.sum() > 0:
+ dyn_loss_i = 20*balanced_binary_cross_entropy(dyn_prob_i_logits[:,:N_gt_mask][mask_traj_gt.squeeze(-1)],
+ mask_dyn.float()[mask_traj_gt.squeeze(-1)])
+ else:
+ dyn_loss_i = 0
+ else:
+ dyn_loss_i = 10*balanced_binary_cross_entropy(dyn_prob_i_logits, dyn_mask_final.float())
+
+ dyn_loss += dyn_loss_i
+
+ # visible loss for out of view points
+ vis_i_train = torch.sigmoid(vis_i_logits)
+ out_of_view_mask = (coords_i[...,0]<0)|(coords_i[...,0]>self.image_size[1])|(coords_i[...,1]<0)|(coords_i[...,1]>self.image_size[0])
+ vis_loss_out_of_view = vis_i_train[out_of_view_mask].sum() / out_of_view_mask.sum()
+
+
+ if traj3d_gt is not None:
+ world_pts_loss = (((gt_world_pts - pts_i_world_refined[:,:,:gt_world_pts.shape[2],...]).norm(dim=-1))[...,:N_gt_mask] * mask_traj_gt.permute(0,2,1)).sum() / mask_traj_gt.sum()
+ # world_pts_init_loss = (((gt_world_pts - pts_i_world[:,:,:gt_world_pts.shape[2],...]).norm(dim=-1))[...,:N_gt_mask] * mask_traj_gt.permute(0,2,1)).sum() / mask_traj_gt.sum()
+ else:
+ world_pts_loss = 0
+
+ # cam regress
+ t_err = (cam_pred_i[...,:3,3] - cam_gt[...,:3,3]).norm(dim=-1).sum()
+
+ # xyz loss
+ in_view_mask_large = (q_static_proj[...,0]>-50) & (q_static_proj[...,0]-50) & (q_static_proj[...,1]0.05).float() * static_mask[:,None,:] * in_view_mask_large
+ xyz_loss = ((coord_predictions[0][iter_] - q_static_proj)).abs()[...,:2].norm(dim=-1)*static_vis_mask
+ xyz_loss = xyz_loss.sum()/static_vis_mask.sum()
+
+ # visualize the q_static_proj
+ # viser = Visualizer(save_dir=".", grayscale=True,
+ # fps=10, pad_value=50, tracks_leave_trace=0)
+ # video_vis_ = F.interpolate(video_vis.view(B*T,3,video_vis.shape[-2],video_vis.shape[-1]), (H_org, W_org), mode='bilinear', align_corners=False)
+ # viser.visualize(video=video_vis_, tracks=q_static_proj[:,:,dyn_mask_final.squeeze(), :2], filename="test")
+ # viser.visualize(video=video_vis_, tracks=coord_predictions[0][-1][:,:,dyn_mask_final.squeeze(), :2], filename="test_pred")
+ # import pdb; pdb.set_trace()
+
+ # temporal loss
+ t_loss = self.track_d2_loss(pts_i_world_refined, [1,2,3], dyn_prob=dyn_prob_i, mask=mask)
+ R_err = (cam_pred_i[...,:3,:3] - cam_gt[...,:3,:3]).abs().sum(dim=-1).mean()
+ if self.stage == 1:
+ cam_loss += 0.8**(len(camera_predictions[0])-iter_-1)*(10*t_err + 500*R_err + 20*pts_loss_refined + 10*xyz_loss + 20*pts_loss_dynamic + 10*vis_loss_out_of_view) #+ 5*(pts_loss + pts_loss_refined + world_pts_loss) + t_loss)
+ elif self.stage == 3:
+ cam_loss += 0.8**(len(camera_predictions[0])-iter_-1)*(10*t_err + 500*R_err + 10*vis_loss_out_of_view) #+ 5*(pts_loss + pts_loss_refined + world_pts_loss) + t_loss)
+ else:
+ cam_loss += 0*vis_loss_out_of_view
+
+ if (cam_loss > 20000)|(torch.isnan(cam_loss)):
+ cam_loss = torch.zeros_like(cam_loss)
+
+
+ if traj3d_gt is None:
+ # ================ Condition 1: The self-supervised signals from the self-consistency ===================
+ return cam_loss, train_data[0][0][0].mean()*0, dyn_loss, train_data[0][0][0].mean()*0, point_map_loss, scale_loss, shift_loss
+
+
+ # ================ Condition 2: The supervision signal given by the ground truth trajectories ===================
+ if (
+ (torch.isnan(traj3d_gt).any()
+ or traj3d_gt.abs().max() > 2000) and (custom_vid==False)
+ ):
+ return cam_loss, train_data[0][0][0].mean()*0, dyn_loss, train_data[0][0][0].mean()*0, point_map_loss, scale_loss, shift_loss
+
+
+ vis_gts = [vis_gt.float()]
+ invis_gts = [1-vis_gt.float()]
+ traj_gts = [traj3d_gt]
+ valids_gts = [valid_mask]
+ seq_loss_all = sequence_loss(
+ coord_predictions,
+ traj_gts,
+ valids_gts,
+ vis=vis_gts,
+ gamma=0.8,
+ add_huber_loss=False,
+ loss_only_for_visible=False if custom_vid==False else True,
+ z_unc=z_unc,
+ mask_traj_gt=mask_traj_gt
+ )
+
+ confidence_loss = sequence_prob_loss(
+ coord_predictions, confidence_predicitons, traj_gts, vis_gts
+ )
+
+ seq_loss_xyz = sequence_loss_xyz(
+ coords_xyz_predictions,
+ traj_gts,
+ valids_gts,
+ intrs=intrs,
+ vis=vis_gts,
+ gamma=0.8,
+ add_huber_loss=False,
+ loss_only_for_visible=False,
+ mask_traj_gt=mask_traj_gt
+ )
+
+ # filter the blinking points
+ mask_vis = vis_gts[0].clone() # B T N
+ mask_vis[mask_vis==0] = -1
+ blink_mask = mask_vis[:,:-1,:] * mask_vis[:,1:,:] # first derivative B (T-1) N
+ mask_vis[:,:-1,:], mask_vis[:,-1,:] = (blink_mask == 1), 0
+
+ vis_loss = sequence_BCE_loss(vis_predictions, vis_gts, mask=[mask_vis])
+
+ track_loss_out = (seq_loss_all+2*seq_loss_xyz + cam_loss)
+ if valid_only:
+ vis_loss = 0.0*vis_loss
+ if custom_vid:
+ return seq_loss_all, 0.0*seq_loss_all, 0.0*seq_loss_all, 10*vis_loss, 0.0*seq_loss_all, 0.0*seq_loss_all, 0.0*seq_loss_all
+
+ return track_loss_out, confidence_loss, dyn_loss, 10*vis_loss, point_map_loss, scale_loss, shift_loss
+
+
+
+
diff --git a/models/SpaTrackV2/models/tracker3D/co_tracker/cotracker_base.py b/models/SpaTrackV2/models/tracker3D/co_tracker/cotracker_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aca4f353d5e8336d1939479e8b33dd90dae6c17
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/co_tracker/cotracker_base.py
@@ -0,0 +1,418 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from models.SpaTrackV2.utils.model_utils import sample_features5d, bilinear_sampler
+
+from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
+ Mlp, BasicEncoder, EfficientUpdateFormer
+)
+
+torch.manual_seed(0)
+
+
+def get_1d_sincos_pos_embed_from_grid(
+ embed_dim: int, pos: torch.Tensor
+) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb[None].float()
+
+def posenc(x, min_deg, max_deg):
+ """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
+ Instead of computing [sin(x), cos(x)], we use the trig identity
+ cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
+ Args:
+ x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
+ min_deg: int, the minimum (inclusive) degree of the encoding.
+ max_deg: int, the maximum (exclusive) degree of the encoding.
+ legacy_posenc_order: bool, keep the same ordering as the original tf code.
+ Returns:
+ encoded: torch.Tensor, encoded variables.
+ """
+ if min_deg == max_deg:
+ return x
+ scales = torch.tensor(
+ [2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
+ )
+
+ xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
+ four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
+ return torch.cat([x] + [four_feat], dim=-1)
+
+class CoTrackerThreeBase(nn.Module):
+ def __init__(
+ self,
+ window_len=8,
+ stride=4,
+ corr_radius=3,
+ corr_levels=4,
+ num_virtual_tracks=64,
+ model_resolution=(384, 512),
+ add_space_attn=True,
+ linear_layer_for_vis_conf=True,
+ ):
+ super(CoTrackerThreeBase, self).__init__()
+ self.window_len = window_len
+ self.stride = stride
+ self.corr_radius = corr_radius
+ self.corr_levels = corr_levels
+ self.hidden_dim = 256
+ self.latent_dim = 128
+
+ self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
+ self.fnet = BasicEncoder(input_dim=3, output_dim=self.latent_dim, stride=stride)
+
+ highres_dim = 128
+ lowres_dim = 256
+
+ self.num_virtual_tracks = num_virtual_tracks
+ self.model_resolution = model_resolution
+
+ self.input_dim = 1110
+
+ self.updateformer = EfficientUpdateFormer(
+ space_depth=3,
+ time_depth=3,
+ input_dim=self.input_dim,
+ hidden_size=384,
+ output_dim=4,
+ mlp_ratio=4.0,
+ num_virtual_tracks=num_virtual_tracks,
+ add_space_attn=add_space_attn,
+ linear_layer_for_vis_conf=linear_layer_for_vis_conf,
+ )
+ self.corr_mlp = Mlp(in_features=49 * 49, hidden_features=384, out_features=256)
+
+ time_grid = torch.linspace(0, window_len - 1, window_len).reshape(
+ 1, window_len, 1
+ )
+
+ self.register_buffer(
+ "time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
+ )
+
+ def get_support_points(self, coords, r, reshape_back=True):
+ B, _, N, _ = coords.shape
+ device = coords.device
+ centroid_lvl = coords.reshape(B, N, 1, 1, 3)
+
+ dx = torch.linspace(-r, r, 2 * r + 1, device=device)
+ dy = torch.linspace(-r, r, 2 * r + 1, device=device)
+
+ xgrid, ygrid = torch.meshgrid(dy, dx, indexing="ij")
+ zgrid = torch.zeros_like(xgrid, device=device)
+ delta = torch.stack([zgrid, xgrid, ygrid], axis=-1)
+ delta_lvl = delta.view(1, 1, 2 * r + 1, 2 * r + 1, 3)
+ coords_lvl = centroid_lvl + delta_lvl
+
+ if reshape_back:
+ return coords_lvl.reshape(B, N, (2 * r + 1) ** 2, 3).permute(0, 2, 1, 3)
+ else:
+ return coords_lvl
+
+ def get_track_feat(self, fmaps, queried_frames, queried_coords, support_radius=0):
+
+ sample_frames = queried_frames[:, None, :, None]
+ sample_coords = torch.cat(
+ [
+ sample_frames,
+ queried_coords[:, None],
+ ],
+ dim=-1,
+ )
+ support_points = self.get_support_points(sample_coords, support_radius)
+ support_track_feats = sample_features5d(fmaps, support_points)
+ return (
+ support_track_feats[:, None, support_track_feats.shape[1] // 2],
+ support_track_feats,
+ )
+
+ def get_correlation_feat(self, fmaps, queried_coords, radius=None, padding_mode="border"):
+ B, T, D, H_, W_ = fmaps.shape
+ N = queried_coords.shape[1]
+ if radius is None:
+ r = self.corr_radius
+ else:
+ r = radius
+ sample_coords = torch.cat(
+ [torch.zeros_like(queried_coords[..., :1]), queried_coords], dim=-1
+ )[:, None]
+ support_points = self.get_support_points(sample_coords, r, reshape_back=False)
+ correlation_feat = bilinear_sampler(
+ fmaps.reshape(B * T, D, 1, H_, W_), support_points, padding_mode=padding_mode
+ )
+ return correlation_feat.view(B, T, D, N, (2 * r + 1), (2 * r + 1)).permute(
+ 0, 1, 3, 4, 5, 2
+ )
+
+ def interpolate_time_embed(self, x, t):
+ previous_dtype = x.dtype
+ T = self.time_emb.shape[1]
+
+ if t == T:
+ return self.time_emb
+
+ time_emb = self.time_emb.float()
+ time_emb = F.interpolate(
+ time_emb.permute(0, 2, 1), size=t, mode="linear"
+ ).permute(0, 2, 1)
+ return time_emb.to(previous_dtype)
+
+class CoTrackerThreeOffline(CoTrackerThreeBase):
+ def __init__(self, **args):
+ super(CoTrackerThreeOffline, self).__init__(**args)
+
+ def forward(
+ self,
+ video,
+ queries,
+ iters=4,
+ is_train=False,
+ add_space_attn=True,
+ fmaps_chunk_size=200,
+ ):
+ """Predict tracks
+
+ Args:
+ video (FloatTensor[B, T, 3]): input videos.
+ queries (FloatTensor[B, N, 3]): point queries.
+ iters (int, optional): number of updates. Defaults to 4.
+ is_train (bool, optional): enables training mode. Defaults to False.
+ Returns:
+ - coords_predicted (FloatTensor[B, T, N, 2]):
+ - vis_predicted (FloatTensor[B, T, N]):
+ - train_data: `None` if `is_train` is false, otherwise:
+ - all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
+ - all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
+ - mask (BoolTensor[B, T, N]):
+ """
+
+ B, T, C, H, W = video.shape
+ device = queries.device
+ assert H % self.stride == 0 and W % self.stride == 0
+
+ B, N, __ = queries.shape
+ # B = batch size
+ # S_trimmed = actual number of frames in the window
+ # N = number of tracks
+ # C = color channels (3 for RGB)
+ # E = positional embedding size
+ # LRR = local receptive field radius
+ # D = dimension of the transformer input tokens
+
+ # video = B T C H W
+ # queries = B N 3
+ # coords_init = B T N 2
+ # vis_init = B T N 1
+
+ assert T >= 1 # A tracker needs at least two frames to track something
+
+ video = 2 * (video / 255.0) - 1.0
+ dtype = video.dtype
+ queried_frames = queries[:, :, 0].long()
+
+ queried_coords = queries[..., 1:3]
+ queried_coords = queried_coords / self.stride
+
+ # We store our predictions here
+ all_coords_predictions, all_vis_predictions, all_confidence_predictions = (
+ [],
+ [],
+ [],
+ )
+ C_ = C
+ H4, W4 = H // self.stride, W // self.stride
+ # Compute convolutional features for the video or for the current chunk in case of online mode
+
+ if T > fmaps_chunk_size:
+ fmaps = []
+ for t in range(0, T, fmaps_chunk_size):
+ video_chunk = video[:, t : t + fmaps_chunk_size]
+ fmaps_chunk = self.fnet(video_chunk.reshape(-1, C_, H, W))
+ T_chunk = video_chunk.shape[1]
+ C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
+ fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
+ fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
+ else:
+ fmaps = self.fnet(video.reshape(-1, C_, H, W))
+ fmaps = fmaps.permute(0, 2, 3, 1)
+ fmaps = fmaps / torch.sqrt(
+ torch.maximum(
+ torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
+ torch.tensor(1e-12, device=fmaps.device),
+ )
+ )
+ fmaps = fmaps.permute(0, 3, 1, 2).reshape(
+ B, -1, self.latent_dim, H // self.stride, W // self.stride
+ )
+ fmaps = fmaps.to(dtype)
+
+ # We compute track features
+ fmaps_pyramid = []
+ track_feat_pyramid = []
+ track_feat_support_pyramid = []
+ fmaps_pyramid.append(fmaps)
+ for i in range(self.corr_levels - 1):
+ fmaps_ = fmaps.reshape(
+ B * T, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
+ )
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
+ fmaps = fmaps_.reshape(
+ B, T, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
+ )
+ fmaps_pyramid.append(fmaps)
+
+ for i in range(self.corr_levels):
+ track_feat, track_feat_support = self.get_track_feat(
+ fmaps_pyramid[i],
+ queried_frames,
+ queried_coords / 2**i,
+ support_radius=self.corr_radius,
+ )
+ track_feat_pyramid.append(track_feat.repeat(1, T, 1, 1))
+ track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))
+
+ D_coords = 2
+
+ coord_preds, vis_preds, confidence_preds = [], [], []
+
+ vis = torch.zeros((B, T, N), device=device).float()
+ confidence = torch.zeros((B, T, N), device=device).float()
+ coords = queried_coords.reshape(B, 1, N, 2).expand(B, T, N, 2).float()
+
+ r = 2 * self.corr_radius + 1
+
+ for it in range(iters):
+ coords = coords.detach() # B T N 2
+ coords_init = coords.view(B * T, N, 2)
+ corr_embs = []
+ corr_feats = []
+ for i in range(self.corr_levels):
+ corr_feat = self.get_correlation_feat(
+ fmaps_pyramid[i], coords_init / 2**i
+ )
+ track_feat_support = (
+ track_feat_support_pyramid[i]
+ .view(B, 1, r, r, N, self.latent_dim)
+ .squeeze(1)
+ .permute(0, 3, 1, 2, 4)
+ )
+ corr_volume = torch.einsum(
+ "btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
+ )
+ corr_emb = self.corr_mlp(corr_volume.reshape(B * T * N, r * r * r * r))
+ corr_embs.append(corr_emb)
+ corr_embs = torch.cat(corr_embs, dim=-1)
+ corr_embs = corr_embs.view(B, T, N, corr_embs.shape[-1])
+
+ transformer_input = [vis[..., None], confidence[..., None], corr_embs]
+
+ rel_coords_forward = coords[:, :-1] - coords[:, 1:]
+ rel_coords_backward = coords[:, 1:] - coords[:, :-1]
+
+ rel_coords_forward = torch.nn.functional.pad(
+ rel_coords_forward, (0, 0, 0, 0, 0, 1)
+ )
+ rel_coords_backward = torch.nn.functional.pad(
+ rel_coords_backward, (0, 0, 0, 0, 1, 0)
+ )
+ scale = (
+ torch.tensor(
+ [self.model_resolution[1], self.model_resolution[0]],
+ device=coords.device,
+ )
+ / self.stride
+ )
+ rel_coords_forward = rel_coords_forward / scale
+ rel_coords_backward = rel_coords_backward / scale
+
+ rel_pos_emb_input = posenc(
+ torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
+ min_deg=0,
+ max_deg=10,
+ ) # batch, num_points, num_frames, 84
+ transformer_input.append(rel_pos_emb_input)
+
+ x = (
+ torch.cat(transformer_input, dim=-1)
+ .permute(0, 2, 1, 3)
+ .reshape(B * N, T, -1)
+ )
+
+ x = x + self.interpolate_time_embed(x, T)
+ x = x.view(B, N, T, -1) # (B N) T D -> B N T D
+
+ delta = self.updateformer(
+ x,
+ add_space_attn=add_space_attn,
+ )
+
+ delta_coords = delta[..., :D_coords].permute(0, 2, 1, 3)
+ delta_vis = delta[..., D_coords].permute(0, 2, 1)
+ delta_confidence = delta[..., D_coords + 1].permute(0, 2, 1)
+
+ vis = vis + delta_vis
+ confidence = confidence + delta_confidence
+
+ coords = coords + delta_coords
+ coords_append = coords.clone()
+ coords_append[..., :2] = coords_append[..., :2] * float(self.stride)
+ coord_preds.append(coords_append)
+ vis_preds.append(torch.sigmoid(vis))
+ confidence_preds.append(torch.sigmoid(confidence))
+
+ if is_train:
+ all_coords_predictions.append([coord[..., :2] for coord in coord_preds])
+ all_vis_predictions.append(vis_preds)
+ all_confidence_predictions.append(confidence_preds)
+
+ if is_train:
+ train_data = (
+ all_coords_predictions,
+ all_vis_predictions,
+ all_confidence_predictions,
+ torch.ones_like(vis_preds[-1], device=vis_preds[-1].device),
+ )
+ else:
+ train_data = None
+
+ return coord_preds[-1][..., :2], vis_preds[-1], confidence_preds[-1], train_data
+
+
+if __name__ == "__main__":
+ cotrack_cktp = "/data0/xyx/scaled_offline.pth"
+ cotracker = CoTrackerThreeOffline(
+ stride=4, corr_radius=3, window_len=60
+ )
+ with open(cotrack_cktp, "rb") as f:
+ state_dict = torch.load(f, map_location="cpu")
+ if "model" in state_dict:
+ state_dict = state_dict["model"]
+ cotracker.load_state_dict(state_dict)
+ import pdb; pdb.set_trace()
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/tracker3D/co_tracker/utils.py b/models/SpaTrackV2/models/tracker3D/co_tracker/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..728785138666591edd4650a506ad5c694d2237fe
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/co_tracker/utils.py
@@ -0,0 +1,929 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from typing import Callable, List
+import collections
+from torch import Tensor
+from itertools import repeat
+from models.SpaTrackV2.utils.model_utils import bilinear_sampler
+from models.SpaTrackV2.models.blocks import CrossAttnBlock as CrossAttnBlock_F
+from torch.nn.functional import scaled_dot_product_attention
+from torch.nn.attention import sdpa_kernel, SDPBackend
+# import flash_attn
+EPS = 1e-6
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes,
+ planes,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ padding_mode="zeros",
+ )
+ self.conv2 = nn.Conv2d(
+ planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
+ )
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+def reduce_masked_mean(input, mask, dim=None, keepdim=False):
+ r"""Masked mean
+
+ `reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
+ over a mask :attr:`mask`, returning
+
+ .. math::
+ \text{output} =
+ \frac
+ {\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
+ {\epsilon + \sum_{i=1}^N \text{mask}_i}
+
+ where :math:`N` is the number of elements in :attr:`input` and
+ :attr:`mask`, and :math:`\epsilon` is a small constant to avoid
+ division by zero.
+
+ `reduced_masked_mean(x, mask, dim)` computes the mean of a tensor
+ :attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
+ Optionally, the dimension can be kept in the output by setting
+ :attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
+ the same dimension as :attr:`input`.
+
+ The interface is similar to `torch.mean()`.
+
+ Args:
+ inout (Tensor): input tensor.
+ mask (Tensor): mask.
+ dim (int, optional): Dimension to sum over. Defaults to None.
+ keepdim (bool, optional): Keep the summed dimension. Defaults to False.
+
+ Returns:
+ Tensor: mean tensor.
+ """
+
+ mask = mask.expand_as(input)
+
+ prod = input * mask
+
+ if dim is None:
+ numer = torch.sum(prod)
+ denom = torch.sum(mask)
+ else:
+ numer = torch.sum(prod, dim=dim, keepdim=keepdim)
+ denom = torch.sum(mask, dim=dim, keepdim=keepdim)
+
+ mean = numer / (EPS + denom)
+ return mean
+
+class GeometryEncoder(nn.Module):
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
+ super(GeometryEncoder, self).__init__()
+ self.stride = stride
+ self.norm_fn = "instance"
+ self.in_planes = output_dim // 2
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
+ self.conv1 = nn.Conv2d(
+ input_dim,
+ self.in_planes,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ padding_mode="zeros",
+ )
+ self.relu1 = nn.ReLU(inplace=True)
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
+
+ self.conv2 = nn.Conv2d(
+ output_dim * 5 // 4,
+ output_dim,
+ kernel_size=3,
+ padding=1,
+ padding_mode="zeros",
+ )
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(output_dim, output_dim, kernel_size=1)
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.InstanceNorm2d)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ _, _, H, W = x.shape
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+ a = self.layer1(x)
+ b = self.layer2(a)
+ def _bilinear_intepolate(x):
+ return F.interpolate(
+ x,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+ a = _bilinear_intepolate(a)
+ b = _bilinear_intepolate(b)
+ x = self.conv2(torch.cat([a, b], dim=1))
+ x = self.norm2(x)
+ x = self.relu2(x)
+ x = self.conv3(x)
+ return x
+
+class BasicEncoder(nn.Module):
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
+ super(BasicEncoder, self).__init__()
+ self.stride = stride
+ self.norm_fn = "instance"
+ self.in_planes = output_dim // 2
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
+
+ self.conv1 = nn.Conv2d(
+ input_dim,
+ self.in_planes,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ padding_mode="zeros",
+ )
+ self.relu1 = nn.ReLU(inplace=True)
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
+ self.layer3 = self._make_layer(output_dim, stride=2)
+ self.layer4 = self._make_layer(output_dim, stride=2)
+
+ self.conv2 = nn.Conv2d(
+ output_dim * 3 + output_dim // 4,
+ output_dim * 2,
+ kernel_size=3,
+ padding=1,
+ padding_mode="zeros",
+ )
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.InstanceNorm2d)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ _, _, H, W = x.shape
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ a = self.layer1(x)
+ b = self.layer2(a)
+ c = self.layer3(b)
+ d = self.layer4(c)
+
+ def _bilinear_intepolate(x):
+ return F.interpolate(
+ x,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ a = _bilinear_intepolate(a)
+ b = _bilinear_intepolate(b)
+ c = _bilinear_intepolate(c)
+ d = _bilinear_intepolate(d)
+
+ x = self.conv2(torch.cat([a, b, c, d], dim=1))
+ x = self.norm2(x)
+ x = self.relu2(x)
+ x = self.conv3(x)
+ return x
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = (
+ norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+ )
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
+ ):
+ super().__init__()
+ inner_dim = dim_head * num_heads
+ self.inner_dim = inner_dim
+ context_dim = default(context_dim, query_dim)
+ self.scale = dim_head**-0.5
+ self.heads = num_heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
+ self.to_out = nn.Linear(inner_dim, query_dim)
+
+ def forward(self, x, context=None, attn_bias=None, flash=True):
+ B, N1, C = x.shape
+ h = self.heads
+
+ q = self.to_q(x).reshape(B, N1, h, self.inner_dim // h).permute(0, 2, 1, 3)
+ context = default(context, x)
+ k, v = self.to_kv(context).chunk(2, dim=-1)
+
+ N2 = context.shape[1]
+ k = k.reshape(B, N2, h, self.inner_dim // h).permute(0, 2, 1, 3)
+ v = v.reshape(B, N2, h, self.inner_dim // h).permute(0, 2, 1, 3)
+
+ if (
+ (N1 < 64 and N2 < 64) or
+ (B > 1e4) or
+ (q.shape[1] != k.shape[1]) or
+ (q.shape[1] % k.shape[1] != 0)
+ ):
+ flash = False
+
+
+ if flash == False:
+ sim = (q @ k.transpose(-2, -1)) * self.scale
+ if attn_bias is not None:
+ sim = sim + attn_bias
+ if sim.abs().max() > 1e2:
+ import pdb; pdb.set_trace()
+ attn = sim.softmax(dim=-1)
+ x = (attn @ v).transpose(1, 2).reshape(B, N1, self.inner_dim)
+ else:
+
+ input_args = [x.contiguous() for x in [q, k, v]]
+ try:
+ # print(f"q.shape: {q.shape}, dtype: {q.dtype}, device: {q.device}")
+ # print(f"Flash SDP available: {torch.backends.cuda.flash_sdp_enabled()}")
+ # print(f"Flash SDP allowed: {torch.backends.cuda.enable_flash_sdp}")
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+ x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore
+ except Exception as e:
+ print(e)
+
+ if self.to_out.bias.dtype != x.dtype:
+ x = x.to(self.to_out.bias.dtype)
+
+ return self.to_out(x)
+
+class CrossAttnBlock(nn.Module):
+ def __init__(
+ self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
+ ):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.norm_context = nn.LayerNorm(context_dim)
+ self.cross_attn = Attention(
+ hidden_size,
+ context_dim=context_dim,
+ num_heads=num_heads,
+ qkv_bias=True,
+ **block_kwargs
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+
+ def forward(self, x, context, mask=None):
+ attn_bias = None
+ if mask is not None:
+ if mask.shape[1] == x.shape[1]:
+ mask = mask[:, None, :, None].expand(
+ -1, self.cross_attn.heads, -1, context.shape[1]
+ )
+ else:
+ mask = mask[:, None, None].expand(
+ -1, self.cross_attn.heads, x.shape[1], -1
+ )
+
+ max_neg_value = -torch.finfo(x.dtype).max
+ attn_bias = (~mask) * max_neg_value
+ x = x + self.cross_attn(
+ self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
+ )
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+class AttnBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ attn_class: Callable[..., nn.Module] = Attention,
+ mlp_ratio=4.0,
+ **block_kwargs
+ ):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.attn = attn_class(
+ hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+
+ def forward(self, x, mask=None):
+ attn_bias = mask
+ if mask is not None:
+ mask = (
+ (mask[:, None] * mask[:, :, None])
+ .unsqueeze(1)
+ .expand(-1, self.attn.num_heads, -1, -1)
+ )
+ max_neg_value = -torch.finfo(x.dtype).max
+ attn_bias = (~mask) * max_neg_value
+ x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+class EfficientUpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
+ """
+
+ def __init__(
+ self,
+ space_depth=6,
+ time_depth=6,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ num_virtual_tracks=64,
+ add_space_attn=True,
+ linear_layer_for_vis_conf=False,
+ patch_feat=False,
+ patch_dim=128,
+ ):
+ super().__init__()
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+ if linear_layer_for_vis_conf:
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
+ self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
+ else:
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
+
+ if patch_feat==False:
+ self.virual_tracks = nn.Parameter(
+ torch.randn(1, num_virtual_tracks, 1, hidden_size)
+ )
+ self.num_virtual_tracks = num_virtual_tracks
+ else:
+ self.patch_proj = nn.Linear(patch_dim, hidden_size, bias=True)
+
+ self.add_space_attn = add_space_attn
+ self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
+ self.time_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=Attention,
+ )
+ for _ in range(time_depth)
+ ]
+ )
+
+ if add_space_attn:
+ self.space_virtual_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=Attention,
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_point2virtual_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_virtual2point_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
+ if self.linear_layer_for_vis_conf:
+ torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
+
+ def _trunc_init(module):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ torch.nn.init.trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor, mask=None, add_space_attn=True, patch_feat=None):
+ tokens = self.input_transform(input_tensor)
+
+ B, _, T, _ = tokens.shape
+ if patch_feat is None:
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
+ else:
+ patch_feat = self.patch_proj(patch_feat.detach())
+ tokens = torch.cat([tokens, patch_feat], dim=1)
+ self.num_virtual_tracks = patch_feat.shape[1]
+
+ _, N, _, _ = tokens.shape
+ j = 0
+ layers = []
+ for i in range(len(self.time_blocks)):
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+ time_tokens = torch.utils.checkpoint.checkpoint(
+ self.time_blocks[i],
+ time_tokens
+ )
+
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
+ if (
+ add_space_attn
+ and hasattr(self, "space_virtual_blocks")
+ and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
+ ):
+ space_tokens = (
+ tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
+ ) # B N T C -> (B T) N C
+
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
+
+ virtual_tokens = torch.utils.checkpoint.checkpoint(
+ self.space_virtual2point_blocks[j],
+ virtual_tokens, point_tokens, mask
+ )
+
+ virtual_tokens = torch.utils.checkpoint.checkpoint(
+ self.space_virtual_blocks[j],
+ virtual_tokens
+ )
+
+ point_tokens = torch.utils.checkpoint.checkpoint(
+ self.space_point2virtual_blocks[j],
+ point_tokens, virtual_tokens, mask
+ )
+
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
+ tokens = space_tokens.view(B, T, N, -1).permute(
+ 0, 2, 1, 3
+ ) # (B T) N C -> B N T C
+ j += 1
+ tokens = tokens[:, : N - self.num_virtual_tracks]
+
+ flow = self.flow_head(tokens)
+ if self.linear_layer_for_vis_conf:
+ vis_conf = self.vis_conf_head(tokens)
+ flow = torch.cat([flow, vis_conf], dim=-1)
+
+ return flow
+
+def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
+ probs = torch.sigmoid(logits)
+ ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
+ p_t = probs * targets + (1 - probs) * (1 - targets)
+ loss = alpha * (1 - p_t) ** gamma * ce_loss
+ return loss.mean()
+
+def balanced_binary_cross_entropy(logits, targets, balance_weight=1.0, eps=1e-6, reduction="mean", pos_bias=0.0, mask=None):
+ """
+ logits: Tensor of arbitrary shape
+ targets: same shape as logits
+ balance_weight: scaling the loss
+ reduction: 'mean', 'sum', or 'none'
+ """
+ targets = targets.float()
+ positive = (targets == 1).float().sum()
+ total = targets.numel()
+ positive_ratio = positive / (total + eps)
+
+ pos_weight = (1 - positive_ratio) / (positive_ratio + eps)
+ pos_weight = pos_weight.clamp(min=0.1, max=10.0)
+ loss = F.binary_cross_entropy_with_logits(
+ logits,
+ targets,
+ pos_weight=pos_weight+pos_bias,
+ reduction=reduction
+ )
+ if mask is not None:
+ loss = (loss * mask).sum() / (mask.sum() + eps)
+ return balance_weight * loss
+
+def sequence_loss(
+ flow_preds,
+ flow_gt,
+ valids,
+ vis=None,
+ gamma=0.8,
+ add_huber_loss=False,
+ loss_only_for_visible=False,
+ depth_sample=None,
+ z_unc=None,
+ mask_traj_gt=None
+):
+ """Loss function defined over sequence of flow predictions"""
+ total_flow_loss = 0.0
+ for j in range(len(flow_gt)):
+ B, S, N, D = flow_gt[j].shape
+ B, S2, N = valids[j].shape
+ assert S == S2
+ n_predictions = len(flow_preds[j])
+ flow_loss = 0.0
+ for i in range(n_predictions):
+ i_weight = gamma ** (n_predictions - i - 1)
+ flow_pred = flow_preds[j][i][:,:,:flow_gt[j].shape[2]]
+ if flow_pred.shape[-1] == 3:
+ flow_pred[...,2] = flow_pred[...,2]
+ if add_huber_loss:
+ i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
+ else:
+ if flow_gt[j][...,2].abs().max() != 0:
+ track_z_loss = (flow_pred- flow_gt[j])[...,2].abs().mean()
+ if mask_traj_gt is not None:
+ track_z_loss = ((flow_pred- flow_gt[j])[...,2].abs() * mask_traj_gt.permute(0,2,1)).sum() / (mask_traj_gt.sum(dim=1)+1e-6)
+ else:
+ track_z_loss = 0
+ i_loss = (flow_pred[...,:2] - flow_gt[j][...,:2]).abs() # B, S, N, 2
+ # print((flow_pred - flow_gt[j])[...,2].abs()[vis[j].bool()].mean())
+ i_loss = torch.mean(i_loss, dim=3) # B, S, N
+ valid_ = valids[j].clone()[:,:, :flow_gt[j].shape[2]] # Ensure valid_ has the same shape as i_loss
+ valid_ = valid_ * (flow_gt[j][...,:2].norm(dim=-1) > 0).float()
+ if loss_only_for_visible:
+ valid_ = valid_ * vis[j]
+ # print(reduce_masked_mean(i_loss, valid_).item(), track_z_loss.item()/16)
+ flow_loss += i_weight * (reduce_masked_mean(i_loss, valid_) + track_z_loss + 10*reduce_masked_mean(i_loss, valid_* vis[j]))
+ # if flow_loss > 5e2:
+ # import pdb; pdb.set_trace()
+ flow_loss = flow_loss / n_predictions
+ total_flow_loss += flow_loss
+ return total_flow_loss / len(flow_gt)
+
+def sequence_loss_xyz(
+ flow_preds,
+ flow_gt,
+ valids,
+ intrs,
+ vis=None,
+ gamma=0.8,
+ add_huber_loss=False,
+ loss_only_for_visible=False,
+ mask_traj_gt=None
+):
+ """Loss function defined over sequence of flow predictions"""
+ total_flow_loss = 0.0
+ for j in range(len(flow_gt)):
+ B, S, N, D = flow_gt[j].shape
+ B, S2, N = valids[j].shape
+ assert S == S2
+ n_predictions = len(flow_preds[j])
+ flow_loss = 0.0
+ for i in range(n_predictions):
+ i_weight = gamma ** (n_predictions - i - 1)
+ flow_pred = flow_preds[j][i][:,:,:flow_gt[j].shape[2]]
+ flow_gt_ = flow_gt[j]
+ flow_gt_one = torch.cat([flow_gt_[...,:2], torch.ones_like(flow_gt_[:,:,:,:1])], dim=-1)
+ flow_gt_cam = torch.einsum('btsc,btnc->btns', torch.inverse(intrs), flow_gt_one)
+ flow_gt_cam *= flow_gt_[...,2:3].abs()
+ flow_gt_cam[...,2] *= torch.sign(flow_gt_cam[...,2])
+
+ if add_huber_loss:
+ i_loss = huber_loss(flow_pred, flow_gt_cam, delta=6.0)
+ else:
+ i_loss = (flow_pred- flow_gt_cam).norm(dim=-1,keepdim=True) # B, S, N, 2
+
+ # print((flow_pred - flow_gt[j])[...,2].abs()[vis[j].bool()].mean())
+ i_loss = torch.mean(i_loss, dim=3) # B, S, N
+ valid_ = valids[j].clone()[:,:, :flow_gt[j].shape[2]] # Ensure valid_ has the same shape as i_loss
+ if loss_only_for_visible:
+ valid_ = valid_ * vis[j]
+ # print(reduce_masked_mean(i_loss, valid_).item(), track_z_loss.item()/16)
+ flow_loss += i_weight * (reduce_masked_mean(i_loss, valid_)) * 1000
+ # if flow_loss > 5e2:
+ # import pdb; pdb.set_trace()
+ flow_loss = flow_loss / n_predictions
+ total_flow_loss += flow_loss
+ return total_flow_loss / len(flow_gt)
+
+def huber_loss(x, y, delta=1.0):
+ """Calculate element-wise Huber loss between x and y"""
+ diff = x - y
+ abs_diff = diff.abs()
+ flag = (abs_diff <= delta).float()
+ return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
+
+
+def sequence_BCE_loss(vis_preds, vis_gts, mask=None):
+ total_bce_loss = 0.0
+ for j in range(len(vis_preds)):
+ n_predictions = len(vis_preds[j])
+ bce_loss = 0.0
+ for i in range(n_predictions):
+ N_gt = vis_gts[j].shape[-1]
+ if mask is not None:
+ vis_loss = balanced_binary_cross_entropy(vis_preds[j][i][...,:N_gt], vis_gts[j], mask=mask[j], reduction="none")
+ else:
+ vis_loss = balanced_binary_cross_entropy(vis_preds[j][i][...,:N_gt], vis_gts[j]) + focal_loss(vis_preds[j][i][...,:N_gt], vis_gts[j])
+ # print(vis_loss, ((torch.sigmoid(vis_preds[j][i][...,:N_gt])>0.5).float() - vis_gts[j]).abs().sum())
+ bce_loss += vis_loss
+ bce_loss = bce_loss / n_predictions
+ total_bce_loss += bce_loss
+ return total_bce_loss / len(vis_preds)
+
+
+def sequence_prob_loss(
+ tracks: torch.Tensor,
+ confidence: torch.Tensor,
+ target_points: torch.Tensor,
+ visibility: torch.Tensor,
+ expected_dist_thresh: float = 12.0,
+):
+ """Loss for classifying if a point is within pixel threshold of its target."""
+ # Points with an error larger than 12 pixels are likely to be useless; marking
+ # them as occluded will actually improve Jaccard metrics and give
+ # qualitatively better results.
+ total_logprob_loss = 0.0
+ for j in range(len(tracks)):
+ n_predictions = len(tracks[j])
+ logprob_loss = 0.0
+ for i in range(n_predictions):
+ N_gt = target_points[j].shape[2]
+ err = torch.sum((tracks[j][i].detach()[:,:,:N_gt,:2] - target_points[j][...,:2]) ** 2, dim=-1)
+ valid = (err <= expected_dist_thresh**2).float()
+ logprob = balanced_binary_cross_entropy(confidence[j][i][...,:N_gt], valid, reduction="none")
+ logprob *= visibility[j]
+ logprob = torch.mean(logprob, dim=[1, 2])
+ logprob_loss += logprob
+ logprob_loss = logprob_loss / n_predictions
+ total_logprob_loss += logprob_loss
+ return total_logprob_loss / len(tracks)
+
+
+def sequence_dyn_prob_loss(
+ tracks: torch.Tensor,
+ confidence: torch.Tensor,
+ target_points: torch.Tensor,
+ visibility: torch.Tensor,
+ expected_dist_thresh: float = 6.0,
+):
+ """Loss for classifying if a point is within pixel threshold of its target."""
+ # Points with an error larger than 12 pixels are likely to be useless; marking
+ # them as occluded will actually improve Jaccard metrics and give
+ # qualitatively better results.
+ total_logprob_loss = 0.0
+ for j in range(len(tracks)):
+ n_predictions = len(tracks[j])
+ logprob_loss = 0.0
+ for i in range(n_predictions):
+ err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
+ valid = (err <= expected_dist_thresh**2).float()
+ valid = (valid.sum(dim=1) > 0).float()
+ logprob = balanced_binary_cross_entropy(confidence[j][i].mean(dim=1), valid, reduction="none")
+ # logprob *= visibility[j]
+ logprob = torch.mean(logprob, dim=[0, 1])
+ logprob_loss += logprob
+ logprob_loss = logprob_loss / n_predictions
+ total_logprob_loss += logprob_loss
+ return total_logprob_loss / len(tracks)
+
+
+def masked_mean(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
+ if mask is None:
+ return data.mean(dim=dim, keepdim=True)
+ mask = mask.float()
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+ mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
+ mask_sum, min=1.0
+ )
+ return mask_mean
+
+
+def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
+ if mask is None:
+ return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
+ mask = mask.float()
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+ mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
+ mask_sum, min=1.0
+ )
+ mask_var = torch.sum(
+ mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
+ ) / torch.clamp(mask_sum, min=1.0)
+ return mask_mean.squeeze(dim), mask_var.squeeze(dim)
+
+class NeighborTransformer(nn.Module):
+ def __init__(self, dim: int, num_heads: int, head_dim: int, mlp_ratio: float):
+ super().__init__()
+ self.dim = dim
+ self.output_token_1 = nn.Parameter(torch.randn(1, dim))
+ self.output_token_2 = nn.Parameter(torch.randn(1, dim))
+ self.xblock1_2 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
+ self.xblock2_1 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
+ self.aggr1 = Attention(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim)
+ self.aggr2 = Attention(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim)
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ from einops import rearrange, repeat
+ import torch.utils.checkpoint as checkpoint
+
+ assert len (x.shape) == 3, "x should be of shape (B, N, D)"
+ assert len (y.shape) == 3, "y should be of shape (B, N, D)"
+
+ # not work so well ...
+
+ def forward_chunk(x, y):
+ new_x = self.xblock1_2(x, y)
+ new_y = self.xblock2_1(y, x)
+ out1 = self.aggr1(repeat(self.output_token_1, 'n d -> b n d', b=x.shape[0]), context=new_x)
+ out2 = self.aggr2(repeat(self.output_token_2, 'n d -> b n d', b=x.shape[0]), context=new_y)
+ return out1 + out2
+
+ return checkpoint.checkpoint(forward_chunk, x, y)
+
+
+class CorrPointformer(nn.Module):
+ def __init__(self, dim: int, num_heads: int, head_dim: int, mlp_ratio: float):
+ super().__init__()
+ self.dim = dim
+ self.xblock1_2 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
+ # self.xblock2_1 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
+ self.aggr = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
+ self.out_proj = nn.Linear(dim, 2*dim)
+
+ def forward(self, query: torch.Tensor, target: torch.Tensor, target_rel_pos: torch.Tensor) -> torch.Tensor:
+ from einops import rearrange, repeat
+ import torch.utils.checkpoint as checkpoint
+
+ def forward_chunk(query, target, target_rel_pos):
+ new_query = self.xblock1_2(query, target).mean(dim=1, keepdim=True)
+ # new_target = self.xblock2_1(target, query).mean(dim=1, keepdim=True)
+ # new_aggr = new_query + new_target
+ out = self.aggr(new_query, target+target_rel_pos) # (potential delta xyz) (target - center)
+ out = self.out_proj(out)
+ return out
+
+ return checkpoint.checkpoint(forward_chunk, query, target, target_rel_pos)
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/tracker3D/delta_utils/__init__.py b/models/SpaTrackV2/models/tracker3D/delta_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/SpaTrackV2/models/tracker3D/delta_utils/blocks.py b/models/SpaTrackV2/models/tracker3D/delta_utils/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..029ade67ba6561550ffd3a6330fc1e0e0b9de734
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/delta_utils/blocks.py
@@ -0,0 +1,842 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import collections
+from functools import partial
+from itertools import repeat
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from models.SpaTrackV2.models.blocks import bilinear_sampler
+from einops import rearrange
+from torch import Tensor, einsum
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ zero_init=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ if zero_init:
+ self.zero_init()
+
+ def zero_init(self):
+ nn.init.constant_(self.fc2.weight, 0)
+ if self.fc2.bias is not None:
+ nn.init.constant_(self.fc2.bias, 0)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x, mode="nearest"):
+ x = F.interpolate(x, scale_factor=2.0, mode=mode)
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes,
+ planes,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ padding_mode="zeros",
+ )
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class BasicEncoder(nn.Module):
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
+ super(BasicEncoder, self).__init__()
+ self.stride = stride
+ self.norm_fn = "instance"
+ self.in_planes = output_dim // 2
+
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
+
+ self.conv1 = nn.Conv2d(
+ input_dim,
+ self.in_planes,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ padding_mode="zeros",
+ )
+ self.relu1 = nn.ReLU(inplace=True)
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
+ self.layer3 = self._make_layer(output_dim, stride=2)
+ self.layer4 = self._make_layer(output_dim, stride=2)
+
+ self.conv2 = nn.Conv2d(
+ output_dim * 3 + output_dim // 4,
+ output_dim * 2,
+ kernel_size=3,
+ padding=1,
+ padding_mode="zeros",
+ )
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.InstanceNorm2d)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x, return_intermediate=False):
+ _, _, H, W = x.shape
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ a = self.layer1(x)
+ b = self.layer2(a)
+ c = self.layer3(b)
+ d = self.layer4(c)
+
+ def _bilinear_intepolate(x):
+ return F.interpolate(
+ x,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ # a = _bilinear_intepolate(a)
+ # b = _bilinear_intepolate(b)
+ # c = _bilinear_intepolate(c)
+ # d = _bilinear_intepolate(d)
+
+ cat_feat = torch.cat(
+ [_bilinear_intepolate(a), _bilinear_intepolate(b), _bilinear_intepolate(c), _bilinear_intepolate(d)], dim=1
+ )
+ x = self.conv2(cat_feat)
+ x = self.norm2(x)
+ x = self.relu2(x)
+ x = self.conv3(x)
+
+ # breakpoint()
+ if return_intermediate:
+ if self.stride == 4:
+ return x, a, c # 128, h/4, w/4, - 64, h/2, w/2 - 128, h/8, w/8
+ elif self.stride == 8:
+ return x, b, d
+ else:
+ raise NotImplementedError
+ return x
+
+
+class CorrBlockFP16:
+ def __init__(
+ self,
+ fmaps,
+ num_levels=4,
+ radius=4,
+ multiple_track_feats=False,
+ padding_mode="zeros",
+ ):
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+ self.padding_mode = padding_mode
+ self.num_levels = num_levels
+ self.radius = radius
+ self.fmaps_pyramid = []
+ self.multiple_track_feats = multiple_track_feats
+
+ self.fmaps_pyramid.append(fmaps)
+ for i in range(self.num_levels - 1):
+ fmaps_ = fmaps.reshape(B * S, C, H, W)
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
+ _, _, H, W = fmaps_.shape
+ fmaps = fmaps_.reshape(B, S, C, H, W)
+ self.fmaps_pyramid.append(fmaps)
+
+ def sample(self, coords):
+ r = self.radius
+ B, S, N, D = coords.shape
+ assert D == 2
+
+ H, W = self.H, self.W
+ out_pyramid = []
+ for i in range(self.num_levels):
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
+ *_, H, W = corrs.shape
+
+ dx = torch.linspace(-r, r, 2 * r + 1)
+ dy = torch.linspace(-r, r, 2 * r + 1)
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
+
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
+ coords_lvl = centroid_lvl + delta_lvl
+
+ # breakpoint()
+ corrs = bilinear_sampler(
+ corrs.reshape(B * S * N, 1, H, W),
+ coords_lvl,
+ padding_mode=self.padding_mode,
+ )
+ corrs = corrs.view(B, S, N, -1)
+ out_pyramid.append(corrs)
+
+ del self.corrs_pyramid
+
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
+ out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
+ return out
+
+ def corr(self, targets):
+ B, S, N, C = targets.shape
+ if self.multiple_track_feats:
+ targets_split = targets.split(C // self.num_levels, dim=-1)
+ B, S, N, C = targets_split[0].shape
+
+ assert C == self.C
+ assert S == self.S
+
+ fmap1 = targets
+
+ self.corrs_pyramid = []
+ for i, fmaps in enumerate(self.fmaps_pyramid):
+ *_, H, W = fmaps.shape
+ fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
+ if self.multiple_track_feats:
+ fmap1 = targets_split[i]
+ corrs = torch.matmul(fmap1, fmap2s)
+ corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
+ # breakpoint()
+ self.corrs_pyramid.append(corrs)
+
+
+class CorrBlock:
+ def __init__(
+ self,
+ fmaps,
+ num_levels=4,
+ radius=4,
+ multiple_track_feats=False,
+ padding_mode="zeros",
+ ):
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+ self.padding_mode = padding_mode
+ self.num_levels = num_levels
+ self.radius = radius
+ self.fmaps_pyramid = []
+ self.multiple_track_feats = multiple_track_feats
+
+ self.fmaps_pyramid.append(fmaps)
+ for i in range(self.num_levels - 1):
+ fmaps_ = fmaps.reshape(B * S, C, H, W)
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
+ _, _, H, W = fmaps_.shape
+ fmaps = fmaps_.reshape(B, S, C, H, W)
+ self.fmaps_pyramid.append(fmaps)
+
+ def sample(self, coords, delete=True):
+ r = self.radius
+ B, S, N, D = coords.shape
+ assert D == 2
+
+ H, W = self.H, self.W
+ out_pyramid = []
+ for i in range(self.num_levels):
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
+ *_, H, W = corrs.shape
+
+ dx = torch.linspace(-r, r, 2 * r + 1)
+ dy = torch.linspace(-r, r, 2 * r + 1)
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
+
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
+ coords_lvl = centroid_lvl + delta_lvl
+
+ # breakpoint()
+
+ # t1 = time.time()
+ corrs = bilinear_sampler(
+ corrs.reshape(B * S * N, 1, H, W),
+ coords_lvl,
+ padding_mode=self.padding_mode,
+ )
+ # t2 = time.time()
+
+ # print(coords_lvl.shape, t2 - t1)
+ corrs = corrs.view(B, S, N, -1)
+ out_pyramid.append(corrs)
+
+ if delete:
+ del self.corrs_pyramid
+
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
+ out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
+ return out
+
+ def corr(self, targets):
+ B, S, N, C = targets.shape
+ if self.multiple_track_feats:
+ targets_split = targets.split(C // self.num_levels, dim=-1)
+ B, S, N, C = targets_split[0].shape
+
+ assert C == self.C
+ assert S == self.S
+
+ fmap1 = targets
+
+ self.corrs_pyramid = []
+ for i, fmaps in enumerate(self.fmaps_pyramid):
+ *_, H, W = fmaps.shape
+ fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
+ if self.multiple_track_feats:
+ fmap1 = targets_split[i]
+ corrs = torch.matmul(fmap1, fmap2s)
+ corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
+ # breakpoint()
+ self.corrs_pyramid.append(corrs)
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ num_heads=8,
+ dim_head=48,
+ qkv_bias=False,
+ flash=False,
+ alibi=False,
+ zero_init=False,
+ ):
+ super().__init__()
+ inner_dim = dim_head * num_heads
+ context_dim = default(context_dim, query_dim)
+ self.scale = dim_head**-0.5
+ self.heads = num_heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
+ self.to_out = nn.Linear(inner_dim, query_dim)
+
+ self.flash = flash
+ self.alibi = alibi
+
+ if zero_init:
+ self.zero_init()
+ # if self.alibi:
+ # self.training_length = 24
+
+ # bias_forward = get_alibi_slope(self.heads // 2) * get_relative_positions(self.training_length)
+ # bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
+ # bias_backward = get_alibi_slope(self.heads // 2) * get_relative_positions(self.training_length, reverse=True)
+ # bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
+
+ # self.precomputed_attn_bias = self.register_buffer("precomputed_attn_bias", torch.cat([bias_forward, bias_backward], dim=0), persistent=False)
+
+ def zero_init(self):
+ nn.init.constant_(self.to_out.weight, 0)
+ nn.init.constant_(self.to_out.bias, 0)
+
+ # breakpoint()
+
+ def forward(self, x, context=None, attn_bias=None):
+ B, N1, C = x.shape
+ h = self.heads
+
+ q = self.to_q(x).reshape(B, N1, h, C // h)
+ context = default(context, x)
+ N2 = context.shape[1]
+ k, v = self.to_kv(context).chunk(2, dim=-1)
+ k = k.reshape(B, N2, h, C // h)
+ v = v.reshape(B, N2, h, C // h)
+
+ if self.flash:
+ with torch.autocast(device_type="cuda", enabled=True):
+ x = flash_attn_func(q.half(), k.half(), v.half())
+ x = x.reshape(B, N1, C)
+ x = x.float()
+ else:
+ q = q.permute(0, 2, 1, 3)
+ k = k.permute(0, 2, 1, 3)
+ v = v.permute(0, 2, 1, 3)
+
+ sim = (q @ k.transpose(-2, -1)) * self.scale
+
+ if attn_bias is not None:
+ sim = sim + attn_bias
+ attn = sim.softmax(dim=-1)
+
+ x = attn @ v
+ x = x.transpose(1, 2).reshape(B, N1, C)
+ x = self.to_out(x)
+ return x
+
+ def forward_noattn(self, x):
+ # B, N1, C = x.shape
+ # h = self.heads
+ _, x = self.to_kv(x).chunk(2, dim=-1)
+ # x = x.reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
+ # x = x.transpose(1, 2).reshape(B, N1, C)
+ x = self.to_out(x)
+
+ return x
+
+
+def get_relative_positions(seq_len, reverse=False, device="cpu"):
+ x = torch.arange(seq_len, device=device)[None, :]
+ y = torch.arange(seq_len, device=device)[:, None]
+ return torch.tril(x - y) if not reverse else torch.triu(y - x)
+
+
+def get_alibi_slope(num_heads, device="cpu"):
+ x = (24) ** (1 / num_heads)
+ return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)], device=device, dtype=torch.float32).view(
+ -1, 1, 1
+ )
+
+
+class RelativeAttention(nn.Module):
+ """Multi-headed attention (MHA) module."""
+
+ def __init__(self, query_dim, num_heads=8, qkv_bias=True, model_size=None, flash=False):
+ super(RelativeAttention, self).__init__()
+
+ query_dim = query_dim // num_heads
+ self.num_heads = num_heads
+ self.query_dim = query_dim
+ self.value_size = query_dim
+ self.model_size = query_dim * num_heads
+
+ self.qkv_bias = qkv_bias
+
+ self.query_proj = nn.Linear(num_heads * query_dim, num_heads * query_dim, bias=qkv_bias)
+ self.key_proj = nn.Linear(num_heads * query_dim, num_heads * query_dim, bias=qkv_bias)
+ self.value_proj = nn.Linear(num_heads * self.value_size, num_heads * self.value_size, bias=qkv_bias)
+ self.final_proj = nn.Linear(num_heads * self.value_size, self.model_size, bias=qkv_bias)
+
+ self.training_length = 24
+
+ bias_forward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(self.training_length)
+ bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
+ bias_backward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(
+ self.training_length, reverse=True
+ )
+ bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
+
+ self.register_buffer(
+ "precomputed_attn_bias", torch.cat([bias_forward, bias_backward], dim=0), persistent=False
+ )
+
+ def forward(self, x, attn_bias=None):
+ batch_size, sequence_length, _ = x.size()
+
+ query_heads = self._linear_projection(x, self.query_dim, self.query_proj) # [T', H, Q=K]
+ key_heads = self._linear_projection(x, self.query_dim, self.key_proj) # [T, H, K]
+ value_heads = self._linear_projection(x, self.value_size, self.value_proj) # [T, H, V]
+
+ if self.training_length == sequence_length:
+ new_attn_bias = self.precomputed_attn_bias
+ else:
+ device = x.device
+ bias_forward = get_alibi_slope(self.num_heads // 2, device=device) * get_relative_positions(
+ sequence_length, device=device
+ )
+ bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
+ bias_backward = get_alibi_slope(self.num_heads // 2, device=device) * get_relative_positions(
+ sequence_length, reverse=True, device=device
+ )
+ bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
+ new_attn_bias = torch.cat([bias_forward, bias_backward], dim=0)
+
+ if attn_bias is not None:
+ attn_bias = attn_bias + new_attn_bias
+ else:
+ attn_bias = new_attn_bias
+
+ attn = F.scaled_dot_product_attention(
+ query_heads, key_heads, value_heads, attn_mask=new_attn_bias, scale=1 / np.sqrt(self.query_dim)
+ )
+ attn = attn.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, -1)
+
+ return self.final_proj(attn) # [T', D']
+
+ # attn_logits = torch.einsum("...thd,...Thd->...htT", query_heads, key_heads)
+ # attn_logits = attn_logits / np.sqrt(self.query_dim) + new_attn_bias
+
+ # # breakpoint()
+ # if attn_bias is not None:
+ # if attn_bias.ndim != attn_logits.ndim:
+ # raise ValueError(f"Mask dimensionality {attn_bias.ndim} must match logits dimensionality {attn_logits.ndim}.")
+ # attn_logits = torch.where(attn_bias, attn_logits, torch.tensor(-1e30))
+
+ # attn_weights = F.softmax(attn_logits, dim=-1) # [H, T', T]
+
+ # attn = torch.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
+ # attn = attn.reshape(batch_size, sequence_length, -1) # [T', H*V]
+
+ # return self.final_proj(attn) # [T', D']
+
+ # def _linear_projection(self, x, head_size, proj_layer):
+ # y = proj_layer(x)
+ # *leading_dims, _ = x.shape
+ # return y.reshape((*leading_dims, self.num_heads, head_size))
+
+ def _linear_projection(self, x, head_size, proj_layer):
+ y = proj_layer(x)
+ batch_size, sequence_length, _ = x.shape
+ return y.reshape((batch_size, sequence_length, self.num_heads, head_size)).permute(0, 2, 1, 3)
+
+
+class AttnBlock(nn.Module):
+ def __init__(
+ self, hidden_size, num_heads, attn_class: Callable[..., nn.Module] = Attention, mlp_ratio=4.0, **block_kwargs
+ ):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+
+ def forward(self, x, mask=None):
+ attn_bias = mask
+ if mask is not None:
+ mask = (mask[:, None] * mask[:, :, None]).unsqueeze(1).expand(-1, self.attn.heads, -1, -1)
+ max_neg_value = -torch.finfo(x.dtype).max
+ attn_bias = (~mask) * max_neg_value
+ x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+ def forward_noattn(self, x):
+ x = x + self.attn.forward_noattn(self.norm1(x))
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+def pix2cam(coords, intr, detach=True):
+ """
+ Args:
+ coords: [B, T, N, 3]
+ intr: [B, T, 3, 3]
+ """
+ if detach:
+ coords = coords.detach()
+
+ (
+ B,
+ S,
+ N,
+ _,
+ ) = coords.shape
+ xy_src = coords.reshape(B * S * N, 3)
+ intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B * S * N, 3, 3)
+ xy_src = torch.cat([xy_src[..., :2], torch.ones_like(xy_src[..., :1])], dim=-1)
+ xyz_src = (torch.inverse(intr) @ xy_src[..., None])[..., 0]
+ dp_pred = coords[..., 2]
+ xyz_src_ = xyz_src * (dp_pred.reshape(S * N, 1))
+ xyz_src_ = xyz_src_.reshape(B, S, N, 3)
+ return xyz_src_
+
+
+def cam2pix(coords, intr):
+ """
+ Args:
+ coords: [B, T, N, 3]
+ intr: [B, T, 3, 3]
+ """
+ coords = coords.detach()
+ (
+ B,
+ S,
+ N,
+ _,
+ ) = coords.shape
+ xy_src = coords.reshape(B * S * N, 3).clone()
+ intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B * S * N, 3, 3)
+ xy_src = xy_src / (xy_src[..., 2:] + 1e-5)
+ xyz_src = (intr @ xy_src[..., None])[..., 0]
+ dp_pred = coords[..., 2]
+ xyz_src[..., 2] *= dp_pred.reshape(S * N)
+ xyz_src = xyz_src.reshape(B, S, N, 3)
+ return xyz_src
+
+
+class BroadMultiHeadAttention(nn.Module):
+ def __init__(self, dim, heads):
+ super(BroadMultiHeadAttention, self).__init__()
+ self.dim = dim
+ self.heads = heads
+ self.scale = (dim / heads) ** -0.5
+ self.attend = nn.Softmax(dim=-1)
+
+ def attend_with_rpe(self, Q, K):
+ Q = rearrange(Q.squeeze(), "i (heads d) -> heads i d", heads=self.heads)
+ K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)
+
+ dots = einsum("hid, bhjd -> bhij", Q, K) * self.scale # (b hw) heads 1 pointnum
+
+ return self.attend(dots)
+
+ def forward(self, Q, K, V):
+ attn = self.attend_with_rpe(Q, K)
+ B, _, _ = K.shape
+ _, N, _ = Q.shape
+
+ V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
+
+ out = einsum("bhij, bhjd -> bhid", attn, V)
+ out = rearrange(out, "b heads n d -> b n (heads d)", b=B, n=N)
+
+ return out
+
+
+class CrossAttentionLayer(nn.Module):
+ def __init__(
+ self,
+ qk_dim,
+ v_dim,
+ query_token_dim,
+ tgt_token_dim,
+ num_heads=8,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ drop_path=0.0,
+ dropout=0.0,
+ ):
+ super(CrossAttentionLayer, self).__init__()
+ assert qk_dim % num_heads == 0, f"dim {qk_dim} should be divided by num_heads {num_heads}."
+ assert v_dim % num_heads == 0, f"dim {v_dim} should be divided by num_heads {num_heads}."
+ """
+ Query Token: [N, C] -> [N, qk_dim] (Q)
+ Target Token: [M, D] -> [M, qk_dim] (K), [M, v_dim] (V)
+ """
+ self.num_heads = num_heads
+ head_dim = qk_dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.norm1 = nn.LayerNorm(query_token_dim)
+ self.norm2 = nn.LayerNorm(query_token_dim)
+ self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads)
+ self.q, self.k, self.v = (
+ nn.Linear(query_token_dim, qk_dim, bias=True),
+ nn.Linear(tgt_token_dim, qk_dim, bias=True),
+ nn.Linear(tgt_token_dim, v_dim, bias=True),
+ )
+
+ self.proj = nn.Linear(v_dim, query_token_dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.ffn = nn.Sequential(
+ nn.Linear(query_token_dim, query_token_dim),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(query_token_dim, query_token_dim),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, query, tgt_token):
+ """
+ x: [BH1W1, H3W3, D]
+ """
+ short_cut = query
+ query = self.norm1(query)
+
+ q, k, v = self.q(query), self.k(tgt_token), self.v(tgt_token)
+
+ x = self.multi_head_attn(q, k, v)
+
+ x = short_cut + self.proj_drop(self.proj(x))
+
+ x = x + self.drop_path(self.ffn(self.norm2(x)))
+
+ return x
+
+
+class LayerNormProxy(nn.Module):
+ def __init__(self, dim):
+
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+
+ def forward(self, x):
+
+ x = rearrange(x, "b c h w -> b h w c")
+ x = self.norm(x)
+ return rearrange(x, "b h w c -> b c h w")
+
+
+def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
+ """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
+
+ Instead of computing [sin(x), cos(x)], we use the trig identity
+ cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
+
+ Args:
+ x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
+ min_deg: int, the minimum (inclusive) degree of the encoding.
+ max_deg: int, the maximum (exclusive) degree of the encoding.
+ legacy_posenc_order: bool, keep the same ordering as the original tf code.
+
+ Returns:
+ encoded: torch.Tensor, encoded variables.
+ """
+ if min_deg == max_deg:
+ return x
+ scales = torch.tensor([2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device)
+ if legacy_posenc_order:
+ xb = x[..., None, :] * scales[:, None]
+ four_feat = torch.reshape(torch.sin(torch.stack([xb, xb + 0.5 * np.pi], dim=-2)), list(x.shape[:-1]) + [-1])
+ else:
+ xb = torch.reshape((x[..., None, :] * scales[:, None]), list(x.shape[:-1]) + [-1])
+ four_feat = torch.sin(torch.cat([xb, xb + 0.5 * np.pi], dim=-1))
+ return torch.cat([x] + [four_feat], dim=-1)
+
+
+def gaussian2D2(shape, sigma=(1, 1), rho=0):
+ if not isinstance(sigma, tuple):
+ sigma = (sigma, sigma)
+ sigma_x, sigma_y = sigma
+
+ m, n = [(ss - 1.0) / 2.0 for ss in shape]
+ y, x = np.ogrid[-m : m + 1, -n : n + 1]
+
+ energy = (x * x) / (sigma_x * sigma_x) - 2 * rho * x * y / (sigma_x * sigma_y) + (y * y) / (sigma_y * sigma_y)
+ h = np.exp(-energy / (2 * (1 - rho * rho)))
+ h[h < np.finfo(h.dtype).eps * h.max()] = 0
+ return h / h.sum()
diff --git a/models/SpaTrackV2/models/tracker3D/delta_utils/upsample_transformer.py b/models/SpaTrackV2/models/tracker3D/delta_utils/upsample_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b2f0b4ccd8839c0fcd6b5a360b7783c7b88bd85
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/delta_utils/upsample_transformer.py
@@ -0,0 +1,438 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import einsum, rearrange, repeat
+from jaxtyping import Float, Int64
+from torch import Tensor, nn
+
+from models.SpaTrackV2.models.tracker3D.delta_utils.blocks import (
+ Attention,
+ AttnBlock,
+ BasicEncoder,
+ CorrBlock,
+ Mlp,
+ ResidualBlock,
+ Upsample,
+ cam2pix,
+ pix2cam,
+)
+
+from models.SpaTrackV2.models.blocks import bilinear_sampler
+
+def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
+ H, W = height, width
+ S = shape if shape else []
+ if align_corners:
+ x = torch.linspace(0, 1, W, device=device)
+ y = torch.linspace(0, 1, H, device=device)
+ if not normalize:
+ x = x * (W - 1)
+ y = y * (H - 1)
+ else:
+ x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
+ y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
+ if not normalize:
+ x = x * W
+ y = y * H
+ x_view, y_view, exp = [1 for _ in S] + [1, -1], [1 for _ in S] + [-1, 1], S + [H, W]
+ x = x.view(*x_view).expand(*exp)
+ y = y.view(*y_view).expand(*exp)
+ grid = torch.stack([x, y], dim=-1)
+ if dtype == "numpy":
+ grid = grid.numpy()
+ return grid
+
+class RelativeAttention(nn.Module):
+ """Multi-headed attention (MHA) module."""
+
+ def __init__(self, query_dim, num_heads=8, qkv_bias=True, model_size=None, flash=False):
+ super(RelativeAttention, self).__init__()
+
+ query_dim = query_dim // num_heads
+ self.num_heads = num_heads
+ self.query_dim = query_dim
+ self.value_size = query_dim
+ self.model_size = query_dim * num_heads
+
+ self.qkv_bias = qkv_bias
+
+ self.flash = flash
+
+ self.query_proj = nn.Linear(num_heads * query_dim, num_heads * query_dim, bias=qkv_bias)
+ self.key_proj = nn.Linear(num_heads * query_dim, num_heads * query_dim, bias=qkv_bias)
+ self.value_proj = nn.Linear(num_heads * self.value_size, num_heads * self.value_size, bias=qkv_bias)
+ self.final_proj = nn.Linear(num_heads * self.value_size, self.model_size, bias=qkv_bias)
+
+ self.scale = 1.0 / math.sqrt(self.query_dim)
+ # self.training_length = 24
+
+ # bias_forward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(self.training_length)
+ # bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
+ # bias_backward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(self.training_length, reverse=True)
+ # bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
+
+ # self.register_buffer("precomputed_attn_bias", torch.cat([bias_forward, bias_backward], dim=0), persistent=False)
+
+ def forward(self, x, context, attn_bias=None):
+ B, N1, C = x.size()
+
+ q = self._linear_projection(x, self.query_dim, self.query_proj) # [T', H, Q=K]
+ k = self._linear_projection(context, self.query_dim, self.key_proj) # [T, H, K]
+ v = self._linear_projection(context, self.value_size, self.value_proj) # [T, H, V]
+
+ if self.flash:
+ with torch.autocast(device_type="cuda", enabled=True):
+ x = flash_attn_func(q.half(), k.half(), v.half())
+ x = x.reshape(B, N1, C)
+ x = x.float()
+ else:
+ q = q.permute(0, 2, 1, 3)
+ k = k.permute(0, 2, 1, 3)
+ v = v.permute(0, 2, 1, 3)
+
+ sim = (q @ k.transpose(-2, -1)) * self.scale
+
+ if attn_bias is not None:
+ sim = sim + attn_bias
+ attn = sim.softmax(dim=-1)
+
+ x = attn @ v
+ x = x.transpose(1, 2).reshape(B, N1, C)
+
+ # with torch.autocast(device_type="cuda", dtype=torch.float32):
+ # attn = F.scaled_dot_product_attention(query_heads, key_heads, value_heads, attn_mask=attn_bias, scale=1.0 / math.sqrt(self.query_dim))
+ # else:
+
+ # sim = (query_heads @ key_heads.transpose(-2, -1)) * self.scale
+
+ # if attn_bias is not None:
+ # sim = sim + attn_bias
+ # attn = sim.softmax(dim=-1)
+
+ # attn = (attn @ value_heads)
+ # attn = attn.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, -1)
+
+ return self.final_proj(x) # [T', D']
+
+ def _linear_projection(self, x, head_size, proj_layer):
+ batch_size, sequence_length, _ = x.shape
+ y = proj_layer(x)
+ y = y.reshape((batch_size, sequence_length, self.num_heads, head_size))
+
+ return y
+
+
+class UpsampleCrossAttnBlock(nn.Module):
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.norm_context = nn.LayerNorm(hidden_size)
+ self.cross_attn = RelativeAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+
+ def forward(self, x, context, attn_bias=None):
+ x = x + self.cross_attn(x=self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias)
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class DecoderUpsampler(nn.Module):
+ def __init__(self, in_channels: int, middle_channels: int, out_channels: int = None, stride: int = 4):
+ super().__init__()
+
+ self.stride = stride
+
+ if out_channels is None:
+ out_channels = middle_channels
+
+ self.conv_in = nn.Conv2d(in_channels, middle_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
+ self.norm1 = nn.GroupNorm(num_groups=middle_channels // 8, num_channels=middle_channels, eps=1e-6)
+
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+
+ for i in range(int(math.log2(self.stride))):
+ self.res_blocks.append(ResidualBlock(middle_channels, middle_channels))
+ self.upsample_blocks.append(Upsample(middle_channels, with_conv=True))
+
+ # in_channels = middle_channels
+
+ self.norm2 = nn.GroupNorm(num_groups=middle_channels // 8, num_channels=middle_channels, eps=1e-6)
+ self.conv_out = nn.Conv2d(middle_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
+
+ self.initialize_weight()
+
+ def initialize_weight(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Conv2d):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.res_blocks.apply(_basic_init)
+ self.conv_in.apply(_basic_init)
+ self.conv_out.apply(_basic_init)
+
+ def forward(
+ self,
+ x: Float[Tensor, "b c1 h_down w_down"],
+ mode: str = "nearest",
+ ) -> Float[Tensor, "b c1 h_up w_up"]:
+
+ x = F.relu(self.norm1(self.conv_in(x)))
+
+ for i in range(len(self.res_blocks)):
+ x = self.res_blocks[i](x)
+ x = self.upsample_blocks[i](x, mode=mode)
+
+ x = self.conv_out(F.relu(self.norm2(x)))
+ return x
+
+
+class UpsampleTransformer(nn.Module):
+ def __init__(
+ self,
+ kernel_size: int = 3,
+ stride: int = 4,
+ latent_dim: int = 128,
+ n_heads: int = 4,
+ num_attn_blocks: int = 2,
+ use_rel_emb: bool = True,
+ flash: bool = False,
+ ):
+ super().__init__()
+
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.latent_dim = latent_dim
+
+ self.n_heads = n_heads
+
+ self.attnup_feat_cnn = DecoderUpsampler(
+ in_channels=self.latent_dim, middle_channels=self.latent_dim, out_channels=self.latent_dim
+ )
+
+ self.cross_blocks = nn.ModuleList(
+ [
+ UpsampleCrossAttnBlock(latent_dim + 64, latent_dim + 64, num_heads=n_heads, mlp_ratio=4, flash=flash)
+ for _ in range(num_attn_blocks)
+ ]
+ )
+
+ self.flow_mlp = nn.Sequential(
+ nn.Conv2d(2 * 16, 128, 7, padding=3),
+ nn.ReLU(),
+ nn.Conv2d(128, 64, 3, padding=1),
+ nn.ReLU(),
+ )
+
+ self.out = nn.Linear(latent_dim + 64, kernel_size * kernel_size, bias=True)
+
+ if use_rel_emb:
+ self.rpb_attnup = nn.Parameter(torch.zeros(kernel_size * kernel_size))
+ torch.nn.init.trunc_normal_(self.rpb_attnup, std=0.1, mean=0.0, a=-2.0, b=2.0)
+ else:
+ self.rpb_attnup = None
+
+ def forward(
+ self,
+ feat_map: Float[Tensor, "b c1 h w"],
+ flow_map: Float[Tensor, "b c2 h w"],
+ ):
+ B = feat_map.shape[0]
+ H_down, W_down = feat_map.shape[-2:]
+ # x0, y0 = x0y0
+
+ feat_map_up = self.attnup_feat_cnn(feat_map) # learnable upsample by 4
+ # feat_map_down = F.interpolate(feat_map_up, scale_factor=1/self.stride, mode='nearest') # B C H*4 W*4
+ feat_map_down = feat_map
+ # depths_down = F.interpolate(depths, scale_factor=1/self.stride, mode='nearest')
+
+ # NOTE prepare attention bias
+ # depths_down_ = torch.stack([depths_down[b, :, y0_:y0_+H_down, x0_:x0_+W_down] for b, (x0_,y0_) in enumerate(zip(x0, y0))], dim=0)
+ # depths_ = torch.stack([depths[b, :, y0_*4:y0_*4+H_down*4, x0_*4:x0_*4+W_down*4] for b, (x0_,y0_) in enumerate(zip(x0, y0))], dim=0)
+ # guidance_downsample = F.interpolate(guidance, size=(H, W), mode='nearest')
+ pad_val = (self.kernel_size - 1) // 2
+ # depths_down_padded = F.pad(depths_down_, (pad_val, pad_val, pad_val, pad_val), "replicate")
+
+ if self.rpb_attnup is not None:
+ relative_pos_attn_map = self.rpb_attnup.view(1, 1, -1, 1, 1).repeat(
+ B, self.n_heads, 1, H_down * 4, W_down * 4
+ )
+ relative_pos_attn_map = rearrange(relative_pos_attn_map, "b k n h w -> (b h w) k 1 n")
+ attn_bias = relative_pos_attn_map
+ else:
+ attn_bias = None
+
+ # NOTE prepare context (low-reso feat)
+ context = feat_map_down
+ context = F.unfold(context, kernel_size=self.kernel_size, padding=pad_val) # B C*kernel**2 H W
+ context = rearrange(context, "b c (h w) -> b c h w", h=H_down, w=W_down)
+ context = F.interpolate(context, scale_factor=self.stride, mode="nearest") # B C*kernel**2 H*4 W*4
+ context = rearrange(context, "b (c i j) h w -> (b h w) (i j) c", i=self.kernel_size, j=self.kernel_size)
+
+ # NOTE prepare queries (high-reso feat)
+ x = feat_map_up
+ x = rearrange(x, "b c h w -> (b h w) 1 c")
+
+ assert flow_map.shape[-2:] == feat_map.shape[-2:]
+
+ flow_map = rearrange(flow_map, "b t c h w -> b (t c) h w")
+ flow_map = self.flow_mlp(flow_map)
+
+ nn_flow_map = F.unfold(flow_map, kernel_size=self.kernel_size, padding=pad_val) # B C*kernel**2 H W
+ nn_flow_map = rearrange(nn_flow_map, "b c (h w) -> b c h w", h=H_down, w=W_down)
+ nn_flow_map = F.interpolate(nn_flow_map, scale_factor=self.stride, mode="nearest") # B C*kernel**2 H*4 W*4
+ nn_flow_map = rearrange(
+ nn_flow_map, "b (c i j) h w -> (b h w) (i j) c", i=self.kernel_size, j=self.kernel_size
+ )
+
+ up_flow_map = F.interpolate(flow_map, scale_factor=4, mode="nearest") # NN up # b 2 h w
+ up_flow_map = rearrange(up_flow_map, "b c h w -> (b h w) 1 c")
+
+ context = torch.cat([context, nn_flow_map], dim=-1)
+ x = torch.cat([x, up_flow_map], dim=-1)
+
+ for lvl in range(len(self.cross_blocks)):
+ x = self.cross_blocks[lvl](x, context, attn_bias)
+
+ mask_out = self.out(x)
+ mask_out = F.softmax(mask_out, dim=-1)
+ mask_out = rearrange(mask_out, "(b h w) 1 c -> b c h w", h=H_down * self.stride, w=W_down * self.stride)
+
+ return mask_out
+
+
+def get_alibi_slope(num_heads):
+ x = (24) ** (1 / num_heads)
+ return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)]).float()
+
+
+class UpsampleTransformerAlibi(nn.Module):
+ def __init__(
+ self,
+ kernel_size: int = 3,
+ stride: int = 4,
+ latent_dim: int = 128,
+ n_heads: int = 4,
+ num_attn_blocks: int = 2,
+ upsample_factor: int = 4,
+ ):
+ super().__init__()
+
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.latent_dim = latent_dim
+ self.upsample_factor = upsample_factor
+
+ self.n_heads = n_heads
+
+ self.attnup_feat_cnn = DecoderUpsampler(
+ in_channels=self.latent_dim,
+ middle_channels=self.latent_dim,
+ out_channels=self.latent_dim,
+ # stride=self.upsample_factor
+ )
+
+ self.cross_blocks = nn.ModuleList(
+ [
+ UpsampleCrossAttnBlock(
+ latent_dim+64,
+ latent_dim+64,
+ num_heads=n_heads,
+ mlp_ratio=4,
+ flash=False
+ )
+ for _ in range(num_attn_blocks)
+ ]
+ )
+
+ self.flow_mlp = nn.Sequential(
+ nn.Conv2d(3*32, 128, 7, padding=3),
+ nn.ReLU(),
+ nn.Conv2d(128, 64, 3, padding=1),
+ nn.ReLU(),
+ )
+
+ self.out = nn.Linear(latent_dim+64, kernel_size*kernel_size, bias=True)
+
+
+ alibi_slope = get_alibi_slope(n_heads // 2)
+ grid_kernel = get_grid(kernel_size, kernel_size, normalize=False).reshape(kernel_size, kernel_size, 2)
+ grid_kernel = grid_kernel - (kernel_size - 1) / 2
+ grid_kernel = -torch.abs(grid_kernel)
+ alibi_bias = torch.cat([
+ alibi_slope.view(-1,1,1) * grid_kernel[..., 0].view(1,kernel_size,kernel_size),
+ alibi_slope.view(-1,1,1) * grid_kernel[..., 1].view(1,kernel_size,kernel_size)
+ ]) # n_heads, kernel_size, kernel_size
+
+ self.register_buffer("alibi_bias", alibi_bias)
+
+
+ def forward(
+ self,
+ feat_map: Float[Tensor, "b c1 h w"],
+ flow_map: Float[Tensor, "b c2 h w"],
+ ):
+ B = feat_map.shape[0]
+ H_down, W_down = feat_map.shape[-2:]
+
+ feat_map_up = self.attnup_feat_cnn(feat_map) # learnable upsample by 4
+ if self.upsample_factor != 4:
+ additional_scale = float(self.upsample_factor / 4)
+ if additional_scale > 1:
+ feat_map_up = F.interpolate(feat_map_up, scale_factor=additional_scale, mode='bilinear', align_corners=False)
+ else:
+ feat_map_up = F.interpolate(feat_map_up, scale_factor=additional_scale, mode='nearest')
+
+ feat_map_down = feat_map
+
+ pad_val = (self.kernel_size - 1) // 2
+
+ attn_bias = self.alibi_bias.view(1,self.n_heads,self.kernel_size**2,1,1).repeat(B,1,1,H_down*self.upsample_factor,W_down*self.upsample_factor)
+ attn_bias = rearrange(attn_bias, "b k n h w -> (b h w) k 1 n")
+
+ # NOTE prepare context (low-reso feat)
+ context = feat_map_down
+ context = F.unfold(context, kernel_size=self.kernel_size, padding=pad_val) # B C*kernel**2 H W
+ context = rearrange(context, 'b c (h w) -> b c h w', h=H_down, w=W_down)
+ context = F.interpolate(context, scale_factor=self.upsample_factor, mode='nearest') # B C*kernel**2 H*4 W*4
+ context = rearrange(context, 'b (c i j) h w -> (b h w) (i j) c', i=self.kernel_size, j=self.kernel_size)
+
+ # NOTE prepare queries (high-reso feat)
+ x = feat_map_up
+ x = rearrange(x, 'b c h w -> (b h w) 1 c')
+
+ assert flow_map.shape[-2:] == feat_map.shape[-2:]
+
+ flow_map = rearrange(flow_map, 'b t c h w -> b (t c) h w')
+ flow_map = self.flow_mlp(flow_map)
+
+ nn_flow_map = F.unfold(flow_map, kernel_size=self.kernel_size, padding=pad_val) # B C*kernel**2 H W
+ nn_flow_map = rearrange(nn_flow_map, 'b c (h w) -> b c h w', h=H_down, w=W_down)
+ nn_flow_map = F.interpolate(nn_flow_map, scale_factor=self.upsample_factor, mode='nearest') # B C*kernel**2 H*4 W*4
+ nn_flow_map = rearrange(nn_flow_map, 'b (c i j) h w -> (b h w) (i j) c', i=self.kernel_size, j=self.kernel_size)
+ up_flow_map = F.interpolate(flow_map, scale_factor=self.upsample_factor, mode="nearest") # NN up # b 2 h w
+ up_flow_map = rearrange(up_flow_map, 'b c h w -> (b h w) 1 c')
+ context = torch.cat([context, nn_flow_map], dim=-1)
+ x = torch.cat([x, up_flow_map], dim=-1)
+ for lvl in range(len(self.cross_blocks)):
+ x = self.cross_blocks[lvl](x, context, attn_bias)
+
+ mask_out = self.out(x)
+ mask_out = F.softmax(mask_out, dim=-1)
+ mask_out = rearrange(mask_out, '(b h w) 1 c -> b c h w', h=H_down*self.upsample_factor, w=W_down*self.upsample_factor)
+
+ return mask_out
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/tracker3D/spatrack_modules/alignment.py b/models/SpaTrackV2/models/tracker3D/spatrack_modules/alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..42a9f24320b066793d3aa3f7e0a4cf3483bed4ae
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/spatrack_modules/alignment.py
@@ -0,0 +1,471 @@
+from typing import *
+import math
+from collections import namedtuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.types
+import utils3d
+
+from models.SpaTrackV2.models.tracker3D.spatrack_modules.geometry_torch import (
+ weighted_mean,
+ harmonic_mean,
+ geometric_mean,
+ mask_aware_nearest_resize,
+ normalized_view_plane_uv,
+ angle_diff_vec3
+)
+
+def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
+ "Scatter the minimum value along the given dimension of `input` into `src` at the indices specified in `index`."
+ shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
+ minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
+ minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index))
+ indices = torch.full(shape, -1, dtype=torch.long, device=src.device)
+ indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim]
+ return torch.return_types.min((minimum, indices))
+
+
+def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
+ batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
+ n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
+ splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
+ splited_kwargs = {k: [v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks] for k, v in kwargs.items()}
+ results = []
+ for i in range(n_chunks):
+ chunk_args = tuple(arg[i] for arg in splited_args)
+ chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
+ results.append(fn(*chunk_args, **chunk_kwargs))
+
+ if isinstance(results[0], tuple):
+ return tuple(torch.cat(r, dim=0) for r in zip(*results))
+ else:
+ return torch.cat(results, dim=0)
+
+
+def _pad_inf(x_: torch.Tensor):
+ return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)
+
+
+def _pad_cumsum(cumsum: torch.Tensor):
+ return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)
+
+
+def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
+ return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)
+
+
+def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
+ """
+ If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.
+
+ w_i must be >= 0.
+
+ ### Parameters:
+ - `x`: tensor of shape (..., n)
+ - `y`: tensor of shape (..., n)
+ - `w`: tensor of shape (..., n)
+ - `trunc`: optional, float or tensor of shape (..., n) or None
+
+ ### Returns:
+ - `a`: tensor of shape (...), differentiable
+ - `loss`: tensor of shape (...), value of loss function at `a`, detached
+ - `index`: tensor of shape (...), where a = y[idx] / x[idx]
+ """
+ if trunc is None:
+ x, y, w = torch.broadcast_tensors(x, y, w)
+ sign = torch.sign(x)
+ x, y = x * sign, y * sign
+ y_div_x = y / x.clamp_min(eps)
+ y_div_x, argsort = y_div_x.sort(dim=-1)
+
+ wx = torch.gather(x * w, dim=-1, index=argsort)
+ derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
+ search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)
+
+ a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
+ index = argsort.gather(dim=-1, index=search).squeeze(-1)
+ loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)
+
+ else:
+ # Reshape to (batch_size, n) for simplicity
+ x, y, w = torch.broadcast_tensors(x, y, w)
+ batch_shape = x.shape[:-1]
+ batch_size = math.prod(batch_shape)
+ x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])
+
+ sign = torch.sign(x)
+ x, y = x * sign, y * sign
+ wx, wy = w * x, w * y
+ xyw = torch.stack([x, y, w], dim=-1) # Stacked for convenient gathering
+
+ y_div_x = A = y / x.clamp_min(eps)
+ B = (wy - trunc) / wx.clamp_min(eps)
+ C = (wy + trunc) / wx.clamp_min(eps)
+ with torch.no_grad():
+ # Caculate prefix sum by orders of A, B, C
+ A, A_argsort = A.sort(dim=-1)
+ Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
+ A, Q_A = _pad_inf(A), _pad_cumsum(Q_A) # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.
+
+ B, B_argsort = B.sort(dim=-1)
+ Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
+ B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)
+
+ C, C_argsort = C.sort(dim=-1)
+ Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
+ C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)
+
+ # Caculate left and right derivative of A
+ j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
+ j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
+ j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
+ left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
+ j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
+ j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
+ j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
+ right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
+
+ # Find extrema
+ is_extrema = (left_derivative < 0) & (right_derivative >= 0)
+ is_extrema[..., 0] |= ~is_extrema.any(dim=-1) # In case all derivatives are zero, take the first one as extrema.
+ where_extrema_batch, where_extrema_index = torch.where(is_extrema)
+
+ # Calculate objective value at extrema
+ extrema_a = y_div_x[where_extrema_batch, where_extrema_index] # (num_extrema,)
+ MAX_ELEMENTS = 4096 ** 2 # Split into small batches to avoid OOM in case there are too many extrema.(~1G)
+ SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
+ extrema_value = torch.cat([
+ _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
+ for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
+ ]) # (num_extrema,)
+
+ # Find minima among corresponding extrema
+ minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value) # (batch_size,)
+ index = where_extrema_index[indices]
+
+ a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
+ a = a.reshape(batch_shape)
+ loss = minima.reshape(batch_shape)
+ index = index.reshape(batch_shape)
+
+ return a, loss, index
+
+
+def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights.
+
+ ### Parameters:
+ - `depth_src: torch.Tensor` of shape (..., N)
+ - `depth_tgt: torch.Tensor` of shape (..., N)
+
+ """
+ scale, _, _ = align(depth_src, depth_tgt, weight, trunc)
+
+ return scale
+
+
+def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights.
+
+ ### Parameters:
+ - `depth_src: torch.Tensor` of shape (..., N)
+ - `depth_tgt: torch.Tensor` of shape (..., N)
+ - `weight: torch.Tensor` of shape (..., N)
+ - `trunc: float` or tensor of shape (..., N) or None
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (...).
+ """
+ dtype, device = depth_src.dtype, depth_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
+ batch_size = math.prod(batch_shape)
+ depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)
+
+ # Here, we take anchors only for non-zero weights.
+ # Although the results will be still correct even anchor points have zero weight,
+ # it is wasting computation and may cause instability in some cases, e.g. too many extrema.
+ anchors_where_batch, anchors_where_n = torch.where(weight > 0)
+
+ # Stop gradient when solving optimal anchors
+ with torch.no_grad():
+ depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n] # (anchors)
+ depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n] # (anchors)
+
+ depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None] # (anchors, n)
+ depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None] # (anchors, n)
+ weight_anchored = weight[anchors_where_batch, :] # (anchors, n)
+
+ scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc) # (anchors)
+
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss) # (batch_size,)
+
+ # Reproduce by indexing for shorter compute graph
+ index_1 = anchors_where_n[index_anchor] # (batch_size,)
+ index_2 = index[index_anchor] # (batch_size,)
+
+ tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
+ tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
+ shift = tgt_1 - scale * src_1
+
+ scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)
+
+ return scale, shift
+
+def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights using IRLS.
+ """
+ dtype, device = depth_src.dtype, depth_src.device
+
+ w = weight
+ x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)
+ y = depth_tgt
+
+ for i in range(max_iter):
+ beta = (x.transpose(-1, -2) @ (w * y)) @ (x.transpose(-1, -2) @ (w[..., None] * x)).inverse().transpose(-2, -1)
+ w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)
+
+ return beta[..., 0], beta[..., 1]
+
+
+def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weight: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `a: torch.Tensor` of shape (...). Only positive solutions are garunteed. You should filter out negative scales before using it.
+ - `b: torch.Tensor` of shape (...)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight[..., None].expand_as(points_src).flatten(-2), trunc)
+
+ return scale
+
+
+def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
+ batch_size = math.prod(batch_shape)
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
+
+ # Take anchors
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
+ with torch.no_grad():
+ zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
+ points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
+ points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
+
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
+
+ # Solve optimal scale and shift for each anchor
+ MAX_ELEMENTS = 2 ** 20
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
+
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
+
+ # Reproduce by indexing for shorter compute graph
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
+
+ zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
+ points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
+ tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
+ tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
+ shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
+
+ return scale, shift
+
+
+def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
+ batch_size = math.prod(batch_shape)
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
+
+ # Take anchors
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
+
+ with torch.no_grad():
+ points_src_anchor = points_src[anchor_where_batch, anchor_where_n] # (anchors, 3)
+ points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n] # (anchors, 3)
+
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
+
+ # Solve optimal scale and shift for each anchor
+ MAX_ELEMENTS = 2 ** 20
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // 2, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
+
+ # Get optimal scale and shift for each batch element
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
+
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
+
+ src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
+ src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
+ shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
+
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
+
+ return scale, shift
+
+
+def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a Z-axis shift.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ shift, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc)
+ shift = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)
+
+ return shift
+
+
+def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a Z-axis shift.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ shift, _, _ = align(torch.ones_like(points_src).swapaxes(-2, -1), (points_tgt - points_src).swapaxes(-2, -1), weight[..., None, :], trunc)
+
+ return shift
+
+
+def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Solve `min sum_i w_i * (a * x_i + b - y_i ) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.
+
+ ### Parameters:
+ - `x: torch.Tensor` of shape (..., N)
+ - `y: torch.Tensor` of shape (..., N)
+ - `w: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `a: torch.Tensor` of shape (...,)
+ - `b: torch.Tensor` of shape (...,)
+ """
+ w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
+ A = torch.stack([w_sqrt * x, torch.ones_like(x)], dim=-1)
+ B = (w_sqrt * y)[..., None]
+ a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
+ return a, b
+
+def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
+ if beta == 0:
+ return err
+ else:
+ return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)
+
+def affine_invariant_global_loss(
+ pred_points: torch.Tensor,
+ gt_points: torch.Tensor,
+ mask: torch.Tensor,
+ align_resolution: int = 64,
+ beta: float = 0.0,
+ trunc: float = 1.0,
+ sparsity_aware: bool = False,
+ only_align: bool = False
+):
+ device = pred_points.device
+
+ # Align
+ (pred_points_lr, gt_points_lr), lr_mask = mask_aware_nearest_resize((pred_points, gt_points), mask=mask, size=(align_resolution, align_resolution))
+ scale, shift = align_points_scale_z_shift(pred_points_lr.flatten(-3, -2), gt_points_lr.flatten(-3, -2), lr_mask.flatten(-2, -1) / gt_points_lr[..., 2].flatten(-2, -1).clamp_min(1e-2), trunc=trunc)
+ valid = scale > 0
+ scale, shift = torch.where(valid, scale, 0), torch.where(valid[..., None], shift, 0)
+
+ pred_points = scale[..., None, None, None] * pred_points + shift[..., None, None, :]
+ if only_align:
+ return pred_points, scale, shift
+ # Compute loss
+ weight = (valid[..., None, None] & mask).float() / gt_points[..., 2].clamp_min(1e-5)
+ weight = weight.clamp_max(10.0 * weighted_mean(weight, mask, dim=(-2, -1), keepdim=True)) # In case your data contains extremely small depth values
+ loss = _smooth((pred_points - gt_points).abs() * weight[..., None], beta=beta).mean(dim=(-3, -2, -1))
+
+ if sparsity_aware:
+ # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
+ sparsity = mask.float().mean(dim=(-2, -1)) / lr_mask.float().mean(dim=(-2, -1))
+ loss = loss / (sparsity + 1e-7)
+
+ err = (pred_points.detach() - gt_points).norm(dim=-1) / gt_points[..., 2]
+
+ # Record any scalar metric
+ misc = {
+ 'truncated_error': weighted_mean(err.clamp_max(1.0), mask).item(),
+ 'delta': weighted_mean((err < 1).float(), mask).item()
+ }
+
+ return loss, misc, scale.detach(), shift.detach()
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/tracker3D/spatrack_modules/ba.py b/models/SpaTrackV2/models/tracker3D/spatrack_modules/ba.py
new file mode 100644
index 0000000000000000000000000000000000000000..503ce93dff57f7a19b07884f3042f2502b454e32
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/spatrack_modules/ba.py
@@ -0,0 +1,538 @@
+import pycolmap
+import torch
+import numpy as np
+import pyceres
+from pyceres import SolverOptions, LinearSolverType, PreconditionerType, TrustRegionStrategyType, LoggingType
+import logging
+from scipy.spatial.transform import Rotation as R
+
+# config logging and make sure it print to the console
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+def extract_static_from_3DTracks(world_tracks, dyn_prob,
+ query_3d_pts, vis_est, tracks2d, img_size, K=100, maintain_invisb=False):
+ """
+ world_tracks: B T N 3 this is the coarse 3d tracks in world coordinate (coarse 3d tracks)
+ dyn_prob: B T N this is the dynamic probability of the 3d tracks
+ query_3d_pts: B T N 3 this is the query 3d points in world coordinate (coarse by camera pose)
+ vis_est: B T N this is the visibility of the 3d tracks
+ tracks2d: B T N 2 this is the 2d tracks
+ K: int top K static points
+ """
+ B, T, N, _ = world_tracks.shape
+ static_msk = (dyn_prob<0.5).bool()
+ world_tracks_static = world_tracks[:,:,static_msk.squeeze(),:]
+ query_3d_pts_static = query_3d_pts[:,static_msk.squeeze(),:]
+ if maintain_invisb:
+ vis = (tracks2d[...,0] > 0).bool() * (tracks2d[...,1] > 0).bool()
+ vis_mask = vis * (img_size[1] > tracks2d[...,0]) * (img_size[0] > tracks2d[...,1])
+ vis_mask = vis_mask[:,:,static_msk.squeeze()]
+ else:
+ vis_mask = (vis_est>0.5).bool()[:,:,static_msk.squeeze()]
+ tracks2d_static = tracks2d[:,:,static_msk.squeeze(),:]
+ world_tracks_static = (world_tracks_static*vis_mask[...,None]).sum(dim=1)/(vis_mask.sum(dim=1)[...,None]+1e-6)
+ # get the distance between the query_3d_pts_static and the world_tracks_static
+ dist = (query_3d_pts_static-world_tracks_static).norm(dim=-1)
+ # get the top K static points, which have the smallest distance
+ topk_idx = torch.argsort(dist,dim=-1)[:,:K]
+ world_tracks_static = world_tracks_static[torch.arange(B)[:,None,None],topk_idx]
+ query_3d_pts_static = query_3d_pts_static[torch.arange(B)[:,None,None],topk_idx]
+ # get the visible selected
+ vis_mask_static = vis_mask[:,:,topk_idx.squeeze()]
+ tracks2d_static = tracks2d_static[:, :, topk_idx.squeeze(), :]
+
+ return world_tracks_static, static_msk, topk_idx, vis_mask_static, tracks2d_static
+
+def log_ba_summary(summary):
+ logging.info(f"Residuals : {summary.num_residuals_reduced}")
+ if summary.num_residuals_reduced > 0:
+ logging.info(f"Parameters : {summary.num_effective_parameters_reduced}")
+ logging.info(
+ f"Iterations : {summary.num_successful_steps + summary.num_unsuccessful_steps}"
+ )
+ logging.info(f"Time : {summary.total_time_in_seconds} [s]")
+ logging.info(
+ f"Initial cost : {np.sqrt(summary.initial_cost / summary.num_residuals_reduced)} [px]"
+ )
+ logging.info(
+ f"Final cost : {np.sqrt(summary.final_cost / summary.num_residuals_reduced)} [px]"
+ )
+ return True
+ else:
+ print("No residuals reduced")
+ return False
+
+# def solve_bundle_adjustment(reconstruction, ba_options, ba_config):
+# bundle_adjuster = pycolmap.BundleAdjuster(ba_options, ba_config)
+# bundle_adjuster.set_up_problem(
+# reconstruction, ba_options.create_loss_function()
+# )
+# solver_options = bundle_adjuster.set_up_solver_options(
+# bundle_adjuster.problem, ba_options.solver_options
+# )
+# summary = pyceres.SolverSummary()
+# pyceres.solve(solver_options, bundle_adjuster.problem, summary)
+# return summary
+
+def efficient_solver(solver_options, stability_mode=True):
+ # Set linear solver to ITERATIVE_SCHUR (using PCG to solve Schur complement)
+ solver_options.linear_solver_type = LinearSolverType.ITERATIVE_SCHUR
+
+ # Set preconditioner (critical for PCG)
+ solver_options.preconditioner_type = PreconditionerType.SCHUR_JACOBI
+
+ # Optimize trust region strategy
+ solver_options.trust_region_strategy_type = TrustRegionStrategyType.LEVENBERG_MARQUARDT
+
+ # Enable multi-threading acceleration
+ solver_options.num_threads = 32 # Adjust based on CPU cores
+
+ if stability_mode:
+ # Stability-first configuration
+ solver_options.initial_trust_region_radius = 1.0 # Reduce initial step size
+ solver_options.max_trust_region_radius = 10.0 # Limit max step size
+ solver_options.min_trust_region_radius = 1e-6 # Allow small step convergence
+
+ # Increase regularization parameters
+ solver_options.use_nonmonotonic_steps = True # Allow non-monotonic steps
+ solver_options.max_consecutive_nonmonotonic_steps = 10
+
+ # Adjust iteration termination conditions
+ solver_options.max_num_iterations = 100 # Increase max iterations
+ solver_options.function_tolerance = 1e-8 # Stricter function convergence
+ solver_options.gradient_tolerance = 1e-12 # Stricter gradient convergence
+ solver_options.parameter_tolerance = 1e-10 # Stricter parameter convergence
+
+ # Control PCG iterations and precision
+ solver_options.min_linear_solver_iterations = 10
+ solver_options.max_linear_solver_iterations = 100
+ solver_options.inner_iteration_tolerance = 0.01 # Higher inner iteration precision
+
+ # Increase damping factor
+ solver_options.min_lm_diagonal = 1e-3 # Increase min LM diagonal
+ solver_options.max_lm_diagonal = 1e+10 # Limit max LM diagonal
+
+ # Enable parameter change limits
+ solver_options.update_state_every_iteration = True # Update state each iteration
+
+ else:
+ # Efficiency-first configuration (original settings)
+ solver_options.initial_trust_region_radius = 10000.0
+ solver_options.max_trust_region_radius = 1e+16
+ solver_options.max_num_iterations = 50
+ solver_options.function_tolerance = 1e-6
+ solver_options.gradient_tolerance = 1e-10
+ solver_options.parameter_tolerance = 1e-8
+ solver_options.min_linear_solver_iterations = 5
+ solver_options.max_linear_solver_iterations = 50
+ solver_options.inner_iteration_tolerance = 0.1
+
+ # Enable Jacobi scaling for better numerical stability
+ solver_options.jacobi_scaling = True
+
+ # Disable verbose logging for better performance (enable for debugging)
+ solver_options.logging_type = LoggingType.SILENT
+ solver_options.minimizer_progress_to_stdout = False
+
+ return solver_options
+
+class SpatTrackCost_static(pyceres.CostFunction):
+ def __init__(self, observed_depth):
+ """
+ observed_depth: float
+ """
+ super().__init__()
+ self.observed_depth = float(observed_depth)
+ self.set_num_residuals(1)
+ self.set_parameter_block_sizes([4, 3, 3]) # [rotation_quat, translation, xyz]
+
+ def Evaluate(self, parameters, residuals, jacobians):
+ # Unpack parameters
+ quat = parameters[0] # shape: (4,) [w, x, y, z]
+ t = parameters[1] # shape: (3,)
+ point = parameters[2] # shape: (3,)
+
+ # Convert COLMAP-style quat [w, x, y, z] to scipy format [x, y, z, w]
+ r = R.from_quat([quat[1], quat[2], quat[3], quat[0]])
+ R_mat = r.as_matrix() # (3, 3)
+
+ # Transform point to camera frame
+ X_cam = R_mat @ point + t
+ z = X_cam[2]
+
+ # Compute residual (normalized depth error)
+ residuals[0] = 20.0 * (z - self.observed_depth) / self.observed_depth
+
+ if jacobians is not None:
+ if jacobians[2] is not None:
+ # dr/d(point3D): only z-axis matters, so only 3rd row of R
+ jacobians[2][0] = 20.0 * R_mat[2, 0] / self.observed_depth
+ jacobians[2][1] = 20.0 * R_mat[2, 1] / self.observed_depth
+ jacobians[2][2] = 20.0 * R_mat[2, 2] / self.observed_depth
+
+ if jacobians[1] is not None:
+ # dr/dt = ∂residual/∂translation = d(z)/dt = [0, 0, 1]
+ jacobians[1][0] = 0.0
+ jacobians[1][1] = 0.0
+ jacobians[1][2] = 20.0 / self.observed_depth
+
+ if jacobians[0] is not None:
+ # Optional: dr/d(quat) — not trivial to derive, can be left for autodiff if needed
+ # Set zero for now (not ideal but legal)
+ jacobians[0][:] = 0.0
+
+ return True
+
+
+class SpatTrackCost_dynamic(pyceres.CostFunction):
+
+ def __init__(self, observed_uv, image, point3D, camera):
+ """
+ observed_uv: 1 1 K 2 this is the 2d tracks
+ image: pycolmap.Image object
+ point3D: pycolmap.Point3D object
+ camera: pycolmap.Camera object
+ """
+ sizes = [image.cam_from_world.params.shape[0], point3D.xyz.shape[0], camera.params.shape[0]]
+ super().__init__(self, residual_size=2, parameter_block_sizes=sizes)
+ self.observed_uv = observed_uv
+ self.image = image
+ self.point3D = point3D
+ self.camera = camera
+
+def solve_bundle_adjustment(reconstruction, ba_options,
+ ba_config=None, extra_residual=None):
+ """
+ Perform bundle adjustment optimization (compatible with pycolmap 0.5+)
+
+ Args:
+ reconstruction: pycolmap.Reconstruction object
+ ba_options: pycolmap.BundleAdjustmentOptions object
+ ba_config: pycolmap.BundleAdjustmentConfig object (optional)
+ """
+ # Alternatively, you can customize the existing problem or options as:
+ # import pyceres
+ bundle_adjuster = pycolmap.create_default_bundle_adjuster(
+ ba_options, ba_config, reconstruction
+ )
+ solver_options = ba_options.create_solver_options(
+ ba_config, bundle_adjuster.problem
+ )
+ summary = pyceres.SolverSummary()
+ solver_options = efficient_solver(solver_options)
+ problem = bundle_adjuster.problem
+ # problem = pyceres.Problem()
+ # if (extra_residual is not None):
+ # observed_depths = []
+ # quaternions = []
+ # translations = []
+ # points3d = []
+ # for res_ in extra_residual:
+ # point_id_i = res_["point3D_id"]
+ # for img_id_i, obs_depth_i in zip(res_["image_ids"], res_["observed_depth"]):
+ # if obs_depth_i > 0:
+ # observed_depths.append(obs_depth_i)
+ # quaternions.append(reconstruction.images[img_id_i].cam_from_world.rotation.quat)
+ # translations.append(reconstruction.images[img_id_i].cam_from_world.translation)
+ # points3d.append(reconstruction.points3D[point_id_i].xyz)
+ # pyceres.add_spatrack_static_problem(
+ # problem,
+ # observed_depths,
+ # quaternions,
+ # translations,
+ # points3d,
+ # huber_loss_delta=5.0
+ # )
+
+ pyceres.solve(solver_options, problem, summary)
+
+ return summary
+
+def batch_matrix_to_pycolmap(
+ points3d,
+ extrinsics,
+ intrinsics,
+ tracks,
+ masks,
+ image_size,
+ max_points3D_val=3000,
+ shared_camera=False,
+ camera_type="SIMPLE_PINHOLE",
+ extra_params=None,
+ cam_tracks_static=None,
+ query_pts=None,
+):
+ """
+ Convert Batched Pytorch Tensors to PyCOLMAP
+
+ Check https://github.com/colmap/pycolmap for more details about its format
+ """
+ # points3d: Px3
+ # extrinsics: Nx3x4
+ # intrinsics: Nx3x3
+ # tracks: NxPx2
+ # masks: NxP
+ # image_size: 2, assume all the frames have been padded to the same size
+ # where N is the number of frames and P is the number of tracks
+
+ N, P, _ = tracks.shape
+ assert len(extrinsics) == N
+ assert len(intrinsics) == N
+ assert len(points3d) == P
+ assert image_size.shape[0] == 2
+
+ extrinsics = extrinsics.cpu().numpy()
+ intrinsics = intrinsics.cpu().numpy()
+
+ if extra_params is not None:
+ extra_params = extra_params.cpu().numpy()
+
+ tracks = tracks.cpu().numpy()
+ masks = masks.cpu().numpy()
+ points3d = points3d.cpu().numpy()
+ image_size = image_size.cpu().numpy()
+ if cam_tracks_static is not None:
+ cam_tracks_static = cam_tracks_static.cpu().numpy()
+
+ # Reconstruction object, following the format of PyCOLMAP/COLMAP
+ reconstruction = pycolmap.Reconstruction()
+
+ inlier_num = masks.sum(0)
+ valid_mask = inlier_num >= 2 # a track is invalid if without two inliers
+ valid_idx = np.nonzero(valid_mask)[0]
+
+ # Only add 3D points that have sufficient 2D points
+ point3d_ids = []
+ for vidx in valid_idx:
+ point3d_id = reconstruction.add_point3D(
+ points3d[vidx], pycolmap.Track(), np.zeros(3)
+ )
+ point3d_ids.append(point3d_id)
+
+ # add the residual pair
+ if cam_tracks_static is not None:
+ extra_residual = []
+ for id_x, vidx in enumerate(valid_idx):
+ points_3d_id = point3d_ids[id_x]
+ point_residual = {
+ "point3D_id": points_3d_id,
+ "image_ids": [],
+ "observed_depth": [],
+ }
+ query_i = query_pts[:,:,vidx]
+ point_residual["image_ids"].append(int(query_i[0,0,0]))
+ point_residual["observed_depth"].append(query_i[0,0,-1])
+ extra_residual.append(point_residual)
+ else:
+ extra_residual = None
+
+ num_points3D = len(valid_idx)
+
+ camera = None
+ # frame idx
+ for fidx in range(N):
+ # set camera
+ if camera is None or (not shared_camera):
+ if camera_type == "SIMPLE_RADIAL":
+ pycolmap_intri = np.array(
+ [
+ intrinsics[fidx][0, 0],
+ intrinsics[fidx][0, 2],
+ intrinsics[fidx][1, 2],
+ extra_params[fidx][0],
+ ]
+ )
+ elif camera_type == "SIMPLE_PINHOLE":
+ pycolmap_intri = np.array(
+ [
+ intrinsics[fidx][0, 0],
+ intrinsics[fidx][0, 2],
+ intrinsics[fidx][1, 2],
+ ]
+ )
+ else:
+ raise ValueError(
+ f"Camera type {camera_type} is not supported yet"
+ )
+
+ camera = pycolmap.Camera(
+ model=camera_type,
+ width=image_size[0],
+ height=image_size[1],
+ params=pycolmap_intri,
+ camera_id=fidx,
+ )
+
+ # add camera
+ reconstruction.add_camera(camera)
+
+ # set image
+ cam_from_world = pycolmap.Rigid3d(
+ pycolmap.Rotation3d(extrinsics[fidx][:3, :3]),
+ extrinsics[fidx][:3, 3],
+ ) # Rot and Trans
+ image = pycolmap.Image(
+ id=fidx,
+ name=f"image_{fidx}",
+ camera_id=camera.camera_id,
+ cam_from_world=cam_from_world,
+ )
+
+ points2D_list = []
+
+ point2D_idx = 0
+ # NOTE point3D_id start by 1
+ for point3D_id in range(1, num_points3D + 1):
+ original_track_idx = valid_idx[point3D_id - 1]
+
+ if (
+ reconstruction.points3D[point3D_id].xyz < max_points3D_val
+ ).all():
+ if masks[fidx][original_track_idx]:
+ # It seems we don't need +0.5 for BA
+ point2D_xy = tracks[fidx][original_track_idx]
+ # Please note when adding the Point2D object
+ # It not only requires the 2D xy location, but also the id to 3D point
+ points2D_list.append(
+ pycolmap.Point2D(point2D_xy, point3D_id)
+ )
+
+ # add element
+ track = reconstruction.points3D[point3D_id].track
+ track.add_element(fidx, point2D_idx)
+ point2D_idx += 1
+
+ assert point2D_idx == len(points2D_list)
+ try:
+ image.points2D = pycolmap.ListPoint2D(points2D_list)
+ except Exception as e:
+ print(f"frame {fidx} is out of BA: {e}")
+
+ # add image
+ reconstruction.add_image(image)
+
+ return reconstruction, valid_idx, extra_residual
+
+def pycolmap_to_batch_matrix(
+ reconstruction, device="cuda", camera_type="SIMPLE_PINHOLE"
+):
+ """
+ Convert a PyCOLMAP Reconstruction Object to batched PyTorch tensors.
+
+ Args:
+ reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP.
+ device (str): The device to place the tensors on (default: "cuda").
+ camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE").
+
+ Returns:
+ tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params.
+ """
+
+ num_images = len(reconstruction.images)
+ max_points3D_id = max(reconstruction.point3D_ids())
+ points3D = np.zeros((max_points3D_id, 3))
+
+ for point3D_id in reconstruction.points3D:
+ points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz
+ points3D = torch.from_numpy(points3D).to(device)
+
+ extrinsics = []
+ intrinsics = []
+
+ extra_params = [] if camera_type == "SIMPLE_RADIAL" else None
+
+ for i in range(num_images):
+ # Extract and append extrinsics
+ pyimg = reconstruction.images[i]
+ pycam = reconstruction.cameras[pyimg.camera_id]
+ matrix = pyimg.cam_from_world.matrix()
+ extrinsics.append(matrix)
+
+ # Extract and append intrinsics
+ calibration_matrix = pycam.calibration_matrix()
+ intrinsics.append(calibration_matrix)
+
+ if camera_type == "SIMPLE_RADIAL":
+ extra_params.append(pycam.params[-1])
+
+ # Convert lists to torch tensors
+ extrinsics = torch.from_numpy(np.stack(extrinsics)).to(device)
+
+ intrinsics = torch.from_numpy(np.stack(intrinsics)).to(device)
+
+ if camera_type == "SIMPLE_RADIAL":
+ extra_params = torch.from_numpy(np.stack(extra_params)).to(device)
+ extra_params = extra_params[:, None]
+
+ return points3D, extrinsics, intrinsics, extra_params
+
+def ba_pycolmap(world_tracks, intrs, c2w_traj, visb, tracks2d, image_size, cam_tracks_static=None, training=True, query_pts=None):
+ """
+ world_tracks: 1 1 K 3 this is the coarse 3d tracks in world coordinate (coarse 3d tracks)
+ intrs: B T 3 3 this is the intrinsic matrix
+ c2w_traj: B T 4 4 this is the camera trajectory
+ visb: B T K this is the visibility of the 3d tracks
+ tracks2d: B T K 2 this is the 2d tracks
+ """
+ with torch.no_grad():
+ B, _, K, _ = world_tracks.shape
+ T = c2w_traj.shape[1]
+ world_tracks = world_tracks.view(K, 3).detach()
+ world_tracks_refine = world_tracks.view(K, 3).detach().clone()
+ c2w_traj_glob = c2w_traj.view(B*T, 4, 4).detach().clone()
+ c2w_traj = c2w_traj.view(B*T, 4, 4).detach()
+ intrs = intrs.view(B*T, 3, 3).detach()
+ visb = visb.view(B*T, K).detach()
+ tracks2d = tracks2d[...,:2].view(B*T, K, 2).detach()
+
+ rec, valid_idx_pts, extra_residual = batch_matrix_to_pycolmap(
+ world_tracks,
+ torch.inverse(c2w_traj)[:,:3,:],
+ intrs,
+ tracks2d,
+ visb,
+ image_size,
+ cam_tracks_static=cam_tracks_static,
+ query_pts=query_pts,
+ )
+ # NOTE It is window_size + 1 instead of window_size
+ ba_options = pycolmap.BundleAdjustmentOptions()
+ ba_options.refine_focal_length = False
+ ba_options.refine_principal_point = False
+ ba_options.refine_extra_params = False
+ ba_config = pycolmap.BundleAdjustmentConfig()
+ for image_id in rec.reg_image_ids():
+ ba_config.add_image(image_id)
+ # Fix frame 0, i.e, the end frame of the last window
+ ba_config.set_constant_cam_pose(0)
+
+ # fix the 3d points
+ for point3D_id in rec.points3D:
+ if training:
+ # ba_config.add_constant_point(point3D_id)
+ ba_config.add_variable_point(point3D_id)
+ else:
+ ba_config.add_variable_point(point3D_id)
+ # ba_config.add_constant_point(point3D_id)
+ if (len(ba_config.variable_point3D_ids) < 50) and (len(ba_config.constant_point3D_ids) < 50):
+ return c2w_traj_glob, world_tracks_refine, intrs
+ summary = solve_bundle_adjustment(rec, ba_options, ba_config, extra_residual=extra_residual)
+ # free the 3d points
+ # for point3D_id in rec.points3D:
+ # ba_config.remove_constant_point(point3D_id)
+ # ba_config.add_variable_point(point3D_id)
+ # summary = solve_bundle_adjustment(rec, ba_options, ba_config)
+ if not training:
+ ba_success = log_ba_summary(summary)
+ # get the refined results
+ points3D, extrinsics, intrinsics, extra_params = pycolmap_to_batch_matrix(rec, device="cuda", camera_type="SIMPLE_PINHOLE")
+ c2w_traj_glob[:, :3, :] = extrinsics
+ c2w_traj_glob = torch.inverse(c2w_traj_glob)
+ world_tracks_refine[valid_idx_pts] = points3D.to(world_tracks_refine.device).to(world_tracks_refine.dtype)
+ intrinsics = intrinsics.to(world_tracks_refine.device).to(world_tracks_refine.dtype)
+ # import pdb; pdb.set_trace()
+ return c2w_traj_glob, world_tracks_refine, intrinsics
+
+
+
+
diff --git a/models/SpaTrackV2/models/tracker3D/spatrack_modules/blocks.py b/models/SpaTrackV2/models/tracker3D/spatrack_modules/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..e849ec97e8b3dd9f254b96a6a12662900bec3c37
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/spatrack_modules/blocks.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
+class PointDinoV2(nn.Module):
+ """
+ PointDinoV2 is a 3D point tracking model that uses a backbone and head to extract features from points and track them.
+ """
+ def __init__(self, ):
+ super(PointDinoV2, self).__init__()
+ # self.backbone = PointDinoV2Backbone()
+ # self.head = PointDinoV2Head()
+
diff --git a/models/SpaTrackV2/models/tracker3D/spatrack_modules/dynamic_point_refine.py b/models/SpaTrackV2/models/tracker3D/spatrack_modules/dynamic_point_refine.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/SpaTrackV2/models/tracker3D/spatrack_modules/geometry_numpy.py b/models/SpaTrackV2/models/tracker3D/spatrack_modules/geometry_numpy.py
new file mode 100644
index 0000000000000000000000000000000000000000..87064378d7db7873ba2b8b5f269f1c34663e74c2
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/spatrack_modules/geometry_numpy.py
@@ -0,0 +1,401 @@
+from typing import *
+from functools import partial
+import math
+
+import cv2
+import numpy as np
+from scipy.signal import fftconvolve
+import numpy as np
+import utils3d
+
+from .tools import timeit
+
+
+def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
+ if w is None:
+ return np.mean(x, axis=axis)
+ else:
+ w = w.astype(x.dtype)
+ return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None)
+
+
+def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
+ if w is None:
+ return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis)
+ else:
+ w = w.astype(x.dtype)
+ return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
+
+
+def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+ if aspect_ratio is None:
+ aspect_ratio = width / height
+
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+ u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
+ v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
+ u, v = np.meshgrid(u, v, indexing='xy')
+ uv = np.stack([u, v], axis=-1)
+ return uv
+
+
+def focal_to_fov_numpy(focal: np.ndarray):
+ return 2 * np.arctan(0.5 / focal)
+
+
+def fov_to_focal_numpy(fov: np.ndarray):
+ return 0.5 / np.tan(fov / 2)
+
+
+def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
+ fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
+ return fov_x, fov_y
+
+
+def point_map_to_depth_legacy_numpy(points: np.ndarray):
+ height, width = points.shape[-3:-1]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+ uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype) # (H, W, 2)
+ _, uv = np.broadcast_arrays(points[..., :2], uv)
+
+ # Solve least squares problem
+ b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2)
+ A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2)
+
+ M = A.swapaxes(-2, -1) @ A
+ solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
+ focal, shift = solution
+
+ depth = points[..., 2] + shift[..., None, None]
+ fov_x = np.arctan(width / diagonal / focal) * 2
+ fov_y = np.arctan(height / diagonal / focal) * 2
+ return depth, fov_x, fov_y, shift
+
+
+def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
+ from scipy.optimize import least_squares
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+ xy_proj = xy / (z + shift)[: , None]
+ f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+ err = (f * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+ optim_shift = solution['x'].squeeze().astype(np.float32)
+
+ xy_proj = xy / (z + optim_shift)[: , None]
+ optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+
+ return optim_shift, optim_focal
+
+
+def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
+ from scipy.optimize import least_squares
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+ xy_proj = xy / (z + shift)[: , None]
+ err = (focal * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+ optim_shift = solution['x'].squeeze().astype(np.float32)
+
+ return optim_shift
+
+
+def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
+ import cv2
+ assert points.shape[-1] == 3, "Points should (H, W, 3)"
+
+ height, width = points.shape[-3], points.shape[-2]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+
+ uv = normalized_view_plane_uv_numpy(width=width, height=height)
+
+ if mask is None:
+ points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
+ uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
+ else:
+ (points_lr, uv_lr), mask_lr = mask_aware_nearest_resize_numpy((points, uv), mask, downsample_size)
+
+ if points_lr.size < 2:
+ return 1., 0.
+
+ if focal is None:
+ focal, shift = solve_optimal_focal_shift(uv_lr, points_lr)
+ else:
+ shift = solve_optimal_shift(uv_lr, points_lr, focal)
+
+ return focal, shift
+
+
+def mask_aware_nearest_resize_numpy(
+ inputs: Union[np.ndarray, Tuple[np.ndarray, ...], None],
+ mask: np.ndarray,
+ size: Tuple[int, int],
+ return_index: bool = False
+) -> Tuple[Union[np.ndarray, Tuple[np.ndarray, ...], None], np.ndarray, Tuple[np.ndarray, ...]]:
+ """
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
+
+ ### Parameters
+ - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
+ - `mask`: input 2D mask of shape (..., H, W)
+ - `size`: target size (width, height)
+
+ ### Returns
+ - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
+ - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
+ - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension.
+ """
+ height, width = mask.shape[-2:]
+ target_width, target_height = size
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
+ filter_size = filter_h_i * filter_w_i
+ padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
+
+ # Window the original mask and uv
+ uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
+ indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
+ padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
+ padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
+ padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
+ windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+ windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
+ windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+
+ # Gather the target pixels's local window
+ target_centers = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
+ target_lefttop = target_centers - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
+ target_window = np.round(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
+
+ target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
+ target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
+ target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(*([-1] * (mask.ndim - 2)), target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
+
+ # Compute nearest neighbor in the local window for each pixel
+ dist = np.square(target_window_centers - target_centers[..., None])
+ dist = dist[..., 0, :] + dist[..., 1, :]
+ dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size)
+ nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1)
+ nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width)
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
+ target_mask = np.any(target_window_mask, axis=-1)
+ batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]
+
+ index = (*batch_indices, nearest_i, nearest_j)
+
+ if inputs is None:
+ outputs = None
+ elif isinstance(inputs, np.ndarray):
+ outputs = inputs[index]
+ elif isinstance(inputs, Sequence):
+ outputs = tuple(x[index] for x in inputs)
+ else:
+ raise ValueError(f'Invalid input type: {type(inputs)}')
+
+ if return_index:
+ return outputs, target_mask, index
+ else:
+ return outputs, target_mask
+
+
+def mask_aware_area_resize_numpy(image: np.ndarray, mask: np.ndarray, target_width: int, target_height: int) -> Tuple[Tuple[np.ndarray, ...], np.ndarray]:
+ """
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
+
+ ### Parameters
+ - `image`: Input 2D image of shape (..., H, W, C)
+ - `mask`: Input 2D mask of shape (..., H, W)
+ - `target_width`: target width of the resized map
+ - `target_height`: target height of the resized map
+
+ ### Returns
+ - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width).
+ - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
+ """
+ height, width = mask.shape[-2:]
+
+ if image.shape[-2:] == (height, width):
+ omit_channel_dim = True
+ else:
+ omit_channel_dim = False
+ if omit_channel_dim:
+ image = image[..., None]
+
+ image = np.where(mask[..., None], image, 0)
+
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
+ filter_h_i, filter_w_i = math.ceil(filter_h_f) + 1, math.ceil(filter_w_f) + 1
+ filter_size = filter_h_i * filter_w_i
+ padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
+
+ # Window the original mask and uv (non-copy)
+ uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
+ indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
+ padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
+ padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
+ padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
+ windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+ windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
+ windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+
+ # Gather the target pixels's local window
+ target_center = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
+ target_lefttop = target_center - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
+ target_bottomright = target_center + np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
+ target_window = np.floor(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
+
+ target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
+ target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
+ target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
+
+ # Compute pixel area in the local windows
+ target_window_lefttop = np.maximum(target_window_centers - 0.5, target_lefttop[..., None])
+ target_window_bottomright = np.minimum(target_window_centers + 0.5, target_bottomright[..., None])
+ target_window_area = (target_window_bottomright - target_window_lefttop).clip(0, None)
+ target_window_area = np.where(target_window_mask, target_window_area[..., 0, :] * target_window_area[..., 1, :], 0)
+
+ # Weighted sum by area
+ target_window_image = image.reshape(*image.shape[:-3], height * width, -1)[..., target_window_indices, :].swapaxes(-2, -1)
+ target_mask = np.sum(target_window_area, axis=-1) >= 0.25
+ target_image = weighted_mean_numpy(target_window_image, target_window_area[..., None, :], axis=-1)
+
+ if omit_channel_dim:
+ target_image = target_image[..., 0]
+
+ return target_image, target_mask
+
+
+def norm3d(x: np.ndarray) -> np.ndarray:
+ "Faster `np.linalg.norm(x, axis=-1)` for 3D vectors"
+ return np.sqrt(np.square(x[..., 0]) + np.square(x[..., 1]) + np.square(x[..., 2]))
+
+
+def depth_occlusion_edge_numpy(depth: np.ndarray, mask: np.ndarray, kernel_size: int = 3, tol: float = 0.1):
+ disp = np.where(mask, 1 / depth, 0)
+ disp_pad = np.pad(disp, (kernel_size // 2, kernel_size // 2), constant_values=0)
+ mask_pad = np.pad(mask, (kernel_size // 2, kernel_size // 2), constant_values=False)
+ disp_window = utils3d.numpy.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]
+ mask_window = utils3d.numpy.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]
+
+ disp_mean = weighted_mean_numpy(disp_window, mask_window, axis=(-2, -1))
+ fg_edge_mask = mask & (disp > (1 + tol) * disp_mean)
+ bg_edge_mask = mask & (disp_mean > (1 + tol) * disp)
+ return fg_edge_mask, bg_edge_mask
+
+
+def disk_kernel(radius: int) -> np.ndarray:
+ """
+ Generate disk kernel with given radius.
+
+ Args:
+ radius (int): Radius of the disk (in pixels).
+
+ Returns:
+ np.ndarray: (2*radius+1, 2*radius+1) normalized convolution kernel.
+ """
+ # Create coordinate grid centered at (0,0)
+ L = np.arange(-radius, radius + 1)
+ X, Y = np.meshgrid(L, L)
+ # Generate disk: region inside circle with radius R is 1
+ kernel = ((X**2 + Y**2) <= radius**2).astype(np.float32)
+ # Normalize the kernel
+ kernel /= np.sum(kernel)
+ return kernel
+
+
+def disk_blur(image: np.ndarray, radius: int) -> np.ndarray:
+ """
+ Apply disk blur to an image using FFT convolution.
+
+ Args:
+ image (np.ndarray): Input image, can be grayscale or color.
+ radius (int): Blur radius (in pixels).
+
+ Returns:
+ np.ndarray: Blurred image.
+ """
+ if radius == 0:
+ return image
+ kernel = disk_kernel(radius)
+ if image.ndim == 2:
+ blurred = fftconvolve(image, kernel, mode='same')
+ elif image.ndim == 3:
+ channels = []
+ for i in range(image.shape[2]):
+ blurred_channel = fftconvolve(image[..., i], kernel, mode='same')
+ channels.append(blurred_channel)
+ blurred = np.stack(channels, axis=-1)
+ else:
+ raise ValueError("Image must be 2D or 3D.")
+ return blurred
+
+
+def depth_of_field(
+ img: np.ndarray,
+ disp: np.ndarray,
+ focus_disp : float,
+ max_blur_radius : int = 10,
+) -> np.ndarray:
+ """
+ Apply depth of field effect to an image.
+
+ Args:
+ img (numpy.ndarray): (H, W, 3) input image.
+ depth (numpy.ndarray): (H, W) depth map of the scene.
+ focus_depth (float): Focus depth of the lens.
+ strength (float): Strength of the depth of field effect.
+ max_blur_radius (int): Maximum blur radius (in pixels).
+
+ Returns:
+ numpy.ndarray: (H, W, 3) output image with depth of field effect applied.
+ """
+ # Precalculate dialated depth map for each blur radius
+ max_disp = np.max(disp)
+ disp = disp / max_disp
+ focus_disp = focus_disp / max_disp
+ dilated_disp = []
+ for radius in range(max_blur_radius + 1):
+ dilated_disp.append(cv2.dilate(disp, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*radius+1, 2*radius+1)), iterations=1))
+
+ # Determine the blur radius for each pixel based on the depth map
+ blur_radii = np.clip(abs(disp - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
+ for radius in range(max_blur_radius + 1):
+ dialted_blur_radii = np.clip(abs(dilated_disp[radius] - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
+ mask = (dialted_blur_radii >= radius) & (dialted_blur_radii >= blur_radii) & (dilated_disp[radius] > disp)
+ blur_radii[mask] = dialted_blur_radii[mask]
+ blur_radii = np.clip(blur_radii, 0, max_blur_radius)
+ blur_radii = cv2.blur(blur_radii, (5, 5))
+
+ # Precalculate the blured image for each blur radius
+ unique_radii = np.unique(blur_radii)
+ precomputed = {}
+ for radius in range(max_blur_radius + 1):
+ if radius not in unique_radii:
+ continue
+ precomputed[radius] = disk_blur(img, radius)
+
+ # Composit the blured image for each pixel
+ output = np.zeros_like(img)
+ for r in unique_radii:
+ mask = blur_radii == r
+ output[mask] = precomputed[r][mask]
+
+ return output
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/tracker3D/spatrack_modules/geometry_torch.py b/models/SpaTrackV2/models/tracker3D/spatrack_modules/geometry_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..3536ab7749ac61c3b50ccbd07d1f8e2f4077c7bc
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/spatrack_modules/geometry_torch.py
@@ -0,0 +1,323 @@
+from typing import *
+import math
+from collections import namedtuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.types
+import utils3d
+
+from .tools import timeit
+from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift
+
+
+def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.mean(dim=dim, keepdim=keepdim)
+ else:
+ w = w.to(x.dtype)
+ return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)
+
+
+def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal()
+ else:
+ w = w.to(x.dtype)
+ return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()
+
+
+def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.add(eps).log().mean(dim=dim).exp()
+ else:
+ w = w.to(x.dtype)
+ return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
+
+
+def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+ if aspect_ratio is None:
+ aspect_ratio = width / height
+
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+ u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
+ v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
+ u, v = torch.meshgrid(u, v, indexing='xy')
+ uv = torch.stack([u, v], dim=-1)
+ return uv
+
+
+def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
+ kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
+ kernel = kernel / kernel.sum()
+ kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
+ input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
+ input = F.conv2d(input, kernel, groups=input.shape[1])
+ return input
+
+
+def focal_to_fov(focal: torch.Tensor):
+ return 2 * torch.atan(0.5 / focal)
+
+
+def fov_to_focal(fov: torch.Tensor):
+ return 0.5 / torch.tan(fov / 2)
+
+
+def angle_diff_vec3(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12):
+ return torch.atan2(torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps, (v1 * v2).sum(dim=-1))
+
+def intrinsics_to_fov(intrinsics: torch.Tensor):
+ """
+ Returns field of view in radians from normalized intrinsics matrix.
+ ### Parameters:
+ - intrinsics: torch.Tensor of shape (..., 3, 3)
+
+ ### Returns:
+ - fov_x: torch.Tensor of shape (...)
+ - fov_y: torch.Tensor of shape (...)
+ """
+ focal_x = intrinsics[..., 0, 0]
+ focal_y = intrinsics[..., 1, 1]
+ return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y)
+
+
+def point_map_to_depth_legacy(points: torch.Tensor):
+ height, width = points.shape[-3:-1]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
+
+ # Solve least squares problem
+ b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2)
+ A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2)
+
+ M = A.transpose(-2, -1) @ A
+ solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
+ focal, shift = solution.unbind(-1)
+
+ depth = points[..., 2] + shift[..., None, None]
+ fov_x = torch.atan(width / diagonal / focal) * 2
+ fov_y = torch.atan(height / diagonal / focal) * 2
+ return depth, fov_x, fov_y, shift
+
+
+def view_plane_uv_to_focal(uv: torch.Tensor):
+ normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
+ focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12)
+ return focal
+
+
+def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
+ """
+ Recover the depth map and FoV from a point map with unknown z shift and focal.
+
+ Note that it assumes:
+ - the optical center is at the center of the map
+ - the map is undistorted
+ - the map is isometric in the x and y directions
+
+ ### Parameters:
+ - `points: torch.Tensor` of shape (..., H, W, 3)
+ - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.
+
+ ### Returns:
+ - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
+ - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
+ """
+ shape = points.shape
+ height, width = points.shape[-3], points.shape[-2]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+
+ points = points.reshape(-1, *shape[-3:])
+ mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
+ focal = focal.reshape(-1) if focal is not None else None
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
+
+ points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
+ uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
+ mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0
+
+ uv_lr_np = uv_lr.cpu().numpy()
+ points_lr_np = points_lr.detach().cpu().numpy()
+ focal_np = focal.cpu().numpy() if focal is not None else None
+ mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
+ optim_shift, optim_focal = [], []
+ for i in range(points.shape[0]):
+ points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
+ uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
+ if uv_lr_i_np.shape[0] < 2:
+ optim_focal.append(1)
+ optim_shift.append(0)
+ continue
+ if focal is None:
+ optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
+ optim_focal.append(float(optim_focal_i))
+ else:
+ optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
+ optim_shift.append(float(optim_shift_i))
+ optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+
+ if focal is None:
+ optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+ else:
+ optim_focal = focal.reshape(shape[:-3])
+
+ return optim_focal, optim_shift
+
+
+def mask_aware_nearest_resize(
+ inputs: Union[torch.Tensor, Sequence[torch.Tensor], None],
+ mask: torch.BoolTensor,
+ size: Tuple[int, int],
+ return_index: bool = False
+) -> Tuple[Union[torch.Tensor, Sequence[torch.Tensor], None], torch.BoolTensor, Tuple[torch.LongTensor, ...]]:
+ """
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
+
+ ### Parameters
+ - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
+ - `mask`: input 2D mask of shape (..., H, W)
+ - `size`: target size (target_width, target_height)
+
+ ### Returns
+ - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
+ - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
+ - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension, .
+ """
+ height, width = mask.shape[-2:]
+ target_width, target_height = size
+ device = mask.device
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
+ filter_size = filter_h_i * filter_w_i
+ padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
+
+ # Window the original mask and uv
+ uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
+ indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
+ padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
+ padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
+ padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
+ windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
+ windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
+ windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))
+
+ # Gather the target pixels's local window
+ target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
+ target_lefttop = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
+ target_window = torch.round(target_lefttop).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)
+
+ target_window_uv = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
+ target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
+ target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
+ target_window_indices = target_window_indices.expand_as(target_window_mask)
+
+ # Compute nearest neighbor in the local window for each pixel
+ dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size)
+ nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1)
+ nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width)
+ target_mask = torch.any(target_window_mask, dim=-1)
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
+ batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]
+
+ index = (*batch_indices, nearest_i, nearest_j)
+
+ if inputs is None:
+ outputs = None
+ elif isinstance(inputs, torch.Tensor):
+ outputs = inputs[index]
+ elif isinstance(inputs, Sequence):
+ outputs = tuple(x[index] for x in inputs)
+ else:
+ raise ValueError(f'Invalid input type: {type(inputs)}')
+
+ if return_index:
+ return outputs, target_mask, index
+ else:
+ return outputs, target_mask
+
+
+def theshold_depth_change(depth: torch.Tensor, mask: torch.Tensor, pooler: Literal['min', 'max'], rtol: float = 0.2, kernel_size: int = 3):
+ *batch_shape, height, width = depth.shape
+ depth = depth.reshape(-1, 1, height, width)
+ mask = mask.reshape(-1, 1, height, width)
+ if pooler =='max':
+ pooled_depth = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
+ output_mask = pooled_depth > depth * (1 + rtol)
+ elif pooler =='min':
+ pooled_depth = -F.max_pool2d(-torch.where(mask, depth, torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
+ output_mask = pooled_depth < depth * (1 - rtol)
+ else:
+ raise ValueError(f'Unsupported pooler: {pooler}')
+ output_mask = output_mask.reshape(*batch_shape, height, width)
+ return output_mask
+
+
+def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
+ device, dtype = depth.device, depth.dtype
+
+ disp = torch.where(mask, 1 / depth, 0)
+ disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
+ mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
+ disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2]
+ mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2]
+
+ x = torch.linspace(-kernel_size // 2, kernel_size // 2, kernel_size, device=device, dtype=dtype)
+ A = torch.stack([*torch.meshgrid(x, x, indexing='xy'), torch.ones((kernel_size, kernel_size), device=device, dtype=dtype)], dim=-1).reshape(kernel_size ** 2, 3) # [kernel_size ** 2, 3]
+ A = mask_window[..., None] * A
+ I = torch.eye(3, device=device, dtype=dtype)
+
+ affine_disp_window = (disp_window[..., None, :] @ A @ torch.inverse(A.mT @ A + 1e-5 * I) @ A.mT).clamp_min(1e-12)[..., 0, :] # [..., H, W, kernel_size ** 2]
+ diff = torch.where(mask_window, torch.maximum(affine_disp_window, disp_window) / torch.minimum(affine_disp_window, disp_window) - 1, 0)
+
+ edge_mask = mask & (diff > tol).any(dim=-1)
+
+ disp_mean = weighted_mean(disp_window, mask_window, dim=-1)
+ fg_edge_mask = edge_mask & (disp > disp_mean)
+ # fg_edge_mask = edge_mask & theshold_depth_change(depth, mask, pooler='max', rtol=tol, kernel_size=kernel_size)
+ bg_edge_mask = edge_mask & ~fg_edge_mask
+ return fg_edge_mask, bg_edge_mask
+
+
+def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
+ device, dtype = depth.device, depth.dtype
+
+ disp = torch.where(mask, 1 / depth, 0)
+ disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
+ mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
+ disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2]
+ mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2]
+
+ disp_mean = weighted_mean(disp_window, mask_window, dim=(-2, -1))
+ fg_edge_mask = mask & (disp / disp_mean > 1 + tol)
+ bg_edge_mask = mask & (disp_mean / disp > 1 + tol)
+
+ fg_edge_mask = fg_edge_mask & F.max_pool2d(bg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()
+ bg_edge_mask = bg_edge_mask & F.max_pool2d(fg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()
+
+ return fg_edge_mask, bg_edge_mask
+
+
+def dilate_with_mask(input: torch.Tensor, mask: torch.BoolTensor, filter: Literal['min', 'max', 'mean', 'median'] = 'mean', iterations: int = 1) -> torch.Tensor:
+ kernel = torch.tensor([[False, True, False], [True, True, True], [False, True, False]], device=input.device, dtype=torch.bool)
+ for _ in range(iterations):
+ input_window = utils3d.torch.sliding_window_2d(F.pad(input, (1, 1, 1, 1), mode='constant', value=0), window_size=3, stride=1, dim=(-2, -1))
+ mask_window = kernel & utils3d.torch.sliding_window_2d(F.pad(mask, (1, 1, 1, 1), mode='constant', value=False), window_size=3, stride=1, dim=(-2, -1))
+ if filter =='min':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.inf).min(dim=(-2, -1)).values)
+ elif filter =='max':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, -torch.inf).max(dim=(-2, -1)).values)
+ elif filter == 'mean':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).nanmean(dim=(-2, -1)))
+ elif filter =='median':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).flatten(-2).nanmedian(dim=-1).values)
+ mask = mask_window.any(dim=(-2, -1))
+ return input, mask
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/tracker3D/spatrack_modules/pointmap_updator.py b/models/SpaTrackV2/models/tracker3D/spatrack_modules/pointmap_updator.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab15fa78606d9e177df6898fa83a1b80a73bbe1a
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/spatrack_modules/pointmap_updator.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+from models.SpaTrackV2.models.blocks import bilinear_sampler
+from models.SpaTrackV2.models.tracker3D.spatrack_modules.alignment import align_points_scale, align_points_scale_xyz_shift
+
+def compute_affine_scale_and_shift(points, pointmap, mask, weights=None, eps=1e-6):
+ """
+ Compute global affine transform (scale * pointmap + shift = points)
+ using least-squares fitting with optional weights and mask.
+
+ Args:
+ points (BT, N, 3): Target points
+ pointmap (BT, N, 3): Source points
+ mask (BT, N): Binary mask indicating valid points
+ weights (BT, N): Optional weights per point
+ eps (float): Numerical stability
+
+ Returns:
+ scale (BT, 1): Scalar scale per batch
+ shift (BT, 3): Shift vector per batch
+ """
+ if weights is None:
+ weights = mask.float()
+ else:
+ weights = weights * mask # combine mask
+
+ # Sum of weights
+ weight_sum = weights.sum(dim=1, keepdim=True) + eps # (BT, 1)
+
+ # Compute weighted centroids
+ centroid_p = (points * weights.unsqueeze(-1)).sum(dim=1) / weight_sum # (BT, 3)
+ centroid_m = (pointmap * weights.unsqueeze(-1)).sum(dim=1) / weight_sum # (BT, 3)
+
+ # Center the point sets
+ p_centered = points - centroid_p.unsqueeze(1) # (BT, N, 3)
+ m_centered = pointmap - centroid_m.unsqueeze(1) # (BT, N, 3)
+
+ # Compute scale: ratio of dot products
+ numerator = (weights.unsqueeze(-1) * (p_centered * m_centered)).sum(dim=1).sum(dim=-1) # (BT,)
+ denominator = (weights.unsqueeze(-1) * (m_centered ** 2)).sum(dim=1).sum(dim=-1) + eps # (BT,)
+ scale = (numerator / denominator).unsqueeze(-1) # (BT, 1)
+
+ # Compute shift: t = c_p - s * c_m
+ shift = centroid_p - scale * centroid_m # (BT, 3)
+
+ return scale, shift
+
+def compute_weighted_std(track2d, vis_est, eps=1e-6):
+ """
+ Compute the weighted standard deviation of 2D tracks across time.
+
+ Args:
+ track2d (Tensor): shape (B, T, N, 2), 2D tracked points.
+ vis_est (Tensor): shape (B, T, N), visibility weights (0~1).
+ eps (float): small epsilon to avoid division by zero.
+
+ Returns:
+ std (Tensor): shape (B, N, 2), weighted standard deviation for each point.
+ """
+ B, T, N, _ = track2d.shape
+
+ # Compute weighted mean
+ weighted_sum = (track2d * vis_est[..., None]).sum(dim=1) # (B, N, 2)
+ weight_sum = vis_est.sum(dim=1)[..., None] + eps # (B, N, 1)
+ track_mean = weighted_sum / weight_sum # (B, N, 2)
+
+ # Compute squared residuals
+ residuals = track2d - track_mean[:, None, :, :] # (B, T, N, 2)
+ weighted_sq_res = (residuals ** 2) * vis_est[..., None] # (B, T, N, 2)
+
+ # Compute weighted variance and std
+ var = weighted_sq_res.sum(dim=1) / (weight_sum + eps) # (B, N, 2)
+ std = var.sqrt() # (B, N, 2)
+
+ return std
+
+class PointMapUpdator(nn.Module):
+ def __init__(self, stablizer):
+ super(PointMapUpdator, self).__init__()
+ self.stablizer = stablizer()
+
+ def init_pointmap(self, points_map):
+
+ pass
+
+ def scale_update_from_tracks(self, cam_pts_est, coords_append, point_map_org, vis_est, reproj_loss):
+ B, T, N, _ = coords_append.shape
+ track2d = coords_append[...,:2].view(B*T, N, 2)
+
+ track_len_std = compute_weighted_std(track2d.view(B, T, N, 2), vis_est.view(B, T, N)).norm(dim=-1)
+
+ point_samp = bilinear_sampler(point_map_org, track2d[:,None], mode="nearest")
+ point_samp = point_samp.permute(0,3,1,2).view(B*T, N, 3)
+ cam_pts_est = cam_pts_est.view(B*T, N, 3)
+ # mask
+ mask = vis_est.view(B*T, N)
+ # using gaussian weights, mean is 2 pixels
+ nm_reproj_loss = (reproj_loss.view(B*T, N) / (track_len_std.view(B, N) + 1e-6)).clamp(0, 5)
+ std = nm_reproj_loss.std(dim=-1).view(B*T, 1) # B*T 1
+ weights = torch.exp(-(0.5-nm_reproj_loss.view(B*T, N))**2 / (2*std**2))
+ mask = mask*(point_samp[...,2]>0)*(cam_pts_est[...,2]>0)*weights
+ scales, shift = align_points_scale_xyz_shift(point_samp, cam_pts_est, mask)
+
+ return scales, shift
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/tracker3D/spatrack_modules/simple_vit_1d.py b/models/SpaTrackV2/models/tracker3D/spatrack_modules/simple_vit_1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..233f83441e6bb846c9aae299003db169931aeca2
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/spatrack_modules/simple_vit_1d.py
@@ -0,0 +1,125 @@
+import torch
+from torch import nn
+
+from einops import rearrange
+from einops.layers.torch import Rearrange
+
+# helpers
+
+def posemb_sincos_1d(patches, temperature = 10000, dtype = torch.float32):
+ _, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype
+
+ n = torch.arange(n, device = device)
+ assert (dim % 2) == 0, 'feature dimension must be multiple of 2 for sincos emb'
+ omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
+ omega = 1. / (temperature ** omega)
+
+ n = n.flatten()[:, None] * omega[None, :]
+ pe = torch.cat((n.sin(), n.cos()), dim = 1)
+ return pe.type(dtype)
+
+# classes
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, hidden_dim),
+ nn.GELU(),
+ nn.Linear(hidden_dim, dim),
+ )
+ def forward(self, x):
+ return self.net(x)
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads = 8, dim_head = 64):
+ super().__init__()
+ inner_dim = dim_head * heads
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+ self.norm = nn.LayerNorm(dim)
+
+ self.attend = nn.Softmax(dim = -1)
+
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+ self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+ attn = self.attend(dots)
+
+ out = torch.matmul(attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+class Transformer(nn.Module):
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ Attention(dim, heads = heads, dim_head = dim_head),
+ FeedForward(dim, mlp_dim)
+ ]))
+ def forward(self, x):
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+ return self.norm(x)
+
+class SimpleViT(nn.Module):
+ def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
+ super().__init__()
+
+ assert seq_len % patch_size == 0
+
+ num_patches = seq_len // patch_size
+ patch_dim = channels * patch_size
+
+ self.to_patch_embedding = nn.Sequential(
+ Rearrange('b c (n p) -> b n (p c)', p = patch_size),
+ nn.LayerNorm(patch_dim),
+ nn.Linear(patch_dim, dim),
+ nn.LayerNorm(dim),
+ )
+
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
+
+ self.to_latent = nn.Identity()
+ self.linear_head = nn.Linear(dim, num_classes)
+
+ def forward(self, series):
+ *_, n, dtype = *series.shape, series.dtype
+
+ x = self.to_patch_embedding(series)
+ pe = posemb_sincos_1d(x)
+ x = rearrange(x, 'b ... d -> b (...) d') + pe
+
+ x = self.transformer(x)
+ x = x.mean(dim = 1)
+
+ x = self.to_latent(x)
+ return self.linear_head(x)
+
+if __name__ == '__main__':
+
+ v = SimpleViT(
+ seq_len = 256,
+ patch_size = 16,
+ num_classes = 1000,
+ dim = 1024,
+ depth = 6,
+ heads = 8,
+ mlp_dim = 2048
+ )
+
+ time_series = torch.randn(4, 3, 256)
+ logits = v(time_series) # (4, 1000)
diff --git a/models/SpaTrackV2/models/tracker3D/spatrack_modules/tools.py b/models/SpaTrackV2/models/tracker3D/spatrack_modules/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..3687f6938fe34433d149a1a8405be7eed5f23c37
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/spatrack_modules/tools.py
@@ -0,0 +1,289 @@
+from typing import *
+import time
+from pathlib import Path
+from numbers import Number
+from functools import wraps
+import warnings
+import math
+import json
+import os
+import importlib
+import importlib.util
+
+
+def catch_exception(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ import traceback
+ print(f"Exception in {fn.__name__}", end='r')
+ # print({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})
+ traceback.print_exc(chain=False)
+ time.sleep(0.1)
+ return None
+ return wrapper
+
+
+class CallbackOnException:
+ def __init__(self, callback: Callable, exception: type):
+ self.exception = exception
+ self.callback = callback
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if isinstance(exc_val, self.exception):
+ self.callback()
+ return True
+ return False
+
+def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
+ for k, v in d.items():
+ if isinstance(v, dict):
+ for sub_key in traverse_nested_dict_keys(v):
+ yield (k, ) + sub_key
+ else:
+ yield (k, )
+
+
+def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
+ for k in keys:
+ d = d.get(k, default)
+ if d is None:
+ break
+ return d
+
+def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
+ for k in keys[:-1]:
+ d = d.setdefault(k, {})
+ d[keys[-1]] = value
+
+
+def key_average(list_of_dicts: list) -> Dict[str, Any]:
+ """
+ Returns a dictionary with the average value of each key in the input list of dictionaries.
+ """
+ _nested_dict_keys = set()
+ for d in list_of_dicts:
+ _nested_dict_keys.update(traverse_nested_dict_keys(d))
+ _nested_dict_keys = sorted(_nested_dict_keys)
+ result = {}
+ for k in _nested_dict_keys:
+ values = []
+ for d in list_of_dicts:
+ v = get_nested_dict(d, k)
+ if v is not None and not math.isnan(v):
+ values.append(v)
+ avg = sum(values) / len(values) if values else float('nan')
+ set_nested_dict(result, k, avg)
+ return result
+
+
+def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
+ """
+ Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
+ """
+ items = []
+ if parent_key is None:
+ parent_key = ()
+ for k, v in d.items():
+ new_key = parent_key + (k, )
+ if isinstance(v, MutableMapping):
+ items.extend(flatten_nested_dict(v, new_key).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
+
+
+def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Unflattens a single-level dictionary into a nested dictionary, with keys as tuples.
+ """
+ result = {}
+ for k, v in d.items():
+ sub_dict = result
+ for k_ in k[:-1]:
+ if k_ not in sub_dict:
+ sub_dict[k_] = {}
+ sub_dict = sub_dict[k_]
+ sub_dict[k[-1]] = v
+ return result
+
+
+def read_jsonl(file):
+ import json
+ with open(file, 'r') as f:
+ data = f.readlines()
+ return [json.loads(line) for line in data]
+
+
+def write_jsonl(data: List[dict], file):
+ import json
+ with open(file, 'w') as f:
+ for item in data:
+ f.write(json.dumps(item) + '\n')
+
+
+def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
+ import pandas as pd
+ data = [flatten_nested_dict(d) for d in data]
+ df = pd.DataFrame(data)
+ df = df.sort_index(axis=1)
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
+ return df
+
+
+def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
+ if isinstance(d, str):
+ for old, new in mapping.items():
+ d = d.replace(old, new)
+ elif isinstance(d, list):
+ for i, item in enumerate(d):
+ d[i] = recursive_replace(item, mapping)
+ elif isinstance(d, dict):
+ for k, v in d.items():
+ d[k] = recursive_replace(v, mapping)
+ return d
+
+
+class timeit:
+ _history: Dict[str, List['timeit']] = {}
+
+ def __init__(self, name: str = None, verbose: bool = True, average: bool = False):
+ self.name = name
+ self.verbose = verbose
+ self.start = None
+ self.end = None
+ self.average = average
+ if average and name not in timeit._history:
+ timeit._history[name] = []
+
+ def __call__(self, func: Callable):
+ import inspect
+ if inspect.iscoroutinefunction(func):
+ async def wrapper(*args, **kwargs):
+ with timeit(self.name or func.__qualname__):
+ ret = await func(*args, **kwargs)
+ return ret
+ return wrapper
+ else:
+ def wrapper(*args, **kwargs):
+ with timeit(self.name or func.__qualname__):
+ ret = func(*args, **kwargs)
+ return ret
+ return wrapper
+
+ def __enter__(self):
+ self.start = time.time()
+ return self
+
+ @property
+ def time(self) -> float:
+ assert self.start is not None, "Time not yet started."
+ assert self.end is not None, "Time not yet ended."
+ return self.end - self.start
+
+ @property
+ def average_time(self) -> float:
+ assert self.average, "Average time not available."
+ return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])
+
+ @property
+ def history(self) -> List['timeit']:
+ return timeit._history.get(self.name, [])
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.end = time.time()
+ if self.average:
+ timeit._history[self.name].append(self)
+ if self.verbose:
+ if self.average:
+ avg = self.average_time
+ print(f"{self.name or 'It'} took {avg:.6f} seconds in average.")
+ else:
+ print(f"{self.name or 'It'} took {self.time:.6f} seconds.")
+
+
+def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
+ first = strings[0]
+
+ for start in range(len(first)):
+ if any(s[start] != strings[0][start] for s in strings):
+ break
+
+ for end in range(1, min(len(s) for s in strings)):
+ if any(s[-end] != first[-end] for s in strings):
+ break
+
+ return [s[start:len(s) - end + 1] for s in strings]
+
+
+def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
+ from concurrent.futures import ThreadPoolExecutor
+ from contextlib import nullcontext
+ from tqdm import tqdm
+
+ if pbar is not None:
+ pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
+ else:
+ pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)
+
+ def decorator(fn: Callable):
+ with (
+ ThreadPoolExecutor(max_workers=num_workers) as executor,
+ pbar
+ ):
+ pbar.refresh()
+ @catch_exception
+ @suppress_traceback
+ def _fn(input):
+ ret = fn(input)
+ pbar.update()
+ return ret
+ executor.map(_fn, inputs)
+ executor.shutdown(wait=True)
+
+ return decorator
+
+
+def suppress_traceback(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ e.__traceback__ = e.__traceback__.tb_next.tb_next
+ raise
+ return wrapper
+
+
+class no_warnings:
+ def __init__(self, action: str = 'ignore', **kwargs):
+ self.action = action
+ self.filter_kwargs = kwargs
+
+ def __call__(self, fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ with warnings.catch_warnings():
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+ return fn(*args, **kwargs)
+ return wrapper
+
+ def __enter__(self):
+ self.warnings_manager = warnings.catch_warnings()
+ self.warnings_manager.__enter__()
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
+
+
+def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str):
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ return module
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/tracker3D/spatrack_modules/utils.py b/models/SpaTrackV2/models/tracker3D/spatrack_modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..465c53337e01ae3041ac87691237c27bfe78b48c
--- /dev/null
+++ b/models/SpaTrackV2/models/tracker3D/spatrack_modules/utils.py
@@ -0,0 +1,1006 @@
+import os, sys
+import torch
+import torch.amp
+import torch.nn.functional as F
+import torch.nn as nn
+from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
+ EfficientUpdateFormer, AttnBlock, Attention, CrossAttnBlock,
+ sequence_BCE_loss, sequence_loss, sequence_prob_loss, sequence_dyn_prob_loss
+)
+import math
+from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
+ Mlp, BasicEncoder, EfficientUpdateFormer, GeometryEncoder, NeighborTransformer
+)
+import numpy as np
+from models.SpaTrackV2.models.tracker3D.spatrack_modules.simple_vit_1d import Transformer,posemb_sincos_1d
+from einops import rearrange
+
+def self_grid_pos_embedding(B, T, H, W, level=None):
+ import pdb; pdb.set_trace()
+
+def random_se3_transformation(
+ batch_size: int = 1,
+ max_rotation_angle: float = math.pi,
+ max_translation: float = 1.0,
+ device: str = "cpu",
+ dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+ """
+ 随机生成刚体变换矩阵(SE(3) Transformation Matrix)。
+
+ Args:
+ batch_size (int): 批大小,默认为 1。
+ max_rotation_angle (float): 最大旋转角度(弧度),默认 π(180°)。
+ max_translation (float): 最大平移量,默认 1.0。
+ device (str): 设备('cpu' 或 'cuda')。
+ dtype (torch.dtype): 数据类型(推荐 float32)。
+
+ Returns:
+ torch.Tensor: 形状为 (batch_size, 4, 4) 的齐次变换矩阵。
+ """
+ # 随机生成旋转矩阵 R (batch_size, 3, 3)
+ # 方法 1:使用轴角表示(Axis-Angle)转换为旋转矩阵
+ axis = torch.randn(batch_size, 3, device=device, dtype=dtype) # 随机旋转轴
+ axis = axis / torch.norm(axis, dim=1, keepdim=True) # 归一化
+ angle = torch.rand(batch_size, 1, device=device, dtype=dtype) * max_rotation_angle # 随机角度 [0, max_angle]
+
+ # 计算旋转矩阵(Rodrigues' rotation formula)
+ K = torch.zeros(batch_size, 3, 3, device=device, dtype=dtype)
+ K[:, 0, 1] = -axis[:, 2]
+ K[:, 0, 2] = axis[:, 1]
+ K[:, 1, 0] = axis[:, 2]
+ K[:, 1, 2] = -axis[:, 0]
+ K[:, 2, 0] = -axis[:, 1]
+ K[:, 2, 1] = axis[:, 0]
+
+ I = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1)
+ R = I + torch.sin(angle).unsqueeze(-1) * K + (1 - torch.cos(angle).unsqueeze(-1)) * (K @ K)
+
+ # 随机生成平移向量 t (batch_size, 3)
+ t = (torch.rand(batch_size, 3, device=device, dtype=dtype) - 0.5) * 2 * max_translation
+
+ # 组合成齐次变换矩阵 T (batch_size, 4, 4)
+ T = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1)
+ T[:, :3, :3] = R
+ T[:, :3, 3] = t
+
+ return T
+
+def weighted_procrustes_torch(X, Y, W=None, RT=None):
+ """
+ Weighted Procrustes Analysis in PyTorch (batched).
+
+ Args:
+ X: (B, 1, N, 3), source point cloud.
+ Y: (B, T, N, 3), target point cloud.
+ W: (B, T, N) or (B, 1, N), optional weights for each point.
+
+ Returns:
+ t: (B, T, 3), optimal translation vectors.
+ R: (B, T, 3, 3), optimal rotation matrices.
+ """
+ device = X.device
+ B, T, N, _ = Y.shape
+
+ # Default weights: uniform
+ if W is None:
+ W = torch.ones(B, 1, N, device=device)
+ elif W.dim() == 3: # (B, T, N) -> expand to match Y
+ W = W.unsqueeze(-1) # (B, T, N, 1)
+ else: # (B, 1, N)
+ W = W.unsqueeze(-1).expand(B, T, N, 1)
+
+ # Reshape X to (B, T, N, 3) by broadcasting
+ X = X.expand(B, T, N, 3)
+
+ # Compute weighted centroids
+ sum_W = torch.sum(W, dim=2, keepdim=True) # (B, T, 1, 1)
+ centroid_X = torch.sum(W * X, dim=2) / sum_W.squeeze(-1) # (B, T, 3)
+ centroid_Y = torch.sum(W * Y, dim=2) / sum_W.squeeze(-1) # (B, T, 3)
+
+ # Center the point clouds
+ X_centered = X - centroid_X.unsqueeze(2) # (B, T, N, 3)
+ Y_centered = Y - centroid_Y.unsqueeze(2) # (B, T, N, 3)
+
+ # Compute weighted covariance matrix H = X^T W Y
+ X_weighted = X_centered * W # (B, T, N, 3)
+ H = torch.matmul(X_weighted.transpose(2, 3), Y_centered) # (B, T, 3, 3)
+
+ # SVD decomposition
+ U, S, Vt = torch.linalg.svd(H) # U/Vt: (B, T, 3, 3)
+
+ # Ensure right-handed rotation (det(R) = +1)
+ det = torch.det(torch.matmul(U, Vt)) # (B, T)
+ Vt_corrected = Vt.clone()
+ mask = det < 0
+ B_idx, T_idx = torch.nonzero(mask, as_tuple=True)
+ Vt_corrected[B_idx, T_idx, -1, :] *= -1 # Flip last row for those needing correction
+
+ # Optimal rotation and translation
+ R = torch.matmul(U, Vt_corrected).inverse() # (B, T, 3, 3)
+ t = centroid_Y - torch.matmul(R, centroid_X.unsqueeze(-1)).squeeze(-1) # (B, T, 3)
+ w2c = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).repeat(B, T, 1, 1)
+ if (torch.det(R) - 1).abs().max() < 1e-3:
+ w2c[:, :, :3, :3] = R
+ else:
+ import pdb; pdb.set_trace()
+ w2c[:, :, :3, 3] = t
+ try:
+ c2w_traj = torch.inverse(w2c) # or torch.linalg.inv()
+ except:
+ c2w_traj = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).repeat(B, T, 1, 1)
+
+ return c2w_traj
+
+def key_fr_wprocrustes(cam_pts, graph_matrix, dyn_weight, vis_mask,slide_len=16, overlap=8, K=3, mode="keyframe"):
+ """
+ cam_pts: (B, T, N, 3)
+ graph_matrix: (B, 1, N)
+ dyn_weight: (B, T, N)
+ K: number of keyframes to select (including start and end)
+
+ Returns:
+ c2w_traj: (B, T, 4, 4)
+ """
+ B, T, N, _ = cam_pts.shape
+ device = cam_pts.device
+
+ if mode == "keyframe":
+ # Step 1: Keyframe selection
+ ky_fr_idx = [0, T - 1]
+ graph_sum = torch.sum(graph_matrix, dim=-1) # (B, T, T)
+ dist = torch.max(graph_sum[:, 0, :], graph_sum[:, T - 1, :]) # (B, T)
+ dist[:, [0, T - 1]] = float('inf')
+ for _ in range(K - 2): # already have 2
+ last_idx = ky_fr_idx[-1]
+ dist = torch.max(dist, graph_sum[:, last_idx, :])
+ dist[:, last_idx] = float('inf')
+ next_id = torch.argmin(dist, dim=1)[0].item() # Assuming batch=1 or shared
+ ky_fr_idx.append(next_id)
+
+ ky_fr_idx = sorted(ky_fr_idx)
+ elif mode == "slide":
+ id_slide = torch.arange(0, T)
+ id_slide = id_slide.unfold(0, slide_len, overlap)
+ vis_mask_slide = vis_mask.unfold(1, slide_len, overlap)
+ cam_pts_slide = cam_pts.unfold(1, slide_len, overlap)
+ ky_fr_idx = torch.arange(0, T - slide_len + 1, overlap)
+ if ky_fr_idx[-1] + slide_len < T:
+ # if the last keyframe does not cover the whole sequence, add one more keyframe
+ ky_fr_idx = torch.cat([ky_fr_idx, ky_fr_idx[-1:] + overlap])
+ id_add = torch.arange(ky_fr_idx[-1], ky_fr_idx[-1] + slide_len).clamp(max=T-1)
+ id_slide = torch.cat([id_slide, id_add[None, :]], dim=0)
+ cam_pts_add = cam_pts[:, id_add, :, :]
+ cam_pts_slide = torch.cat([cam_pts_slide, cam_pts_add.permute(0,2,3,1)[:, None, ...]], dim=1)
+ vis_mask_add = vis_mask[:, id_add, :]
+ vis_mask_slide = torch.cat([vis_mask_slide, vis_mask_add.permute(0,2,3,1)[:, None, ...]], dim=1)
+
+ if mode == "keyframe":
+ # Step 2: Weighted Procrustes in windows
+ base_pose = torch.eye(4, device=cam_pts.device).view(1, 1, 4, 4).repeat(B, 1, 1, 1) # (B, 1, 4, 4)
+ c2w_traj_out = []
+ for i in range(len(ky_fr_idx) - 1):
+ start_idx = ky_fr_idx[i]
+ end_idx = ky_fr_idx[i + 1]
+
+ # Visibility mask
+ vis_mask_i = graph_matrix[:, start_idx, end_idx, :] # (B, N) or (N,)
+ if vis_mask_i.dim() == 1:
+ vis_mask_i = vis_mask_i.unsqueeze(0) # (1, N)
+
+ # Broadcast cam_pts and dyn_weight
+ cam_ref = cam_pts[:, start_idx:start_idx+1, :, :] # (B, 1, M, 3)
+ cam_win = cam_pts[:, start_idx:end_idx+1, :, :] # (B, W, M, 3)
+ weight = dyn_weight[:, :, :] * vis_mask_i[:, None, :] # (B, W, M)
+
+ # Compute relative transformations
+ if weight.sum() < 50:
+ weight = weight.clamp(min=5e-2)
+ relative_tfms = weighted_procrustes_torch(cam_ref, cam_win, weight) # (B, W, 4, 4)
+
+ # Apply to original c2w_traj
+ updated_pose = base_pose.detach() @ relative_tfms # (B, W, 4, 4)
+ base_pose = relative_tfms[:, -1:, :, :].detach() # (B, 1, 4, 4)
+
+ # Assign to output trajectory (avoid in-place on autograd path)
+ c2w_traj_out.append(updated_pose[:, 1:, ...])
+
+ c2w_traj_out = torch.cat(c2w_traj_out, dim=1)
+ c2w_traj_out = torch.cat([torch.eye(4, device=device).repeat(B, 1, 1, 1), c2w_traj_out], dim=1)
+ elif mode == "slide":
+ c2w_traj_out = torch.eye(4, device=device).repeat(B, T, 1, 1)
+ for i in range(cam_pts_slide.shape[1]):
+ cam_pts_slide_i = cam_pts_slide[:, i, :, :].permute(0,3,1,2)
+ id_slide_i = id_slide[i, :]
+ vis_mask_i = vis_mask_slide[:, i, :, 0, :].permute(0,2,1) # (B, N) or (N,)
+ vis_mask_i = vis_mask_i[:,:1] * vis_mask_i
+ weight_i = dyn_weight * vis_mask_i
+ if weight_i.sum() < 50:
+ weight_i = weight_i.clamp(min=5e-2)
+ if i == 0:
+ c2w_traj_out[:, id_slide_i, :, :] = weighted_procrustes_torch(cam_pts_slide_i[:,:1], cam_pts_slide_i, weight_i)
+ else:
+ campts_update = torch.einsum("btij,btnj->btni", c2w_traj_out[:,id_slide_i][...,:3,:3], cam_pts_slide_i) + c2w_traj_out[:,id_slide_i][...,None,:3,3]
+ c2w_traj_update = weighted_procrustes_torch(campts_update[:,:1], campts_update, weight_i)
+ c2w_traj_out[:, id_slide_i, :, :] = c2w_traj_update@c2w_traj_out[:,id_slide_i]
+
+ return c2w_traj_out
+
+def posenc(x, min_deg, max_deg):
+ """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
+ Instead of computing [sin(x), cos(x)], we use the trig identity
+ cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
+ Args:
+ x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
+ min_deg: int, the minimum (inclusive) degree of the encoding.
+ max_deg: int, the maximum (exclusive) degree of the encoding.
+ legacy_posenc_order: bool, keep the same ordering as the original tf code.
+ Returns:
+ encoded: torch.Tensor, encoded variables.
+ """
+ if min_deg == max_deg:
+ return x
+ scales = torch.tensor(
+ [2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
+ )
+
+ xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
+ four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
+ return torch.cat([x] + [four_feat], dim=-1)
+
+
+class EfficientUpdateFormer3D(nn.Module):
+ """
+ Transformer model that updates track in 3D
+ """
+
+ def __init__(
+ self,
+ EFormer: EfficientUpdateFormer,
+ update_points=True
+ ):
+ super().__init__()
+
+ hidden_size = EFormer.hidden_size
+ num_virtual_tracks = EFormer.num_virtual_tracks
+ num_heads = EFormer.num_heads
+ mlp_ratio = 4.0
+
+ #NOTE: we design a switcher to bridege the camera pose, 3d tracks and 2d tracks
+
+ # iteract with pretrained 2d tracking
+ self.switcher_tokens = nn.Parameter(
+ torch.randn(1, num_virtual_tracks, 1, hidden_size)
+ )
+ # cross attention
+ space_depth=len(EFormer.space_virtual_blocks)
+ self.space_switcher_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=Attention,
+ )
+ for _ in range(space_depth)
+ ]
+ )
+
+ # config 3d tracks blocks
+ self.space_track3d2switcher_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_switcher2track3d_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ # config switcher blocks
+ self.space_virtual2switcher_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_switcher2virtual_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ # config the temporal blocks
+ self.time_blocks_new = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=Attention,
+ )
+ for _ in range(len(EFormer.time_blocks))
+ ]
+ )
+ # scale and shift cross attention
+ self.scale_shift_cross_attn = nn.ModuleList(
+ [
+ CrossAttnBlock(
+ 128, hidden_size, num_heads, mlp_ratio=mlp_ratio
+ )
+ for _ in range(len(EFormer.time_blocks))
+ ]
+ )
+ self.scale_shift_self_attn = nn.ModuleList(
+ [
+ AttnBlock(
+ 128, num_heads, mlp_ratio=mlp_ratio, attn_class=Attention
+ )
+ for _ in range(len(EFormer.time_blocks))
+ ]
+ )
+ self.scale_shift_dec = torch.nn.Linear(128, 128+1, bias=True)
+
+ # dense cross attention
+ self.dense_res_cross_attn = nn.ModuleList(
+ [
+ CrossAttnBlock(
+ 128, hidden_size, num_heads, mlp_ratio=mlp_ratio
+ )
+ for _ in range(len(EFormer.time_blocks))
+ ]
+ )
+ self.dense_res_self_attn = nn.ModuleList(
+ [
+ AttnBlock(
+ 128, num_heads, mlp_ratio=mlp_ratio, attn_class=Attention
+ )
+ for _ in range(len(EFormer.time_blocks))
+ ]
+ )
+ self.dense_res_dec = torch.nn.Conv2d(128, 3+128, kernel_size=1, stride=1, padding=0)
+
+ # set different heads
+ self.update_points = update_points
+ if update_points:
+ self.point_head = torch.nn.Linear(hidden_size, 4, bias=True)
+ else:
+ self.depth_head = torch.nn.Linear(hidden_size, 1, bias=True)
+ self.pro_analysis_w_head = torch.nn.Linear(hidden_size, 1, bias=True)
+ self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
+ self.residual_head = torch.nn.Linear(hidden_size,
+ hidden_size, bias=True)
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ if getattr(self, "point_head", None) is not None:
+ torch.nn.init.trunc_normal_(self.point_head.weight, std=1e-6)
+ torch.nn.init.constant_(self.point_head.bias, 0)
+ if getattr(self, "depth_head", None) is not None:
+ torch.nn.init.trunc_normal_(self.depth_head.weight, std=0.001)
+ if getattr(self, "vis_conf_head", None) is not None:
+ torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=1e-6)
+ if getattr(self, "scale_shift_dec", None) is not None:
+ torch.nn.init.trunc_normal_(self.scale_shift_dec.weight, std=0.001)
+ if getattr(self, "residual_head", None) is not None:
+ torch.nn.init.trunc_normal_(self.residual_head.weight, std=0.001)
+
+
+ def _trunc_init(module):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ torch.nn.init.trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor, input_tensor3d, EFormer: EfficientUpdateFormer,
+ mask=None, add_space_attn=True, extra_sparse_tokens=None, extra_dense_tokens=None):
+
+ #NOTE: prepare the pose and 3d tracks features
+ tokens3d = EFormer.input_transform(input_tensor3d)
+
+ tokens = EFormer.input_transform(input_tensor)
+ B, _, T, _ = tokens.shape
+ virtual_tokens = EFormer.virual_tracks.repeat(B, 1, T, 1)
+ switcher_tokens = self.switcher_tokens.repeat(B, 1, T, 1)
+
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
+ tokens3d = torch.cat([tokens3d, switcher_tokens], dim=1)
+
+ _, N, _, _ = tokens.shape
+ j = 0
+ layers = []
+
+
+ for i in range(len(EFormer.time_blocks)):
+ if extra_sparse_tokens is not None:
+ extra_sparse_tokens = rearrange(extra_sparse_tokens, 'b n t c -> (b t) n c')
+ extra_sparse_tokens = self.scale_shift_cross_attn[i](extra_sparse_tokens, rearrange(tokens3d, 'b n t c -> (b t) n c'))
+ extra_sparse_tokens = rearrange(extra_sparse_tokens, '(b t) n c -> (b n) t c', b=B, t=T)
+ extra_sparse_tokens = self.scale_shift_self_attn[i](extra_sparse_tokens)
+ extra_sparse_tokens = rearrange(extra_sparse_tokens, '(b n) t c -> b n t c', b=B, n=2, t=T)
+
+ if extra_dense_tokens is not None:
+ h_p, w_p = extra_dense_tokens.shape[-2:]
+ extra_dense_tokens = rearrange(extra_dense_tokens, 'b t c h w -> (b t) (h w) c')
+ extra_dense_tokens = self.dense_res_cross_attn[i](extra_dense_tokens, rearrange(tokens3d, 'b n t c -> (b t) n c'))
+ extra_dense_tokens = rearrange(extra_dense_tokens, '(b t) n c -> (b n) t c', b=B, t=T)
+ extra_dense_tokens = self.dense_res_self_attn[i](extra_dense_tokens)
+ extra_dense_tokens = rearrange(extra_dense_tokens, '(b h w) t c -> b t c h w', b=B, h=h_p, w=w_p)
+
+ # temporal
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+ time_tokens = EFormer.time_blocks[i](time_tokens)
+
+ # temporal 3d
+ time_tokens3d = tokens3d.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+ time_tokens3d = self.time_blocks_new[i](time_tokens3d)
+
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
+ tokens3d = time_tokens3d.view(B, N, T, -1)
+
+ if (
+ add_space_attn
+ and hasattr(EFormer, "space_virtual_blocks")
+ and (i % (len(EFormer.time_blocks) // len(EFormer.space_virtual_blocks)) == 0)
+ ):
+ space_tokens = (
+ tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
+ ) # B N T C -> (B T) N C
+ space_tokens3d = (
+ tokens3d.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
+ ) # B N T C -> (B T) N C
+
+ point_tokens = space_tokens[:, : N - EFormer.num_virtual_tracks]
+ virtual_tokens = space_tokens[:, N - EFormer.num_virtual_tracks :]
+ # get the 3d relevant tokens
+ track3d_tokens = space_tokens3d[:, : N - EFormer.num_virtual_tracks]
+ switcher_tokens = space_tokens[:, N - EFormer.num_virtual_tracks + 1:]
+
+ # iteract switcher with pose and tracks3d
+ switcher_tokens = self.space_track3d2switcher_blocks[j](
+ switcher_tokens, track3d_tokens, mask=mask
+ )
+
+
+ virtual_tokens = EFormer.space_virtual2point_blocks[j](
+ virtual_tokens, point_tokens, mask=mask
+ )
+
+ # get the switcher_tokens
+ switcher_tokens = self.space_virtual2switcher_blocks[j](
+ switcher_tokens, virtual_tokens
+ )
+ virtual_tokens_res = self.residual_head(
+ self.space_switcher2virtual_blocks[j](
+ virtual_tokens, switcher_tokens
+ )
+ )
+ switcher_tokens_res = self.residual_head(
+ self.space_switcher2virtual_blocks[j](
+ switcher_tokens, virtual_tokens
+ )
+ )
+ # add residual
+ virtual_tokens = virtual_tokens + virtual_tokens_res
+ switcher_tokens = switcher_tokens + switcher_tokens_res
+
+ virtual_tokens = EFormer.space_virtual_blocks[j](virtual_tokens)
+ switcher_tokens = self.space_switcher_blocks[j](switcher_tokens)
+ # decode
+ point_tokens = EFormer.space_point2virtual_blocks[j](
+ point_tokens, virtual_tokens, mask=mask
+ )
+ track3d_tokens = self.space_switcher2track3d_blocks[j](
+ track3d_tokens, switcher_tokens, mask=mask
+ )
+
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
+ space_tokens3d = torch.cat([track3d_tokens, virtual_tokens], dim=1)
+ tokens = space_tokens.view(B, T, N, -1).permute(
+ 0, 2, 1, 3
+ ) # (B T) N C -> B N T C
+ tokens3d = space_tokens3d.view(B, T, N, -1).permute(
+ 0, 2, 1, 3
+ ) # (B T) N C -> B N T C
+
+ j += 1
+
+ tokens = tokens[:, : N - EFormer.num_virtual_tracks]
+ track3d_tokens = tokens3d[:, : N - EFormer.num_virtual_tracks]
+
+ if self.update_points:
+ depth_update, dynamic_prob_update = self.point_head(track3d_tokens)[..., :3], self.point_head(track3d_tokens)[..., 3:]
+ else:
+ depth_update, dynamic_prob_update = self.depth_head(track3d_tokens)[..., :1], self.depth_head(track3d_tokens)[..., 1:]
+ pro_analysis_w = self.pro_analysis_w_head(track3d_tokens)
+ flow = EFormer.flow_head(tokens)
+ if EFormer.linear_layer_for_vis_conf:
+ vis_conf = EFormer.vis_conf_head(tokens)
+ flow = torch.cat([flow, vis_conf], dim=-1)
+ if extra_sparse_tokens is not None:
+ scale_shift_out = self.scale_shift_dec(extra_sparse_tokens)
+ dense_res_out = self.dense_res_dec(extra_dense_tokens.view(B*T, -1, h_p, w_p)).view(B, T, -1, h_p, w_p)
+ return flow, depth_update, dynamic_prob_update, pro_analysis_w, scale_shift_out, dense_res_out
+ else:
+ return flow, depth_update, dynamic_prob_update, pro_analysis_w, None, None
+
+def recover_global_translations_batch(global_rot, c2w_traj, graph_weight):
+ B, T = global_rot.shape[:2]
+ device = global_rot.device
+
+ # Compute R_i @ t_ij
+ t_rel = c2w_traj[:, :, :, :3, 3] # (B, T, T, 3)
+ R_i = global_rot[:, :, None, :, :] # (B, T, 1, 3, 3)
+ t_rhs = torch.matmul(R_i, t_rel.unsqueeze(-1)).squeeze(-1) # (B, T, T, 3)
+
+ # Mask: exclude self-loops and small weights
+ valid_mask = (graph_weight > 1e-5) & (~torch.eye(T, dtype=bool, device=device)[None, :, :]) # (B, T, T)
+
+ # Get all valid (i, j) edge indices
+ i_idx, j_idx = torch.meshgrid(
+ torch.arange(T, device=device),
+ torch.arange(T, device=device),
+ indexing="ij"
+ )
+ i_idx = i_idx.reshape(-1) # (T*T,)
+ j_idx = j_idx.reshape(-1)
+
+ # Expand to batch (B, T*T)
+ i_idx = i_idx[None, :].repeat(B, 1)
+ j_idx = j_idx[None, :].repeat(B, 1)
+
+ # Flatten everything
+ valid_mask_flat = valid_mask.view(B, -1) # (B, T*T)
+ w_flat = graph_weight.view(B, -1) # (B, T*T)
+ rhs_flat = t_rhs.view(B, -1, 3) # (B, T*T, 3)
+
+ # Initialize output translations
+ global_translations = torch.zeros(B, T, 3, device=device)
+
+ for b_id in range(B):
+ mask = valid_mask_flat[b_id]
+ i_valid = i_idx[b_id][mask]
+ j_valid = j_idx[b_id][mask]
+ w_valid = w_flat[b_id][mask]
+ rhs_valid = rhs_flat[b_id][mask]
+
+ n_edges = i_valid.shape[0]
+
+ # Build A matrix: (n_edges*3, T*3)
+ A = torch.zeros(n_edges*3, T*3, device=device)
+
+ # Build b vector: (n_edges*3,)
+ b = torch.zeros(n_edges*3, device=device)
+
+ for k in range(n_edges):
+ i, j = i_valid[k], j_valid[k]
+ weight = w_valid[k]
+
+ # Fill A matrix for x,y,z components
+ for dim in range(3):
+ row = k*3 + dim
+ A[row, i*3 + dim] = -weight
+ A[row, j*3 + dim] = weight
+
+ # Fill b vector
+ b[row] = rhs_valid[k, dim] * weight
+
+ # Solve least squares
+ try:
+ # Add small regularization for stability
+ AtA = A.transpose(-1, -2) @ A + 1e-4 * torch.eye(A.shape[-1], device=A.device)
+ Atb = A.transpose(-1, -2) @ b.unsqueeze(-1)
+
+ solution = torch.linalg.solve(AtA, Atb).squeeze(-1) # (3*T,)
+ t_batch = solution.view(T, 3)
+
+ # Fix scale by setting first frame to origin
+ t_batch = t_batch - t_batch[0:1]
+ global_translations[b_id] = t_batch
+
+ except RuntimeError as e:
+ print(f"Error in batch {b_id}: {e}")
+ global_translations[b_id] = torch.zeros(T, 3, device=device)
+ return global_translations
+
+
+def global_graph_motion_average(c2w_traj, graph_weight):
+ """
+ This function will average the c2w_traj by the graph_weight
+ """
+ B, T, T, _, _ = c2w_traj.shape
+ mask = graph_weight[..., 0, 0] < 1e-5 # (B, T, T)
+ mask = mask.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, 4, 4) # (B, T, T, 4, 4)
+ identity = torch.eye(4, device=c2w_traj.device).view(1, 1, 1, 4, 4).expand(B, T, T, 4, 4)
+ c2w_traj = torch.where(mask, identity, c2w_traj)
+
+ Rot_rel_weighted = c2w_traj[:,:,:,:3,:3].contiguous() * graph_weight # B T T 3 3
+ Rot_big = Rot_rel_weighted.permute(0, 1, 3, 2, 4).reshape(B, 3*T, 3*T) # B 3T 3T
+ epsilon = 1e-8
+ I_big = torch.eye(3*T, device=Rot_big.device).unsqueeze(0) # (1, 3T, 3T)
+ Rot_big_reg = Rot_big + epsilon * I_big # (B, 3T, 3T)
+ #NOTE: cal the global rotation
+ # Step 1: batch eigendecomposition
+ try:
+ eigvals, eigvecs = torch.linalg.eigh(Rot_big_reg) # eigvecs: (B, 3T, 3)
+ except:
+ import pdb; pdb.set_trace()
+ # Step 2: get the largest 3 eigenvectors
+ X = eigvecs[:, :, -3:] # (B, 3T, 3)
+ # Step 3: split into (B, T, 3, 3)
+ X = X.view(B, T, 3, 3) # each frame's rotation block (non-orthogonal)
+ # Step 4: project to SO(3), using SVD
+ U, _, Vh = torch.linalg.svd(X) # (B, T, 3, 3)
+ R = U @ Vh
+ # Step 5: ensure det(R)=1 (right-handed coordinate system)
+ det = torch.linalg.det(R) # (B, T)
+ neg_det_mask = det < 0
+ # if det<0, reverse the last column and multiply
+ U_flip = U.clone()
+ U_flip[neg_det_mask, :, -1] *= -1
+ R = U_flip @ Vh
+ # global rotation
+ Rot_glob = R[:,:1].inverse() @ R
+ # global translation
+ t_glob = recover_global_translations_batch(Rot_glob,
+ c2w_traj, graph_weight[...,0,0])
+ c2w_traj_final = torch.eye(4, device=c2w_traj.device)[None,None].repeat(B, T, 1, 1)
+ c2w_traj_final[:,:,:3,:3] = Rot_glob
+ c2w_traj_final[:,:,:3,3] = t_glob
+
+ return c2w_traj_final
+
+
+def depth_to_points_colmap(metric_depth: torch.Tensor,
+ intrinsics: torch.Tensor) -> torch.Tensor:
+ """
+ Unproject a depth map to a point cloud in COLMAP convention.
+
+ Args:
+ metric_depth: (B, H, W) depth map, meters.
+ intrinsics: (B, 3, 3) COLMAP-style K matrix.
+ Returns:
+ points_map: (B, H, W, 3) point cloud in camera coordinates.
+ """
+ # 因为输入的 metric_depth 维度是 (B, H, W)
+ B, H, W = metric_depth.shape
+
+ # 因为需要每个像素的 [u, v, 1] 齐次坐标
+ u = torch.arange(W, device=metric_depth.device, dtype=metric_depth.dtype)
+ v = torch.arange(H, device=metric_depth.device, dtype=metric_depth.dtype)
+ uu, vv = torch.meshgrid(u, v, indexing='xy')
+ pix = torch.stack([uu, vv, torch.ones_like(uu)], dim=-1)
+ pix = pix.reshape(-1, 3) # (H*W, 3)
+ # 因为要对 B 张图做相同操作
+ pix = pix.unsqueeze(0).expand(B, -1, -1) # (B, H*W, 3)
+ # import pdb; pdb.set_trace()
+ # 因为 K 是 (B, 3, 3)
+ K_inv = torch.inverse(intrinsics) # (B, 3, 3)
+
+ # 因为反投影方向是 X_cam = K^{-1} * pix
+ dirs = torch.einsum('bij,bkj->bki', K_inv, pix) # (B, H*W, 3)
+
+ # 因为要按深度伸缩
+ depths = metric_depth.reshape(B, -1) # (B, H*W)
+ pts = dirs * depths.unsqueeze(-1) # (B, H*W, 3)
+
+ # 因为希望输出 (B, H, W, 3)
+ points_map = pts.view(B, H, W, 3) # (B, H, W, 3)
+
+ return points_map
+
+def vec6d_to_R(vector_6D):
+ v1=vector_6D[:,:3]/vector_6D[:,:3].norm(dim=-1,keepdim=True)
+ v2=vector_6D[:,3:]-(vector_6D[:,3:]*v1).sum(dim=-1,keepdim=True)*v1
+ v2=v2/v2.norm(dim=-1,keepdim=True)
+ v3=torch.cross(v1,v2,dim=-1)
+ return torch.concatenate((v1.unsqueeze(1),v2.unsqueeze(1),v3.unsqueeze(1)),dim=1)
+
+class MyTransformerHead(nn.Module):
+ def __init__(self,input_dim,dim,use_positional_encoding_transformer):
+ super(MyTransformerHead,self).__init__()
+
+ patch_dim=input_dim+1
+ self.layers=3
+ # dim=128
+ self.use_positional_encoding_transformer=use_positional_encoding_transformer
+ self.to_patch_embedding = nn.Sequential(
+ nn.LayerNorm(patch_dim),
+ nn.Linear(patch_dim, dim),
+ nn.LayerNorm(dim),
+ )
+ self.transformer_frames=[]
+ self.transformer_points=[]
+
+ for i in range(self.layers):
+ self.transformer_frames.append(Transformer(dim, 1, 16, 64, 2048))
+ self.transformer_points.append(Transformer(dim, 1, 16, 64, 2048))
+ self.transformer_frames=nn.ModuleList(self.transformer_frames)
+ self.transformer_points=nn.ModuleList(self.transformer_points)
+
+ def forward(self, x):
+
+
+ x=torch.cat((x,torch.ones(x.shape[0],x.shape[1],1,x.shape[3]).cuda()),dim=2)
+
+ x=x.transpose(2,3)
+
+ b,n,f,c=x.shape
+ x=self.to_patch_embedding(x)
+
+ x=x.view(b*n,f,-1) # x.shape [390, 33, 256]
+ if self.use_positional_encoding_transformer:
+ pe = posemb_sincos_1d(x) #pe.shape= [33,256] (33 frame, 256 embedding dim)
+ x=pe.unsqueeze(0)+x
+ for i in range(self.layers):
+ #frames aggregation
+ x=self.transformer_frames[i](x)
+
+ #point sets aggregation
+ x=x.view(b,n,f,-1).transpose(1,2).reshape(b*f,n,-1)
+
+ x=self.transformer_points[i](x)
+
+ x=x.view(b,f,n,-1)
+ x=x.transpose(1,2).reshape(b*n,f,-1)
+
+ x=x.view(b,n,f,-1)
+ x=x.transpose(2,3)
+
+
+ return x
+
+def positionalEncoding_vec(in_tensor, b):
+ proj = torch.einsum('ij, k -> ijk', in_tensor, b)
+ mapped_coords = torch.cat((torch.sin(proj), torch.cos(proj)), dim=1)
+ output = mapped_coords.transpose(2, 1).contiguous().view(mapped_coords.size(0), -1)
+ return output
+
+class TrackFusion(nn.Module):
+ def __init__(self,width1=320,conv2_kernel_size=31,K=12,
+ conv_kernel_size=3,inputdim=2,use_positionl_encoding=True,
+ positional_dim=4,use_transformer=True,detach_cameras_dynamic=True,
+ use_positional_encoding_transformer=True,use_set_of_sets=False,predict_focal_length=False):
+ super(TrackFusion, self).__init__()
+ self.predict_focal_length=predict_focal_length
+ self.inputdim = inputdim
+ self.n1 = width1
+
+ self.K=K
+ self.n2 = 6+3+1+self.K+2
+ self.detach_cameras_dynamic=detach_cameras_dynamic
+ l=conv_kernel_size
+ # layers
+ self.use_set_of_sets=use_set_of_sets
+ self.use_positionl_encoding=use_positionl_encoding
+ self.positional_dim=positional_dim
+ actual_input_dim=inputdim
+ if self.use_positionl_encoding:
+ actual_input_dim=2 * inputdim * self.positional_dim+inputdim
+
+ self.use_transformer=use_transformer
+
+ if self.use_positionl_encoding:
+ self.b = torch.tensor([(2 ** j) * np.pi for j in range(self.positional_dim)],requires_grad = False)
+
+ if True:
+ if self.use_transformer:
+ self.transformer_my=MyTransformerHead(actual_input_dim,width1,use_positional_encoding_transformer)
+
+ self.conv_final = nn.Conv1d(self.n1, self.n2, kernel_size=conv2_kernel_size,stride=1, padding=conv2_kernel_size//2, padding_mode='circular')
+
+ self.fc1 = nn.Linear(self.n1,3*self.K+1)
+
+
+
+ torch.nn.init.xavier_uniform_(self.conv_final.weight)
+
+ torch.nn.init.xavier_uniform_(self.fc1.weight)
+
+ def forward(self, x, pts_miu=None, pts_radis=None, simple_return=True):
+
+ B, N, C, T = x.shape
+ if self.use_positionl_encoding:
+ x_original_shape=x.shape
+ x=x.transpose(2,3)
+ x=x.reshape(-1,x.shape[-1])
+ if self.b.device!=x.device:
+ self.b=self.b.to(x.device)
+ pos = positionalEncoding_vec(x,self.b)
+ x=torch.cat((x,pos),dim=1)
+ x=x.view(x_original_shape[0],x_original_shape[1],x_original_shape[3],x.shape[-1]).transpose(2,3)
+
+ b = len(x)
+ n= x.shape[1]
+ l= x.shape[-1]
+ if self.use_set_of_sets:
+ cameras,perpoint_features=self.set_of_sets_my(x)
+ else:
+ if self.use_transformer:
+ x=self.transformer_my(x)
+ else:
+ for i in range(len( self.conv1)):
+ if i==0:
+ x = x.reshape(n*b, x.shape[2],l)
+ else:
+ x = x.view(n * b, self.n1, l)
+ x1 = self.bn1[i](self.conv1[i](x)).view(b,n,self.n1,l)
+ x2 = self.bn1s[i](self.conv1s[i](x)).view(b,n,self.n1,l).mean(dim=1).view(b,1,self.n1,l).repeat(1,n,1,1)
+ x = F.relu(x1 + x2)
+
+ cameras=torch.mean(x,dim=1)
+ cameras=self.conv_final(cameras)
+ perpoint_features = torch.mean(x,dim=3)
+ perpoint_features = self.fc1(perpoint_features.view(n*b,self.n1))
+
+ B=perpoint_features[:,:self.K*3].view(b,n,3,self.K) # motion basis
+ NR=F.elu(perpoint_features[:,-1].view(b,n))+1+0.00001
+
+ position_params=cameras[:,:3,:]
+ if self.predict_focal_length:
+ focal_params=1+0.05*cameras[:,3:4,:].clone().transpose(1,2)
+ else:
+ focal_params=1.0
+ basis_params=cameras[:,4:4+self.K]
+ basis_params[:,0,:]=torch.clamp(basis_params[:,0,:].clone(),min=1.0,max=1.0)
+ basis_params.transpose(1,2).unsqueeze(1).unsqueeze(1)
+ rotation_params=cameras[:,4+self.K:4+self.K+6]
+ # Converting rotation parameters into a valid rotation matrix (probably better to move to 6d representation)
+ rotation_params=vec6d_to_R(rotation_params.transpose(1,2).reshape(b*l,6)).view(b,l,3,3)
+
+ # Transfering global 3D into each camera coordinates (using per camera roation and translation)
+ points3D_static=((basis_params.transpose(1,2).unsqueeze(1).unsqueeze(1))[:,:,:,:,:1]*B.unsqueeze(-2)[:,:,:,:,:1]).sum(-1)
+
+ if self.detach_cameras_dynamic==False:
+ points3D=((basis_params.transpose(1,2).unsqueeze(1).unsqueeze(1))[:,:,:,:,1:]*B.unsqueeze(-2)[:,:,:,:,1:]).sum(-1)+points3D_static
+ else:
+ points3D=((basis_params.transpose(1,2).unsqueeze(1).unsqueeze(1))[:,:,:,:,1:]*B.unsqueeze(-2)[:,:,:,:,1:]).sum(-1)+points3D_static.detach()
+
+ points3D=points3D.transpose(1,3)
+ points3D_static=points3D_static.transpose(1,3)
+ position_params=position_params.transpose(1,2)
+ if pts_miu is not None:
+ position_params=position_params*pts_radis.squeeze(-1)+pts_miu.squeeze(-2)
+ points3D_static = points3D_static*pts_radis.squeeze(-1)+pts_miu.permute(0,1,3,2)
+ points3D = points3D*pts_radis.squeeze(-1)+pts_miu.permute(0,1,3,2)
+
+ if self.detach_cameras_dynamic==False:
+ points3D_camera=(torch.bmm(rotation_params.view(b*l,3,3).transpose(1,2),points3D.reshape(b*l,3,n)-position_params.reshape(b*l,3).unsqueeze(-1)))
+ points3D_camera=points3D_camera.view(b,l,3,n)
+ else:
+ points3D_camera=(torch.bmm(rotation_params.view(b*l,3,3).transpose(1,2).detach(),points3D.reshape(b*l,3,n)-position_params.detach().reshape(b*l,3).unsqueeze(-1)))
+ points3D_camera=points3D_camera.view(b,l,3,n)
+ points3D_static_camera=(torch.bmm(rotation_params.view(b*l,3,3).transpose(1,2),points3D_static.reshape(b*l,3,n)-position_params.reshape(b*l,3).unsqueeze(-1)))
+ points3D_static_camera=points3D_static_camera.view(b,l,3,n)
+
+ # Projecting from 3D to 2D
+ projections=points3D_camera.clone()
+ projections_static=points3D_static_camera.clone()
+
+ depths=projections[:,:,2,:]
+ depths_static=projections_static[:,:,2,:]
+
+ projectionx=focal_params*projections[:,:,0,:]/torch.clamp(projections[:,:,2,:].clone(),min=0.01)
+ projectiony=focal_params*projections[:,:,1,:]/torch.clamp(projections[:,:,2,:].clone(),min=0.01)
+
+ projectionx_static=focal_params*projections_static[:,:,0,:]/torch.clamp(projections_static[:,:,2,:].clone(),min=0.01)
+ projectiony_static=focal_params*projections_static[:,:,1,:]/torch.clamp(projections_static[:,:,2,:].clone(),min=0.01)
+
+ projections2=torch.cat((projectionx.unsqueeze(2),projectiony.unsqueeze(2)),dim=2)
+ projections2_static=torch.cat((projectionx_static.unsqueeze(2),projectiony_static.unsqueeze(2)),dim=2)
+
+ if simple_return:
+ c2w_traj = torch.eye(4, device=x.device)[None,None].repeat(b,T,1,1)
+ c2w_traj[:,:,:3,:3] = rotation_params
+ c2w_traj[:,:,:3,3] = position_params
+ return c2w_traj, points3D, points3D_camera
+ else:
+ return focal_params,projections2,projections2_static,rotation_params,position_params,B,points3D,points3D_static,depths,depths_static,0,basis_params,0,0,points3D_camera,NR
+
+
+def get_nth_visible_time_index(vis_gt: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
+ """
+ vis_gt: [B, T, N] 0/1 binary tensor
+ n: [B, N] int tensor, the n-th visible time index to get (1-based)
+ Returns: [B, N] tensor of time indices into T, or -1 if not enough visible steps
+ """
+ B, T, N = vis_gt.shape
+
+ # Create a tensor [0, 1, ..., T-1] for time indices
+ time_idx = torch.arange(T, device=vis_gt.device).view(1, T, 1).expand(B, T, N) # [B, T, N]
+
+ # Mask invisible steps with a large number (T)
+ masked_time = torch.where(vis_gt.bool(), time_idx, torch.full_like(time_idx, T))
+
+ # Sort along time dimension
+ sorted_time, _ = masked_time.sort(dim=1) # [B, T, N]
+
+ # Prepare index tensor for gather: [B, N] -> [B, 1, N]
+ gather_idx = (n - 1).clamp(min=0, max=T-1).unsqueeze(1) # shape: [B, 1, N]
+ assert gather_idx.shape == sorted_time.shape[:1] + (1, sorted_time.shape[2]) # [B, 1, N]
+
+ # Gather from sorted_time: result is [B, 1, N]
+ nth_time = sorted_time.gather(dim=1, index=gather_idx).squeeze(1) # [B, N]
+
+ # If value is T (i.e., masked), then not enough visible → set to -1
+ nth_time = torch.where(nth_time == T, torch.full_like(nth_time, -1), nth_time)
+
+ return nth_time # [B, N]
+
+def knn_torch(x, k):
+ """
+ x: (B, T, N, 2)
+ return: indices of k-NN, shape (B, T, N, k)
+ """
+ B, T, N, C = x.shape
+ # Reshape to (B*T, N, 2)
+ x = x.view(B*T, N, C) # Merge the first two dimensions for easier processing
+ # Calculate pairwise distance: (B*T, N, N)
+ dist = torch.cdist(x, x, p=2) # Euclidean distance
+
+ # Exclude self: set diagonal to a large number (to prevent self from being a neighbor)
+ mask = torch.eye(N, device=x.device).bool()[None, :, :] # (1, N, N)
+ dist.masked_fill_(mask, float('inf'))
+
+ # Get indices of top k smallest distances
+ knn_idx = dist.topk(k, largest=False).indices # (B*T, N, k)
+ # Restore dimensions (B, T, N, k)
+ knn_idx = knn_idx.view(B, T, N, k)
+ return knn_idx
+
+def get_topo_mask(coords_xyz_append: torch.Tensor,
+ coords_2d_lift: torch.Tensor, replace_ratio: float = 0.6) -> torch.Tensor:
+ """
+ coords_xyz_append: [B, T, N, 3] 3d coordinates
+ coords_2d_lift: [B*T, N] depth map
+ replace_ratio: float, the ratio of the depth change to be considered as a topological change
+ """
+ B, T, N, _ = coords_xyz_append.shape
+ # if N > 1024:
+ # pick_idx = torch.randperm(N)[:1024]
+ # else:
+ pick_idx = torch.arange(N, device=coords_xyz_append.device)
+ coords_xyz_append = coords_xyz_append[:,:,pick_idx,:]
+ knn_idx = knn_torch(coords_xyz_append, 49)
+ knn_idx = pick_idx[knn_idx]
+ # raw topology
+ raw_depth = coords_xyz_append[...,2:] # B T N 1 knn_idx B T N K
+ knn_depth = torch.gather(
+ raw_depth.expand(-1, -1, -1, knn_idx.shape[-1]), # (B, T, N, K)
+ dim=2,
+ index=knn_idx # (B, T, N, K)
+ ).squeeze(-1) # → (B, T, N, K)
+ depth_rel_neg_raw = (knn_depth - raw_depth)
+ # unproj depth
+ knn_depth_unproj = torch.gather(
+ depth_unproj.view(B,T,N,1).expand(-1, -1, -1, knn_idx.shape[-1]), # (B, T, N, K)
+ dim=2,
+ index=knn_idx # (B, T, N, K)
+ ).squeeze(-1) # → (B, T, N, K)
+ depth_rel_neg_unproj = (knn_depth_unproj - depth_unproj.view(B,T,N,1))
+ # topological change threshold
+ mask_topo = (depth_rel_neg_raw.abs() / (depth_rel_neg_unproj.abs()+1e-8) - 1).abs() < 0.4
+ mask_topo = mask_topo.sum(dim=-1) > 9
+
+ return mask_topo
+
+
diff --git a/models/SpaTrackV2/models/utils.py b/models/SpaTrackV2/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5335574731a05f171138d85800f4a2d529eccf60
--- /dev/null
+++ b/models/SpaTrackV2/models/utils.py
@@ -0,0 +1,1221 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from https://github.com/facebookresearch/PoseDiffusion
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Union, List
+from einops import rearrange, repeat
+
+import cv2
+import numpy as np
+
+# from torchmetrics.functional.regression import pearson_corrcoef
+from easydict import EasyDict as edict
+from enum import Enum
+import torch.utils.data.distributed as dist
+from typing import Literal, Union, List, Tuple, Dict
+from models.monoD.depth_anything_v2.util.transform import Resize
+from models.SpaTrackV2.utils.model_utils import sample_features5d
+EPS = 1e-9
+
+class Summary(Enum):
+ NONE = 0
+ AVERAGE = 1
+ SUM = 2
+ COUNT = 3
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
+ self.name = name
+ self.fmt = fmt
+ self.summary_type = summary_type
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def all_reduce(self):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if isinstance(self.sum, np.ndarray):
+ total = torch.tensor(
+ self.sum.tolist()
+ + [
+ self.count,
+ ],
+ dtype=torch.float32,
+ device=device,
+ )
+ else:
+ total = torch.tensor(
+ [self.sum, self.count], dtype=torch.float32, device=device
+ )
+
+ dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
+ if total.shape[0] > 2:
+ self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
+ else:
+ self.sum, self.count = total.tolist()
+ self.avg = self.sum / (self.count + 1e-5)
+
+ def __str__(self):
+ fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
+ return fmtstr.format(**self.__dict__)
+
+ def summary(self):
+ fmtstr = ""
+ if self.summary_type is Summary.NONE:
+ fmtstr = ""
+ elif self.summary_type is Summary.AVERAGE:
+ fmtstr = "{name} {avg:.3f}"
+ elif self.summary_type is Summary.SUM:
+ fmtstr = "{name} {sum:.3f}"
+ elif self.summary_type is Summary.COUNT:
+ fmtstr = "{name} {count:.3f}"
+ else:
+ raise ValueError("invalid summary type %r" % self.summary_type)
+
+ return fmtstr.format(**self.__dict__)
+
+
+def procrustes_analysis(X0,X1): # [N,3]
+ # translation
+ t0 = X0.mean(dim=0,keepdim=True)
+ t1 = X1.mean(dim=0,keepdim=True)
+ X0c = X0-t0
+ X1c = X1-t1
+ # scale
+ s0 = (X0c**2).sum(dim=-1).mean().sqrt()
+ s1 = (X1c**2).sum(dim=-1).mean().sqrt()
+ X0cs = X0c/s0
+ X1cs = X1c/s1
+ # rotation (use double for SVD, float loses precision)
+ U,S,V = (X0cs.t()@X1cs).double().svd(some=True)
+ R = (U@V.t()).float()
+ if R.det()<0: R[2] *= -1
+ # align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0
+ sim3 = edict(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R)
+ return sim3
+
+def create_intri_matrix(focal_length, principal_point):
+ """
+ Creates a intri matrix from focal length and principal point.
+
+ Args:
+ focal_length (torch.Tensor): A Bx2 or BxSx2 tensor containing the focal lengths (fx, fy) for each image.
+ principal_point (torch.Tensor): A Bx2 or BxSx2 tensor containing the principal point coordinates (cx, cy) for each image.
+
+ Returns:
+ torch.Tensor: A Bx3x3 or BxSx3x3 tensor containing the camera matrix for each image.
+ """
+
+ if len(focal_length.shape) == 2:
+ B = focal_length.shape[0]
+ intri_matrix = torch.zeros(B, 3, 3, dtype=focal_length.dtype, device=focal_length.device)
+ intri_matrix[:, 0, 0] = focal_length[:, 0]
+ intri_matrix[:, 1, 1] = focal_length[:, 1]
+ intri_matrix[:, 2, 2] = 1.0
+ intri_matrix[:, 0, 2] = principal_point[:, 0]
+ intri_matrix[:, 1, 2] = principal_point[:, 1]
+ else:
+ B, S = focal_length.shape[0], focal_length.shape[1]
+ intri_matrix = torch.zeros(B, S, 3, 3, dtype=focal_length.dtype, device=focal_length.device)
+ intri_matrix[:, :, 0, 0] = focal_length[:, :, 0]
+ intri_matrix[:, :, 1, 1] = focal_length[:, :, 1]
+ intri_matrix[:, :, 2, 2] = 1.0
+ intri_matrix[:, :, 0, 2] = principal_point[:, :, 0]
+ intri_matrix[:, :, 1, 2] = principal_point[:, :, 1]
+
+ return intri_matrix
+
+
+def closed_form_inverse_OpenCV(se3, R=None, T=None):
+ """
+ Computes the inverse of each 4x4 SE3 matrix in the batch.
+
+ Args:
+ - se3 (Tensor): Nx4x4 tensor of SE3 matrices.
+
+ Returns:
+ - Tensor: Nx4x4 tensor of inverted SE3 matrices.
+
+
+ | R t |
+ | 0 1 |
+ -->
+ | R^T -R^T t|
+ | 0 1 |
+ """
+ if R is None:
+ R = se3[:, :3, :3]
+
+ if T is None:
+ T = se3[:, :3, 3:]
+
+ # Compute the transpose of the rotation
+ R_transposed = R.transpose(1, 2)
+
+ # -R^T t
+ top_right = -R_transposed.bmm(T)
+
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(se3), 1, 1)
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+ inverted_matrix[:, :3, :3] = R_transposed
+ inverted_matrix[:, :3, 3:] = top_right
+
+ return inverted_matrix
+
+
+def get_EFP(pred_cameras, image_size, B, S, default_focal=False):
+ """
+ Converting PyTorch3D cameras to extrinsics, intrinsics matrix
+
+ Return extrinsics, intrinsics, focal_length, principal_point
+ """
+ scale = image_size.min()
+
+ focal_length = pred_cameras.focal_length
+
+ principal_point = torch.zeros_like(focal_length)
+
+ focal_length = focal_length * scale / 2
+ principal_point = (image_size[None] - principal_point * scale) / 2
+
+ Rots = pred_cameras.R.clone()
+ Trans = pred_cameras.T.clone()
+
+ extrinsics = torch.cat([Rots, Trans[..., None]], dim=-1)
+
+ # reshape
+ extrinsics = extrinsics.reshape(B, S, 3, 4)
+ focal_length = focal_length.reshape(B, S, 2)
+ principal_point = principal_point.reshape(B, S, 2)
+
+ # only one dof focal length
+ if default_focal:
+ focal_length[:] = scale
+ else:
+ focal_length = focal_length.mean(dim=-1, keepdim=True).expand(-1, -1, 2)
+ focal_length = focal_length.clamp(0.2 * scale, 5 * scale)
+
+ intrinsics = create_intri_matrix(focal_length, principal_point)
+ return extrinsics, intrinsics
+
+def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+def pose_encoding_to_camera(
+ pose_encoding,
+ pose_encoding_type="absT_quaR_logFL",
+ log_focal_length_bias=1.8,
+ min_focal_length=0.1,
+ max_focal_length=30,
+ return_dict=False,
+ to_OpenCV=True,
+):
+ """
+ Args:
+ pose_encoding: A tensor of shape `BxNxC`, containing a batch of
+ `BxN` `C`-dimensional pose encodings.
+ pose_encoding_type: The type of pose encoding,
+ """
+ pose_encoding_reshaped = pose_encoding.reshape(-1, pose_encoding.shape[-1]) # Reshape to BNxC
+
+ if pose_encoding_type == "absT_quaR_logFL":
+ # 3 for absT, 4 for quaR, 2 for absFL
+ abs_T = pose_encoding_reshaped[:, :3]
+ quaternion_R = pose_encoding_reshaped[:, 3:7]
+ R = quaternion_to_matrix(quaternion_R)
+ log_focal_length = pose_encoding_reshaped[:, 7:9]
+ # log_focal_length_bias was the hyperparameter
+ # to ensure the mean of logFL close to 0 during training
+ # Now converted back
+ focal_length = (log_focal_length + log_focal_length_bias).exp()
+ # clamp to avoid weird fl values
+ focal_length = torch.clamp(focal_length,
+ min=min_focal_length, max=max_focal_length)
+ elif pose_encoding_type == "absT_quaR_OneFL":
+ # 3 for absT, 4 for quaR, 1 for absFL
+ # [absolute translation, quaternion rotation, normalized focal length]
+ abs_T = pose_encoding_reshaped[:, :3]
+ quaternion_R = pose_encoding_reshaped[:, 3:7]
+ R = quaternion_to_matrix(quaternion_R)
+ focal_length = pose_encoding_reshaped[:, 7:8]
+ focal_length = torch.clamp(focal_length,
+ min=min_focal_length, max=max_focal_length)
+ else:
+ raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
+
+ if to_OpenCV:
+ ### From Pytorch3D coordinate to OpenCV coordinate:
+ # I hate coordinate conversion
+ R = R.clone()
+ abs_T = abs_T.clone()
+ R[:, :, :2] *= -1
+ abs_T[:, :2] *= -1
+ R = R.permute(0, 2, 1)
+
+ extrinsics_4x4 = torch.eye(4, 4).to(R.dtype).to(R.device)
+ extrinsics_4x4 = extrinsics_4x4[None].repeat(len(R), 1, 1)
+
+ extrinsics_4x4[:, :3, :3] = R.clone()
+ extrinsics_4x4[:, :3, 3] = abs_T.clone()
+
+ rel_transform = closed_form_inverse_OpenCV(extrinsics_4x4[0:1])
+ rel_transform = rel_transform.expand(len(extrinsics_4x4), -1, -1)
+
+ # relative to the first camera
+ # NOTE it is extrinsics_4x4 x rel_transform instead of rel_transform x extrinsics_4x4
+ extrinsics_4x4 = torch.bmm(extrinsics_4x4, rel_transform)
+
+ R = extrinsics_4x4[:, :3, :3].clone()
+ abs_T = extrinsics_4x4[:, :3, 3].clone()
+
+ if return_dict:
+ return {"focal_length": focal_length, "R": R, "T": abs_T}
+
+ pred_cameras = PerspectiveCameras(focal_length=focal_length,
+ R=R, T=abs_T, device=R.device, in_ndc=False)
+ return pred_cameras
+
+
+def camera_to_pose_encoding(
+ camera, pose_encoding_type="absT_quaR_logFL",
+ log_focal_length_bias=1.8, min_focal_length=0.1, max_focal_length=30
+):
+ """
+ Inverse to pose_encoding_to_camera
+ """
+ if pose_encoding_type == "absT_quaR_logFL":
+ # Convert rotation matrix to quaternion
+ quaternion_R = matrix_to_quaternion(camera.R)
+
+ # Calculate log_focal_length
+ log_focal_length = (
+ torch.log(torch.clamp(camera.focal_length,
+ min=min_focal_length, max=max_focal_length))
+ - log_focal_length_bias
+ )
+
+ # Concatenate to form pose_encoding
+ pose_encoding = torch.cat([camera.T, quaternion_R, log_focal_length], dim=-1)
+
+ elif pose_encoding_type == "absT_quaR_OneFL":
+ # [absolute translation, quaternion rotation, normalized focal length]
+ quaternion_R = matrix_to_quaternion(camera.R)
+ focal_length = (torch.clamp(camera.focal_length,
+ min=min_focal_length,
+ max=max_focal_length))[..., 0:1]
+ pose_encoding = torch.cat([camera.T, quaternion_R, focal_length], dim=-1)
+ else:
+ raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
+
+ return pose_encoding
+
+
+def init_pose_enc(B: int,
+ S: int, pose_encoding_type: str="absT_quaR_logFL",
+ device: Optional[torch.device]=None):
+ """
+ Initialize the pose encoding tensor
+ args:
+ B: batch size
+ S: number of frames
+ pose_encoding_type: the type of pose encoding
+ device: device to put the tensor
+ return:
+ pose_enc: [B S C]
+ """
+ if pose_encoding_type == "absT_quaR_logFL":
+ C = 9
+ elif pose_encoding_type == "absT_quaR_OneFL":
+ C = 8
+ else:
+ raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
+
+ pose_enc = torch.zeros(B, S, C, device=device)
+ pose_enc[..., :3] = 0 # absT
+ pose_enc[..., 3] = 1 # quaR
+ pose_enc[..., 7:] = 1 # logFL
+ return pose_enc
+
+def first_pose_enc_norm(pose_enc: torch.Tensor,
+ pose_encoding_type: str="absT_quaR_OneFL",
+ pose_mode: str = "W2C"):
+ """
+ make sure the poses in on window are normalized by the first frame, where the
+ first frame transformation is the Identity Matrix.
+ NOTE: Poses are all W2C
+ args:
+ pose_enc: [B S C]
+ return:
+ pose_enc_norm: [B S C]
+ """
+ B, S, C = pose_enc.shape
+ # Pose encoding to Cameras (Pytorch3D coordinate)
+ pred_cameras = pose_encoding_to_camera(
+ pose_enc, pose_encoding_type=pose_encoding_type,
+ to_OpenCV=False
+ ) #NOTE: the camera parameters are not in NDC
+
+ R = pred_cameras.R # [B*S, 3, 3]
+ T = pred_cameras.T # [B*S, 3]
+
+ Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*S, 3, 4]
+ extra_ = torch.tensor([[[0, 0, 0, 1]]],
+ device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
+ Tran_M = torch.cat([Tran_M, extra_
+ ], dim=1)
+ Tran_M = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)
+
+ # Take the first frame as the base of world coordinate
+ if pose_mode == "C2W":
+ Tran_M_new = (Tran_M[:,:1,...].inverse())@Tran_M
+ elif pose_mode == "W2C":
+ Tran_M_new = Tran_M@(Tran_M[:,:1,...].inverse())
+
+ Tran_M_new = rearrange(Tran_M_new, 'b s c d -> (b s) c d')
+
+ R_ = Tran_M_new[:, :3, :3]
+ T_ = Tran_M_new[:, :3, 3]
+
+ # Cameras to Pose encoding
+ pred_cameras.R = R_
+ pred_cameras.T = T_
+ pose_enc_norm = camera_to_pose_encoding(pred_cameras,
+ pose_encoding_type=pose_encoding_type)
+ pose_enc_norm = rearrange(pose_enc_norm, '(b s) c -> b s c', b=B)
+ return pose_enc_norm
+
+def first_pose_enc_denorm(
+ pose_enc: torch.Tensor,
+ pose_enc_1st: torch.Tensor,
+ pose_encoding_type: str="absT_quaR_OneFL",
+ pose_mode: str = "W2C"):
+ """
+ make sure the poses in on window are de-normalized by the first frame, where the
+ first frame transformation is the Identity Matrix.
+ args:
+ pose_enc: [B S C]
+ pose_enc_1st: [B 1 C]
+ return:
+ pose_enc_denorm: [B S C]
+ """
+ B, S, C = pose_enc.shape
+ pose_enc_all = torch.cat([pose_enc_1st, pose_enc], dim=1)
+
+ # Pose encoding to Cameras (Pytorch3D coordinate)
+ pred_cameras = pose_encoding_to_camera(
+ pose_enc_all, pose_encoding_type=pose_encoding_type,
+ to_OpenCV=False
+ ) #NOTE: the camera parameters are not in NDC
+ R = pred_cameras.R # [B*(1+S), 3, 3]
+ T = pred_cameras.T # [B*(1+S), 3]
+
+ Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*(1+S), 3, 4]
+ extra_ = torch.tensor([[[0, 0, 0, 1]]],
+ device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
+ Tran_M = torch.cat([Tran_M, extra_
+ ], dim=1)
+ Tran_M_new = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)[:, 1:]
+ Tran_M_1st = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)[:,:1]
+
+ if pose_mode == "C2W":
+ Tran_M_new = Tran_M_1st@Tran_M_new
+ elif pose_mode == "W2C":
+ Tran_M_new = Tran_M_new@Tran_M_1st
+
+ Tran_M_new_ = torch.cat([Tran_M_1st, Tran_M_new], dim=1)
+ R_ = Tran_M_new_[..., :3, :3].view(-1, 3, 3)
+ T_ = Tran_M_new_[..., :3, 3].view(-1, 3)
+
+ # Cameras to Pose encoding
+ pred_cameras.R = R_
+ pred_cameras.T = T_
+
+ # Cameras to Pose encoding
+ pose_enc_denorm = camera_to_pose_encoding(pred_cameras,
+ pose_encoding_type=pose_encoding_type)
+ pose_enc_denorm = rearrange(pose_enc_denorm, '(b s) c -> b s c', b=B)
+ return pose_enc_denorm[:, 1:]
+
+def compute_scale_and_shift(prediction, target, mask):
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
+ a_00 = torch.sum(mask * prediction * prediction, (1, 2))
+ a_01 = torch.sum(mask * prediction, (1, 2))
+ a_11 = torch.sum(mask, (1, 2))
+
+ # right hand side: b = [b_0, b_1]
+ b_0 = torch.sum(mask * prediction * target, (1, 2))
+ b_1 = torch.sum(mask * target, (1, 2))
+
+ # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
+ x_0 = torch.zeros_like(b_0)
+ x_1 = torch.zeros_like(b_1)
+
+ det = a_00 * a_11 - a_01 * a_01
+ # A needs to be a positive definite matrix.
+ valid = det > 0
+
+ x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
+ x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
+
+ return x_0, x_1
+
+
+def normalize_prediction_robust(target, mask, Bs):
+ ssum = torch.sum(mask, (1, 2))
+ valid = ssum > 0
+
+ m = torch.zeros_like(ssum).to(target.dtype)
+ s = torch.ones_like(ssum).to(target.dtype)
+ m[valid] = torch.median(
+ (mask[valid] * target[valid]).view(valid.sum(), -1), dim=1
+ ).values
+ target = rearrange(target, '(b c) h w -> b c h w', b=Bs)
+ m_vid = rearrange(m, '(b c) -> b c 1 1', b=Bs) #.mean(dim=1, keepdim=True)
+ mask = rearrange(mask, '(b c) h w -> b c h w', b=Bs)
+
+ target = target - m_vid
+
+ sq = torch.sum(mask * target.abs(), (2, 3))
+ sq = rearrange(sq, 'b c -> (b c)')
+ s[valid] = torch.clamp((sq[valid] / ssum[valid]), min=1e-6)
+ s_vid = rearrange(s, '(b c) -> b c 1 1', b=Bs) #.mean(dim=1, keepdim=True)
+ target = target / s_vid
+ target = rearrange(target, 'b c h w -> (b c) h w', b=Bs)
+
+ return target, m_vid, s_vid
+
+def normalize_video_robust(target, mask, Bs):
+
+ vid_valid = target[mask]
+ # downsample to 1/20
+ with torch.no_grad():
+ vid_valid = vid_valid[torch.randperm(vid_valid.shape[0], device='cuda')[:vid_valid.shape[0]//5]]
+ t_2, t_98 = torch.quantile(vid_valid, 0.02), torch.quantile(vid_valid, 0.98)
+ # normalize
+ target = (target - t_2) / (t_98 - t_2)*2 - 1
+ return target, t_2, t_98
+
+def video_loss(prediction, target, mask, Bs):
+ # median norm
+ prediction_nm, a_norm, b_norm = normalize_video_robust(prediction, mask, Bs)
+ target_nm, a_norm_gt, b_norm_gt = normalize_video_robust(target.float(), mask, Bs)
+ depth_loss = nn.functional.l1_loss(prediction_nm[mask], target_nm[mask])
+ # rel depth 2 metric --> (pred - a')/(b'-a')*(b-a) + a
+ scale = (b_norm_gt - a_norm_gt) / (b_norm - a_norm)
+ shift = a_norm_gt - a_norm*scale
+ return depth_loss, scale, shift, prediction_nm, target_nm
+
+def median_loss(prediction, target, mask, Bs):
+ # median norm
+ prediction_nm, a_norm, b_norm = normalize_prediction_robust(prediction, mask, Bs)
+ target_nm, a_norm_gt, b_norm_gt = normalize_prediction_robust(target.float(), mask, Bs)
+ depth_loss = nn.functional.l1_loss(prediction_nm[mask], target_nm[mask])
+ scale = b_norm_gt/b_norm
+ shift = a_norm_gt - a_norm*scale
+ return depth_loss, scale, shift, prediction_nm, target_nm
+
+def reduction_batch_based(image_loss, M):
+ # average of all valid pixels of the batch
+
+ # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
+ divisor = torch.sum(M)
+
+ if divisor == 0:
+ return 0
+ else:
+ return torch.sum(image_loss) / divisor
+
+
+def reduction_image_based(image_loss, M):
+ # mean of average of valid pixels of an image
+
+ # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
+ valid = M.nonzero()
+
+ image_loss[valid] = image_loss[valid] / M[valid]
+
+ return torch.mean(image_loss)
+
+
+class ScaleAndShiftInvariantLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.name = "SSILoss"
+
+ def forward(self, prediction, target, mask, Bs,
+ interpolate=True, return_interpolated=False):
+
+ if prediction.shape[-1] != target.shape[-1] and interpolate:
+ prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True)
+ intr_input = prediction
+ else:
+ intr_input = prediction
+
+ prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze()
+ assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}."
+
+
+ scale, shift = compute_scale_and_shift(prediction, target, mask)
+ a_norm = scale.view(Bs, -1, 1, 1).mean(dim=1, keepdim=True)
+ b_norm = shift.view(Bs, -1, 1, 1).mean(dim=1, keepdim=True)
+ prediction = rearrange(prediction, '(b c) h w -> b c h w', b=Bs)
+ target = rearrange(target, '(b c) h w -> b c h w', b=Bs)
+ mask = rearrange(mask, '(b c) h w -> b c h w', b=Bs)
+ scaled_prediction = a_norm * prediction + b_norm
+ loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask])
+ if not return_interpolated:
+ return loss, a_norm, b_norm
+ return loss, a_norm, b_norm
+
+ScaleAndShiftInvariantLoss_fn = ScaleAndShiftInvariantLoss()
+
+class GradientLoss(nn.Module):
+ def __init__(self, scales=4, reduction='batch-based'):
+ super().__init__()
+
+ if reduction == 'batch-based':
+ self.__reduction = reduction_batch_based
+ else:
+ self.__reduction = reduction_image_based
+
+ self.__scales = scales
+
+ def forward(self, prediction, target, mask):
+ total = 0
+
+ for scale in range(self.__scales):
+ step = pow(2, scale)
+ l1_ln, a_nm, b_nm = ScaleAndShiftInvariantLoss_fn(prediction[:, ::step, ::step],
+ target[:, ::step, ::step], mask[:, ::step, ::step], 1)
+ total += l1_ln
+ a_nm = a_nm.squeeze().detach() # [B, 1, 1]
+ b_nm = b_nm.squeeze().detach() # [B, 1, 1]
+ total += 2*gradient_loss(a_nm*prediction[:, ::step, ::step]+b_nm, target[:, ::step, ::step],
+ mask[:, ::step, ::step], reduction=self.__reduction)
+
+ return total
+
+Grad_fn = GradientLoss()
+
+def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
+
+ M = torch.sum(mask, (1, 2))
+
+ diff = prediction - target
+ diff = torch.mul(mask, diff)
+ grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
+ mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
+ grad_x = torch.mul(mask_x, grad_x)
+
+ grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
+ mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
+ grad_y = torch.mul(mask_y, grad_y)
+
+ image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
+
+ return reduction(image_loss, M)
+
+def loss_fn(
+ poses_preds: List[torch.Tensor],
+ poses_pred_all: List[torch.Tensor],
+ poses_gt: torch.Tensor,
+ inv_depth_preds: List[torch.Tensor],
+ inv_depth_raw: List[torch.Tensor],
+ depths_gt: torch.Tensor,
+ S: int = 16,
+ gamma: float = 0.8,
+ logger=None,
+ logger_tf=None,
+ global_step=0,
+ ):
+ """
+ Args:
+ poses_preds: list of predicted poses
+ poses_gt: ground truth poses
+ inv_depth_preds: list of predicted inverse depth maps
+ depths_gt: ground truth depth maps
+ S: length of sliding window
+ """
+ B, T, _, H, W = depths_gt.shape
+
+ loss_total = 0
+ for i in range(len(poses_preds)):
+ poses_preds_i = poses_preds[i][0]
+ poses_unc_i = poses_preds[i][1]
+ poses_gt_i = poses_gt[:, i*S//2:i*S//2+S,:]
+ poses_gt_i_norm = first_pose_enc_norm(poses_gt_i,
+ pose_encoding_type="absT_quaR_OneFL")
+ pose_loss = 0.0
+ for idx, poses_preds_ij in enumerate(poses_preds_i):
+ i_weight = gamma ** (len(poses_preds_i) - idx - 1)
+ if logger is not None:
+ if poses_preds_ij.max()>5e1:
+ logger.info(f"pose_pred_max_and_mean: {poses_preds_ij.max(), poses_preds_ij.mean()}")
+
+ trans_loss = (poses_preds_ij[...,:3] - poses_gt_i_norm[...,:3]).abs().sum(dim=-1).mean()
+ rot_loss = (poses_preds_ij[...,3:7] - poses_gt_i_norm[...,3:7]).abs().sum(dim=-1).mean()
+ focal_loss = (poses_preds_ij[...,7:] - poses_gt_i_norm[...,7:]).abs().sum(dim=-1).mean()
+ if torch.isnan((trans_loss + rot_loss + focal_loss)).any():
+ pose_loss += 0
+ else:
+ pose_loss += i_weight*(trans_loss + rot_loss + focal_loss)
+ if (logger_tf is not None)&(i==len(poses_preds)-1):
+ logger_tf.add_scalar(f"loss@pose/trans_iter{idx}",
+ trans_loss, global_step=global_step)
+ logger_tf.add_scalar(f"loss@pose/rot_iter{idx}",
+ rot_loss, global_step=global_step)
+ logger_tf.add_scalar(f"loss@pose/focal_iter{idx}",
+ focal_loss, global_step=global_step)
+ # compute the uncertainty loss
+ with torch.no_grad():
+ pose_loss_dist = (poses_preds_ij-poses_gt_i_norm).detach().abs()
+ pose_loss_std = 3*pose_loss_dist.view(-1,8).std(dim=0)[None,None,:]
+ gt_dist = F.relu(pose_loss_std - pose_loss_dist) / (pose_loss_std + 1e-3)
+ unc_loss = (gt_dist - poses_unc_i).abs().mean()
+ if (logger_tf is not None)&(i==len(poses_preds)-1):
+ logger_tf.add_scalar(f"loss@uncertainty/unc",
+ unc_loss,
+ global_step=global_step)
+ # if logger is not None:
+ # logger.info(f"pose_loss: {pose_loss}, unc_loss: {unc_loss}")
+ # total loss
+ loss_total += 0.1*unc_loss + 2*pose_loss
+
+ poses_gt_norm = poses_gt
+ pose_all_loss = 0.0
+ prev_loss = None
+ for idx, poses_preds_all_j in enumerate(poses_pred_all):
+ i_weight = gamma ** (len(poses_pred_all) - idx - 1)
+ trans_loss = (poses_preds_all_j[...,:3] - poses_gt_norm[...,:3]).abs().sum(dim=-1).mean()
+ rot_loss = (poses_preds_all_j[...,3:7] - poses_gt_norm[...,3:7]).abs().sum(dim=-1).mean()
+ focal_loss = (poses_preds_all_j[...,7:] - poses_gt_norm[...,7:]).abs().sum(dim=-1).mean()
+ if (logger_tf is not None):
+ if prev_loss is None:
+ prev_loss = (trans_loss + rot_loss + focal_loss)
+ else:
+ des_loss = (trans_loss + rot_loss + focal_loss) - prev_loss
+ prev_loss = trans_loss + rot_loss + focal_loss
+ logger_tf.add_scalar(f"loss@global_pose/des_iter{idx}",
+ des_loss, global_step=global_step)
+ logger_tf.add_scalar(f"loss@global_pose/trans_iter{idx}",
+ trans_loss, global_step=global_step)
+ logger_tf.add_scalar(f"loss@global_pose/rot_iter{idx}",
+ rot_loss, global_step=global_step)
+ logger_tf.add_scalar(f"loss@global_pose/focal_iter{idx}",
+ focal_loss, global_step=global_step)
+ if torch.isnan((trans_loss + rot_loss + focal_loss)).any():
+ pose_all_loss += 0
+ else:
+ pose_all_loss += i_weight*(trans_loss + rot_loss + focal_loss)
+
+ # if logger is not None:
+ # logger.info(f"global_pose_loss: {pose_all_loss}")
+
+ # compute the depth loss
+ if inv_depth_preds[0] is not None:
+ depths_gt = depths_gt[:,:,0]
+ msk = depths_gt > 5e-2
+ inv_gt = 1.0 / (depths_gt.clamp(1e-3, 1e16))
+ inv_gt_reshp = rearrange(inv_gt, 'b t h w -> (b t) h w')
+ inv_depth_preds_reshp = rearrange(inv_depth_preds[0], 'b t h w -> (b t) h w')
+ inv_raw_reshp = rearrange(inv_depth_raw[0], 'b t h w -> (b t) h w')
+ msk_reshp = rearrange(msk, 'b t h w -> (b t) h w')
+ huber_loss = ScaleAndShiftInvariantLoss_fn(inv_depth_preds_reshp, inv_gt_reshp, msk_reshp)
+ huber_loss_raw = ScaleAndShiftInvariantLoss_fn(inv_raw_reshp, inv_gt_reshp, msk_reshp)
+ # huber_loss = (inv_depth_preds[0][msk]-inv_gt[msk]).abs().mean()
+ # cal perason loss
+ perason_loss = 0
+ # for i in range(B):
+ # perason_loss += (1 - pearson_corrcoef(inv_depth_preds[0].view(B*T,-1), inv_gt.view(B*T,-1))).mean()
+ # perason_loss = perason_loss/B
+ if torch.isnan(huber_loss).any():
+ huber_loss = 0
+ depth_loss = huber_loss + perason_loss
+ if (logger_tf is not None)&(i==len(poses_preds)-1):
+ logger_tf.add_scalar(f"loss@depth/huber_iter{idx}",
+ depth_loss,
+ global_step=global_step)
+ # if logger is not None:
+ # logger.info(f"opt_depth: {huber_loss_raw - huber_loss}")
+ else:
+ depth_loss = 0.0
+
+
+ loss_total = loss_total/(len(poses_preds)) + 20*depth_loss + pose_all_loss
+
+ return loss_total, (huber_loss_raw - huber_loss)
+
+
+def vis_depth(x: torch.tensor,
+ logger_tf = None, title: str = "depth", step: int = 0):
+ """
+ args:
+ x: H W
+ """
+ assert len(x.shape) == 2
+
+ depth_map_normalized = cv2.normalize(x.cpu().numpy(),
+ None, 0, 255, cv2.NORM_MINMAX)
+ depth_map_colored = cv2.applyColorMap(depth_map_normalized.astype(np.uint8),
+ cv2.COLORMAP_JET)
+ depth_map_tensor = torch.from_numpy(depth_map_colored).permute(2, 0, 1).unsqueeze(0)
+ if logger_tf is not None:
+ logger_tf.add_image(title, depth_map_tensor[0], step)
+ else:
+ return depth_map_tensor
+
+def vis_pcd(
+ rgbs: torch.Tensor,
+ R: torch.Tensor,
+ T: torch.Tensor,
+ xy_depth: torch.Tensor,
+ focal_length: torch.Tensor,
+ pick_idx: List = [0]
+ ):
+ """
+ args:
+ rgbs: [S C H W]
+ R: [S 3 3]
+ T: [S 3]
+ xy_depth: [S H W 3]
+ focal_length: [S]
+ pick_idx: list of the index to pick
+ """
+ S, C, H, W = rgbs.shape
+
+ rgbs_pick = rgbs[pick_idx]
+ R_pick = R[pick_idx]
+ T_pick = T[pick_idx]
+ xy_depth_pick = xy_depth[pick_idx]
+ focal_length_pick = focal_length[pick_idx]
+ pcd_world = depth2pcd(xy_depth_pick.clone(),
+ focal_length_pick, R_pick.clone(), T_pick.clone(),
+ device=xy_depth.device, H=H, W=W)
+ pcd_world = pcd_world.permute(0, 2, 1) #[...,[1,0,2]]
+ mask = pcd_world.reshape(-1,3)[:,2] < 20
+ rgb_world = rgbs_pick.view(len(pick_idx), 3, -1).permute(0, 2, 1)
+ pcl = Pointclouds(points=[pcd_world.reshape(-1,3)[mask]],
+ features=[rgb_world.reshape(-1,3)[mask]/255])
+ return pcl
+
+def vis_result(rgbs, poses_pred, poses_gt,
+ depth_gt, depth_pred, iter_num=0,
+ vis=None, logger_tf=None, cfg=None):
+ """
+ Args:
+ rgbs: [S C H W]
+ depths_gt: [S C H W]
+ poses_gt: [S C]
+ poses_pred: [S C]
+ depth_pred: [S H W]
+ """
+ assert len(rgbs.shape) == 4, "only support one sequence, T 3 H W of rbg"
+
+ if vis is None:
+ return
+ S, _, H, W = depth_gt.shape
+ # get the xy
+ yx = torch.meshgrid(torch.arange(H).to(depth_pred.device),
+ torch.arange(W).to(depth_pred.device),indexing='ij')
+ xy = torch.stack(yx[::-1], dim=0).float().to(depth_pred.device)
+ xy_norm = (xy / torch.tensor([W, H],
+ device=depth_pred.device).view(2, 1, 1) - 0.5)*2
+ xy = xy[None].repeat(S, 1, 1, 1)
+ xy_depth = torch.cat([xy, depth_pred[:,None]], dim=1).permute(0, 2, 3, 1)
+ xy_depth_gt = torch.cat([xy, depth_gt], dim=1).permute(0, 2, 3, 1)
+ # get the focal length
+ focal_length = poses_gt[:,-1]*max(H, W)
+
+ # vis the camera poses
+ poses_gt_vis = pose_encoding_to_camera(poses_gt,
+ pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
+ poses_pred_vis = pose_encoding_to_camera(poses_pred,
+ pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
+
+ R_gt = poses_gt_vis.R.float()
+ R_pred = poses_pred_vis.R.float()
+ T_gt = poses_gt_vis.T.float()
+ T_pred = poses_pred_vis.T.float()
+ # C2W poses
+ R_gt_c2w = R_gt.permute(0,2,1)
+ T_gt_c2w = (-R_gt_c2w @ T_gt[:, :, None]).squeeze(-1)
+ R_pred_c2w = R_pred.permute(0,2,1)
+ T_pred_c2w = (-R_pred_c2w @ T_pred[:, :, None]).squeeze(-1)
+ with torch.cuda.amp.autocast(enabled=False):
+ pick_idx = torch.randperm(S)[:min(24, S)]
+ # pick_idx = [1]
+ #NOTE: very strange that the camera need C2W Rotation and W2C translation as input
+ poses_gt_vis = PerspectiveCamerasVisual(
+ R=R_gt_c2w[pick_idx], T=T_gt[pick_idx],
+ device=poses_gt_vis.device, image_size=((H, W),)
+ )
+ poses_pred_vis = PerspectiveCamerasVisual(
+ R=R_pred_c2w[pick_idx], T=T_pred[pick_idx],
+ device=poses_pred_vis.device
+ )
+ visual_dict = {"scenes": {"cameras": poses_pred_vis, "cameras_gt": poses_gt_vis}}
+ env_name = f"train_visualize_iter_{iter_num:05d}"
+ print(f"Visualizing the scene by visdom at env: {env_name}")
+ # visualize the depth map
+ vis_depth(depth_pred[0].detach(), logger_tf, title="vis/depth_pred",step=iter_num)
+ msk = depth_pred[0] > 1e-3
+ vis_depth(depth_gt[0,0].detach(), logger_tf, title="vis/depth_gt",step=iter_num)
+ depth_res = (depth_gt[0,0] - depth_pred[0]).abs()
+ vis_depth(depth_res.detach(), logger_tf, title="vis/depth_res",step=iter_num)
+ # visualize the point cloud
+ if cfg.debug.vis_pcd:
+ visual_dict["scenes"]["points_gt"] = vis_pcd(rgbs, R_gt, T_gt,
+ xy_depth_gt, focal_length, pick_idx)
+ else:
+ visual_dict["scenes"]["points_pred"] = vis_pcd(rgbs, R_pred, T_pred,
+ xy_depth, focal_length, pick_idx)
+ # visualize in visdom
+ fig = plot_scene(visual_dict, camera_scale=0.05)
+ vis.plotlyplot(fig, env=env_name, win="3D")
+ vis.save([env_name])
+
+ return
+
+def depth2pcd(
+ xy_depth: torch.Tensor,
+ focal_length: torch.Tensor,
+ R: torch.Tensor,
+ T: torch.Tensor,
+ device: torch.device = None,
+ H: int = 518,
+ W: int = 518
+ ):
+ """
+ args:
+ xy_depth: [S H W 3]
+ focal_length: [S]
+ R: [S 3 3] W2C
+ T: [S 3] W2C
+ return:
+ xyz: [S 3 (H W)]
+ """
+ S, H, W, _ = xy_depth.shape
+ # get the intrinsic
+ K = torch.eye(3, device=device)[None].repeat(len(focal_length), 1, 1).to(device)
+ K[:, 0, 0] = focal_length
+ K[:, 1, 1] = focal_length
+ K[:, 0, 2] = 0.5 * W
+ K[:, 1, 2] = 0.5 * H
+ K_inv = K.inverse()
+ # xyz
+ xyz = xy_depth.view(S, -1, 3).permute(0, 2, 1) # S 3 (H W)
+ depth = xyz[:, 2:].clone() # S (H W) 1
+ xyz[:, 2] = 1
+ xyz = K_inv @ xyz # S 3 (H W)
+ xyz = xyz * depth
+ # to world coordinate
+ xyz = R.permute(0,2,1) @ (xyz - T[:, :, None])
+
+ return xyz
+
+
+def pose_enc2mat(poses_pred,
+ H_resize, W_resize, resolution=336):
+ """
+ This function convert the pose encoding into `intrinsic` and `extrinsic`
+
+ Args:
+ poses_pred: B T 8
+ Return:
+ Intrinsic B T 3 3
+ Extrinsic B T 4 4
+ """
+ B, T, _ = poses_pred.shape
+ focal_pred = poses_pred[:, :, -1].clone()
+ pos_quat_preds = poses_pred[:, :, :7].clone()
+ pos_quat_preds = pos_quat_preds.view(B*T, -1)
+ # get extrinsic
+ c2w_rot = quaternion_to_matrix(pos_quat_preds[:, 3:])
+ c2w_tran = pos_quat_preds[:, :3]
+ c2w_traj = torch.eye(4)[None].repeat(B*T, 1, 1).to(poses_pred.device)
+ c2w_traj[:, :3, :3], c2w_traj[:, :3, 3] = c2w_rot, c2w_tran
+ c2w_traj = c2w_traj.view(B, T, 4, 4)
+ # get intrinsic
+ fxs, fys = focal_pred*resolution, focal_pred*resolution
+ intrs = torch.eye(3).to(c2w_traj.device).to(c2w_traj.dtype)[None, None].repeat(B, T, 1, 1)
+ intrs[:,:,0,0], intrs[:,:,1,1] = fxs, fys
+ intrs[:,:,0,2], intrs[:,:,1,2] = W_resize/2, H_resize/2
+
+ return intrs, c2w_traj
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ return ret
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
+ return standardize_quaternion(out)
+
+
+def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
+ # returns a meshgrid sized B x Y x X
+
+ grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
+ grid_y = torch.reshape(grid_y, [1, Y, 1])
+ grid_y = grid_y.repeat(B, 1, X)
+
+ grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
+ grid_x = torch.reshape(grid_x, [1, 1, X])
+ grid_x = grid_x.repeat(B, Y, 1)
+
+ if stack:
+ # note we stack in xy order
+ # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
+ grid = torch.stack([grid_x, grid_y], dim=-1)
+ return grid
+ else:
+ return grid_y, grid_x
+
+def get_points_on_a_grid(grid_size, interp_shape,
+ grid_center=(0, 0), device="cuda"):
+ if grid_size == 1:
+ return torch.tensor([interp_shape[1] / 2,
+ interp_shape[0] / 2], device=device)[
+ None, None
+ ]
+
+ grid_y, grid_x = meshgrid2d(
+ 1, grid_size, grid_size, stack=False, norm=False, device=device
+ )
+ step = interp_shape[1] // 64
+ if grid_center[0] != 0 or grid_center[1] != 0:
+ grid_y = grid_y - grid_size / 2.0
+ grid_x = grid_x - grid_size / 2.0
+ grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
+ interp_shape[0] - step * 2
+ )
+ grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
+ interp_shape[1] - step * 2
+ )
+
+ grid_y = grid_y + grid_center[0]
+ grid_x = grid_x + grid_center[1]
+ xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
+ return xy
+
+def normalize_rgb(x,input_size=224,
+ resize_mode: Literal['resize', 'padding'] = 'resize',
+ if_da=False):
+ """
+ normalize the image for depth anything input
+
+ args:
+ x: the input images [B T C H W]
+ """
+ if isinstance(x, np.ndarray):
+ x = torch.from_numpy(x) / 255.0
+ elif isinstance(x, torch.Tensor):
+ x = x / 255.0
+ B, T, C, H, W = x.shape
+ x = x.view(B * T, C, H, W)
+ Resizer = Resize(
+ width=input_size,
+ height=input_size,
+ resize_target=False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method='lower_bound',
+ )
+ if resize_mode == 'padding':
+ # zero padding to make the input size to be multiple of 14
+ if H > W:
+ H_scale = input_size
+ W_scale = W * input_size // H
+ else:
+ W_scale = input_size
+ H_scale = H * input_size // W
+ # resize the image
+ x = F.interpolate(x, size=(H_scale, W_scale),
+ mode='bilinear', align_corners=False)
+ # central padding the image
+ padding_x = (input_size - W_scale) // 2
+ padding_y = (input_size - H_scale) // 2
+ extra_x = (input_size - W_scale) % 2
+ extra_y = (input_size - H_scale) % 2
+ x = F.pad(x, (padding_x, padding_x+extra_x,
+ padding_y, padding_y+extra_y), value=0.)
+ elif resize_mode == 'resize':
+ H_scale, W_scale = Resizer.get_size(H, W)
+ x = F.interpolate(x, size=(int(H_scale), int(W_scale)),
+ mode='bicubic', align_corners=True)
+ # get the mean and std
+ __mean__ = torch.tensor([0.485,
+ 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
+ __std__ = torch.tensor([0.229,
+ 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
+ # normalize the image
+ if if_da:
+ x = (x - __mean__) / __std__
+ else:
+ x = x
+ return x.view(B, T, C, x.shape[-2], x.shape[-1])
+
+def get_track_points(H, W, T, device, size=100, support_frame=0,
+ query_size=768, unc_metric=None, mode="mixed"):
+ """
+ This function is used to get the points on the grid
+ args:
+ H: the height of the grid.
+ W: the width of the grid.
+ T: the number of frames.
+ device: the device of the points.
+ size: the size of the grid.
+ """
+ grid_pts = get_points_on_a_grid(size, (H, W), device=device)
+ grid_pts = grid_pts.round()
+ if mode == "incremental":
+ queries = torch.cat(
+ [torch.randint_like(grid_pts[:, :, :1], T), grid_pts],
+ dim=2,
+ )
+ elif mode == "first":
+ queries_first = torch.cat(
+ [torch.zeros_like(grid_pts[:, :, :1]), grid_pts],
+ dim=2,
+ )
+ queries_support = torch.cat(
+ [torch.randint_like(grid_pts[:, :, :1], T), grid_pts],
+ dim=2,
+ )
+ queries = torch.cat([queries_first, queries_support, queries_support], dim=1)
+ elif mode == "mixed":
+ queries = torch.cat(
+ [torch.randint_like(grid_pts[:, :, :1], T), grid_pts],
+ dim=2,
+ )
+ queries_first = torch.cat(
+ [torch.ones_like(grid_pts[:, :, :1]) * support_frame, grid_pts],
+ dim=2,
+ )
+ queries = torch.cat([queries_first, queries, queries], dim=1)
+ if unc_metric is not None:
+ # filter the points with high uncertainty
+ sample_unc = sample_features5d(unc_metric[None], queries[:,None]).squeeze()
+ if ((sample_unc>0.5).sum() < 20):
+ queries = queries
+ else:
+ queries = queries[:,sample_unc>0.5,:]
+ idx_ = torch.randperm(queries.shape[1], device=device)[:query_size]
+ queries = queries[:, idx_]
+ return queries
\ No newline at end of file
diff --git a/models/SpaTrackV2/utils/embeddings.py b/models/SpaTrackV2/utils/embeddings.py
new file mode 100755
index 0000000000000000000000000000000000000000..60d68984dd881f8f6261ff43ae1fd3efa33be5dd
--- /dev/null
+++ b/models/SpaTrackV2/utils/embeddings.py
@@ -0,0 +1,247 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+
+def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate(
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
+ )
+ return pos_embed
+
+
+def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 3 == 0
+
+ # use half of dimensions to encode grid_h
+ B, S, N, _ = grid.shape
+ gridx = grid[..., 0].view(B*S*N).detach().cpu().numpy()
+ gridy = grid[..., 1].view(B*S*N).detach().cpu().numpy()
+ gridz = grid[..., 2].view(B*S*N).detach().cpu().numpy()
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridx) # (N, D/3)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridy) # (N, D/3)
+ emb_z = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridz) # (N, D/3)
+
+ emb = np.concatenate([emb_h, emb_w, emb_z], axis=1) # (N, D)
+ emb = torch.from_numpy(emb).to(grid.device)
+ return emb.view(B, S, N, embed_dim)
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate(
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
+ )
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000 ** omega # (D/2,)
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+def get_2d_embedding(xy, C, cat_coords=True):
+ B, N, D = xy.shape
+ assert D == 2
+
+ x = xy[:, :, 0:1]
+ y = xy[:, :, 1:2]
+ div_term = (
+ torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
+ ).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+ pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3
+ if cat_coords:
+ pe = torch.cat([xy, pe], dim=2) # B, N, C*3+3
+ return pe
+
+
+def get_3d_embedding(xyz, C, cat_coords=True):
+ B, N, D = xyz.shape
+ assert D == 3
+
+ x = xyz[:, :, 0:1]
+ y = xyz[:, :, 1:2]
+ z = xyz[:, :, 2:3]
+ div_term = (
+ torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
+ ).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
+ pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+ pe_z[:, :, 0::2] = torch.sin(z * div_term)
+ pe_z[:, :, 1::2] = torch.cos(z * div_term)
+
+ pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
+ if cat_coords:
+ pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
+ return pe
+
+
+def get_4d_embedding(xyzw, C, cat_coords=True):
+ B, N, D = xyzw.shape
+ assert D == 4
+
+ x = xyzw[:, :, 0:1]
+ y = xyzw[:, :, 1:2]
+ z = xyzw[:, :, 2:3]
+ w = xyzw[:, :, 3:4]
+ div_term = (
+ torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
+ ).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
+ pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
+ pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+ pe_z[:, :, 0::2] = torch.sin(z * div_term)
+ pe_z[:, :, 1::2] = torch.cos(z * div_term)
+
+ pe_w[:, :, 0::2] = torch.sin(w * div_term)
+ pe_w[:, :, 1::2] = torch.cos(w * div_term)
+
+ pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*3
+ if cat_coords:
+ pe = torch.cat([pe, xyzw], dim=2) # B, N, C*3+3
+ return pe
+
+import torch.nn as nn
+class Embedder_Fourier(nn.Module):
+ def __init__(self, input_dim, max_freq_log2, N_freqs,
+ log_sampling=True, include_input=True,
+ periodic_fns=(torch.sin, torch.cos)):
+ '''
+ :param input_dim: dimension of input to be embedded
+ :param max_freq_log2: log2 of max freq; min freq is 1 by default
+ :param N_freqs: number of frequency bands
+ :param log_sampling: if True, frequency bands are linerly sampled in log-space
+ :param include_input: if True, raw input is included in the embedding
+ :param periodic_fns: periodic functions used to embed input
+ '''
+ super(Embedder_Fourier, self).__init__()
+
+ self.input_dim = input_dim
+ self.include_input = include_input
+ self.periodic_fns = periodic_fns
+
+ self.out_dim = 0
+ if self.include_input:
+ self.out_dim += self.input_dim
+
+ self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns)
+
+ if log_sampling:
+ self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
+ else:
+ self.freq_bands = torch.linspace(
+ 2. ** 0., 2. ** max_freq_log2, N_freqs)
+
+ self.freq_bands = self.freq_bands.numpy().tolist()
+
+ def forward(self,
+ input: torch.Tensor,
+ rescale: float = 1.0):
+ '''
+ :param input: tensor of shape [..., self.input_dim]
+ :return: tensor of shape [..., self.out_dim]
+ '''
+ assert (input.shape[-1] == self.input_dim)
+ out = []
+ if self.include_input:
+ out.append(input/rescale)
+
+ for i in range(len(self.freq_bands)):
+ freq = self.freq_bands[i]
+ for p_fn in self.periodic_fns:
+ out.append(p_fn(input * freq))
+ out = torch.cat(out, dim=-1)
+
+ assert (out.shape[-1] == self.out_dim)
+ return out
\ No newline at end of file
diff --git a/models/SpaTrackV2/utils/model_utils.py b/models/SpaTrackV2/utils/model_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..b1a706742a43aedba0a2001426a600180d4b00c4
--- /dev/null
+++ b/models/SpaTrackV2/utils/model_utils.py
@@ -0,0 +1,444 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from easydict import EasyDict as edict
+from sklearn.decomposition import PCA
+import matplotlib.pyplot as plt
+
+EPS = 1e-6
+
+def nearest_sample2d(im, x, y, return_inbounds=False):
+ # x and y are each B, N
+ # output is B, C, N
+ if len(im.shape) == 5:
+ B, N, C, H, W = list(im.shape)
+ else:
+ B, C, H, W = list(im.shape)
+ N = list(x.shape)[1]
+
+ x = x.float()
+ y = y.float()
+ H_f = torch.tensor(H, dtype=torch.float32)
+ W_f = torch.tensor(W, dtype=torch.float32)
+
+ # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x -0.5).byte() & (x < float(W_f - 0.5)).byte()
+ y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
+ inbounds = (x_valid & y_valid).float()
+ inbounds = inbounds.reshape(
+ B, N
+ ) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
+ return output, inbounds
+
+ return output # B, C, N
+
+def smart_cat(tensor1, tensor2, dim):
+ if tensor1 is None:
+ return tensor2
+ return torch.cat([tensor1, tensor2], dim=dim)
+
+
+def normalize_single(d):
+ # d is a whatever shape torch tensor
+ dmin = torch.min(d)
+ dmax = torch.max(d)
+ d = (d - dmin) / (EPS + (dmax - dmin))
+ return d
+
+
+def normalize(d):
+ # d is B x whatever. normalize within each element of the batch
+ out = torch.zeros(d.size())
+ if d.is_cuda:
+ out = out.cuda()
+ B = list(d.size())[0]
+ for b in list(range(B)):
+ out[b] = normalize_single(d[b])
+ return out
+
+
+def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
+ # returns a meshgrid sized B x Y x X
+
+ grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
+ grid_y = torch.reshape(grid_y, [1, Y, 1])
+ grid_y = grid_y.repeat(B, 1, X)
+
+ grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
+ grid_x = torch.reshape(grid_x, [1, 1, X])
+ grid_x = grid_x.repeat(B, Y, 1)
+
+ if stack:
+ # note we stack in xy order
+ # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
+ grid = torch.stack([grid_x, grid_y], dim=-1)
+ return grid
+ else:
+ return grid_y, grid_x
+
+
+def reduce_masked_mean(x, mask, dim=None, keepdim=False):
+ # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
+ # returns shape-1
+ # axis can be a list of axes
+ for (a, b) in zip(x.size(), mask.size()):
+ assert a == b # some shape mismatch!
+ prod = x * mask
+ if dim is None:
+ numer = torch.sum(prod)
+ denom = EPS + torch.sum(mask)
+ else:
+ numer = torch.sum(prod, dim=dim, keepdim=keepdim)
+ denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
+
+ mean = numer / denom
+ return mean
+
+
+def bilinear_sample2d(im, x, y, return_inbounds=False):
+ # x and y are each B, N
+ # output is B, C, N
+ if len(im.shape) == 5:
+ B, N, C, H, W = list(im.shape)
+ else:
+ B, C, H, W = list(im.shape)
+ N = list(x.shape)[1]
+
+ x = x.float()
+ y = y.float()
+ H_f = torch.tensor(H, dtype=torch.float32)
+ W_f = torch.tensor(W, dtype=torch.float32)
+
+ # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x -0.5).byte() & (x < float(W_f - 0.5)).byte()
+ y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
+ inbounds = (x_valid & y_valid).float()
+ inbounds = inbounds.reshape(
+ B, N
+ ) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
+ return output, inbounds
+
+ return output # B, C, N
+
+
+def procrustes_analysis(X0,X1,Weight): # [B,N,3]
+ # translation
+ t0 = X0.mean(dim=1,keepdim=True)
+ t1 = X1.mean(dim=1,keepdim=True)
+ X0c = X0-t0
+ X1c = X1-t1
+ # scale
+ # s0 = (X0c**2).sum(dim=-1).mean().sqrt()
+ # s1 = (X1c**2).sum(dim=-1).mean().sqrt()
+ # X0cs = X0c/s0
+ # X1cs = X1c/s1
+ # rotation (use double for SVD, float loses precision)
+ U,_,V = (X0c.t()@X1c).double().svd(some=True)
+ R = (U@V.t()).float()
+ if R.det()<0: R[2] *= -1
+ # align X1 to X0: X1to0 = (X1-t1)/@R.t()+t0
+ se3 = edict(t0=t0[0],t1=t1[0],R=R)
+
+ return se3
+
+def bilinear_sampler(input, coords, align_corners=True, padding_mode="border", interp_mode="bilinear"):
+ r"""Sample a tensor using bilinear interpolation
+
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
+ convention.
+
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
+ :math:`B` is the batch size, :math:`C` is the number of channels,
+ :math:`H` is the height of the image, and :math:`W` is the width of the
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
+
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
+ that in this case the order of the components is slightly different
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
+
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
+ left-most image pixel :math:`W-1` to the center of the right-most
+ pixel.
+
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
+ the left-most pixel :math:`W` to the right edge of the right-most
+ pixel.
+
+ Similar conventions apply to the :math:`y` for the range
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
+ :math:`[0,T-1]` and :math:`[0,T]`.
+
+ Args:
+ input (Tensor): batch of input images.
+ coords (Tensor): batch of coordinates.
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
+
+ Returns:
+ Tensor: sampled points.
+ """
+
+ sizes = input.shape[2:]
+
+ assert len(sizes) in [2, 3]
+
+ if len(sizes) == 3:
+ # t x y -> x y t to match dimensions T H W in grid_sample
+ coords = coords[..., [1, 2, 0]]
+
+ if align_corners:
+ coords = coords * torch.tensor(
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
+ )
+ else:
+ coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
+
+ coords -= 1
+
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode, mode=interp_mode)
+
+
+def sample_features4d(input, coords, interp_mode="bilinear"):
+ r"""Sample spatial features
+
+ `sample_features4d(input, coords)` samples the spatial features
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
+
+ The field is sampled at coordinates :attr:`coords` using bilinear
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
+ 3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
+
+ The output tensor has one feature per point, and has shape :math:`(B,
+ R, C)`.
+
+ Args:
+ input (Tensor): spatial features.
+ coords (Tensor): points.
+
+ Returns:
+ Tensor: sampled features.
+ """
+
+ B, _, _, _ = input.shape
+
+ # B R 2 -> B R 1 2
+ coords = coords.unsqueeze(2)
+
+ # B C R 1
+ feats = bilinear_sampler(input, coords, interp_mode=interp_mode)
+
+ return feats.permute(0, 2, 1, 3).view(
+ B, -1, feats.shape[1] * feats.shape[3]
+ ) # B C R 1 -> B R C
+
+
+def sample_features5d(input, coords, interp_mode="bilinear"):
+ r"""Sample spatio-temporal features
+
+ `sample_features5d(input, coords)` works in the same way as
+ :func:`sample_features4d` but for spatio-temporal features and points:
+ :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
+ a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
+ x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
+
+ Args:
+ input (Tensor): spatio-temporal features.
+ coords (Tensor): spatio-temporal points.
+
+ Returns:
+ Tensor: sampled features.
+ """
+
+ B, T, _, _, _ = input.shape
+
+ # B T C H W -> B C T H W
+ input = input.permute(0, 2, 1, 3, 4)
+
+ # B R1 R2 3 -> B R1 R2 1 3
+ coords = coords.unsqueeze(3)
+
+ # B C R1 R2 1
+ feats = bilinear_sampler(input, coords, interp_mode=interp_mode)
+
+ return feats.permute(0, 2, 3, 1, 4).view(
+ B, feats.shape[2], feats.shape[3], feats.shape[1]
+ ) # B C R1 R2 1 -> B R1 R2 C
+
+def vis_PCA(fmaps, save_dir):
+ """
+ visualize the PCA of the feature maps
+ args:
+ fmaps: feature maps 1 C H W
+ save_dir: the directory to save the PCA visualization
+ """
+
+ pca = PCA(n_components=3)
+ fmap_vis = fmaps[0,...]
+ fmap_vnorm = (
+ (fmap_vis-fmap_vis.min())/
+ (fmap_vis.max()-fmap_vis.min()))
+ H_vis, W_vis = fmap_vis.shape[1:]
+ fmap_vnorm = fmap_vnorm.reshape(fmap_vnorm.shape[0],
+ -1).permute(1,0)
+ fmap_pca = pca.fit_transform(fmap_vnorm.detach().cpu().numpy())
+ pca = fmap_pca.reshape(H_vis,W_vis,3)
+ plt.imsave(save_dir,
+ (
+ (pca-pca.min())/
+ (pca.max()-pca.min())
+ ))
\ No newline at end of file
diff --git a/models/SpaTrackV2/utils/visualizer.py b/models/SpaTrackV2/utils/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..14666a85ff33809717d7115046e5a13fa7f694b2
--- /dev/null
+++ b/models/SpaTrackV2/utils/visualizer.py
@@ -0,0 +1,352 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import numpy as np
+import cv2
+import torch
+import flow_vis
+
+from matplotlib import cm
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import moviepy
+from moviepy.editor import ImageSequenceClip
+import matplotlib.pyplot as plt
+
+
+def read_video_from_path(path):
+ cap = cv2.VideoCapture(path)
+ if not cap.isOpened():
+ print("Error opening video file")
+ else:
+ frames = []
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret == True:
+ frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
+ else:
+ break
+ cap.release()
+ return np.stack(frames)
+
+
+class Visualizer:
+ def __init__(
+ self,
+ save_dir: str = "./results",
+ grayscale: bool = False,
+ pad_value: int = 0,
+ fps: int = 10,
+ mode: str = "rainbow", # 'cool', 'optical_flow'
+ linewidth: int = 2,
+ show_first_frame: int = 10,
+ tracks_leave_trace: int = 0, # -1 for infinite
+ ):
+ self.mode = mode
+ self.save_dir = save_dir
+ if mode == "rainbow":
+ self.color_map = cm.get_cmap("gist_rainbow")
+ elif mode == "cool":
+ self.color_map = cm.get_cmap(mode)
+ self.show_first_frame = show_first_frame
+ self.grayscale = grayscale
+ self.tracks_leave_trace = tracks_leave_trace
+ self.pad_value = pad_value
+ self.linewidth = linewidth
+ self.fps = fps
+
+ def visualize(
+ self,
+ video: torch.Tensor, # (B,T,C,H,W)
+ tracks: torch.Tensor, # (B,T,N,2)
+ visibility: torch.Tensor = None, # (B, T, N, 1) bool
+ gt_tracks: torch.Tensor = None, # (B,T,N,2)
+ segm_mask: torch.Tensor = None, # (B,1,H,W)
+ filename: str = "video",
+ writer=None, # tensorboard Summary Writer, used for visualization during training
+ step: int = 0,
+ query_frame: int = 0,
+ save_video: bool = True,
+ compensate_for_camera_motion: bool = False,
+ rigid_part = None,
+ video_depth = None # (B,T,C,H,W)
+ ):
+ if compensate_for_camera_motion:
+ assert segm_mask is not None
+ if segm_mask is not None:
+ coords = tracks[0, query_frame].round().long()
+ segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
+
+ video = F.pad(
+ video,
+ (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
+ "constant",
+ 255,
+ )
+
+ if video_depth is not None:
+ video_depth = (video_depth*255).cpu().numpy().astype(np.uint8)
+ video_depth = ([cv2.applyColorMap(video_depth[0,i,0], cv2.COLORMAP_INFERNO)
+ for i in range(video_depth.shape[1])])
+ video_depth = np.stack(video_depth, axis=0)
+ video_depth = torch.from_numpy(video_depth).permute(0, 3, 1, 2)[None]
+
+ tracks = tracks + self.pad_value
+
+ if self.grayscale:
+ transform = transforms.Grayscale()
+ video = transform(video)
+ video = video.repeat(1, 1, 3, 1, 1)
+
+ res_video = self.draw_tracks_on_video(
+ video=video,
+ tracks=tracks,
+ visibility=visibility,
+ segm_mask=segm_mask,
+ gt_tracks=gt_tracks,
+ query_frame=query_frame,
+ compensate_for_camera_motion=compensate_for_camera_motion,
+ rigid_part=rigid_part
+ )
+
+ if save_video:
+ self.save_video(res_video, filename=filename,
+ writer=writer, step=step)
+ if video_depth is not None:
+ self.save_video(video_depth, filename=filename+"_depth",
+ writer=writer, step=step)
+ return res_video
+
+ def save_video(self, video, filename, writer=None, step=0):
+ if writer is not None:
+ writer.add_video(
+ f"{filename}_pred_track",
+ video.to(torch.uint8),
+ global_step=step,
+ fps=self.fps,
+ )
+ else:
+ os.makedirs(self.save_dir, exist_ok=True)
+ wide_list = list(video.unbind(1))
+ wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
+ clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)
+
+ # Write the video file
+ save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4")
+ clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
+
+ print(f"Video saved to {save_path}")
+
+ def draw_tracks_on_video(
+ self,
+ video: torch.Tensor,
+ tracks: torch.Tensor,
+ visibility: torch.Tensor = None,
+ segm_mask: torch.Tensor = None,
+ gt_tracks=None,
+ query_frame: int = 0,
+ compensate_for_camera_motion=False,
+ rigid_part=None,
+ ):
+ B, T, C, H, W = video.shape
+ _, _, N, D = tracks.shape
+
+ assert D == 2
+ assert C == 3
+ video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
+ tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
+ if gt_tracks is not None:
+ gt_tracks = gt_tracks.detach().cpu().numpy()
+
+ res_video = []
+
+ # process input video
+ for rgb in video:
+ res_video.append(rgb.copy())
+
+ vector_colors = np.zeros((T, N, 3))
+ if self.mode == "optical_flow":
+ vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
+ elif segm_mask is None:
+ if self.mode == "rainbow":
+ y_min, y_max = (
+ tracks[query_frame, :, 1].min(),
+ tracks[query_frame, :, 1].max(),
+ )
+ norm = plt.Normalize(y_min, y_max)
+ for n in range(N):
+ color = self.color_map(norm(tracks[query_frame, n, 1]))
+ color = np.array(color[:3])[None] * 255
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
+ else:
+ # color changes with time
+ for t in range(T):
+ color = np.array(self.color_map(t / T)[:3])[None] * 255
+ vector_colors[t] = np.repeat(color, N, axis=0)
+ else:
+ if self.mode == "rainbow":
+ vector_colors[:, segm_mask <= 0, :] = 255
+
+ y_min, y_max = (
+ tracks[0, segm_mask > 0, 1].min(),
+ tracks[0, segm_mask > 0, 1].max(),
+ )
+ norm = plt.Normalize(y_min, y_max)
+ for n in range(N):
+ if segm_mask[n] > 0:
+ color = self.color_map(norm(tracks[0, n, 1]))
+ color = np.array(color[:3])[None] * 255
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
+
+ else:
+ # color changes with segm class
+ segm_mask = segm_mask.cpu()
+ color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
+ color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
+ color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
+ vector_colors = np.repeat(color[None], T, axis=0)
+
+ # draw tracks
+ if self.tracks_leave_trace != 0:
+ for t in range(1, T):
+ first_ind = (
+ max(0, t - self.tracks_leave_trace)
+ if self.tracks_leave_trace >= 0
+ else 0
+ )
+ curr_tracks = tracks[first_ind : t + 1]
+ curr_colors = vector_colors[first_ind : t + 1]
+ if compensate_for_camera_motion:
+ diff = (
+ tracks[first_ind : t + 1, segm_mask <= 0]
+ - tracks[t : t + 1, segm_mask <= 0]
+ ).mean(1)[:, None]
+
+ curr_tracks = curr_tracks - diff
+ curr_tracks = curr_tracks[:, segm_mask > 0]
+ curr_colors = curr_colors[:, segm_mask > 0]
+
+ res_video[t] = self._draw_pred_tracks(
+ res_video[t],
+ curr_tracks,
+ curr_colors,
+ )
+ if gt_tracks is not None:
+ res_video[t] = self._draw_gt_tracks(
+ res_video[t], gt_tracks[first_ind : t + 1]
+ )
+
+ if rigid_part is not None:
+ cls_label = torch.unique(rigid_part)
+ cls_num = len(torch.unique(rigid_part))
+ # visualize the clustering results
+ cmap = plt.get_cmap('jet') # get the color mapping
+ colors = cmap(np.linspace(0, 1, cls_num))
+ colors = (colors[:, :3] * 255)
+ color_map = {lable.item(): color for lable, color in zip(cls_label, colors)}
+
+
+ # draw points
+ for t in range(T):
+ for i in range(N):
+ coord = (tracks[t, i, 0], tracks[t, i, 1])
+ visibile = True
+ if visibility is not None:
+ visibile = visibility[0, t, i] > 0.5
+ if coord[0] != 0 and coord[1] != 0:
+ if not compensate_for_camera_motion or (
+ compensate_for_camera_motion and segm_mask[i] > 0
+ ):
+ if rigid_part is not None:
+ color = color_map[rigid_part.squeeze()[i].item()]
+ cv2.circle(
+ res_video[t],
+ coord,
+ int(self.linewidth * 2),
+ color.tolist(),
+ thickness=-1 if visibile else 2
+ -1,
+ )
+ else:
+ cv2.circle(
+ res_video[t],
+ coord,
+ int(self.linewidth * 2),
+ vector_colors[t, i].tolist(),
+ thickness=-1 if visibile else 2
+ -1,
+ )
+
+ # construct the final rgb sequence
+ if self.show_first_frame > 0:
+ res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
+ return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
+
+ def _draw_pred_tracks(
+ self,
+ rgb: np.ndarray, # H x W x 3
+ tracks: np.ndarray, # T x 2
+ vector_colors: np.ndarray,
+ alpha: float = 0.5,
+ ):
+ T, N, _ = tracks.shape
+
+ for s in range(T - 1):
+ vector_color = vector_colors[s]
+ original = rgb.copy()
+ alpha = (s / T) ** 2
+ for i in range(N):
+ coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
+ coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
+ if coord_y[0] != 0 and coord_y[1] != 0:
+ cv2.line(
+ rgb,
+ coord_y,
+ coord_x,
+ vector_color[i].tolist(),
+ self.linewidth,
+ cv2.LINE_AA,
+ )
+ if self.tracks_leave_trace > 0:
+ rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
+ return rgb
+
+ def _draw_gt_tracks(
+ self,
+ rgb: np.ndarray, # H x W x 3,
+ gt_tracks: np.ndarray, # T x 2
+ ):
+ T, N, _ = gt_tracks.shape
+ color = np.array((211.0, 0.0, 0.0))
+
+ for t in range(T):
+ for i in range(N):
+ gt_tracks_i = gt_tracks[t][i]
+ # draw a red cross
+ if gt_tracks_i[0] > 0 and gt_tracks_i[1] > 0:
+ length = self.linewidth * 3
+ coord_y = (int(gt_tracks_i[0]) + length, int(gt_tracks_i[1]) + length)
+ coord_x = (int(gt_tracks_i[0]) - length, int(gt_tracks_i[1]) - length)
+ cv2.line(
+ rgb,
+ coord_y,
+ coord_x,
+ color,
+ self.linewidth,
+ cv2.LINE_AA,
+ )
+ coord_y = (int(gt_tracks_i[0]) - length, int(gt_tracks_i[1]) + length)
+ coord_x = (int(gt_tracks_i[0]) + length, int(gt_tracks_i[1]) - length)
+ cv2.line(
+ rgb,
+ coord_y,
+ coord_x,
+ color,
+ self.linewidth,
+ cv2.LINE_AA,
+ )
+ return rgb
diff --git a/models/moge/__init__.py b/models/moge/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/moge/model/__init__.py b/models/moge/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c919e3be42c0005752e8c800129bd5f724b47ff9
--- /dev/null
+++ b/models/moge/model/__init__.py
@@ -0,0 +1,18 @@
+import importlib
+from typing import *
+
+if TYPE_CHECKING:
+ from .v1 import MoGeModel as MoGeModelV1
+ from .v2 import MoGeModel as MoGeModelV2
+
+
+def import_model_class_by_version(version: str) -> Type[Union['MoGeModelV1', 'MoGeModelV2']]:
+ assert version in ['v1', 'v2'], f'Unsupported model version: {version}'
+
+ try:
+ module = importlib.import_module(f'.{version}', __package__)
+ except ModuleNotFoundError:
+ raise ValueError(f'Model version "{version}" not found.')
+
+ cls = getattr(module, 'MoGeModel')
+ return cls
diff --git a/models/moge/model/dinov2/__init__.py b/models/moge/model/dinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9
--- /dev/null
+++ b/models/moge/model/dinov2/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+__version__ = "0.0.1"
diff --git a/models/moge/model/dinov2/hub/__init__.py b/models/moge/model/dinov2/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/models/moge/model/dinov2/hub/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/models/moge/model/dinov2/hub/backbones.py b/models/moge/model/dinov2/hub/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002
--- /dev/null
+++ b/models/moge/model/dinov2/hub/backbones.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+ *,
+ arch_name: str = "vit_large",
+ img_size: int = 518,
+ patch_size: int = 14,
+ init_values: float = 1.0,
+ ffn_layer: str = "mlp",
+ block_chunks: int = 0,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.LVD142M,
+ **kwargs,
+):
+ from ..models import vision_transformer as vits
+
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=patch_size,
+ init_values=init_values,
+ ffn_layer=ffn_layer,
+ block_chunks=block_chunks,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ )
+ vit_kwargs.update(**kwargs)
+ model = vits.__dict__[arch_name](**vit_kwargs)
+
+ if pretrained:
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ model.load_state_dict(state_dict, strict=True)
+
+ return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_small",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_base",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_large",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
diff --git a/models/moge/model/dinov2/hub/utils.py b/models/moge/model/dinov2/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64
--- /dev/null
+++ b/models/moge/model/dinov2/hub/utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import itertools
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
+ compact_arch_name = arch_name.replace("_", "")[:4]
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+class CenterPadding(nn.Module):
+ def __init__(self, multiple):
+ super().__init__()
+ self.multiple = multiple
+
+ def _get_pad(self, size):
+ new_size = math.ceil(size / self.multiple) * self.multiple
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ @torch.inference_mode()
+ def forward(self, x):
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+ output = F.pad(x, pads)
+ return output
diff --git a/models/moge/model/dinov2/layers/__init__.py b/models/moge/model/dinov2/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e
--- /dev/null
+++ b/models/moge/model/dinov2/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/models/moge/model/dinov2/layers/attention.py b/models/moge/model/dinov2/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fed573116d5c837be46a7525d8acf77422c2400
--- /dev/null
+++ b/models/moge/model/dinov2/layers/attention.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Attention)")
+ else:
+ # warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/models/moge/model/dinov2/layers/block.py b/models/moge/model/dinov2/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd5b8a7bb8527b74186af7c1e060e37bdb52c73d
--- /dev/null
+++ b/models/moge/model/dinov2/layers/block.py
@@ -0,0 +1,259 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Block)")
+ else:
+ # warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Block)")
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/models/moge/model/dinov2/layers/dino_head.py b/models/moge/model/dinov2/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565
--- /dev/null
+++ b/models/moge/model/dinov2/layers/dino_head.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/models/moge/model/dinov2/layers/drop_path.py b/models/moge/model/dinov2/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/models/moge/model/dinov2/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/models/moge/model/dinov2/layers/layer_scale.py b/models/moge/model/dinov2/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386
--- /dev/null
+++ b/models/moge/model/dinov2/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/models/moge/model/dinov2/layers/mlp.py b/models/moge/model/dinov2/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/models/moge/model/dinov2/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/models/moge/model/dinov2/layers/patch_embed.py b/models/moge/model/dinov2/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339
--- /dev/null
+++ b/models/moge/model/dinov2/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/models/moge/model/dinov2/layers/swiglu_ffn.py b/models/moge/model/dinov2/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35
--- /dev/null
+++ b/models/moge/model/dinov2/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (SwiGLU)")
+ else:
+ # warnings.warn("xFormers is disabled (SwiGLU)")
+ raise ImportError
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+ # warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/models/moge/model/dinov2/models/__init__.py b/models/moge/model/dinov2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265
--- /dev/null
+++ b/models/moge/model/dinov2/models/__init__.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+from . import vision_transformer as vits
+
+
+logger = logging.getLogger("dinov2")
+
+
+def build_model(args, only_teacher=False, img_size=224):
+ args.arch = args.arch.removesuffix("_memeff")
+ if "vit" in args.arch:
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=args.patch_size,
+ init_values=args.layerscale,
+ ffn_layer=args.ffn_layer,
+ block_chunks=args.block_chunks,
+ qkv_bias=args.qkv_bias,
+ proj_bias=args.proj_bias,
+ ffn_bias=args.ffn_bias,
+ num_register_tokens=args.num_register_tokens,
+ interpolate_offset=args.interpolate_offset,
+ interpolate_antialias=args.interpolate_antialias,
+ )
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
+ if only_teacher:
+ return teacher, teacher.embed_dim
+ student = vits.__dict__[args.arch](
+ **vit_kwargs,
+ drop_path_rate=args.drop_path_rate,
+ drop_path_uniform=args.drop_path_uniform,
+ )
+ embed_dim = student.embed_dim
+ return student, teacher, embed_dim
+
+
+def build_model_from_cfg(cfg, only_teacher=False):
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
diff --git a/models/moge/model/dinov2/models/vision_transformer.py b/models/moge/model/dinov2/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1007ba57ddb35109c91716f1f5bf203db346e7be
--- /dev/null
+++ b/models/moge/model/dinov2/models/vision_transformer.py
@@ -0,0 +1,396 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if self.interpolate_offset:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sx, sy)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (w0, h0)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
diff --git a/models/moge/model/dinov2/utils/__init__.py b/models/moge/model/dinov2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/models/moge/model/dinov2/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/models/moge/model/dinov2/utils/cluster.py b/models/moge/model/dinov2/utils/cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314
--- /dev/null
+++ b/models/moge/model/dinov2/utils/cluster.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+
+class ClusterType(Enum):
+ AWS = "aws"
+ FAIR = "fair"
+ RSC = "rsc"
+
+
+def _guess_cluster_type() -> ClusterType:
+ uname = os.uname()
+ if uname.sysname == "Linux":
+ if uname.release.endswith("-aws"):
+ # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
+ return ClusterType.AWS
+ elif uname.nodename.startswith("rsc"):
+ # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
+ return ClusterType.RSC
+
+ return ClusterType.FAIR
+
+
+def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
+ if cluster_type is None:
+ return _guess_cluster_type()
+
+ return cluster_type
+
+
+def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ CHECKPOINT_DIRNAMES = {
+ ClusterType.AWS: "checkpoints",
+ ClusterType.FAIR: "checkpoint",
+ ClusterType.RSC: "checkpoint/dino",
+ }
+ return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
+
+
+def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ checkpoint_path = get_checkpoint_path(cluster_type)
+ if checkpoint_path is None:
+ return None
+
+ username = os.environ.get("USER")
+ assert username is not None
+ return checkpoint_path / username
+
+
+def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ SLURM_PARTITIONS = {
+ ClusterType.AWS: "learnlab",
+ ClusterType.FAIR: "learnlab",
+ ClusterType.RSC: "learn",
+ }
+ return SLURM_PARTITIONS[cluster_type]
+
+
+def get_slurm_executor_parameters(
+ nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
+) -> Dict[str, Any]:
+ # create default parameters
+ params = {
+ "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
+ "gpus_per_node": num_gpus_per_node,
+ "tasks_per_node": num_gpus_per_node, # one task per GPU
+ "cpus_per_task": 10,
+ "nodes": nodes,
+ "slurm_partition": get_slurm_partition(cluster_type),
+ }
+ # apply cluster-specific adjustments
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type == ClusterType.AWS:
+ params["cpus_per_task"] = 12
+ del params["mem_gb"]
+ elif cluster_type == ClusterType.RSC:
+ params["cpus_per_task"] = 12
+ # set additional parameters / apply overrides
+ params.update(kwargs)
+ return params
diff --git a/models/moge/model/dinov2/utils/config.py b/models/moge/model/dinov2/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52
--- /dev/null
+++ b/models/moge/model/dinov2/utils/config.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import math
+import logging
+import os
+
+from omegaconf import OmegaConf
+
+import dinov2.distributed as distributed
+from dinov2.logging import setup_logging
+from dinov2.utils import utils
+from dinov2.configs import dinov2_default_config
+
+
+logger = logging.getLogger("dinov2")
+
+
+def apply_scaling_rules_to_cfg(cfg): # to fix
+ if cfg.optim.scaling_rule == "sqrt_wrt_1024":
+ base_lr = cfg.optim.base_lr
+ cfg.optim.lr = base_lr
+ cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
+ logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
+ else:
+ raise NotImplementedError
+ return cfg
+
+
+def write_config(cfg, output_dir, name="config.yaml"):
+ logger.info(OmegaConf.to_yaml(cfg))
+ saved_cfg_path = os.path.join(output_dir, name)
+ with open(saved_cfg_path, "w") as f:
+ OmegaConf.save(config=cfg, f=f)
+ return saved_cfg_path
+
+
+def get_cfg_from_args(args):
+ args.output_dir = os.path.abspath(args.output_dir)
+ args.opts += [f"train.output_dir={args.output_dir}"]
+ default_cfg = OmegaConf.create(dinov2_default_config)
+ cfg = OmegaConf.load(args.config_file)
+ cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
+ return cfg
+
+
+def default_setup(args):
+ distributed.enable(overwrite=True)
+ seed = getattr(args, "seed", 0)
+ rank = distributed.get_global_rank()
+
+ global logger
+ setup_logging(output=args.output_dir, level=logging.INFO)
+ logger = logging.getLogger("dinov2")
+
+ utils.fix_random_seeds(seed + rank)
+ logger.info("git:\n {}\n".format(utils.get_sha()))
+ logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
+
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg_from_args(args)
+ os.makedirs(args.output_dir, exist_ok=True)
+ default_setup(args)
+ apply_scaling_rules_to_cfg(cfg)
+ write_config(cfg, args.output_dir)
+ return cfg
diff --git a/models/moge/model/dinov2/utils/dtype.py b/models/moge/model/dinov2/utils/dtype.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8
--- /dev/null
+++ b/models/moge/model/dinov2/utils/dtype.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+from typing import Dict, Union
+
+import numpy as np
+import torch
+
+
+TypeSpec = Union[str, np.dtype, torch.dtype]
+
+
+_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
+ np.dtype("bool"): torch.bool,
+ np.dtype("uint8"): torch.uint8,
+ np.dtype("int8"): torch.int8,
+ np.dtype("int16"): torch.int16,
+ np.dtype("int32"): torch.int32,
+ np.dtype("int64"): torch.int64,
+ np.dtype("float16"): torch.float16,
+ np.dtype("float32"): torch.float32,
+ np.dtype("float64"): torch.float64,
+ np.dtype("complex64"): torch.complex64,
+ np.dtype("complex128"): torch.complex128,
+}
+
+
+def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
+ if isinstance(dtype, torch.dtype):
+ return dtype
+ if isinstance(dtype, str):
+ dtype = np.dtype(dtype)
+ assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}"
+ return _NUMPY_TO_TORCH_DTYPE[dtype]
diff --git a/models/moge/model/dinov2/utils/param_groups.py b/models/moge/model/dinov2/utils/param_groups.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f
--- /dev/null
+++ b/models/moge/model/dinov2/utils/param_groups.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+import logging
+
+
+logger = logging.getLogger("dinov2")
+
+
+def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
+ """
+ Calculate lr decay rate for different ViT blocks.
+ Args:
+ name (string): parameter name.
+ lr_decay_rate (float): base lr decay rate.
+ num_layers (int): number of ViT blocks.
+ Returns:
+ lr decay rate for the given parameter.
+ """
+ layer_id = num_layers + 1
+ if name.startswith("backbone") or force_is_backbone:
+ if (
+ ".pos_embed" in name
+ or ".patch_embed" in name
+ or ".mask_token" in name
+ or ".cls_token" in name
+ or ".register_tokens" in name
+ ):
+ layer_id = 0
+ elif force_is_backbone and (
+ "pos_embed" in name
+ or "patch_embed" in name
+ or "mask_token" in name
+ or "cls_token" in name
+ or "register_tokens" in name
+ ):
+ layer_id = 0
+ elif ".blocks." in name and ".residual." not in name:
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+ elif chunked_blocks and "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
+ elif "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
+
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
+
+
+def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
+ chunked_blocks = False
+ if hasattr(model, "n_blocks"):
+ logger.info("chunked fsdp")
+ n_blocks = model.n_blocks
+ chunked_blocks = model.chunked_blocks
+ elif hasattr(model, "blocks"):
+ logger.info("first code branch")
+ n_blocks = len(model.blocks)
+ elif hasattr(model, "backbone"):
+ logger.info("second code branch")
+ n_blocks = len(model.backbone.blocks)
+ else:
+ logger.info("else code branch")
+ n_blocks = 0
+ all_param_groups = []
+
+ for name, param in model.named_parameters():
+ name = name.replace("_fsdp_wrapped_module.", "")
+ if not param.requires_grad:
+ continue
+ decay_rate = get_vit_lr_decay_rate(
+ name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
+ )
+ d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
+
+ if "last_layer" in name:
+ d.update({"is_last_layer": True})
+
+ if name.endswith(".bias") or "norm" in name or "gamma" in name:
+ d.update({"wd_multiplier": 0.0})
+
+ if "patch_embed" in name:
+ d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
+
+ all_param_groups.append(d)
+ logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
+
+ return all_param_groups
+
+
+def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
+ fused_params_groups = defaultdict(lambda: {"params": []})
+ for d in all_params_groups:
+ identifier = ""
+ for k in keys:
+ identifier += k + str(d[k]) + "_"
+
+ for k in keys:
+ fused_params_groups[identifier][k] = d[k]
+ fused_params_groups[identifier]["params"].append(d["params"])
+
+ return fused_params_groups.values()
diff --git a/models/moge/model/dinov2/utils/utils.py b/models/moge/model/dinov2/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f
--- /dev/null
+++ b/models/moge/model/dinov2/utils/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import random
+import subprocess
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
+ if urlparse(pretrained_weights).scheme: # If it looks like an URL
+ state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
+ else:
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
+ if checkpoint_key is not None and checkpoint_key in state_dict:
+ logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
+ state_dict = state_dict[checkpoint_key]
+ # remove `module.` prefix
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ # remove `backbone.` prefix induced by multicrop wrapper
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+ msg = model.load_state_dict(state_dict, strict=False)
+ logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
+
+
+def fix_random_seeds(seed=31):
+ """
+ Fix random seeds.
+ """
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+ sha = "N/A"
+ diff = "clean"
+ branch = "N/A"
+ try:
+ sha = _run(["git", "rev-parse", "HEAD"])
+ subprocess.check_output(["git", "diff"], cwd=cwd)
+ diff = _run(["git", "diff-index", "HEAD"])
+ diff = "has uncommitted changes" if diff else "clean"
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+class CosineScheduler(object):
+ def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
+ super().__init__()
+ self.final_value = final_value
+ self.total_iters = total_iters
+
+ freeze_schedule = np.zeros((freeze_iters))
+
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+ iters = np.arange(total_iters - warmup_iters - freeze_iters)
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
+ self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
+
+ assert len(self.schedule) == self.total_iters
+
+ def __getitem__(self, it):
+ if it >= self.total_iters:
+ return self.final_value
+ else:
+ return self.schedule[it]
+
+
+def has_batchnorms(model):
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
+ for name, module in model.named_modules():
+ if isinstance(module, bn_types):
+ return True
+ return False
diff --git a/models/moge/model/modules.py b/models/moge/model/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b6731993a10bae0348a928b6018533cabcc1551
--- /dev/null
+++ b/models/moge/model/modules.py
@@ -0,0 +1,250 @@
+from typing import *
+from numbers import Number
+import importlib
+import itertools
+import functools
+import sys
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .dinov2.models.vision_transformer import DinoVisionTransformer
+from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
+from ..utils.geometry_torch import normalized_view_plane_uv
+
+
+class ResidualConvBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int = None,
+ hidden_channels: int = None,
+ kernel_size: int = 3,
+ padding_mode: str = 'replicate',
+ activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
+ in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm',
+ hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm',
+ ):
+ super(ResidualConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ if hidden_channels is None:
+ hidden_channels = in_channels
+
+ if activation =='relu':
+ activation_cls = nn.ReLU
+ elif activation == 'leaky_relu':
+ activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2)
+ elif activation =='silu':
+ activation_cls = nn.SiLU
+ elif activation == 'elu':
+ activation_cls = nn.ELU
+ else:
+ raise ValueError(f'Unsupported activation function: {activation}')
+
+ self.layers = nn.Sequential(
+ nn.GroupNorm(in_channels // 32, in_channels) if in_norm == 'group_norm' else \
+ nn.GroupNorm(1, in_channels) if in_norm == 'layer_norm' else \
+ nn.InstanceNorm2d(in_channels) if in_norm == 'instance_norm' else \
+ nn.Identity(),
+ activation_cls(),
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode),
+ nn.GroupNorm(hidden_channels // 32, hidden_channels) if hidden_norm == 'group_norm' else \
+ nn.GroupNorm(1, hidden_channels) if hidden_norm == 'layer_norm' else \
+ nn.InstanceNorm2d(hidden_channels) if hidden_norm == 'instance_norm' else\
+ nn.Identity(),
+ activation_cls(),
+ nn.Conv2d(hidden_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode)
+ )
+
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
+
+ def forward(self, x):
+ skip = self.skip_connection(x)
+ x = self.layers(x)
+ x = x + skip
+ return x
+
+
+class DINOv2Encoder(nn.Module):
+ "Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]."
+ backbone: DinoVisionTransformer
+ image_mean: torch.Tensor
+ image_std: torch.Tensor
+ dim_features: int
+
+ def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, **deprecated_kwargs):
+ super(DINOv2Encoder, self).__init__()
+
+ self.intermediate_layers = intermediate_layers
+
+ # Load the backbone
+ self.hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), backbone)
+ self.backbone_name = backbone
+ self.backbone = self.hub_loader(pretrained=False)
+
+ self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
+ self.num_features = intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers)
+
+ self.output_projections = nn.ModuleList([
+ nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0,)
+ for _ in range(self.num_features)
+ ])
+
+ self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def init_weights(self):
+ pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
+ self.backbone.load_state_dict(pretrained_backbone_state_dict)
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.backbone.blocks)):
+ wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
+
+ def enable_pytorch_native_sdpa(self):
+ for i in range(len(self.backbone.blocks)):
+ wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
+
+ def forward(self, image: torch.Tensor, token_rows: int, token_cols: int, return_class_token: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+ image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=True)
+ image_14 = (image_14 - self.image_mean) / self.image_std
+
+ # Get intermediate layers from the backbone
+ features = self.backbone.get_intermediate_layers(image_14, n=self.intermediate_layers, return_class_token=True)
+
+ # Project features to the desired dimensionality
+ x = torch.stack([
+ proj(feat.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous())
+ for proj, (feat, clstoken) in zip(self.output_projections, features)
+ ], dim=1).sum(dim=1)
+
+ if return_class_token:
+ return x, features[-1][1]
+ else:
+ return x
+
+
+class Resampler(nn.Sequential):
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'],
+ scale_factor: int = 2,
+ ):
+ if type_ == 'pixel_shuffle':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.PixelShuffle(scale_factor),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ for i in range(1, scale_factor ** 2):
+ self[0].weight.data[i::scale_factor ** 2] = self[0].weight.data[0::scale_factor ** 2]
+ self[0].bias.data[i::scale_factor ** 2] = self[0].bias.data[0::scale_factor ** 2]
+ elif type_ in ['nearest', 'bilinear']:
+ nn.Sequential.__init__(self,
+ nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None),
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ elif type_ == 'conv_transpose':
+ nn.Sequential.__init__(self,
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
+ elif type_ == 'pixel_unshuffle':
+ nn.Sequential.__init__(self,
+ nn.PixelUnshuffle(scale_factor),
+ nn.Conv2d(in_channels * (scale_factor ** 2), out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ elif type_ == 'avg_pool':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
+ )
+ elif type_ == 'max_pool':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
+ )
+ else:
+ raise ValueError(f'Unsupported resampler type: {type_}')
+
+class MLP(nn.Sequential):
+ def __init__(self, dims: Sequence[int]):
+ nn.Sequential.__init__(self,
+ *itertools.chain(*[
+ (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True))
+ for dim_in, dim_out in zip(dims[:-2], dims[1:-1])
+ ]),
+ nn.Linear(dims[-2], dims[-1]),
+ )
+
+
+class ConvStack(nn.Module):
+ def __init__(self,
+ dim_in: List[Optional[int]],
+ dim_res_blocks: List[int],
+ dim_out: List[Optional[int]],
+ resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ res_block_in_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'layer_norm',
+ res_block_hidden_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'group_norm',
+ activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
+ ):
+ super().__init__()
+ self.input_blocks = nn.ModuleList([
+ nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity()
+ for dim_in_, dim_res_block_ in zip(dim_in if isinstance(dim_in, Sequence) else itertools.repeat(dim_in), dim_res_blocks)
+ ])
+ self.resamplers = nn.ModuleList([
+ Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
+ for i, (dim_prev, dim_succ, resampler) in enumerate(zip(
+ dim_res_blocks[:-1],
+ dim_res_blocks[1:],
+ resamplers if isinstance(resamplers, Sequence) else itertools.repeat(resamplers)
+ ))
+ ])
+ self.res_blocks = nn.ModuleList([
+ nn.Sequential(
+ *(
+ ResidualConvBlock(
+ dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_,
+ activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm
+ ) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks)
+ )
+ ) for i, dim_res_block_ in enumerate(dim_res_blocks)
+ ])
+ self.output_blocks = nn.ModuleList([
+ nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity()
+ for dim_out_, dim_res_block_ in zip(dim_out if isinstance(dim_out, Sequence) else itertools.repeat(dim_out), dim_res_blocks)
+ ])
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.resamplers)):
+ self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i])
+ for i in range(len(self.res_blocks)):
+ for j in range(len(self.res_blocks[i])):
+ self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j])
+
+ def forward(self, in_features: List[torch.Tensor]):
+ batch_shape = in_features[0].shape[:-3]
+ in_features = [x.reshape(-1, *x.shape[-3:]) for x in in_features]
+
+ out_features = []
+ for i in range(len(self.res_blocks)):
+ feature = self.input_blocks[i](in_features[i])
+ if i == 0:
+ x = feature
+ elif feature is not None:
+ x = x + feature
+ x = self.res_blocks[i](x)
+ out_features.append(self.output_blocks[i](x))
+ if i < len(self.res_blocks) - 1:
+ x = self.resamplers[i](x)
+
+ out_features = [x.unflatten(0, batch_shape) for x in out_features]
+ return out_features
diff --git a/models/moge/model/utils.py b/models/moge/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c50761d8740d9d0a0284e129503b8931c6fe08c4
--- /dev/null
+++ b/models/moge/model/utils.py
@@ -0,0 +1,49 @@
+from typing import *
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def wrap_module_with_gradient_checkpointing(module: nn.Module):
+ from torch.utils.checkpoint import checkpoint
+ class _CheckpointingWrapper(module.__class__):
+ _restore_cls = module.__class__
+ def forward(self, *args, **kwargs):
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
+
+ module.__class__ = _CheckpointingWrapper
+ return module
+
+
+def unwrap_module_with_gradient_checkpointing(module: nn.Module):
+ module.__class__ = module.__class__._restore_cls
+
+
+def wrap_dinov2_attention_with_sdpa(module: nn.Module):
+ assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
+ class _AttentionWrapper(module.__class__):
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
+
+ q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
+
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+ module.__class__ = _AttentionWrapper
+ return module
+
+
+def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]:
+ group_to_use = torch.distributed.group.WORLD
+ world_size = group_to_use.size()
+ grad = bucket.buffer()
+ grad.div_(world_size)
+ torch.distributed.all_reduce(grad, group=group_to_use)
+ fut = torch.futures.Future()
+ fut.set_result(grad)
+ return fut
diff --git a/models/moge/model/v1.py b/models/moge/model/v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c1d850507a54239df4a117d216a939df3939ce6
--- /dev/null
+++ b/models/moge/model/v1.py
@@ -0,0 +1,393 @@
+from typing import *
+from numbers import Number
+from functools import partial
+from pathlib import Path
+import importlib
+import warnings
+import json
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.version
+import utils3d
+from huggingface_hub import hf_hub_download
+
+
+from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d, dilate_with_mask
+from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
+from ..utils.tools import timeit
+
+
+class ResidualConvBlock(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'):
+ super(ResidualConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ if hidden_channels is None:
+ hidden_channels = in_channels
+
+ if activation =='relu':
+ activation_cls = lambda: nn.ReLU(inplace=True)
+ elif activation == 'leaky_relu':
+ activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ elif activation =='silu':
+ activation_cls = lambda: nn.SiLU(inplace=True)
+ elif activation == 'elu':
+ activation_cls = lambda: nn.ELU(inplace=True)
+ else:
+ raise ValueError(f'Unsupported activation function: {activation}')
+
+ self.layers = nn.Sequential(
+ nn.GroupNorm(1, in_channels),
+ activation_cls(),
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
+ nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels),
+ activation_cls(),
+ nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode)
+ )
+
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
+
+ def forward(self, x):
+ skip = self.skip_connection(x)
+ x = self.layers(x)
+ x = x + skip
+ return x
+
+
+class Head(nn.Module):
+ def __init__(
+ self,
+ num_features: int,
+ dim_in: int,
+ dim_out: List[int],
+ dim_proj: int = 512,
+ dim_upsample: List[int] = [256, 128, 128],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
+ last_res_blocks: int = 0,
+ last_conv_channels: int = 32,
+ last_conv_size: int = 1
+ ):
+ super().__init__()
+
+ self.projects = nn.ModuleList([
+ nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features)
+ ])
+
+ self.upsample_blocks = nn.ModuleList([
+ nn.Sequential(
+ self._make_upsampler(in_ch + 2, out_ch),
+ *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks))
+ ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
+ ])
+
+ self.output_block = nn.ModuleList([
+ self._make_output_block(
+ dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,
+ ) for dim_out_ in dim_out
+ ])
+
+ def _make_upsampler(self, in_channels: int, out_channels: int):
+ upsampler = nn.Sequential(
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
+ return upsampler
+
+ def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
+ return nn.Sequential(
+ nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
+ )
+
+ def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
+ img_h, img_w = image.shape[-2:]
+ patch_h, patch_w = img_h // 14, img_w // 14
+
+ # Process the hidden states
+ x = torch.stack([
+ proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
+ for proj, (feat, clstoken) in zip(self.projects, hidden_states)
+ ], dim=1).sum(dim=1)
+
+ # Upsample stage
+ # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
+ for i, block in enumerate(self.upsample_blocks):
+ # UV coordinates is for awareness of image aspect ratio
+ uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
+ x = torch.cat([x, uv], dim=1)
+ for layer in block:
+ x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
+
+ # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
+ x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
+ uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
+ x = torch.cat([x, uv], dim=1)
+
+ if isinstance(self.output_block, nn.ModuleList):
+ output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block]
+ else:
+ output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False)
+
+ return output
+
+
+class MoGeModel(nn.Module):
+ image_mean: torch.Tensor
+ image_std: torch.Tensor
+
+ def __init__(self,
+ encoder: str = 'dinov2_vitb14',
+ intermediate_layers: Union[int, List[int]] = 4,
+ dim_proj: int = 512,
+ dim_upsample: List[int] = [256, 128, 128],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
+ num_tokens_range: Tuple[Number, Number] = [1200, 2500],
+ last_res_blocks: int = 0,
+ last_conv_channels: int = 32,
+ last_conv_size: int = 1,
+ mask_threshold: float = 0.5,
+ **deprecated_kwargs
+ ):
+ super(MoGeModel, self).__init__()
+
+ if deprecated_kwargs:
+ # Process legacy arguments
+ if 'trained_area_range' in deprecated_kwargs:
+ num_tokens_range = [deprecated_kwargs['trained_area_range'][0] // 14 ** 2, deprecated_kwargs['trained_area_range'][1] // 14 ** 2]
+ del deprecated_kwargs['trained_area_range']
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
+
+ self.encoder = encoder
+ self.remap_output = remap_output
+ self.intermediate_layers = intermediate_layers
+ self.num_tokens_range = num_tokens_range
+ self.mask_threshold = mask_threshold
+
+ # NOTE: We have copied the DINOv2 code in torchhub to this repository.
+ # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues.
+ hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
+ self.backbone = hub_loader(pretrained=False)
+ dim_feature = self.backbone.blocks[0].attn.qkv.in_features
+
+ self.head = Head(
+ num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers),
+ dim_in=dim_feature,
+ dim_out=[3, 1],
+ dim_proj=dim_proj,
+ dim_upsample=dim_upsample,
+ dim_times_res_block_hidden=dim_times_res_block_hidden,
+ num_res_blocks=num_res_blocks,
+ res_block_norm=res_block_norm,
+ last_res_blocks=last_res_blocks,
+ last_conv_channels=last_conv_channels,
+ last_conv_size=last_conv_size
+ )
+
+ image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+ image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+
+ self.register_buffer("image_mean", image_mean)
+ self.register_buffer("image_std", image_std)
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return next(self.parameters()).dtype
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
+ """
+ Load a model from a checkpoint file.
+
+ ### Parameters:
+ - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
+ - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
+ - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
+
+ ### Returns:
+ - A new instance of `MoGe` with the parameters loaded from the checkpoint.
+ """
+ if Path(pretrained_model_name_or_path).exists():
+ checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True)
+ else:
+ cached_checkpoint_path = hf_hub_download(
+ repo_id=pretrained_model_name_or_path,
+ repo_type="model",
+ filename="model.pt",
+ **hf_kwargs
+ )
+ checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True)
+ model_config = checkpoint['model_config']
+ if model_kwargs is not None:
+ model_config.update(model_kwargs)
+ model = cls(**model_config)
+ model.load_state_dict(checkpoint['model'])
+ return model
+
+ def init_weights(self):
+ "Load the backbone with pretrained dinov2 weights from torch hub"
+ state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
+ self.backbone.load_state_dict(state_dict)
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.backbone.blocks)):
+ self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
+
+ def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
+ if self.remap_output == 'linear':
+ pass
+ elif self.remap_output =='sinh':
+ points = torch.sinh(points)
+ elif self.remap_output == 'exp':
+ xy, z = points.split([2, 1], dim=-1)
+ z = torch.exp(z)
+ points = torch.cat([xy * z, z], dim=-1)
+ elif self.remap_output =='sinh_exp':
+ xy, z = points.split([2, 1], dim=-1)
+ points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
+ else:
+ raise ValueError(f"Invalid remap output type: {self.remap_output}")
+ return points
+
+ def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
+ original_height, original_width = image.shape[-2:]
+
+ # Resize to expected resolution defined by num_tokens
+ resize_factor = ((num_tokens * 14 ** 2) / (original_height * original_width)) ** 0.5
+ resized_width, resized_height = int(original_width * resize_factor), int(original_height * resize_factor)
+ image = F.interpolate(image, (resized_height, resized_width), mode="bicubic", align_corners=False, antialias=True)
+
+ # Apply image transformation for DINOv2
+ image = (image - self.image_mean) / self.image_std
+ image_14 = F.interpolate(image, (resized_height // 14 * 14, resized_width // 14 * 14), mode="bilinear", align_corners=False, antialias=True)
+
+ # Get intermediate layers from the backbone
+ features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)
+
+ # Predict points (and mask)
+ output = self.head(features, image)
+ points, mask = output
+
+ # Make sure fp32 precision for output
+ with torch.autocast(device_type=image.device.type, dtype=torch.float32):
+ # Resize to original resolution
+ points = F.interpolate(points, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)
+ mask = F.interpolate(mask, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)
+
+ # Post-process points and mask
+ points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
+ points = self._remap_points(points) # slightly improves the performance in case of very large output values
+
+ return_dict = {'points': points, 'mask': mask}
+ return return_dict
+
+ @torch.inference_mode()
+ def infer(
+ self,
+ image: torch.Tensor,
+ fov_x: Union[Number, torch.Tensor] = None,
+ resolution_level: int = 9,
+ num_tokens: int = None,
+ apply_mask: bool = True,
+ force_projection: bool = True,
+ use_fp16: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ """
+ User-friendly inference function
+
+ ### Parameters
+ - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)\
+ - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
+ - `resolution_level`: An integer [0-9] for the resolution level for inference.
+ The higher, the finer details will be captured, but slower. Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size.
+ `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.
+ - `num_tokens`: number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`.
+ `resolution_level` will be ignored if `num_tokens` is provided. Default: None
+ - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
+ - `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True
+ - `use_fp16`: if True, use mixed precision to speed up inference. Default: True
+
+ ### Returns
+
+ A dictionary containing the following keys:
+ - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
+ - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
+ - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
+ """
+ if image.dim() == 3:
+ omit_batch_dim = True
+ image = image.unsqueeze(0)
+ else:
+ omit_batch_dim = False
+ image = image.to(dtype=self.dtype, device=self.device)
+
+ original_height, original_width = image.shape[-2:]
+ aspect_ratio = original_width / original_height
+
+ if num_tokens is None:
+ min_tokens, max_tokens = self.num_tokens_range
+ num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
+
+ with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
+ output = self.forward(image, num_tokens)
+ points, mask = output['points'], output['mask']
+
+ # Always process the output in fp32 precision
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
+ points, mask, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, mask, fov_x])
+
+ mask_binary = mask > self.mask_threshold
+
+ # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
+ if fov_x is None:
+ focal, shift = recover_focal_shift(points, mask_binary)
+ else:
+ focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
+ if focal.ndim == 0:
+ focal = focal[None].expand(points.shape[0])
+ _, shift = recover_focal_shift(points, mask_binary, focal=focal)
+ fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
+ fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
+ intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
+ depth = points[..., 2] + shift[..., None, None]
+
+ # If projection constraint is forced, recompute the point map using the actual depth map
+ if force_projection:
+ points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
+ else:
+ points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :]
+
+ # Apply mask if needed
+ if apply_mask:
+ points = torch.where(mask_binary[..., None], points, torch.inf)
+ depth = torch.where(mask_binary, depth, torch.inf)
+
+ return_dict = {
+ 'points': points,
+ 'intrinsics': intrinsics,
+ 'depth': depth,
+ 'mask': mask_binary,
+ "mask_prob": mask,
+ }
+
+ if omit_batch_dim:
+ return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
+
+ return return_dict
\ No newline at end of file
diff --git a/models/moge/model/v2.py b/models/moge/model/v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cb183584376f824992df003cbafab8ee0b94947
--- /dev/null
+++ b/models/moge/model/v2.py
@@ -0,0 +1,291 @@
+from typing import *
+from numbers import Number
+from functools import partial
+from pathlib import Path
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.amp
+import torch.version
+import utils3d
+from huggingface_hub import hf_hub_download
+
+from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, angle_diff_vec3
+from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
+from .modules import DINOv2Encoder, MLP, ConvStack
+
+
+class MoGeModel(nn.Module):
+ encoder: DINOv2Encoder
+ neck: ConvStack
+ points_head: ConvStack
+ mask_head: ConvStack
+ scale_head: MLP
+
+ def __init__(self,
+ encoder: Dict[str, Any],
+ neck: Dict[str, Any],
+ points_head: Dict[str, Any] = None,
+ mask_head: Dict[str, Any] = None,
+ normal_head: Dict[str, Any] = None,
+ scale_head: Dict[str, Any] = None,
+ remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
+ num_tokens_range: List[int] = [1200, 3600],
+ **deprecated_kwargs
+ ):
+ super(MoGeModel, self).__init__()
+ if deprecated_kwargs:
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
+
+ self.remap_output = remap_output
+ self.num_tokens_range = num_tokens_range
+
+ self.encoder = DINOv2Encoder(**encoder)
+ self.neck = ConvStack(**neck)
+ if points_head is not None:
+ self.points_head = ConvStack(**points_head)
+ if mask_head is not None:
+ self.mask_head = ConvStack(**mask_head)
+ if normal_head is not None:
+ self.normal_head = ConvStack(**normal_head)
+ if scale_head is not None:
+ self.scale_head = MLP(**scale_head)
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return next(self.parameters()).dtype
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
+ """
+ Load a model from a checkpoint file.
+
+ ### Parameters:
+ - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
+ - `compiled`
+ - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
+ - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
+
+ ### Returns:
+ - A new instance of `MoGe` with the parameters loaded from the checkpoint.
+ """
+ if Path(pretrained_model_name_or_path).exists():
+ checkpoint_path = pretrained_model_name_or_path
+ else:
+ checkpoint_path = hf_hub_download(
+ repo_id=pretrained_model_name_or_path,
+ repo_type="model",
+ filename="model.pt",
+ **hf_kwargs
+ )
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
+
+ model_config = checkpoint['model_config']
+ if model_kwargs is not None:
+ model_config.update(model_kwargs)
+ model = cls(**model_config)
+ model.load_state_dict(checkpoint['model'], strict=False)
+
+ return model
+
+ def init_weights(self):
+ self.encoder.init_weights()
+
+ def enable_gradient_checkpointing(self):
+ self.encoder.enable_gradient_checkpointing()
+ self.neck.enable_gradient_checkpointing()
+ for head in ['points_head', 'normal_head', 'mask_head']:
+ if hasattr(self, head):
+ getattr(self, head).enable_gradient_checkpointing()
+
+ def enable_pytorch_native_sdpa(self):
+ self.encoder.enable_pytorch_native_sdpa()
+
+ def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
+ if self.remap_output == 'linear':
+ pass
+ elif self.remap_output =='sinh':
+ points = torch.sinh(points)
+ elif self.remap_output == 'exp':
+ xy, z = points.split([2, 1], dim=-1)
+ z = torch.exp(z)
+ points = torch.cat([xy * z, z], dim=-1)
+ elif self.remap_output =='sinh_exp':
+ xy, z = points.split([2, 1], dim=-1)
+ points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
+ else:
+ raise ValueError(f"Invalid remap output type: {self.remap_output}")
+ return points
+
+ def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
+ batch_size, _, img_h, img_w = image.shape
+ device, dtype = image.device, image.dtype
+
+ aspect_ratio = img_w / img_h
+ base_h, base_w = int((num_tokens / aspect_ratio) ** 0.5), int((num_tokens * aspect_ratio) ** 0.5)
+ num_tokens = base_h * base_w
+
+ # Backbones encoding
+ features, cls_token = self.encoder(image, base_h, base_w, return_class_token=True)
+ features = [features, None, None, None, None]
+
+ # Concat UVs for aspect ratio input
+ for level in range(5):
+ uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
+ if features[level] is None:
+ features[level] = uv
+ else:
+ features[level] = torch.concat([features[level], uv], dim=1)
+
+ # Shared neck
+ features = self.neck(features)
+
+ # Heads decoding
+ points, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['points_head', 'normal_head', 'mask_head'])
+ metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None
+
+ # Resize
+ points, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [points, normal, mask])
+
+ # Remap output
+ if points is not None:
+ points = points.permute(0, 2, 3, 1)
+ points = self._remap_points(points) # slightly improves the performance in case of very large output values
+ if normal is not None:
+ normal = normal.permute(0, 2, 3, 1)
+ normal = F.normalize(normal, dim=-1)
+ if mask is not None:
+ mask = mask.squeeze(1).sigmoid()
+ if metric_scale is not None:
+ metric_scale = metric_scale.squeeze(1).exp()
+
+ return_dict = {
+ 'points': points,
+ 'normal': normal,
+ 'mask': mask,
+ 'metric_scale': metric_scale
+ }
+ return_dict = {k: v for k, v in return_dict.items() if v is not None}
+
+ return return_dict
+
+ @torch.inference_mode()
+ def infer(
+ self,
+ image: torch.Tensor,
+ num_tokens: int = None,
+ resolution_level: int = 9,
+ force_projection: bool = True,
+ apply_mask: Literal[False, True, 'blend'] = True,
+ fov_x: Optional[Union[Number, torch.Tensor]] = None,
+ use_fp16: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ """
+ User-friendly inference function
+
+ ### Parameters
+ - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
+ - `num_tokens`: the number of base ViT tokens to use for inference, `'least'` or `'most'` or an integer. Suggested range: 1200 ~ 2500.
+ More tokens will result in significantly higher accuracy and finer details, but slower inference time. Default: `'most'`.
+ - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True
+ - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
+ - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
+ - `use_fp16`: if True, use mixed precision to speed up inference. Default: True
+
+ ### Returns
+
+ A dictionary containing the following keys:
+ - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
+ - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
+ - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
+ """
+ if image.dim() == 3:
+ omit_batch_dim = True
+ image = image.unsqueeze(0)
+ else:
+ omit_batch_dim = False
+ image = image.to(dtype=self.dtype, device=self.device)
+
+ original_height, original_width = image.shape[-2:]
+ area = original_height * original_width
+ aspect_ratio = original_width / original_height
+
+ # Determine the number of base tokens to use
+ if num_tokens is None:
+ min_tokens, max_tokens = self.num_tokens_range
+ num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
+
+ # Forward pass
+ with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
+ output = self.forward(image, num_tokens=num_tokens)
+ points, normal, mask, metric_scale = (output.get(k, None) for k in ['points', 'normal', 'mask', 'metric_scale'])
+
+ # Always process the output in fp32 precision
+ points, normal, mask, metric_scale, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, normal, mask, metric_scale, fov_x])
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
+ if mask is not None:
+ mask_binary = mask > 0.5
+ else:
+ mask_binary = None
+
+ if points is not None:
+ # Convert affine point map to camera-space. Recover depth and intrinsics from point map.
+ # NOTE: Focal here is the focal length relative to half the image diagonal
+ if fov_x is None:
+ # Recover focal and shift from predicted point map
+ focal, shift = recover_focal_shift(points, mask_binary)
+ else:
+ # Focal is known, recover shift only
+ focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
+ if focal.ndim == 0:
+ focal = focal[None].expand(points.shape[0])
+ _, shift = recover_focal_shift(points, mask_binary, focal=focal)
+ fx, fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio, focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
+ intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
+ points[..., 2] += shift[..., None, None]
+ if mask_binary is not None:
+ mask_binary &= points[..., 2] > 0 # in case depth is contains negative values (which should never happen in practice)
+ depth = points[..., 2].clone()
+ else:
+ depth, intrinsics = None, None
+
+ # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics
+ if force_projection and depth is not None:
+ points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
+
+ # Apply metric scale
+ if metric_scale is not None:
+ if points is not None:
+ points *= metric_scale[:, None, None, None]
+ if depth is not None:
+ depth *= metric_scale[:, None, None]
+
+ # Apply mask
+ if apply_mask and mask_binary is not None:
+ points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None
+ depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None
+ normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) if normal is not None else None
+
+ return_dict = {
+ 'points': points,
+ 'intrinsics': intrinsics,
+ 'depth': depth,
+ 'mask': mask_binary,
+ 'normal': normal,
+ "mask_prob": mask,
+ }
+ return_dict = {k: v for k, v in return_dict.items() if v is not None}
+
+ if omit_batch_dim:
+ return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
+
+ return return_dict
diff --git a/models/moge/test/__init__.py b/models/moge/test/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/moge/test/baseline.py b/models/moge/test/baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..05980aaf96870304534fcec6532225e870351a66
--- /dev/null
+++ b/models/moge/test/baseline.py
@@ -0,0 +1,43 @@
+from typing import *
+
+import click
+import torch
+
+
+class MGEBaselineInterface:
+ """
+ Abstract class for model wrapper to uniformize the interface of loading and inference across different models.
+ """
+ device: torch.device
+
+ @click.command()
+ @staticmethod
+ def load(*args, **kwargs) -> "MGEBaselineInterface":
+ """
+ Customized static method to create an instance of the model wrapper from command line arguments. Decorated by `click.command()`
+ """
+ raise NotImplementedError(f"{type(self).__name__} has not implemented the load method.")
+
+ def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+ """
+ ### Parameters
+ `image`: [B, 3, H, W] or [3, H, W], RGB values in range [0, 1]
+ `intrinsics`: [B, 3, 3] or [3, 3], camera intrinsics. Optional.
+
+ ### Returns
+ A dictionary containing:
+ - `points_*`. point map output in OpenCV identity camera space.
+ Supported suffixes: `metric`, `scale_invariant`, `affine_invariant`.
+ - `depth_*`. depth map output
+ Supported suffixes: `metric` (in meters), `scale_invariant`, `affine_invariant`.
+ - `disparity_affine_invariant`. affine disparity map output
+ """
+ raise NotImplementedError(f"{type(self).__name__} has not implemented the infer method.")
+
+ def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+ """
+ If the model has a special evaluation mode, override this method to provide the evaluation mode inference.
+
+ By default, this method simply calls `infer()`.
+ """
+ return self.infer(image, intrinsics)
\ No newline at end of file
diff --git a/models/moge/test/dataloader.py b/models/moge/test/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..76679829afdf385938b604fa8bb5ef07b2560e7b
--- /dev/null
+++ b/models/moge/test/dataloader.py
@@ -0,0 +1,221 @@
+import os
+from typing import *
+from pathlib import Path
+import math
+
+import numpy as np
+import torch
+from PIL import Image
+import cv2
+import utils3d
+
+from ..utils import pipeline
+from ..utils.geometry_numpy import focal_to_fov_numpy, mask_aware_nearest_resize_numpy, norm3d
+from ..utils.io import *
+from ..utils.tools import timeit
+
+
+class EvalDataLoaderPipeline:
+
+ def __init__(
+ self,
+ path: str,
+ width: int,
+ height: int,
+ split: int = '.index.txt',
+ drop_max_depth: float = 1000.,
+ num_load_workers: int = 4,
+ num_process_workers: int = 8,
+ include_segmentation: bool = False,
+ include_normal: bool = False,
+ depth_to_normal: bool = False,
+ max_segments: int = 100,
+ min_seg_area: int = 1000,
+ depth_unit: str = None,
+ has_sharp_boundary = False,
+ subset: int = None,
+ ):
+ filenames = Path(path).joinpath(split).read_text(encoding='utf-8').splitlines()
+ filenames = filenames[::subset]
+ self.width = width
+ self.height = height
+ self.drop_max_depth = drop_max_depth
+ self.path = Path(path)
+ self.filenames = filenames
+ self.include_segmentation = include_segmentation
+ self.include_normal = include_normal
+ self.max_segments = max_segments
+ self.min_seg_area = min_seg_area
+ self.depth_to_normal = depth_to_normal
+ self.depth_unit = depth_unit
+ self.has_sharp_boundary = has_sharp_boundary
+
+ self.rng = np.random.default_rng(seed=0)
+
+ self.pipeline = pipeline.Sequential([
+ self._generator,
+ pipeline.Parallel([self._load_instance] * num_load_workers),
+ pipeline.Parallel([self._process_instance] * num_process_workers),
+ pipeline.Buffer(4)
+ ])
+
+ def __len__(self):
+ return math.ceil(len(self.filenames))
+
+ def _generator(self):
+ for idx in range(len(self)):
+ yield idx
+
+ def _load_instance(self, idx):
+ if idx >= len(self.filenames):
+ return None
+
+ path = self.path.joinpath(self.filenames[idx])
+
+ instance = {
+ 'filename': self.filenames[idx],
+ 'width': self.width,
+ 'height': self.height,
+ }
+ instance['image'] = read_image(Path(path, 'image.jpg'))
+
+ depth, _ = read_depth(Path(path, 'depth.png')) # ignore depth unit from depth file, use config instead
+ instance.update({
+ 'depth': np.nan_to_num(depth, nan=1, posinf=1, neginf=1),
+ 'depth_mask': np.isfinite(depth),
+ 'depth_mask_inf': np.isinf(depth),
+ })
+
+ if self.include_segmentation:
+ segmentation_mask, segmentation_labels = read_segmentation(Path(path,'segmentation.png'))
+ instance.update({
+ 'segmentation_mask': segmentation_mask,
+ 'segmentation_labels': segmentation_labels,
+ })
+
+ meta = read_meta(Path(path, 'meta.json'))
+ instance['intrinsics'] = np.array(meta['intrinsics'], dtype=np.float32)
+
+ return instance
+
+ def _process_instance(self, instance: dict):
+ if instance is None:
+ return None
+
+ image, depth, depth_mask, intrinsics = instance['image'], instance['depth'], instance['depth_mask'], instance['intrinsics']
+ segmentation_mask, segmentation_labels = instance.get('segmentation_mask', None), instance.get('segmentation_labels', None)
+
+ raw_height, raw_width = image.shape[:2]
+ raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
+ raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height
+ tgt_width, tgt_height = instance['width'], instance['height']
+ tgt_aspect = tgt_width / tgt_height
+
+ # set expected target view field
+ tgt_horizontal = min(raw_horizontal, raw_vertical * tgt_aspect)
+ tgt_vertical = tgt_horizontal / tgt_aspect
+
+ # set target view direction
+ cu, cv = 0.5, 0.5
+ direction = utils3d.numpy.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0]
+ R = utils3d.numpy.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32))
+
+ # restrict target view field within the raw view
+ corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32)
+ corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane
+ corners = corners[:, :2] / corners[:, 2:3]
+
+ warp_horizontal, warp_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
+ for i in range(4):
+ intersection, _ = utils3d.numpy.ray_intersection(
+ np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]),
+ corners[i - 1], corners[i] - corners[i - 1],
+ )
+ warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min())
+ tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical)
+
+ # get target view intrinsics
+ fx, fy = 1.0 / tgt_horizontal, 1.0 / tgt_vertical
+ tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32)
+
+ # do homogeneous transformation with the rotation and intrinsics
+ # 4.1 The image and depth is resized first to approximately the same pixel size as the target image with PIL's antialiasing resampling
+ tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes)
+ rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h)
+ image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS))
+
+ depth, depth_mask = mask_aware_nearest_resize_numpy(depth, depth_mask, (rescaled_w, rescaled_h))
+ distance = norm3d(utils3d.numpy.depth_to_points(depth, intrinsics=intrinsics))
+ segmentation_mask = cv2.resize(segmentation_mask, (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) if segmentation_mask is not None else None
+
+ # 4.2 calculate homography warping
+ transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics)
+ uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height)
+ pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T
+ uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12)
+ pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32)
+
+ tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR)
+ tgt_distance = cv2.remap(distance, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST)
+ tgt_ray_length = utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics)
+ tgt_ray_length = (tgt_ray_length[:, :, 0] ** 2 + tgt_ray_length[:, :, 1] ** 2 + tgt_ray_length[:, :, 2] ** 2) ** 0.5
+ tgt_depth = tgt_distance / (tgt_ray_length + 1e-12)
+ tgt_depth_mask = cv2.remap(depth_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
+ tgt_segmentation_mask = cv2.remap(segmentation_mask, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) if segmentation_mask is not None else None
+
+ # drop depth greater than drop_max_depth
+ max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.drop_max_depth
+ tgt_depth_mask &= tgt_depth <= max_depth
+ tgt_depth = np.nan_to_num(tgt_depth, nan=0.0)
+
+ if self.depth_unit is not None:
+ tgt_depth *= self.depth_unit
+
+ if not np.any(tgt_depth_mask):
+ # always make sure that mask is not empty, otherwise the loss calculation will crash
+ tgt_depth_mask = np.ones_like(tgt_depth_mask)
+ tgt_depth = np.ones_like(tgt_depth)
+ instance['label_type'] = 'invalid'
+
+ tgt_pts = utils3d.numpy.unproject_cv(uv_tgt, tgt_depth, intrinsics=tgt_intrinsics)
+
+ # Process segmentation labels
+ if self.include_segmentation and segmentation_mask is not None:
+ for k in ['undefined', 'unannotated', 'background', 'sky']:
+ if k in segmentation_labels:
+ del segmentation_labels[k]
+ seg_id2count = dict(zip(*np.unique(tgt_segmentation_mask, return_counts=True)))
+ sorted_labels = sorted(segmentation_labels.keys(), key=lambda x: seg_id2count.get(segmentation_labels[x], 0), reverse=True)
+ segmentation_labels = {k: segmentation_labels[k] for k in sorted_labels[:self.max_segments] if seg_id2count.get(segmentation_labels[k], 0) >= self.min_seg_area}
+
+ instance.update({
+ 'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1),
+ 'depth': torch.from_numpy(tgt_depth).float(),
+ 'depth_mask': torch.from_numpy(tgt_depth_mask).bool(),
+ 'intrinsics': torch.from_numpy(tgt_intrinsics).float(),
+ 'points': torch.from_numpy(tgt_pts).float(),
+ 'segmentation_mask': torch.from_numpy(tgt_segmentation_mask).long() if tgt_segmentation_mask is not None else None,
+ 'segmentation_labels': segmentation_labels,
+ 'is_metric': self.depth_unit is not None,
+ 'has_sharp_boundary': self.has_sharp_boundary,
+ })
+
+ instance = {k: v for k, v in instance.items() if v is not None}
+
+ return instance
+
+ def start(self):
+ self.pipeline.start()
+
+ def stop(self):
+ self.pipeline.stop()
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.stop()
+
+ def get(self):
+ return self.pipeline.get()
\ No newline at end of file
diff --git a/models/moge/test/metrics.py b/models/moge/test/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..904064f2a30d05dca3a53db7ecc076a0c2aaa0ad
--- /dev/null
+++ b/models/moge/test/metrics.py
@@ -0,0 +1,343 @@
+from typing import *
+from numbers import Number
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import utils3d
+
+from ..utils.geometry_torch import (
+ weighted_mean,
+ mask_aware_nearest_resize,
+ intrinsics_to_fov
+)
+from ..utils.alignment import (
+ align_points_scale_z_shift,
+ align_points_scale_xyz_shift,
+ align_points_xyz_shift,
+ align_affine_lstsq,
+ align_depth_scale,
+ align_depth_affine,
+ align_points_scale,
+)
+from ..utils.tools import key_average, timeit
+
+
+def rel_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
+ rel = (torch.abs(pred - gt) / (gt + eps)).mean()
+ return rel.item()
+
+
+def delta1_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
+ delta1 = (torch.maximum(gt / pred, pred / gt) < 1.25).float().mean()
+ return delta1.item()
+
+
+def rel_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
+ dist_gt = torch.norm(gt, dim=-1)
+ dist_err = torch.norm(pred - gt, dim=-1)
+ rel = (dist_err / (dist_gt + eps)).mean()
+ return rel.item()
+
+
+def delta1_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
+ dist_pred = torch.norm(pred, dim=-1)
+ dist_gt = torch.norm(gt, dim=-1)
+ dist_err = torch.norm(pred - gt, dim=-1)
+
+ delta1 = (dist_err < 0.25 * torch.minimum(dist_gt, dist_pred)).float().mean()
+ return delta1.item()
+
+
+def rel_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor):
+ dist_err = torch.norm(pred - gt, dim=-1)
+ rel = (dist_err / diameter).mean()
+ return rel.item()
+
+
+def delta1_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor):
+ dist_err = torch.norm(pred - gt, dim=-1)
+ delta1 = (dist_err < 0.25 * diameter).float().mean()
+ return delta1.item()
+
+
+def boundary_f1(pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor, radius: int = 1):
+ neighbor_x, neight_y = torch.meshgrid(
+ torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device),
+ torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device),
+ indexing='xy'
+ )
+ neighbor_mask = (neighbor_x ** 2 + neight_y ** 2) <= radius ** 2 + 1e-5
+
+ pred_window = utils3d.torch.sliding_window_2d(pred, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
+ gt_window = utils3d.torch.sliding_window_2d(gt, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
+ mask_window = neighbor_mask & utils3d.torch.sliding_window_2d(mask, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
+
+ pred_rel = pred_window / pred[radius:-radius, radius:-radius, None, None]
+ gt_rel = gt_window / gt[radius:-radius, radius:-radius, None, None]
+ valid = mask[radius:-radius, radius:-radius, None, None] & mask_window
+
+ f1_list = []
+ w_list = t_list = torch.linspace(0.05, 0.25, 10).tolist()
+
+ for t in t_list:
+ pred_label = pred_rel > 1 + t
+ gt_label = gt_rel > 1 + t
+ TP = (pred_label & gt_label & valid).float().sum()
+ precision = TP / (gt_label & valid).float().sum().clamp_min(1e-12)
+ recall = TP / (pred_label & valid).float().sum().clamp_min(1e-12)
+ f1 = 2 * precision * recall / (precision + recall).clamp_min(1e-12)
+ f1_list.append(f1.item())
+
+ f1_avg = sum(w * f1 for w, f1 in zip(w_list, f1_list)) / sum(w_list)
+ return f1_avg
+
+
+def compute_metrics(
+ pred: Dict[str, torch.Tensor],
+ gt: Dict[str, torch.Tensor],
+ vis: bool = False
+) -> Tuple[Dict[str, Dict[str, Number]], Dict[str, torch.Tensor]]:
+ """
+ A unified function to compute metrics for different types of predictions and ground truths.
+
+ #### Supported keys in pred:
+ - `disparity_affine_invariant`: disparity map predicted by a depth estimator with scale and shift invariant.
+ - `depth_scale_invariant`: depth map predicted by a depth estimator with scale invariant.
+ - `depth_affine_invariant`: depth map predicted by a depth estimator with scale and shift invariant.
+ - `depth_metric`: depth map predicted by a depth estimator with no scale or shift.
+ - `points_scale_invariant`: point map predicted by a point estimator with scale invariant.
+ - `points_affine_invariant`: point map predicted by a point estimator with scale and xyz shift invariant.
+ - `points_metric`: point map predicted by a point estimator with no scale or shift.
+ - `intrinsics`: normalized camera intrinsics matrix.
+
+ #### Required keys in gt:
+ - `depth`: depth map ground truth (in metric units if `depth_metric` is used)
+ - `points`: point map ground truth in camera coordinates.
+ - `mask`: mask indicating valid pixels in the ground truth.
+ - `intrinsics`: normalized ground-truth camera intrinsics matrix.
+ - `is_metric`: whether the depth is in metric units.
+ """
+ metrics = {}
+ misc = {}
+
+ mask = gt['depth_mask']
+ gt_depth = gt['depth']
+ gt_points = gt['points']
+
+ height, width = mask.shape[-2:]
+ _, lr_mask, lr_index = mask_aware_nearest_resize(None, mask, (64, 64), return_index=True)
+
+ only_depth = not any('point' in k for k in pred)
+ pred_depth_aligned, pred_points_aligned = None, None
+
+ # Metric depth
+ if 'depth_metric' in pred and gt['is_metric']:
+ pred_depth, gt_depth = pred['depth_metric'], gt['depth']
+ metrics['depth_metric'] = {
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
+ }
+
+ if pred_depth_aligned is None:
+ pred_depth_aligned = pred_depth
+
+ # Scale-invariant depth
+ if 'depth_scale_invariant' in pred:
+ pred_depth_scale_invariant = pred['depth_scale_invariant']
+ elif 'depth_metric' in pred:
+ pred_depth_scale_invariant = pred['depth_metric']
+ else:
+ pred_depth_scale_invariant = None
+
+ if pred_depth_scale_invariant is not None:
+ pred_depth = pred_depth_scale_invariant
+
+ pred_depth_lr_masked, gt_depth_lr_masked = pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask]
+ scale = align_depth_scale(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked)
+ pred_depth = pred_depth * scale
+
+ metrics['depth_scale_invariant'] = {
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
+ }
+
+ if pred_depth_aligned is None:
+ pred_depth_aligned = pred_depth
+
+ # Affine-invariant depth
+ if 'depth_affine_invariant' in pred:
+ pred_depth_affine_invariant = pred['depth_affine_invariant']
+ elif 'depth_scale_invariant' in pred:
+ pred_depth_affine_invariant = pred['depth_scale_invariant']
+ elif 'depth_metric' in pred:
+ pred_depth_affine_invariant = pred['depth_metric']
+ else:
+ pred_depth_affine_invariant = None
+
+ if pred_depth_affine_invariant is not None:
+ pred_depth = pred_depth_affine_invariant
+
+ pred_depth_lr_masked, gt_depth_lr_masked = pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask]
+ scale, shift = align_depth_affine(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked)
+ pred_depth = pred_depth * scale + shift
+
+ metrics['depth_affine_invariant'] = {
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
+ }
+
+ if pred_depth_aligned is None:
+ pred_depth_aligned = pred_depth
+
+ # Affine-invariant disparity
+ if 'disparity_affine_invariant' in pred:
+ pred_disparity_affine_invariant = pred['disparity_affine_invariant']
+ elif 'depth_scale_invariant' in pred:
+ pred_disparity_affine_invariant = 1 / pred['depth_scale_invariant']
+ elif 'depth_metric' in pred:
+ pred_disparity_affine_invariant = 1 / pred['depth_metric']
+ else:
+ pred_disparity_affine_invariant = None
+
+ if pred_disparity_affine_invariant is not None:
+ pred_disp = pred_disparity_affine_invariant
+
+ scale, shift = align_affine_lstsq(pred_disp[mask], 1 / gt_depth[mask])
+ pred_disp = pred_disp * scale + shift
+
+ # NOTE: The alignment is done on the disparity map could introduce extreme outliers at disparities close to 0.
+ # Therefore we clamp the disparities by minimum ground truth disparity.
+ pred_depth = 1 / pred_disp.clamp_min(1 / gt_depth[mask].max().item())
+
+ metrics['disparity_affine_invariant'] = {
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
+ }
+
+ if pred_depth_aligned is None:
+ pred_depth_aligned = 1 / pred_disp.clamp_min(1e-6)
+
+ # Metric points
+ if 'points_metric' in pred and gt['is_metric']:
+ pred_points = pred['points_metric']
+
+ pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask]
+ shift = align_points_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
+ pred_points = pred_points + shift
+
+ metrics['points_metric'] = {
+ 'rel': rel_point(pred_points[mask], gt_points[mask]),
+ 'delta1': delta1_point(pred_points[mask], gt_points[mask])
+ }
+
+ if pred_points_aligned is None:
+ pred_points_aligned = pred['points_metric']
+
+ # Scale-invariant points (in camera space)
+ if 'points_scale_invariant' in pred:
+ pred_points_scale_invariant = pred['points_scale_invariant']
+ elif 'points_metric' in pred:
+ pred_points_scale_invariant = pred['points_metric']
+ else:
+ pred_points_scale_invariant = None
+
+ if pred_points_scale_invariant is not None:
+ pred_points = pred_points_scale_invariant
+
+ pred_points_lr_masked, gt_points_lr_masked = pred_points_scale_invariant[lr_index][lr_mask], gt_points[lr_index][lr_mask]
+ scale = align_points_scale(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
+ pred_points = pred_points * scale
+
+ metrics['points_scale_invariant'] = {
+ 'rel': rel_point(pred_points[mask], gt_points[mask]),
+ 'delta1': delta1_point(pred_points[mask], gt_points[mask])
+ }
+
+ if vis and pred_points_aligned is None:
+ pred_points_aligned = pred['points_scale_invariant'] * scale
+
+ # Affine-invariant points
+ if 'points_affine_invariant' in pred:
+ pred_points_affine_invariant = pred['points_affine_invariant']
+ elif 'points_scale_invariant' in pred:
+ pred_points_affine_invariant = pred['points_scale_invariant']
+ elif 'points_metric' in pred:
+ pred_points_affine_invariant = pred['points_metric']
+ else:
+ pred_points_affine_invariant = None
+
+ if pred_points_affine_invariant is not None:
+ pred_points = pred_points_affine_invariant
+
+ pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask]
+ scale, shift = align_points_scale_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
+ pred_points = pred_points * scale + shift
+
+ metrics['points_affine_invariant'] = {
+ 'rel': rel_point(pred_points[mask], gt_points[mask]),
+ 'delta1': delta1_point(pred_points[mask], gt_points[mask])
+ }
+
+ if vis and pred_points_aligned is None:
+ pred_points_aligned = pred['points_affine_invariant'] * scale + shift
+
+ # Local points
+ if 'segmentation_mask' in gt and 'points' in gt and any('points' in k for k in pred.keys()):
+ pred_points = next(pred[k] for k in pred.keys() if 'points' in k)
+ gt_points = gt['points']
+ segmentation_mask = gt['segmentation_mask']
+ segmentation_labels = gt['segmentation_labels']
+ segmentation_mask_lr = segmentation_mask[lr_index]
+ local_points_metrics = []
+ for _, seg_id in segmentation_labels.items():
+ valid_mask = (segmentation_mask == seg_id) & mask
+
+ pred_points_masked = pred_points[valid_mask]
+ gt_points_masked = gt_points[valid_mask]
+
+ valid_mask_lr = (segmentation_mask_lr == seg_id) & lr_mask
+ if valid_mask_lr.sum().item() < 10:
+ continue
+ pred_points_masked_lr = pred_points[lr_index][valid_mask_lr]
+ gt_points_masked_lr = gt_points[lr_index][valid_mask_lr]
+ diameter = (gt_points_masked.max(dim=0).values - gt_points_masked.min(dim=0).values).max()
+ scale, shift = align_points_scale_xyz_shift(pred_points_masked_lr, gt_points_masked_lr, 1 / diameter.expand(gt_points_masked_lr.shape[0]))
+ pred_points_masked = pred_points_masked * scale + shift
+
+ local_points_metrics.append({
+ 'rel': rel_point_local(pred_points_masked, gt_points_masked, diameter),
+ 'delta1': delta1_point_local(pred_points_masked, gt_points_masked, diameter),
+ })
+
+ metrics['local_points'] = key_average(local_points_metrics)
+
+ # FOV. NOTE: If there is no random augmentation applied to the input images, all GT FOV are generallly the same.
+ # Fair evaluation of FOV requires random augmentation.
+ if 'intrinsics' in pred and 'intrinsics' in gt:
+ pred_intrinsics = pred['intrinsics']
+ gt_intrinsics = gt['intrinsics']
+ pred_fov_x, pred_fov_y = intrinsics_to_fov(pred_intrinsics)
+ gt_fov_x, gt_fov_y = intrinsics_to_fov(gt_intrinsics)
+ metrics['fov_x'] = {
+ 'mae': torch.rad2deg(pred_fov_x - gt_fov_x).abs().mean().item(),
+ 'deviation': torch.rad2deg(pred_fov_x - gt_fov_x).item(),
+ }
+
+ # Boundary F1
+ if pred_depth_aligned is not None and gt['has_sharp_boundary']:
+ metrics['boundary'] = {
+ 'radius1_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=1),
+ 'radius2_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=2),
+ 'radius3_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=3),
+ }
+
+ if vis:
+ if pred_points_aligned is not None:
+ misc['pred_points'] = pred_points_aligned
+ if only_depth:
+ misc['pred_points'] = utils3d.torch.depth_to_points(pred_depth_aligned, intrinsics=gt['intrinsics'])
+ if pred_depth_aligned is not None:
+ misc['pred_depth'] = pred_depth_aligned
+
+ return metrics, misc
\ No newline at end of file
diff --git a/models/moge/train/__init__.py b/models/moge/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/moge/train/dataloader.py b/models/moge/train/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3bfc280844dac602e89bee747e247946dbc6f67
--- /dev/null
+++ b/models/moge/train/dataloader.py
@@ -0,0 +1,338 @@
+import os
+from pathlib import Path
+import json
+import time
+import random
+from typing import *
+import traceback
+import itertools
+from numbers import Number
+import io
+
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+import torchvision.transforms.v2.functional as TF
+import utils3d
+from tqdm import tqdm
+
+from ..utils import pipeline
+from ..utils.io import *
+from ..utils.geometry_numpy import mask_aware_nearest_resize_numpy, harmonic_mean_numpy, norm3d, depth_occlusion_edge_numpy, depth_of_field
+
+
+class TrainDataLoaderPipeline:
+ def __init__(self, config: dict, batch_size: int, num_load_workers: int = 4, num_process_workers: int = 8, buffer_size: int = 8):
+ self.config = config
+
+ self.batch_size = batch_size
+ self.clamp_max_depth = config['clamp_max_depth']
+ self.fov_range_absolute = config.get('fov_range_absolute', 0.0)
+ self.fov_range_relative = config.get('fov_range_relative', 0.0)
+ self.center_augmentation = config.get('center_augmentation', 0.0)
+ self.image_augmentation = config.get('image_augmentation', [])
+ self.depth_interpolation = config.get('depth_interpolation', 'bilinear')
+
+ if 'image_sizes' in config:
+ self.image_size_strategy = 'fixed'
+ self.image_sizes = config['image_sizes']
+ elif 'aspect_ratio_range' in config and 'area_range' in config:
+ self.image_size_strategy = 'aspect_area'
+ self.aspect_ratio_range = config['aspect_ratio_range']
+ self.area_range = config['area_range']
+ else:
+ raise ValueError('Invalid image size configuration')
+
+ # Load datasets
+ self.datasets = {}
+ for dataset in tqdm(config['datasets'], desc='Loading datasets'):
+ name = dataset['name']
+ content = Path(dataset['path'], dataset.get('index', '.index.txt')).joinpath().read_text()
+ filenames = content.splitlines()
+ self.datasets[name] = {
+ **dataset,
+ 'path': dataset['path'],
+ 'filenames': filenames,
+ }
+ self.dataset_names = [dataset['name'] for dataset in config['datasets']]
+ self.dataset_weights = [dataset['weight'] for dataset in config['datasets']]
+
+ # Build pipeline
+ self.pipeline = pipeline.Sequential([
+ self._sample_batch,
+ pipeline.Unbatch(),
+ pipeline.Parallel([self._load_instance] * num_load_workers),
+ pipeline.Parallel([self._process_instance] * num_process_workers),
+ pipeline.Batch(self.batch_size),
+ self._collate_batch,
+ pipeline.Buffer(buffer_size),
+ ])
+
+ self.invalid_instance = {
+ 'intrinsics': np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]], dtype=np.float32),
+ 'image': np.zeros((256, 256, 3), dtype=np.uint8),
+ 'depth': np.ones((256, 256), dtype=np.float32),
+ 'depth_mask': np.ones((256, 256), dtype=bool),
+ 'depth_mask_inf': np.zeros((256, 256), dtype=bool),
+ 'label_type': 'invalid',
+ }
+
+ def _sample_batch(self):
+ batch_id = 0
+ last_area = None
+ while True:
+ # Depending on the sample strategy, choose a dataset and a filename
+ batch_id += 1
+ batch = []
+
+ # Sample instances
+ for _ in range(self.batch_size):
+ dataset_name = random.choices(self.dataset_names, weights=self.dataset_weights)[0]
+ filename = random.choice(self.datasets[dataset_name]['filenames'])
+
+ path = Path(self.datasets[dataset_name]['path'], filename)
+
+ instance = {
+ 'batch_id': batch_id,
+ 'seed': random.randint(0, 2 ** 32 - 1),
+ 'dataset': dataset_name,
+ 'filename': filename,
+ 'path': path,
+ 'label_type': self.datasets[dataset_name]['label_type'],
+ }
+ batch.append(instance)
+
+ # Decide the image size for this batch
+ if self.image_size_strategy == 'fixed':
+ width, height = random.choice(self.config['image_sizes'])
+ elif self.image_size_strategy == 'aspect_area':
+ area = random.uniform(*self.area_range)
+ aspect_ratio_ranges = [self.datasets[instance['dataset']].get('aspect_ratio_range', self.aspect_ratio_range) for instance in batch]
+ aspect_ratio_range = (min(r[0] for r in aspect_ratio_ranges), max(r[1] for r in aspect_ratio_ranges))
+ aspect_ratio = random.uniform(*aspect_ratio_range)
+ width, height = int((area * aspect_ratio) ** 0.5), int((area / aspect_ratio) ** 0.5)
+ else:
+ raise ValueError('Invalid image size strategy')
+
+ for instance in batch:
+ instance['width'], instance['height'] = width, height
+
+ yield batch
+
+ def _load_instance(self, instance: dict):
+ try:
+ image = read_image(Path(instance['path'], 'image.jpg'))
+ depth, _ = read_depth(Path(instance['path'], self.datasets[instance['dataset']].get('depth', 'depth.png')))
+
+ meta = read_meta(Path(instance['path'], 'meta.json'))
+ intrinsics = np.array(meta['intrinsics'], dtype=np.float32)
+ depth_mask = np.isfinite(depth)
+ depth_mask_inf = np.isinf(depth)
+ depth = np.nan_to_num(depth, nan=1, posinf=1, neginf=1)
+ data = {
+ 'image': image,
+ 'depth': depth,
+ 'depth_mask': depth_mask,
+ 'depth_mask_inf': depth_mask_inf,
+ 'intrinsics': intrinsics
+ }
+ instance.update({
+ **data,
+ })
+ except Exception as e:
+ print(f"Failed to load instance {instance['dataset']}/{instance['filename']} because of exception:", e)
+ instance.update(self.invalid_instance)
+ return instance
+
+ def _process_instance(self, instance: Dict[str, Union[np.ndarray, str, float, bool]]):
+ image, depth, depth_mask, depth_mask_inf, intrinsics, label_type = instance['image'], instance['depth'], instance['depth_mask'], instance['depth_mask_inf'], instance['intrinsics'], instance['label_type']
+ depth_unit = self.datasets[instance['dataset']].get('depth_unit', None)
+
+ raw_height, raw_width = image.shape[:2]
+ raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
+ raw_fov_x, raw_fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics)
+ raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height
+ tgt_width, tgt_height = instance['width'], instance['height']
+ tgt_aspect = tgt_width / tgt_height
+
+ rng = np.random.default_rng(instance['seed'])
+
+ # 1. set target fov
+ center_augmentation = self.datasets[instance['dataset']].get('center_augmentation', self.center_augmentation)
+ fov_range_absolute_min, fov_range_absolute_max = self.datasets[instance['dataset']].get('fov_range_absolute', self.fov_range_absolute)
+ fov_range_relative_min, fov_range_relative_max = self.datasets[instance['dataset']].get('fov_range_relative', self.fov_range_relative)
+ tgt_fov_x_min = min(fov_range_relative_min * raw_fov_x, fov_range_relative_min * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect))
+ tgt_fov_x_max = min(fov_range_relative_max * raw_fov_x, fov_range_relative_max * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect))
+ tgt_fov_x_min, tgt_fov_x_max = max(np.deg2rad(fov_range_absolute_min), tgt_fov_x_min), min(np.deg2rad(fov_range_absolute_max), tgt_fov_x_max)
+ tgt_fov_x = rng.uniform(min(tgt_fov_x_min, tgt_fov_x_max), tgt_fov_x_max)
+ tgt_fov_y = utils3d.focal_to_fov(utils3d.numpy.fov_to_focal(tgt_fov_x) * tgt_aspect)
+
+ # 2. set target image center (principal point) and the corresponding z-direction in raw camera space
+ center_dtheta = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_x - tgt_fov_x)
+ center_dphi = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_y - tgt_fov_y)
+ cu, cv = 0.5 + 0.5 * np.tan(center_dtheta) / np.tan(raw_fov_x / 2), 0.5 + 0.5 * np.tan(center_dphi) / np.tan(raw_fov_y / 2)
+ direction = utils3d.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0]
+
+ # 3. obtain the rotation matrix for homography warping
+ R = utils3d.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32))
+
+ # 4. shrink the target view to fit into the warped image
+ corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32)
+ corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane
+ corners = corners[:, :2] / corners[:, 2:3]
+ tgt_horizontal, tgt_vertical = np.tan(tgt_fov_x / 2) * 2, np.tan(tgt_fov_y / 2) * 2
+ warp_horizontal, warp_vertical = float('inf'), float('inf')
+ for i in range(4):
+ intersection, _ = utils3d.numpy.ray_intersection(
+ np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]),
+ corners[i - 1], corners[i] - corners[i - 1],
+ )
+ warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min())
+ tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical)
+
+ # 5. obtain the target intrinsics
+ fx, fy = 1 / tgt_horizontal, 1 / tgt_vertical
+ tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32)
+
+ # 6. do homogeneous transformation
+ # 6.1 The image and depth are resized first to approximately the same pixel size as the target image with PIL's antialiasing resampling
+ tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes)
+ rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h)
+ image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS))
+
+ edge_mask = depth_occlusion_edge_numpy(depth, mask=depth_mask, thickness=2, tol=0.01)
+ _, depth_mask_nearest, resize_index = mask_aware_nearest_resize_numpy(None, depth_mask, (rescaled_w, rescaled_h), return_index=True)
+ depth_nearest = depth[resize_index]
+ distance_nearest = norm3d(utils3d.numpy.depth_to_points(depth_nearest, intrinsics=intrinsics))
+ edge_mask = edge_mask[resize_index]
+
+ if self.depth_interpolation == 'bilinear':
+ depth_mask_bilinear = cv2.resize(depth_mask.astype(np.float32), (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR)
+ depth_bilinear = 1 / cv2.resize(1 / depth, (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR)
+ distance_bilinear = norm3d(utils3d.numpy.depth_to_points(depth_bilinear, intrinsics=intrinsics))
+
+ depth_mask_inf = cv2.resize(depth_mask_inf.astype(np.uint8), (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) > 0
+
+ # 6.2 calculate homography warping
+ transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics)
+ uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height)
+ pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T
+ uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12)
+ pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32)
+
+ tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LANCZOS4)
+ tgt_ray_length = norm3d(utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics))
+ tgt_depth_mask_nearest = cv2.remap(depth_mask_nearest.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
+ tgt_depth_nearest = cv2.remap(distance_nearest, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) / tgt_ray_length
+ tgt_edge_mask = cv2.remap(edge_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
+ if self.depth_interpolation == 'bilinear':
+ tgt_depth_mask_bilinear = cv2.remap(depth_mask_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR)
+ tgt_depth_bilinear = cv2.remap(distance_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR) / tgt_ray_length
+ tgt_depth = np.where((tgt_depth_mask_bilinear == 1) & ~tgt_edge_mask, tgt_depth_bilinear, tgt_depth_nearest)
+ else:
+ tgt_depth = tgt_depth_nearest
+ tgt_depth_mask = tgt_depth_mask_nearest
+
+ tgt_depth_mask_inf = cv2.remap(depth_mask_inf.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
+
+ # always make sure that mask is not empty
+ if tgt_depth_mask.sum() / tgt_depth_mask.size < 0.001:
+ tgt_depth_mask = np.ones_like(tgt_depth_mask)
+ tgt_depth = np.ones_like(tgt_depth)
+ instance['label_type'] = 'invalid'
+
+ # Flip augmentation
+ if rng.choice([True, False]):
+ tgt_image = np.flip(tgt_image, axis=1).copy()
+ tgt_depth = np.flip(tgt_depth, axis=1).copy()
+ tgt_depth_mask = np.flip(tgt_depth_mask, axis=1).copy()
+ tgt_depth_mask_inf = np.flip(tgt_depth_mask_inf, axis=1).copy()
+
+ # Color augmentation
+ image_augmentation = self.datasets[instance['dataset']].get('image_augmentation', self.image_augmentation)
+ if 'jittering' in image_augmentation:
+ tgt_image = torch.from_numpy(tgt_image).permute(2, 0, 1)
+ tgt_image = TF.adjust_brightness(tgt_image, rng.uniform(0.7, 1.3))
+ tgt_image = TF.adjust_contrast(tgt_image, rng.uniform(0.7, 1.3))
+ tgt_image = TF.adjust_saturation(tgt_image, rng.uniform(0.7, 1.3))
+ tgt_image = TF.adjust_hue(tgt_image, rng.uniform(-0.1, 0.1))
+ tgt_image = TF.adjust_gamma(tgt_image, rng.uniform(0.7, 1.3))
+ tgt_image = tgt_image.permute(1, 2, 0).numpy()
+ if 'dof' in image_augmentation:
+ if rng.uniform() < 0.5:
+ dof_strength = rng.integers(12)
+ tgt_disp = np.where(tgt_depth_mask_inf, 0, 1 / tgt_depth)
+ disp_min, disp_max = tgt_disp[tgt_depth_mask].min(), tgt_disp[tgt_depth_mask].max()
+ tgt_disp = cv2.inpaint(tgt_disp, (~tgt_depth_mask & ~tgt_depth_mask_inf).astype(np.uint8), 3, cv2.INPAINT_TELEA).clip(disp_min, disp_max)
+ dof_focus = rng.uniform(disp_min, disp_max)
+ tgt_image = depth_of_field(tgt_image, tgt_disp, dof_focus, dof_strength)
+ if 'shot_noise' in image_augmentation:
+ if rng.uniform() < 0.5:
+ k = np.exp(rng.uniform(np.log(100), np.log(10000))) / 255
+ tgt_image = (rng.poisson(tgt_image * k) / k).clip(0, 255).astype(np.uint8)
+ if 'jpeg_loss' in image_augmentation:
+ if rng.uniform() < 0.5:
+ tgt_image = cv2.imdecode(cv2.imencode('.jpg', tgt_image, [cv2.IMWRITE_JPEG_QUALITY, rng.integers(20, 100)])[1], cv2.IMREAD_COLOR)
+ if 'blurring' in image_augmentation:
+ if rng.uniform() < 0.5:
+ ratio = rng.uniform(0.25, 1)
+ tgt_image = cv2.resize(cv2.resize(tgt_image, (int(tgt_width * ratio), int(tgt_height * ratio)), interpolation=cv2.INTER_AREA), (tgt_width, tgt_height), interpolation=rng.choice([cv2.INTER_LINEAR_EXACT, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]))
+
+ # convert depth to metric if necessary
+ if depth_unit is not None:
+ tgt_depth *= depth_unit
+ instance['is_metric'] = True
+ else:
+ instance['is_metric'] = False
+
+ # clamp depth maximum values
+ max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.clamp_max_depth
+ tgt_depth = np.clip(tgt_depth, 0, max_depth)
+ tgt_depth = np.nan_to_num(tgt_depth, nan=1.0)
+
+ if self.datasets[instance['dataset']].get('finite_depth_mask', None) == "only_known":
+ tgt_depth_mask_fin = tgt_depth_mask
+ else:
+ tgt_depth_mask_fin = ~tgt_depth_mask_inf
+
+ instance.update({
+ 'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1),
+ 'depth': torch.from_numpy(tgt_depth).float(),
+ 'depth_mask': torch.from_numpy(tgt_depth_mask).bool(),
+ 'depth_mask_fin': torch.from_numpy(tgt_depth_mask_fin).bool(),
+ 'depth_mask_inf': torch.from_numpy(tgt_depth_mask_inf).bool(),
+ 'intrinsics': torch.from_numpy(tgt_intrinsics).float(),
+ })
+
+ return instance
+
+ def _collate_batch(self, instances: List[Dict[str, Any]]):
+ batch = {k: torch.stack([instance[k] for instance in instances], dim=0) for k in ['image', 'depth', 'depth_mask', 'depth_mask_fin', 'depth_mask_inf', 'intrinsics']}
+ batch = {
+ 'label_type': [instance['label_type'] for instance in instances],
+ 'is_metric': [instance['is_metric'] for instance in instances],
+ 'info': [{'dataset': instance['dataset'], 'filename': instance['filename']} for instance in instances],
+ **batch,
+ }
+ return batch
+
+ def get(self) -> Dict[str, Union[torch.Tensor, str]]:
+ return self.pipeline.get()
+
+ def start(self):
+ self.pipeline.start()
+
+ def stop(self):
+ self.pipeline.stop()
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.pipeline.terminate()
+ self.pipeline.join()
+ return False
+
+
diff --git a/models/moge/train/losses.py b/models/moge/train/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b251b230f4cc86d8358f613acf483badfb49e14
--- /dev/null
+++ b/models/moge/train/losses.py
@@ -0,0 +1,270 @@
+from typing import *
+import math
+
+import torch
+import torch.nn.functional as F
+import utils3d
+
+from ..utils.geometry_torch import (
+ weighted_mean,
+ harmonic_mean,
+ geometric_mean,
+ mask_aware_nearest_resize,
+ normalized_view_plane_uv,
+ angle_diff_vec3
+)
+from ..utils.alignment import (
+ align_points_scale_z_shift,
+ align_points_scale,
+ align_points_scale_xyz_shift,
+ align_points_z_shift,
+)
+
+
+def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
+ if beta == 0:
+ return err
+ else:
+ return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)
+
+
+def affine_invariant_global_loss(
+ pred_points: torch.Tensor,
+ gt_points: torch.Tensor,
+ mask: torch.Tensor,
+ align_resolution: int = 64,
+ beta: float = 0.0,
+ trunc: float = 1.0,
+ sparsity_aware: bool = False
+):
+ device = pred_points.device
+
+ # Align
+ (pred_points_lr, gt_points_lr), lr_mask = mask_aware_nearest_resize((pred_points, gt_points), mask=mask, size=(align_resolution, align_resolution))
+ scale, shift = align_points_scale_z_shift(pred_points_lr.flatten(-3, -2), gt_points_lr.flatten(-3, -2), lr_mask.flatten(-2, -1) / gt_points_lr[..., 2].flatten(-2, -1).clamp_min(1e-2), trunc=trunc)
+ valid = scale > 0
+ scale, shift = torch.where(valid, scale, 0), torch.where(valid[..., None], shift, 0)
+
+ pred_points = scale[..., None, None, None] * pred_points + shift[..., None, None, :]
+
+ # Compute loss
+ weight = (valid[..., None, None] & mask).float() / gt_points[..., 2].clamp_min(1e-5)
+ weight = weight.clamp_max(10.0 * weighted_mean(weight, mask, dim=(-2, -1), keepdim=True)) # In case your data contains extremely small depth values
+ loss = _smooth((pred_points - gt_points).abs() * weight[..., None], beta=beta).mean(dim=(-3, -2, -1))
+
+ if sparsity_aware:
+ # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
+ sparsity = mask.float().mean(dim=(-2, -1)) / lr_mask.float().mean(dim=(-2, -1))
+ loss = loss / (sparsity + 1e-7)
+
+ err = (pred_points.detach() - gt_points).norm(dim=-1) / gt_points[..., 2]
+
+ # Record any scalar metric
+ misc = {
+ 'truncated_error': weighted_mean(err.clamp_max(1.0), mask).item(),
+ 'delta': weighted_mean((err < 1).float(), mask).item()
+ }
+
+ return loss, misc, scale.detach()
+
+
+def monitoring(points: torch.Tensor):
+ return {
+ 'std': points.std().item(),
+ }
+
+
+def compute_anchor_sampling_weight(
+ points: torch.Tensor,
+ mask: torch.Tensor,
+ radius_2d: torch.Tensor,
+ radius_3d: torch.Tensor,
+ num_test: int = 64
+) -> torch.Tensor:
+ # Importance sampling to balance the sampled probability of fine strutures.
+ # NOTE: MoGe-1 uses uniform random sampling instead of importance sampling.
+ # This is an incremental trick introduced later than the publication of MoGe-1 paper.
+
+ height, width = points.shape[-3:-1]
+
+ pixel_i, pixel_j = torch.meshgrid(
+ torch.arange(height, device=points.device),
+ torch.arange(width, device=points.device),
+ indexing='ij'
+ )
+
+ test_delta_i = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device) # [num_test]
+ test_delta_j = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device) # [num_test]
+ test_i, test_j = pixel_i[..., None] + test_delta_i, pixel_j[..., None] + test_delta_j # [height, width, num_test]
+ test_mask = (test_i >= 0) & (test_i < height) & (test_j >= 0) & (test_j < width) # [height, width, num_test]
+ test_i, test_j = test_i.clamp(0, height - 1), test_j.clamp(0, width - 1) # [height, width, num_test]
+ test_mask = test_mask & mask[..., test_i, test_j] # [..., height, width, num_test]
+ test_points = points[..., test_i, test_j, :] # [..., height, width, num_test, 3]
+ test_dist = (test_points - points[..., None, :]).norm(dim=-1) # [..., height, width, num_test]
+
+ weight = 1 / ((test_dist <= radius_3d[..., None]) & test_mask).float().sum(dim=-1).clamp_min(1)
+ weight = torch.where(mask, weight, 0)
+ weight = weight / weight.sum(dim=(-2, -1), keepdim=True).add(1e-7) # [..., height, width]
+ return weight
+
+
+def affine_invariant_local_loss(
+ pred_points: torch.Tensor,
+ gt_points: torch.Tensor,
+ gt_mask: torch.Tensor,
+ focal: torch.Tensor,
+ global_scale: torch.Tensor,
+ level: Literal[4, 16, 64],
+ align_resolution: int = 32,
+ num_patches: int = 16,
+ beta: float = 0.0,
+ trunc: float = 1.0,
+ sparsity_aware: bool = False
+):
+ device, dtype = pred_points.device, pred_points.dtype
+ *batch_shape, height, width, _ = pred_points.shape
+ batch_size = math.prod(batch_shape)
+ pred_points, gt_points, gt_mask, focal, global_scale = pred_points.reshape(-1, height, width, 3), gt_points.reshape(-1, height, width, 3), gt_mask.reshape(-1, height, width), focal.reshape(-1), global_scale.reshape(-1) if global_scale is not None else None
+
+ # Sample patch anchor points indices [num_total_patches]
+ radius_2d = math.ceil(0.5 / level * (height ** 2 + width ** 2) ** 0.5)
+ radius_3d = 0.5 / level / focal * gt_points[..., 2]
+ anchor_sampling_weights = compute_anchor_sampling_weight(gt_points, gt_mask, radius_2d, radius_3d, num_test=64)
+ where_mask = torch.where(gt_mask)
+ random_selection = torch.multinomial(anchor_sampling_weights[where_mask], num_patches * batch_size, replacement=True)
+ patch_batch_idx, patch_anchor_i, patch_anchor_j = [indices[random_selection] for indices in where_mask] # [num_total_patches]
+
+ # Get patch indices [num_total_patches, patch_h, patch_w]
+ patch_i, patch_j = torch.meshgrid(
+ torch.arange(-radius_2d, radius_2d + 1, device=device),
+ torch.arange(-radius_2d, radius_2d + 1, device=device),
+ indexing='ij'
+ )
+ patch_i, patch_j = patch_i + patch_anchor_i[:, None, None], patch_j + patch_anchor_j[:, None, None]
+ patch_mask = (patch_i >= 0) & (patch_i < height) & (patch_j >= 0) & (patch_j < width)
+ patch_i, patch_j = patch_i.clamp(0, height - 1), patch_j.clamp(0, width - 1)
+
+ # Get patch mask and gt patch points
+ gt_patch_anchor_points = gt_points[patch_batch_idx, patch_anchor_i, patch_anchor_j]
+ gt_patch_radius_3d = 0.5 / level / focal[patch_batch_idx] * gt_patch_anchor_points[:, 2]
+ gt_patch_points = gt_points[patch_batch_idx[:, None, None], patch_i, patch_j]
+ gt_patch_dist = (gt_patch_points - gt_patch_anchor_points[:, None, None, :]).norm(dim=-1)
+ patch_mask &= gt_mask[patch_batch_idx[:, None, None], patch_i, patch_j]
+ patch_mask &= gt_patch_dist <= gt_patch_radius_3d[:, None, None]
+
+ # Pick only non-empty patches
+ MINIMUM_POINTS_PER_PATCH = 32
+ nonempty = torch.where(patch_mask.sum(dim=(-2, -1)) >= MINIMUM_POINTS_PER_PATCH)
+ num_nonempty_patches = nonempty[0].shape[0]
+ if num_nonempty_patches == 0:
+ return torch.tensor(0.0, dtype=dtype, device=device), {}
+
+ # Finalize all patch variables
+ patch_batch_idx, patch_i, patch_j = patch_batch_idx[nonempty], patch_i[nonempty], patch_j[nonempty]
+ patch_mask = patch_mask[nonempty] # [num_nonempty_patches, patch_h, patch_w]
+ gt_patch_points = gt_patch_points[nonempty] # [num_nonempty_patches, patch_h, patch_w, 3]
+ gt_patch_radius_3d = gt_patch_radius_3d[nonempty] # [num_nonempty_patches]
+ gt_patch_anchor_points = gt_patch_anchor_points[nonempty] # [num_nonempty_patches, 3]
+ pred_patch_points = pred_points[patch_batch_idx[:, None, None], patch_i, patch_j]
+
+ # Align patch points
+ (pred_patch_points_lr, gt_patch_points_lr), patch_lr_mask = mask_aware_nearest_resize((pred_patch_points, gt_patch_points), mask=patch_mask, size=(align_resolution, align_resolution))
+ local_scale, local_shift = align_points_scale_xyz_shift(pred_patch_points_lr.flatten(-3, -2), gt_patch_points_lr.flatten(-3, -2), patch_lr_mask.flatten(-2) / gt_patch_radius_3d[:, None].add(1e-7), trunc=trunc)
+ if global_scale is not None:
+ scale_differ = local_scale / global_scale[patch_batch_idx]
+ patch_valid = (scale_differ > 0.1) & (scale_differ < 10.0) & (global_scale > 0)
+ else:
+ patch_valid = local_scale > 0
+ local_scale, local_shift = torch.where(patch_valid, local_scale, 0), torch.where(patch_valid[:, None], local_shift, 0)
+ patch_mask &= patch_valid[:, None, None]
+
+ pred_patch_points = local_scale[:, None, None, None] * pred_patch_points + local_shift[:, None, None, :] # [num_patches_nonempty, patch_h, patch_w, 3]
+
+ # Compute loss
+ gt_mean = harmonic_mean(gt_points[..., 2], gt_mask, dim=(-2, -1))
+ patch_weight = patch_mask.float() / gt_patch_points[..., 2].clamp_min(0.1 * gt_mean[patch_batch_idx, None, None]) # [num_patches_nonempty, patch_h, patch_w]
+ loss = _smooth((pred_patch_points - gt_patch_points).abs() * patch_weight[..., None], beta=beta).mean(dim=(-3, -2, -1)) # [num_patches_nonempty]
+
+ if sparsity_aware:
+ # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
+ sparsity = patch_mask.float().mean(dim=(-2, -1)) / patch_lr_mask.float().mean(dim=(-2, -1))
+ loss = loss / (sparsity + 1e-7)
+ loss = torch.scatter_reduce(torch.zeros(batch_size, dtype=dtype, device=device), dim=0, index=patch_batch_idx, src=loss, reduce='sum') / num_patches
+ loss = loss.reshape(batch_shape)
+
+ err = (pred_patch_points.detach() - gt_patch_points).norm(dim=-1) / gt_patch_radius_3d[..., None, None]
+
+ # Record any scalar metric
+ misc = {
+ 'truncated_error': weighted_mean(err.clamp_max(1), patch_mask).item(),
+ 'delta': weighted_mean((err < 1).float(), patch_mask).item()
+ }
+
+ return loss, misc
+
+def normal_loss(points: torch.Tensor, gt_points: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ device, dtype = points.device, points.dtype
+ height, width = points.shape[-3:-1]
+
+ leftup, rightup, leftdown, rightdown = points[..., :-1, :-1, :], points[..., :-1, 1:, :], points[..., 1:, :-1, :], points[..., 1:, 1:, :]
+ upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1)
+ leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1)
+ downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1)
+ rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1)
+
+ gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = gt_points[..., :-1, :-1, :], gt_points[..., :-1, 1:, :], gt_points[..., 1:, :-1, :], gt_points[..., 1:, 1:, :]
+ gt_upxleft = torch.cross(gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1)
+ gt_leftxdown = torch.cross(gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1)
+ gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1)
+ gt_rightxup = torch.cross(gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1)
+
+ mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = mask[..., :-1, :-1], mask[..., :-1, 1:], mask[..., 1:, :-1], mask[..., 1:, 1:]
+ mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown
+ mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup
+ mask_downxright = mask_leftdown & mask_rightup & mask_leftup
+ mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown
+
+ MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3)
+
+ loss = mask_upxleft * _smooth(angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
+ + mask_leftxdown * _smooth(angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
+ + mask_downxright * _smooth(angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
+ + mask_rightxup * _smooth(angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
+
+ loss = loss.mean() / (4 * max(points.shape[-3:-1]))
+
+ return loss, {}
+
+
+def edge_loss(points: torch.Tensor, gt_points: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ device, dtype = points.device, points.dtype
+ height, width = points.shape[-3:-1]
+
+ dx = points[..., :-1, :, :] - points[..., 1:, :, :]
+ dy = points[..., :, :-1, :] - points[..., :, 1:, :]
+
+ gt_dx = gt_points[..., :-1, :, :] - gt_points[..., 1:, :, :]
+ gt_dy = gt_points[..., :, :-1, :] - gt_points[..., :, 1:, :]
+
+ mask_dx = mask[..., :-1, :] & mask[..., 1:, :]
+ mask_dy = mask[..., :, :-1] & mask[..., :, 1:]
+
+ MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(0.1), math.radians(90), math.radians(3)
+
+ loss_dx = mask_dx * _smooth(angle_diff_vec3(dx, gt_dx).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
+ loss_dy = mask_dy * _smooth(angle_diff_vec3(dy, gt_dy).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
+ loss = (loss_dx.mean(dim=(-2, -1)) + loss_dy.mean(dim=(-2, -1))) / (2 * max(points.shape[-3:-1]))
+
+ return loss, {}
+
+
+def mask_l2_loss(pred_mask: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> torch.Tensor:
+ loss = gt_mask_neg.float() * pred_mask.square() + gt_mask_pos.float() * (1 - pred_mask).square()
+ loss = loss.mean(dim=(-2, -1))
+ return loss, {}
+
+
+def mask_bce_loss(pred_mask_prob: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> torch.Tensor:
+ loss = (gt_mask_pos | gt_mask_neg) * F.binary_cross_entropy(pred_mask_prob, gt_mask_pos.float(), reduction='none')
+ loss = loss.mean(dim=(-2, -1))
+ return loss, {}
diff --git a/models/moge/train/utils.py b/models/moge/train/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f21e00876b927991381bf2f777a68b02c5b38cc
--- /dev/null
+++ b/models/moge/train/utils.py
@@ -0,0 +1,57 @@
+from typing import *
+import fnmatch
+
+import sympy
+import torch
+import torch.nn as nn
+
+
+def any_match(s: str, patterns: List[str]) -> bool:
+ return any(fnmatch.fnmatch(s, pat) for pat in patterns)
+
+
+def build_optimizer(model: nn.Module, optimizer_config: Dict[str, Any]) -> torch.optim.Optimizer:
+ named_param_groups = [
+ {
+ k: p for k, p in model.named_parameters() if any_match(k, param_group_config['params']['include']) and not any_match(k, param_group_config['params'].get('exclude', []))
+ } for param_group_config in optimizer_config['params']
+ ]
+ excluded_params = [k for k, p in model.named_parameters() if p.requires_grad and not any(k in named_params for named_params in named_param_groups)]
+ assert len(excluded_params) == 0, f'The following parameters require grad but are excluded from the optimizer: {excluded_params}'
+ optimizer_cls = getattr(torch.optim, optimizer_config['type'])
+ optimizer = optimizer_cls([
+ {
+ **param_group_config,
+ 'params': list(params.values()),
+ } for param_group_config, params in zip(optimizer_config['params'], named_param_groups)
+ ])
+ return optimizer
+
+
+def parse_lr_lambda(s: str) -> Callable[[int], float]:
+ epoch = sympy.symbols('epoch')
+ lr_lambda = sympy.sympify(s)
+ return sympy.lambdify(epoch, lr_lambda, 'math')
+
+
+def build_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_config: Dict[str, Any]) -> torch.optim.lr_scheduler._LRScheduler:
+ if scheduler_config['type'] == "SequentialLR":
+ child_schedulers = [
+ build_lr_scheduler(optimizer, child_scheduler_config)
+ for child_scheduler_config in scheduler_config['params']['schedulers']
+ ]
+ return torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=child_schedulers, milestones=scheduler_config['params']['milestones'])
+ elif scheduler_config['type'] == "LambdaLR":
+ lr_lambda = scheduler_config['params']['lr_lambda']
+ if isinstance(lr_lambda, str):
+ lr_lambda = parse_lr_lambda(lr_lambda)
+ elif isinstance(lr_lambda, list):
+ lr_lambda = [parse_lr_lambda(l) for l in lr_lambda]
+ return torch.optim.lr_scheduler.LambdaLR(
+ optimizer,
+ lr_lambda=lr_lambda,
+ )
+ else:
+ scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_config['type'])
+ scheduler = scheduler_cls(optimizer, **scheduler_config.get('params', {}))
+ return scheduler
\ No newline at end of file
diff --git a/models/moge/utils/__init__.py b/models/moge/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/moge/utils/alignment.py b/models/moge/utils/alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d6bb78766ec1a43a89a4fc931b64f70c5201e2d
--- /dev/null
+++ b/models/moge/utils/alignment.py
@@ -0,0 +1,416 @@
+from typing import *
+import math
+from collections import namedtuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.types
+import utils3d
+
+
+def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
+ "Scatter the minimum value along the given dimension of `input` into `src` at the indices specified in `index`."
+ shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
+ minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
+ minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index))
+ indices = torch.full(shape, -1, dtype=torch.long, device=src.device)
+ indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim]
+ return torch.return_types.min((minimum, indices))
+
+
+def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
+ batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
+ n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
+ splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
+ splited_kwargs = {k: [v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks] for k, v in kwargs.items()}
+ results = []
+ for i in range(n_chunks):
+ chunk_args = tuple(arg[i] for arg in splited_args)
+ chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
+ results.append(fn(*chunk_args, **chunk_kwargs))
+
+ if isinstance(results[0], tuple):
+ return tuple(torch.cat(r, dim=0) for r in zip(*results))
+ else:
+ return torch.cat(results, dim=0)
+
+
+def _pad_inf(x_: torch.Tensor):
+ return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)
+
+
+def _pad_cumsum(cumsum: torch.Tensor):
+ return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)
+
+
+def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
+ return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)
+
+
+def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
+ """
+ If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.
+
+ w_i must be >= 0.
+
+ ### Parameters:
+ - `x`: tensor of shape (..., n)
+ - `y`: tensor of shape (..., n)
+ - `w`: tensor of shape (..., n)
+ - `trunc`: optional, float or tensor of shape (..., n) or None
+
+ ### Returns:
+ - `a`: tensor of shape (...), differentiable
+ - `loss`: tensor of shape (...), value of loss function at `a`, detached
+ - `index`: tensor of shape (...), where a = y[idx] / x[idx]
+ """
+ if trunc is None:
+ x, y, w = torch.broadcast_tensors(x, y, w)
+ sign = torch.sign(x)
+ x, y = x * sign, y * sign
+ y_div_x = y / x.clamp_min(eps)
+ y_div_x, argsort = y_div_x.sort(dim=-1)
+
+ wx = torch.gather(x * w, dim=-1, index=argsort)
+ derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
+ search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)
+
+ a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
+ index = argsort.gather(dim=-1, index=search).squeeze(-1)
+ loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)
+
+ else:
+ # Reshape to (batch_size, n) for simplicity
+ x, y, w = torch.broadcast_tensors(x, y, w)
+ batch_shape = x.shape[:-1]
+ batch_size = math.prod(batch_shape)
+ x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])
+
+ sign = torch.sign(x)
+ x, y = x * sign, y * sign
+ wx, wy = w * x, w * y
+ xyw = torch.stack([x, y, w], dim=-1) # Stacked for convenient gathering
+
+ y_div_x = A = y / x.clamp_min(eps)
+ B = (wy - trunc) / wx.clamp_min(eps)
+ C = (wy + trunc) / wx.clamp_min(eps)
+ with torch.no_grad():
+ # Caculate prefix sum by orders of A, B, C
+ A, A_argsort = A.sort(dim=-1)
+ Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
+ A, Q_A = _pad_inf(A), _pad_cumsum(Q_A) # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.
+
+ B, B_argsort = B.sort(dim=-1)
+ Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
+ B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)
+
+ C, C_argsort = C.sort(dim=-1)
+ Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
+ C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)
+
+ # Caculate left and right derivative of A
+ j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
+ j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
+ j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
+ left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
+ j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
+ j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
+ j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
+ right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
+
+ # Find extrema
+ is_extrema = (left_derivative < 0) & (right_derivative >= 0)
+ is_extrema[..., 0] |= ~is_extrema.any(dim=-1) # In case all derivatives are zero, take the first one as extrema.
+ where_extrema_batch, where_extrema_index = torch.where(is_extrema)
+
+ # Calculate objective value at extrema
+ extrema_a = y_div_x[where_extrema_batch, where_extrema_index] # (num_extrema,)
+ MAX_ELEMENTS = 4096 ** 2 # Split into small batches to avoid OOM in case there are too many extrema.(~1G)
+ SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
+ extrema_value = torch.cat([
+ _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
+ for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
+ ]) # (num_extrema,)
+
+ # Find minima among corresponding extrema
+ minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value) # (batch_size,)
+ index = where_extrema_index[indices]
+
+ a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
+ a = a.reshape(batch_shape)
+ loss = minima.reshape(batch_shape)
+ index = index.reshape(batch_shape)
+
+ return a, loss, index
+
+
+def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights.
+
+ ### Parameters:
+ - `depth_src: torch.Tensor` of shape (..., N)
+ - `depth_tgt: torch.Tensor` of shape (..., N)
+
+ """
+ scale, _, _ = align(depth_src, depth_tgt, weight, trunc)
+
+ return scale
+
+
+def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights.
+
+ ### Parameters:
+ - `depth_src: torch.Tensor` of shape (..., N)
+ - `depth_tgt: torch.Tensor` of shape (..., N)
+ - `weight: torch.Tensor` of shape (..., N)
+ - `trunc: float` or tensor of shape (..., N) or None
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (...).
+ """
+ dtype, device = depth_src.dtype, depth_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
+ batch_size = math.prod(batch_shape)
+ depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)
+
+ # Here, we take anchors only for non-zero weights.
+ # Although the results will be still correct even anchor points have zero weight,
+ # it is wasting computation and may cause instability in some cases, e.g. too many extrema.
+ anchors_where_batch, anchors_where_n = torch.where(weight > 0)
+
+ # Stop gradient when solving optimal anchors
+ with torch.no_grad():
+ depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n] # (anchors)
+ depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n] # (anchors)
+
+ depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None] # (anchors, n)
+ depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None] # (anchors, n)
+ weight_anchored = weight[anchors_where_batch, :] # (anchors, n)
+
+ scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc) # (anchors)
+
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss) # (batch_size,)
+
+ # Reproduce by indexing for shorter compute graph
+ index_1 = anchors_where_n[index_anchor] # (batch_size,)
+ index_2 = index[index_anchor] # (batch_size,)
+
+ tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
+ tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
+ shift = tgt_1 - scale * src_1
+
+ scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)
+
+ return scale, shift
+
+def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights using IRLS.
+ """
+ dtype, device = depth_src.dtype, depth_src.device
+
+ w = weight
+ x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)
+ y = depth_tgt
+
+ for i in range(max_iter):
+ beta = (x.transpose(-1, -2) @ (w * y)) @ (x.transpose(-1, -2) @ (w[..., None] * x)).inverse().transpose(-2, -1)
+ w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)
+
+ return beta[..., 0], beta[..., 1]
+
+
+def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weight: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `a: torch.Tensor` of shape (...). Only positive solutions are garunteed. You should filter out negative scales before using it.
+ - `b: torch.Tensor` of shape (...)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight[..., None].expand_as(points_src).flatten(-2), trunc)
+
+ return scale
+
+
+def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
+ batch_size = math.prod(batch_shape)
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
+
+ # Take anchors
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
+ with torch.no_grad():
+ zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
+ points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
+ points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
+
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
+
+ # Solve optimal scale and shift for each anchor
+ MAX_ELEMENTS = 2 ** 20
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
+
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
+
+ # Reproduce by indexing for shorter compute graph
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
+
+ zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
+ points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
+ tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
+ tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
+ shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
+
+ return scale, shift
+
+
+def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
+ batch_size = math.prod(batch_shape)
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
+
+ # Take anchors
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
+
+ with torch.no_grad():
+ points_src_anchor = points_src[anchor_where_batch, anchor_where_n] # (anchors, 3)
+ points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n] # (anchors, 3)
+
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
+
+ # Solve optimal scale and shift for each anchor
+ MAX_ELEMENTS = 2 ** 20
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // 2, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
+
+ # Get optimal scale and shift for each batch element
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
+
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
+
+ src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
+ src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
+ shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
+
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
+
+ return scale, shift
+
+
+def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a Z-axis shift.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ shift, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc)
+ shift = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)
+
+ return shift
+
+
+def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a Z-axis shift.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ shift, _, _ = align(torch.ones_like(points_src).swapaxes(-2, -1), (points_tgt - points_src).swapaxes(-2, -1), weight[..., None, :], trunc)
+
+ return shift
+
+
+def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Solve `min sum_i w_i * (a * x_i + b - y_i ) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.
+
+ ### Parameters:
+ - `x: torch.Tensor` of shape (..., N)
+ - `y: torch.Tensor` of shape (..., N)
+ - `w: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `a: torch.Tensor` of shape (...,)
+ - `b: torch.Tensor` of shape (...,)
+ """
+ w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
+ A = torch.stack([w_sqrt * x, torch.ones_like(x)], dim=-1)
+ B = (w_sqrt * y)[..., None]
+ a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
+ return a, b
\ No newline at end of file
diff --git a/models/moge/utils/download.py b/models/moge/utils/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..886edbccc81cc0c3daed4d858f641097bdfceee2
--- /dev/null
+++ b/models/moge/utils/download.py
@@ -0,0 +1,55 @@
+from pathlib import Path
+from typing import *
+import requests
+
+from tqdm import tqdm
+
+
+__all__ = ["download_file", "download_bytes"]
+
+
+def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None:
+ # Ensure headers is a dict if not provided
+ headers = headers or {}
+
+ # Initialize local variables
+ file_path = Path(filepath)
+ downloaded_bytes = 0
+
+ # Check if we should resume the download
+ if resume and file_path.exists():
+ downloaded_bytes = file_path.stat().st_size
+ headers['Range'] = f"bytes={downloaded_bytes}-"
+
+ # Make a GET request to fetch the file
+ with requests.get(url, stream=True, headers=headers) as response:
+ response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
+
+ # Calculate the total size to download
+ total_size = downloaded_bytes + int(response.headers.get('content-length', 0))
+
+ # Display a progress bar while downloading
+ with (
+ tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar,
+ open(file_path, 'ab') as file,
+ ):
+ # Set the initial position of the progress bar
+ pbar.update(downloaded_bytes)
+
+ # Write the content to the file in chunks
+ for chunk in response.iter_content(chunk_size=4096):
+ file.write(chunk)
+ pbar.update(len(chunk))
+
+
+def download_bytes(url: str, headers: dict = None) -> bytes:
+ # Ensure headers is a dict if not provided
+ headers = headers or {}
+
+ # Make a GET request to fetch the file
+ with requests.get(url, stream=True, headers=headers) as response:
+ response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
+
+ # Read the content of the response
+ return response.content
+
\ No newline at end of file
diff --git a/models/moge/utils/geometry_numpy.py b/models/moge/utils/geometry_numpy.py
new file mode 100644
index 0000000000000000000000000000000000000000..6975471e9fb7443d5a615a47de94d49841c789e1
--- /dev/null
+++ b/models/moge/utils/geometry_numpy.py
@@ -0,0 +1,406 @@
+from typing import *
+from functools import partial
+import math
+
+import cv2
+import numpy as np
+from scipy.signal import fftconvolve
+import numpy as np
+import utils3d
+
+from .tools import timeit
+
+
+def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
+ if w is None:
+ return np.mean(x, axis=axis)
+ else:
+ w = w.astype(x.dtype)
+ return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None)
+
+
+def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
+ if w is None:
+ return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis)
+ else:
+ w = w.astype(x.dtype)
+ return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
+
+
+def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+ if aspect_ratio is None:
+ aspect_ratio = width / height
+
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+ u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
+ v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
+ u, v = np.meshgrid(u, v, indexing='xy')
+ uv = np.stack([u, v], axis=-1)
+ return uv
+
+
+def focal_to_fov_numpy(focal: np.ndarray):
+ return 2 * np.arctan(0.5 / focal)
+
+
+def fov_to_focal_numpy(fov: np.ndarray):
+ return 0.5 / np.tan(fov / 2)
+
+
+def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
+ fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
+ return fov_x, fov_y
+
+
+def point_map_to_depth_legacy_numpy(points: np.ndarray):
+ height, width = points.shape[-3:-1]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+ uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype) # (H, W, 2)
+ _, uv = np.broadcast_arrays(points[..., :2], uv)
+
+ # Solve least squares problem
+ b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2)
+ A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2)
+
+ M = A.swapaxes(-2, -1) @ A
+ solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
+ focal, shift = solution
+
+ depth = points[..., 2] + shift[..., None, None]
+ fov_x = np.arctan(width / diagonal / focal) * 2
+ fov_y = np.arctan(height / diagonal / focal) * 2
+ return depth, fov_x, fov_y, shift
+
+
+def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
+ from scipy.optimize import least_squares
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+ xy_proj = xy / (z + shift)[: , None]
+ f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+ err = (f * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+ optim_shift = solution['x'].squeeze().astype(np.float32)
+
+ xy_proj = xy / (z + optim_shift)[: , None]
+ optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+
+ return optim_shift, optim_focal
+
+
+def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
+ from scipy.optimize import least_squares
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+ xy_proj = xy / (z + shift)[: , None]
+ err = (focal * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+ optim_shift = solution['x'].squeeze().astype(np.float32)
+
+ return optim_shift
+
+
+def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
+ import cv2
+ assert points.shape[-1] == 3, "Points should (H, W, 3)"
+
+ height, width = points.shape[-3], points.shape[-2]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+
+ uv = normalized_view_plane_uv_numpy(width=width, height=height)
+
+ if mask is None:
+ points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
+ uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
+ else:
+ (points_lr, uv_lr), mask_lr = mask_aware_nearest_resize_numpy((points, uv), mask, downsample_size)
+
+ if points_lr.size < 2:
+ return 1., 0.
+
+ if focal is None:
+ focal, shift = solve_optimal_focal_shift(uv_lr, points_lr)
+ else:
+ shift = solve_optimal_shift(uv_lr, points_lr, focal)
+
+ return focal, shift
+
+
+def mask_aware_nearest_resize_numpy(
+ inputs: Union[np.ndarray, Tuple[np.ndarray, ...], None],
+ mask: np.ndarray,
+ size: Tuple[int, int],
+ return_index: bool = False
+) -> Tuple[Union[np.ndarray, Tuple[np.ndarray, ...], None], np.ndarray, Tuple[np.ndarray, ...]]:
+ """
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
+
+ ### Parameters
+ - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
+ - `mask`: input 2D mask of shape (..., H, W)
+ - `size`: target size (width, height)
+
+ ### Returns
+ - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
+ - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
+ - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension.
+ """
+ height, width = mask.shape[-2:]
+ target_width, target_height = size
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
+ filter_size = filter_h_i * filter_w_i
+ padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
+
+ # Window the original mask and uv
+ uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
+ indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
+ padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
+ padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
+ padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
+ windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+ windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
+ windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+
+ # Gather the target pixels's local window
+ target_centers = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
+ target_lefttop = target_centers - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
+ target_window = np.round(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
+
+ target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
+ target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
+ target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(*([-1] * (mask.ndim - 2)), target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
+
+ # Compute nearest neighbor in the local window for each pixel
+ dist = np.square(target_window_centers - target_centers[..., None])
+ dist = dist[..., 0, :] + dist[..., 1, :]
+ dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size)
+ nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1)
+ nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width)
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
+ target_mask = np.any(target_window_mask, axis=-1)
+ batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]
+
+ index = (*batch_indices, nearest_i, nearest_j)
+
+ if inputs is None:
+ outputs = None
+ elif isinstance(inputs, np.ndarray):
+ outputs = inputs[index]
+ elif isinstance(inputs, Sequence):
+ outputs = tuple(x[index] for x in inputs)
+ else:
+ raise ValueError(f'Invalid input type: {type(inputs)}')
+
+ if return_index:
+ return outputs, target_mask, index
+ else:
+ return outputs, target_mask
+
+
+def mask_aware_area_resize_numpy(image: np.ndarray, mask: np.ndarray, target_width: int, target_height: int) -> Tuple[Tuple[np.ndarray, ...], np.ndarray]:
+ """
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
+
+ ### Parameters
+ - `image`: Input 2D image of shape (..., H, W, C)
+ - `mask`: Input 2D mask of shape (..., H, W)
+ - `target_width`: target width of the resized map
+ - `target_height`: target height of the resized map
+
+ ### Returns
+ - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width).
+ - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
+ """
+ height, width = mask.shape[-2:]
+
+ if image.shape[-2:] == (height, width):
+ omit_channel_dim = True
+ else:
+ omit_channel_dim = False
+ if omit_channel_dim:
+ image = image[..., None]
+
+ image = np.where(mask[..., None], image, 0)
+
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
+ filter_h_i, filter_w_i = math.ceil(filter_h_f) + 1, math.ceil(filter_w_f) + 1
+ filter_size = filter_h_i * filter_w_i
+ padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
+
+ # Window the original mask and uv (non-copy)
+ uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
+ indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
+ padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
+ padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
+ padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
+ windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+ windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
+ windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+
+ # Gather the target pixels's local window
+ target_center = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
+ target_lefttop = target_center - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
+ target_bottomright = target_center + np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
+ target_window = np.floor(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
+
+ target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
+ target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
+ target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
+
+ # Compute pixel area in the local windows
+ target_window_lefttop = np.maximum(target_window_centers - 0.5, target_lefttop[..., None])
+ target_window_bottomright = np.minimum(target_window_centers + 0.5, target_bottomright[..., None])
+ target_window_area = (target_window_bottomright - target_window_lefttop).clip(0, None)
+ target_window_area = np.where(target_window_mask, target_window_area[..., 0, :] * target_window_area[..., 1, :], 0)
+
+ # Weighted sum by area
+ target_window_image = image.reshape(*image.shape[:-3], height * width, -1)[..., target_window_indices, :].swapaxes(-2, -1)
+ target_mask = np.sum(target_window_area, axis=-1) >= 0.25
+ target_image = weighted_mean_numpy(target_window_image, target_window_area[..., None, :], axis=-1)
+
+ if omit_channel_dim:
+ target_image = target_image[..., 0]
+
+ return target_image, target_mask
+
+
+def norm3d(x: np.ndarray) -> np.ndarray:
+ "Faster `np.linalg.norm(x, axis=-1)` for 3D vectors"
+ return np.sqrt(np.square(x[..., 0]) + np.square(x[..., 1]) + np.square(x[..., 2]))
+
+
+def depth_occlusion_edge_numpy(depth: np.ndarray, mask: np.ndarray, thickness: int = 1, tol: float = 0.1):
+ disp = np.where(mask, 1 / depth, 0)
+ disp_pad = np.pad(disp, (thickness, thickness), constant_values=0)
+ mask_pad = np.pad(mask, (thickness, thickness), constant_values=False)
+ kernel_size = 2 * thickness + 1
+ disp_window = utils3d.numpy.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]
+ mask_window = utils3d.numpy.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]
+
+ disp_mean = weighted_mean_numpy(disp_window, mask_window, axis=(-2, -1))
+ fg_edge_mask = mask & (disp > (1 + tol) * disp_mean)
+ bg_edge_mask = mask & (disp_mean > (1 + tol) * disp)
+
+ edge_mask = (cv2.dilate(fg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0) \
+ & (cv2.dilate(bg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0)
+
+ return edge_mask
+
+
+def disk_kernel(radius: int) -> np.ndarray:
+ """
+ Generate disk kernel with given radius.
+
+ Args:
+ radius (int): Radius of the disk (in pixels).
+
+ Returns:
+ np.ndarray: (2*radius+1, 2*radius+1) normalized convolution kernel.
+ """
+ # Create coordinate grid centered at (0,0)
+ L = np.arange(-radius, radius + 1)
+ X, Y = np.meshgrid(L, L)
+ # Generate disk: region inside circle with radius R is 1
+ kernel = ((X**2 + Y**2) <= radius**2).astype(np.float32)
+ # Normalize the kernel
+ kernel /= np.sum(kernel)
+ return kernel
+
+
+def disk_blur(image: np.ndarray, radius: int) -> np.ndarray:
+ """
+ Apply disk blur to an image using FFT convolution.
+
+ Args:
+ image (np.ndarray): Input image, can be grayscale or color.
+ radius (int): Blur radius (in pixels).
+
+ Returns:
+ np.ndarray: Blurred image.
+ """
+ if radius == 0:
+ return image
+ kernel = disk_kernel(radius)
+ if image.ndim == 2:
+ blurred = fftconvolve(image, kernel, mode='same')
+ elif image.ndim == 3:
+ channels = []
+ for i in range(image.shape[2]):
+ blurred_channel = fftconvolve(image[..., i], kernel, mode='same')
+ channels.append(blurred_channel)
+ blurred = np.stack(channels, axis=-1)
+ else:
+ raise ValueError("Image must be 2D or 3D.")
+ return blurred
+
+
+def depth_of_field(
+ img: np.ndarray,
+ disp: np.ndarray,
+ focus_disp : float,
+ max_blur_radius : int = 10,
+) -> np.ndarray:
+ """
+ Apply depth of field effect to an image.
+
+ Args:
+ img (numpy.ndarray): (H, W, 3) input image.
+ depth (numpy.ndarray): (H, W) depth map of the scene.
+ focus_depth (float): Focus depth of the lens.
+ strength (float): Strength of the depth of field effect.
+ max_blur_radius (int): Maximum blur radius (in pixels).
+
+ Returns:
+ numpy.ndarray: (H, W, 3) output image with depth of field effect applied.
+ """
+ # Precalculate dialated depth map for each blur radius
+ max_disp = np.max(disp)
+ disp = disp / max_disp
+ focus_disp = focus_disp / max_disp
+ dilated_disp = []
+ for radius in range(max_blur_radius + 1):
+ dilated_disp.append(cv2.dilate(disp, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*radius+1, 2*radius+1)), iterations=1))
+
+ # Determine the blur radius for each pixel based on the depth map
+ blur_radii = np.clip(abs(disp - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
+ for radius in range(max_blur_radius + 1):
+ dialted_blur_radii = np.clip(abs(dilated_disp[radius] - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
+ mask = (dialted_blur_radii >= radius) & (dialted_blur_radii >= blur_radii) & (dilated_disp[radius] > disp)
+ blur_radii[mask] = dialted_blur_radii[mask]
+ blur_radii = np.clip(blur_radii, 0, max_blur_radius)
+ blur_radii = cv2.blur(blur_radii, (5, 5))
+
+ # Precalculate the blured image for each blur radius
+ unique_radii = np.unique(blur_radii)
+ precomputed = {}
+ for radius in range(max_blur_radius + 1):
+ if radius not in unique_radii:
+ continue
+ precomputed[radius] = disk_blur(img, radius)
+
+ # Composit the blured image for each pixel
+ output = np.zeros_like(img)
+ for r in unique_radii:
+ mask = blur_radii == r
+ output[mask] = precomputed[r][mask]
+
+ return output
diff --git a/models/moge/utils/geometry_torch.py b/models/moge/utils/geometry_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab5dbe965a42d0e0b3cbe53eb213bdcb829f8243
--- /dev/null
+++ b/models/moge/utils/geometry_torch.py
@@ -0,0 +1,354 @@
+from typing import *
+import math
+from collections import namedtuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.types
+import utils3d
+
+from .tools import timeit
+from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift
+
+
+def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.mean(dim=dim, keepdim=keepdim)
+ else:
+ w = w.to(x.dtype)
+ return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)
+
+
+def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal()
+ else:
+ w = w.to(x.dtype)
+ return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()
+
+
+def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.add(eps).log().mean(dim=dim).exp()
+ else:
+ w = w.to(x.dtype)
+ return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
+
+
+def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+ if aspect_ratio is None:
+ aspect_ratio = width / height
+
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+ u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
+ v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
+ u, v = torch.meshgrid(u, v, indexing='xy')
+ uv = torch.stack([u, v], dim=-1)
+ return uv
+
+
+def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
+ kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
+ kernel = kernel / kernel.sum()
+ kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
+ input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
+ input = F.conv2d(input, kernel, groups=input.shape[1])
+ return input
+
+
+def focal_to_fov(focal: torch.Tensor):
+ return 2 * torch.atan(0.5 / focal)
+
+
+def fov_to_focal(fov: torch.Tensor):
+ return 0.5 / torch.tan(fov / 2)
+
+
+def angle_diff_vec3(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12):
+ return torch.atan2(torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps, (v1 * v2).sum(dim=-1))
+
+def intrinsics_to_fov(intrinsics: torch.Tensor):
+ """
+ Returns field of view in radians from normalized intrinsics matrix.
+ ### Parameters:
+ - intrinsics: torch.Tensor of shape (..., 3, 3)
+
+ ### Returns:
+ - fov_x: torch.Tensor of shape (...)
+ - fov_y: torch.Tensor of shape (...)
+ """
+ focal_x = intrinsics[..., 0, 0]
+ focal_y = intrinsics[..., 1, 1]
+ return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y)
+
+
+def point_map_to_depth_legacy(points: torch.Tensor):
+ height, width = points.shape[-3:-1]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
+
+ # Solve least squares problem
+ b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2)
+ A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2)
+
+ M = A.transpose(-2, -1) @ A
+ solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
+ focal, shift = solution.unbind(-1)
+
+ depth = points[..., 2] + shift[..., None, None]
+ fov_x = torch.atan(width / diagonal / focal) * 2
+ fov_y = torch.atan(height / diagonal / focal) * 2
+ return depth, fov_x, fov_y, shift
+
+
+def view_plane_uv_to_focal(uv: torch.Tensor):
+ normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
+ focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12)
+ return focal
+
+
+def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
+ """
+ Recover the depth map and FoV from a point map with unknown z shift and focal.
+
+ Note that it assumes:
+ - the optical center is at the center of the map
+ - the map is undistorted
+ - the map is isometric in the x and y directions
+
+ ### Parameters:
+ - `points: torch.Tensor` of shape (..., H, W, 3)
+ - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.
+
+ ### Returns:
+ - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
+ - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
+ """
+ shape = points.shape
+ height, width = points.shape[-3], points.shape[-2]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+
+ points = points.reshape(-1, *shape[-3:])
+ mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
+ focal = focal.reshape(-1) if focal is not None else None
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
+
+ points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
+ uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
+ mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0
+
+ uv_lr_np = uv_lr.cpu().numpy()
+ points_lr_np = points_lr.detach().cpu().numpy()
+ focal_np = focal.cpu().numpy() if focal is not None else None
+ mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
+ optim_shift, optim_focal = [], []
+ for i in range(points.shape[0]):
+ points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
+ uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
+ if uv_lr_i_np.shape[0] < 2:
+ optim_focal.append(1)
+ optim_shift.append(0)
+ continue
+ if focal is None:
+ optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
+ optim_focal.append(float(optim_focal_i))
+ else:
+ optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
+ optim_shift.append(float(optim_shift_i))
+ optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+
+ if focal is None:
+ optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+ else:
+ optim_focal = focal.reshape(shape[:-3])
+
+ return optim_focal, optim_shift
+
+
+def mask_aware_nearest_resize(
+ inputs: Union[torch.Tensor, Sequence[torch.Tensor], None],
+ mask: torch.BoolTensor,
+ size: Tuple[int, int],
+ return_index: bool = False
+) -> Tuple[Union[torch.Tensor, Sequence[torch.Tensor], None], torch.BoolTensor, Tuple[torch.LongTensor, ...]]:
+ """
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
+
+ ### Parameters
+ - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
+ - `mask`: input 2D mask of shape (..., H, W)
+ - `size`: target size (target_width, target_height)
+
+ ### Returns
+ - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
+ - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
+ - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension, .
+ """
+ height, width = mask.shape[-2:]
+ target_width, target_height = size
+ device = mask.device
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
+ filter_size = filter_h_i * filter_w_i
+ padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
+
+ # Window the original mask and uv
+ uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
+ indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
+ padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
+ padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
+ padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
+ windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
+ windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
+ windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))
+
+ # Gather the target pixels's local window
+ target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
+ target_lefttop = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
+ target_window = torch.round(target_lefttop).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)
+
+ target_window_uv = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
+ target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
+ target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
+ target_window_indices = target_window_indices.expand_as(target_window_mask)
+
+ # Compute nearest neighbor in the local window for each pixel
+ dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size)
+ nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1)
+ nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width)
+ target_mask = torch.any(target_window_mask, dim=-1)
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
+ batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]
+
+ index = (*batch_indices, nearest_i, nearest_j)
+
+ if inputs is None:
+ outputs = None
+ elif isinstance(inputs, torch.Tensor):
+ outputs = inputs[index]
+ elif isinstance(inputs, Sequence):
+ outputs = tuple(x[index] for x in inputs)
+ else:
+ raise ValueError(f'Invalid input type: {type(inputs)}')
+
+ if return_index:
+ return outputs, target_mask, index
+ else:
+ return outputs, target_mask
+
+
+def theshold_depth_change(depth: torch.Tensor, mask: torch.Tensor, pooler: Literal['min', 'max'], rtol: float = 0.2, kernel_size: int = 3):
+ *batch_shape, height, width = depth.shape
+ depth = depth.reshape(-1, 1, height, width)
+ mask = mask.reshape(-1, 1, height, width)
+ if pooler =='max':
+ pooled_depth = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
+ output_mask = pooled_depth > depth * (1 + rtol)
+ elif pooler =='min':
+ pooled_depth = -F.max_pool2d(-torch.where(mask, depth, torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
+ output_mask = pooled_depth < depth * (1 - rtol)
+ else:
+ raise ValueError(f'Unsupported pooler: {pooler}')
+ output_mask = output_mask.reshape(*batch_shape, height, width)
+ return output_mask
+
+
+def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
+ device, dtype = depth.device, depth.dtype
+
+ disp = torch.where(mask, 1 / depth, 0)
+ disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
+ mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
+ disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2]
+ mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2]
+
+ x = torch.linspace(-kernel_size // 2, kernel_size // 2, kernel_size, device=device, dtype=dtype)
+ A = torch.stack([*torch.meshgrid(x, x, indexing='xy'), torch.ones((kernel_size, kernel_size), device=device, dtype=dtype)], dim=-1).reshape(kernel_size ** 2, 3) # [kernel_size ** 2, 3]
+ A = mask_window[..., None] * A
+ I = torch.eye(3, device=device, dtype=dtype)
+
+ affine_disp_window = (disp_window[..., None, :] @ A @ torch.inverse(A.mT @ A + 1e-5 * I) @ A.mT).clamp_min(1e-12)[..., 0, :] # [..., H, W, kernel_size ** 2]
+ diff = torch.where(mask_window, torch.maximum(affine_disp_window, disp_window) / torch.minimum(affine_disp_window, disp_window) - 1, 0)
+
+ edge_mask = mask & (diff > tol).any(dim=-1)
+
+ disp_mean = weighted_mean(disp_window, mask_window, dim=-1)
+ fg_edge_mask = edge_mask & (disp > disp_mean)
+ # fg_edge_mask = edge_mask & theshold_depth_change(depth, mask, pooler='max', rtol=tol, kernel_size=kernel_size)
+ bg_edge_mask = edge_mask & ~fg_edge_mask
+ return fg_edge_mask, bg_edge_mask
+
+
+def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
+ device, dtype = depth.device, depth.dtype
+
+ disp = torch.where(mask, 1 / depth, 0)
+ disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
+ mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
+ disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2]
+ mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2]
+
+ disp_mean = weighted_mean(disp_window, mask_window, dim=(-2, -1))
+ fg_edge_mask = mask & (disp / disp_mean > 1 + tol)
+ bg_edge_mask = mask & (disp_mean / disp > 1 + tol)
+
+ fg_edge_mask = fg_edge_mask & F.max_pool2d(bg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()
+ bg_edge_mask = bg_edge_mask & F.max_pool2d(fg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()
+
+ return fg_edge_mask, bg_edge_mask
+
+
+def dilate_with_mask(input: torch.Tensor, mask: torch.BoolTensor, filter: Literal['min', 'max', 'mean', 'median'] = 'mean', iterations: int = 1) -> torch.Tensor:
+ kernel = torch.tensor([[False, True, False], [True, True, True], [False, True, False]], device=input.device, dtype=torch.bool)
+ for _ in range(iterations):
+ input_window = utils3d.torch.sliding_window_2d(F.pad(input, (1, 1, 1, 1), mode='constant', value=0), window_size=3, stride=1, dim=(-2, -1))
+ mask_window = kernel & utils3d.torch.sliding_window_2d(F.pad(mask, (1, 1, 1, 1), mode='constant', value=False), window_size=3, stride=1, dim=(-2, -1))
+ if filter =='min':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.inf).min(dim=(-2, -1)).values)
+ elif filter =='max':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, -torch.inf).max(dim=(-2, -1)).values)
+ elif filter == 'mean':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).nanmean(dim=(-2, -1)))
+ elif filter =='median':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).flatten(-2).nanmedian(dim=-1).values)
+ mask = mask_window.any(dim=(-2, -1))
+ return input, mask
+
+
+def refine_depth_with_normal(depth: torch.Tensor, normal: torch.Tensor, intrinsics: torch.Tensor, iterations: int = 10, damp: float = 1e-3, eps: float = 1e-12, kernel_size: int = 5) -> torch.Tensor:
+ device, dtype = depth.device, depth.dtype
+ height, width = depth.shape[-2:]
+ radius = kernel_size // 2
+
+ duv = torch.stack(torch.meshgrid(torch.linspace(-radius / width, radius / width, kernel_size, device=device, dtype=dtype), torch.linspace(-radius / height, radius / height, kernel_size, device=device, dtype=dtype), indexing='xy'), dim=-1).to(dtype=dtype, device=device)
+
+ log_depth = depth.clamp_min_(eps).log()
+ log_depth_diff = utils3d.torch.sliding_window_2d(log_depth, window_size=kernel_size, stride=1, dim=(-2, -1)) - log_depth[..., radius:-radius, radius:-radius, None, None]
+
+ weight = torch.exp(-(log_depth_diff / duv.norm(dim=-1).clamp_min_(eps) / 10).square())
+ tot_weight = weight.sum(dim=(-2, -1)).clamp_min_(eps)
+
+ uv = utils3d.torch.image_uv(height=height, width=width, device=device, dtype=dtype)
+ K_inv = torch.inverse(intrinsics)
+
+ grad = -(normal[..., None, :2] @ K_inv[..., None, None, :2, :2]).squeeze(-2) \
+ / (normal[..., None, 2:] + normal[..., None, :2] @ (K_inv[..., None, None, :2, :2] @ uv[..., :, None] + K_inv[..., None, None, :2, 2:])).squeeze(-2)
+ laplacian = (weight * ((utils3d.torch.sliding_window_2d(grad, window_size=kernel_size, stride=1, dim=(-3, -2)) + grad[..., radius:-radius, radius:-radius, :, None, None]) * (duv.permute(2, 0, 1) / 2)).sum(dim=-3)).sum(dim=(-2, -1))
+
+ laplacian = laplacian.clamp(-0.1, 0.1)
+ log_depth_refine = log_depth.clone()
+
+ for _ in range(iterations):
+ log_depth_refine[..., radius:-radius, radius:-radius] = 0.1 * log_depth_refine[..., radius:-radius, radius:-radius] + 0.9 * (damp * log_depth[..., radius:-radius, radius:-radius] - laplacian + (weight * utils3d.torch.sliding_window_2d(log_depth_refine, window_size=kernel_size, stride=1, dim=(-2, -1))).sum(dim=(-2, -1))) / (tot_weight + damp)
+
+ depth_refine = log_depth_refine.exp()
+
+ return depth_refine
\ No newline at end of file
diff --git a/models/moge/utils/io.py b/models/moge/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..108548caaa34dfcbf394ed4b021874c5ac12edf8
--- /dev/null
+++ b/models/moge/utils/io.py
@@ -0,0 +1,236 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+from typing import IO
+import zipfile
+import json
+import io
+from typing import *
+from pathlib import Path
+import re
+from PIL import Image, PngImagePlugin
+
+import numpy as np
+import cv2
+
+from .tools import timeit
+
+
+def save_glb(
+ save_path: Union[str, os.PathLike],
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ vertex_uvs: np.ndarray,
+ texture: np.ndarray,
+ vertex_normals: Optional[np.ndarray] = None,
+):
+ import trimesh
+ import trimesh.visual
+ from PIL import Image
+
+ trimesh.Trimesh(
+ vertices=vertices,
+ vertex_normals=vertex_normals,
+ faces=faces,
+ visual = trimesh.visual.texture.TextureVisuals(
+ uv=vertex_uvs,
+ material=trimesh.visual.material.PBRMaterial(
+ baseColorTexture=Image.fromarray(texture),
+ metallicFactor=0.5,
+ roughnessFactor=1.0
+ )
+ ),
+ process=False
+ ).export(save_path)
+
+
+def save_ply(
+ save_path: Union[str, os.PathLike],
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ vertex_colors: np.ndarray,
+ vertex_normals: Optional[np.ndarray] = None,
+):
+ import trimesh
+ import trimesh.visual
+ from PIL import Image
+
+ trimesh.Trimesh(
+ vertices=vertices,
+ faces=faces,
+ vertex_colors=vertex_colors,
+ vertex_normals=vertex_normals,
+ process=False
+ ).export(save_path)
+
+
+def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray:
+ """
+ Read a image, return uint8 RGB array of shape (H, W, 3).
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ image = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
+ return image
+
+
+def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95):
+ """
+ Write a image, input uint8 RGB array of shape (H, W, 3).
+ """
+ data = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes()
+ if isinstance(path, (str, os.PathLike)):
+ Path(path).write_bytes(data)
+ else:
+ path.write(data)
+
+
+def read_depth(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, float]:
+ """
+ Read a depth image, return float32 depth array of shape (H, W).
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ pil_image = Image.open(io.BytesIO(data))
+ near = float(pil_image.info.get('near'))
+ far = float(pil_image.info.get('far'))
+ unit = float(pil_image.info.get('unit')) if 'unit' in pil_image.info else None
+ depth = np.array(pil_image)
+ mask_nan, mask_inf = depth == 0, depth == 65535
+ depth = (depth.astype(np.float32) - 1) / 65533
+ depth = near ** (1 - depth) * far ** depth
+ depth[mask_nan] = np.nan
+ depth[mask_inf] = np.inf
+ return depth, unit
+
+
+def write_depth(
+ path: Union[str, os.PathLike, IO],
+ depth: np.ndarray,
+ unit: float = None,
+ max_range: float = 1e5,
+ compression_level: int = 7,
+):
+ """
+ Encode and write a depth image as 16-bit PNG format.
+ ### Parameters:
+ - `path: Union[str, os.PathLike, IO]`
+ The file path or file object to write to.
+ - `depth: np.ndarray`
+ The depth array, float32 array of shape (H, W).
+ May contain `NaN` for invalid values and `Inf` for infinite values.
+ - `unit: float = None`
+ The unit of the depth values.
+
+ Depth values are encoded as follows:
+ - 0: unknown
+ - 1 ~ 65534: depth values in logarithmic
+ - 65535: infinity
+
+ metadata is stored in the PNG file as text fields:
+ - `near`: the minimum depth value
+ - `far`: the maximum depth value
+ - `unit`: the unit of the depth values (optional)
+ """
+ mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth),np.isinf(depth)
+
+ depth = depth.astype(np.float32)
+ mask_finite = depth
+ near = max(depth[mask_values].min(), 1e-5)
+ far = max(near * 1.1, min(depth[mask_values].max(), near * max_range))
+ depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16) # 1~65534
+ depth[mask_nan] = 0
+ depth[mask_inf] = 65535
+
+ pil_image = Image.fromarray(depth)
+ pnginfo = PngImagePlugin.PngInfo()
+ pnginfo.add_text('near', str(near))
+ pnginfo.add_text('far', str(far))
+ if unit is not None:
+ pnginfo.add_text('unit', str(unit))
+ pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
+
+
+def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]:
+ """
+ Read a segmentation mask
+ ### Parameters:
+ - `path: Union[str, os.PathLike, IO]`
+ The file path or file object to read from.
+ ### Returns:
+ - `Tuple[np.ndarray, Dict[str, int]]`
+ A tuple containing:
+ - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W).
+ - `labels`: Dict[str, int]. The label mapping, a dictionary of {label_name: label_id}.
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ pil_image = Image.open(io.BytesIO(data))
+ labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None
+ mask = np.array(pil_image)
+ return mask, labels
+
+
+def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7):
+ """
+ Write a segmentation mask and label mapping, as PNG format.
+ ### Parameters:
+ - `path: Union[str, os.PathLike, IO]`
+ The file path or file object to write to.
+ - `mask: np.ndarray`
+ The segmentation mask, uint8 or uint16 array of shape (H, W).
+ - `labels: Dict[str, int] = None`
+ The label mapping, a dictionary of {label_name: label_id}.
+ - `compression_level: int = 7`
+ The compression level for PNG compression.
+ """
+ assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}"
+ pil_image = Image.fromarray(mask)
+ pnginfo = PngImagePlugin.PngInfo()
+ if labels is not None:
+ labels_json = json.dumps(labels, ensure_ascii=True, separators=(',', ':'))
+ pnginfo.add_text('labels', labels_json)
+ pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
+
+
+
+def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray:
+ """
+ Read a normal image, return float32 normal array of shape (H, W, 3).
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
+ mask_nan = np.all(normal == 0, axis=-1)
+ normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
+ normal = normal / (np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2])) + 1e-12)
+ normal[mask_nan] = np.nan
+ return normal
+
+
+def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> np.ndarray:
+ """
+ Write a normal image, input float32 normal array of shape (H, W, 3).
+ """
+ mask_nan = np.isnan(normal).any(axis=-1)
+ normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
+ normal[mask_nan] = 0
+ data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
+ if isinstance(path, (str, os.PathLike)):
+ Path(path).write_bytes(data)
+ else:
+ path.write(data)
+
+
+def read_meta(path: Union[str, os.PathLike, IO]) -> Dict[str, Any]:
+ return json.loads(Path(path).read_text())
+
+def write_meta(path: Union[str, os.PathLike, IO], meta: Dict[str, Any]):
+ Path(path).write_text(json.dumps(meta))
\ No newline at end of file
diff --git a/models/moge/utils/panorama.py b/models/moge/utils/panorama.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f9d121c3c189770a7fd9f88be66f74f1ba5cfd3
--- /dev/null
+++ b/models/moge/utils/panorama.py
@@ -0,0 +1,191 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+from pathlib import Path
+from typing import *
+import itertools
+import json
+import warnings
+
+import cv2
+import numpy as np
+from numpy import ndarray
+from tqdm import tqdm, trange
+from scipy.sparse import csr_array, hstack, vstack
+from scipy.ndimage import convolve
+from scipy.sparse.linalg import lsmr
+
+import utils3d
+
+
+def get_panorama_cameras():
+ vertices, _ = utils3d.numpy.icosahedron()
+ intrinsics = utils3d.numpy.intrinsics_from_fov(fov_x=np.deg2rad(90), fov_y=np.deg2rad(90))
+ extrinsics = utils3d.numpy.extrinsics_look_at([0, 0, 0], vertices, [0, 0, 1]).astype(np.float32)
+ return extrinsics, [intrinsics] * len(vertices)
+
+
+def spherical_uv_to_directions(uv: np.ndarray):
+ theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi
+ directions = np.stack([np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)], axis=-1)
+ return directions
+
+
+def directions_to_spherical_uv(directions: np.ndarray):
+ directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
+ u = 1 - np.arctan2(directions[..., 1], directions[..., 0]) / (2 * np.pi) % 1.0
+ v = np.arccos(directions[..., 2]) / np.pi
+ return np.stack([u, v], axis=-1)
+
+
+def split_panorama_image(image: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray, resolution: int):
+ height, width = image.shape[:2]
+ uv = utils3d.numpy.image_uv(width=resolution, height=resolution)
+ splitted_images = []
+ for i in range(len(extrinsics)):
+ spherical_uv = directions_to_spherical_uv(utils3d.numpy.unproject_cv(uv, extrinsics=extrinsics[i], intrinsics=intrinsics[i]))
+ pixels = utils3d.numpy.uv_to_pixel(spherical_uv, width=width, height=height).astype(np.float32)
+
+ splitted_image = cv2.remap(image, pixels[..., 0], pixels[..., 1], interpolation=cv2.INTER_LINEAR)
+ splitted_images.append(splitted_image)
+ return splitted_images
+
+
+def poisson_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> Tuple[csr_array, ndarray]:
+ grid_index = np.arange(height * width).reshape(height, width)
+ grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode='wrap' if wrap_x else 'edge')
+ grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode='wrap' if wrap_y else 'edge')
+
+ data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(height * width, axis=0).reshape(-1)
+ indices = np.stack([
+ grid_index[1:-1, 1:-1],
+ grid_index[:-2, 1:-1], # up
+ grid_index[2:, 1:-1], # down
+ grid_index[1:-1, :-2], # left
+ grid_index[1:-1, 2:] # right
+ ], axis=-1).reshape(-1)
+ indptr = np.arange(0, height * width * 5 + 1, 5)
+ A = csr_array((data, indices, indptr), shape=(height * width, height * width))
+
+ return A
+
+
+def grad_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> Tuple[csr_array, np.ndarray]:
+ grid_index = np.arange(width * height).reshape(height, width)
+ if wrap_x:
+ grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode='wrap')
+ if wrap_y:
+ grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode='wrap')
+
+ data = np.concatenate([
+ np.concatenate([
+ np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), # x[i,j]
+ -np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), # x[i,j-1]
+ ], axis=1).reshape(-1),
+ np.concatenate([
+ np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), # x[i,j]
+ -np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), # x[i-1,j]
+ ], axis=1).reshape(-1),
+ ])
+ indices = np.concatenate([
+ np.concatenate([
+ grid_index[:, :-1].reshape(-1, 1),
+ grid_index[:, 1:].reshape(-1, 1),
+ ], axis=1).reshape(-1),
+ np.concatenate([
+ grid_index[:-1, :].reshape(-1, 1),
+ grid_index[1:, :].reshape(-1, 1),
+ ], axis=1).reshape(-1),
+ ])
+ indptr = np.arange(0, grid_index.shape[0] * (grid_index.shape[1] - 1) * 2 + (grid_index.shape[0] - 1) * grid_index.shape[1] * 2 + 1, 2)
+ A = csr_array((data, indices, indptr), shape=(grid_index.shape[0] * (grid_index.shape[1] - 1) + (grid_index.shape[0] - 1) * grid_index.shape[1], height * width))
+
+ return A
+
+
+def merge_panorama_depth(width: int, height: int, distance_maps: List[np.ndarray], pred_masks: List[np.ndarray], extrinsics: List[np.ndarray], intrinsics: List[np.ndarray]):
+ if max(width, height) > 256:
+ panorama_depth_init, _ = merge_panorama_depth(width // 2, height // 2, distance_maps, pred_masks, extrinsics, intrinsics)
+ panorama_depth_init = cv2.resize(panorama_depth_init, (width, height), cv2.INTER_LINEAR)
+ else:
+ panorama_depth_init = None
+
+ uv = utils3d.numpy.image_uv(width=width, height=height)
+ spherical_directions = spherical_uv_to_directions(uv)
+
+ # Warp each view to the panorama
+ panorama_log_distance_grad_maps, panorama_grad_masks = [], []
+ panorama_log_distance_laplacian_maps, panorama_laplacian_masks = [], []
+ panorama_pred_masks = []
+ for i in range(len(distance_maps)):
+ projected_uv, projected_depth = utils3d.numpy.project_cv(spherical_directions, extrinsics=extrinsics[i], intrinsics=intrinsics[i])
+ projection_valid_mask = (projected_depth > 0) & (projected_uv > 0).all(axis=-1) & (projected_uv < 1).all(axis=-1)
+
+ projected_pixels = utils3d.numpy.uv_to_pixel(np.clip(projected_uv, 0, 1), width=distance_maps[i].shape[1], height=distance_maps[i].shape[0]).astype(np.float32)
+
+ log_splitted_distance = np.log(distance_maps[i])
+ panorama_log_distance_map = np.where(projection_valid_mask, cv2.remap(log_splitted_distance, projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE), 0)
+ panorama_pred_mask = projection_valid_mask & (cv2.remap(pred_masks[i].astype(np.uint8), projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE) > 0)
+
+ # calculate gradient map
+ padded = np.pad(panorama_log_distance_map, ((0, 0), (0, 1)), mode='wrap')
+ grad_x, grad_y = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :]
+
+ padded = np.pad(panorama_pred_mask, ((0, 0), (0, 1)), mode='wrap')
+ mask_x, mask_y = padded[:, :-1] & padded[:, 1:], padded[:-1, :] & padded[1:, :]
+
+ panorama_log_distance_grad_maps.append((grad_x, grad_y))
+ panorama_grad_masks.append((mask_x, mask_y))
+
+ # calculate laplacian map
+ padded = np.pad(panorama_log_distance_map, ((1, 1), (0, 0)), mode='edge')
+ padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap')
+ laplacian = convolve(padded, np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32))[1:-1, 1:-1]
+
+ padded = np.pad(panorama_pred_mask, ((1, 1), (0, 0)), mode='edge')
+ padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap')
+ mask = convolve(padded.astype(np.uint8), np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8))[1:-1, 1:-1] == 5
+
+ panorama_log_distance_laplacian_maps.append(laplacian)
+ panorama_laplacian_masks.append(mask)
+
+ panorama_pred_masks.append(panorama_pred_mask)
+
+ panorama_log_distance_grad_x = np.stack([grad_map[0] for grad_map in panorama_log_distance_grad_maps], axis=0)
+ panorama_log_distance_grad_y = np.stack([grad_map[1] for grad_map in panorama_log_distance_grad_maps], axis=0)
+ panorama_grad_mask_x = np.stack([mask_map[0] for mask_map in panorama_grad_masks], axis=0)
+ panorama_grad_mask_y = np.stack([mask_map[1] for mask_map in panorama_grad_masks], axis=0)
+
+ panorama_log_distance_grad_x = np.sum(panorama_log_distance_grad_x * panorama_grad_mask_x, axis=0) / np.sum(panorama_grad_mask_x, axis=0).clip(1e-3)
+ panorama_log_distance_grad_y = np.sum(panorama_log_distance_grad_y * panorama_grad_mask_y, axis=0) / np.sum(panorama_grad_mask_y, axis=0).clip(1e-3)
+
+ panorama_laplacian_maps = np.stack(panorama_log_distance_laplacian_maps, axis=0)
+ panorama_laplacian_masks = np.stack(panorama_laplacian_masks, axis=0)
+ panorama_laplacian_map = np.sum(panorama_laplacian_maps * panorama_laplacian_masks, axis=0) / np.sum(panorama_laplacian_masks, axis=0).clip(1e-3)
+
+ grad_x_mask = np.any(panorama_grad_mask_x, axis=0).reshape(-1)
+ grad_y_mask = np.any(panorama_grad_mask_y, axis=0).reshape(-1)
+ grad_mask = np.concatenate([grad_x_mask, grad_y_mask])
+ laplacian_mask = np.any(panorama_laplacian_masks, axis=0).reshape(-1)
+
+ # Solve overdetermined system
+ A = vstack([
+ grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask],
+ poisson_equation(width, height, wrap_x=True, wrap_y=False)[laplacian_mask],
+ ])
+ b = np.concatenate([
+ panorama_log_distance_grad_x.reshape(-1)[grad_x_mask],
+ panorama_log_distance_grad_y.reshape(-1)[grad_y_mask],
+ panorama_laplacian_map.reshape(-1)[laplacian_mask]
+ ])
+ x, *_ = lsmr(
+ A, b,
+ atol=1e-5, btol=1e-5,
+ x0=np.log(panorama_depth_init).reshape(-1) if panorama_depth_init is not None else None,
+ show=False,
+ )
+
+ panorama_depth = np.exp(x).reshape(height, width).astype(np.float32)
+ panorama_mask = np.any(panorama_pred_masks, axis=0)
+
+ return panorama_depth, panorama_mask
+
diff --git a/models/moge/utils/pipeline.py b/models/moge/utils/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..daa522e987317e949899d4159e61d7a7066e1fba
--- /dev/null
+++ b/models/moge/utils/pipeline.py
@@ -0,0 +1,503 @@
+from typing import *
+from abc import abstractmethod
+from queue import Empty, Full
+from threading import Thread
+from queue import Queue
+from multiprocessing import Process
+from threading import Thread, Event
+import multiprocessing
+import threading
+import inspect
+import time
+import uuid
+from copy import deepcopy
+import itertools
+import functools
+
+__all__ = [
+ 'Node',
+ 'Link',
+ 'ConcurrentNode',
+ 'Worker',
+ 'WorkerFunction',
+ 'Provider',
+ 'ProviderFunction',
+ 'Sequential',
+ 'Batch',
+ 'Unbatch',
+ 'Parallel',
+ 'Graph',
+ 'Buffer',
+]
+
+TERMINATE_CHECK_INTERVAL = 0.5
+
+
+class _ItemWrapper:
+ def __init__(self, data: Any, id: Union[int, List[int]] = None):
+ self.data = data
+ self.id = id
+
+
+class Terminate(Exception):
+ pass
+
+
+def _get_queue_item(queue: Queue, terminate_flag: Event, timeout: float = None) -> _ItemWrapper:
+ while True:
+ try:
+ item: _ItemWrapper = queue.get(block=True, timeout=TERMINATE_CHECK_INTERVAL if timeout is None else min(timeout, TERMINATE_CHECK_INTERVAL))
+ if terminate_flag.is_set():
+ raise Terminate()
+ return item
+ except Empty:
+ if terminate_flag.is_set():
+ raise Terminate()
+
+ if timeout is not None:
+ timeout -= TERMINATE_CHECK_INTERVAL
+ if timeout <= 0:
+ raise Empty()
+
+
+def _put_queue_item(queue: Queue, item: _ItemWrapper, terminate_flag: Event):
+ while True:
+ try:
+ queue.put(item, block=True, timeout=TERMINATE_CHECK_INTERVAL)
+ if terminate_flag.is_set():
+ raise Terminate()
+ return
+ except Full:
+ if terminate_flag.is_set():
+ raise Terminate()
+
+class Node:
+ def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
+ self.input: Queue = Queue(maxsize=in_buffer_size)
+ self.output: Queue = Queue(maxsize=out_buffer_size)
+ self.in_buffer_size = in_buffer_size
+ self.out_buffer_size = out_buffer_size
+
+ @abstractmethod
+ def start(self):
+ pass
+
+ @abstractmethod
+ def terminate(self):
+ pass
+
+ def stop(self):
+ self.terminate()
+ self.join()
+
+ @abstractmethod
+ def join(self):
+ pass
+
+ def put(self, data: Any, key: str = None, block: bool = True) -> None:
+ item = _ItemWrapper(data)
+ self.input.put(item, block=block)
+
+ def get(self, key: str = None, block: bool = True) -> Any:
+ item: _ItemWrapper = self.output.get(block=block)
+ return item.data
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.terminate()
+ self.join()
+
+
+class ConcurrentNode(Node):
+ job: Union[Thread, Process]
+
+ def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
+ super().__init__(in_buffer_size, out_buffer_size)
+ self.running_as = running_as
+
+ @abstractmethod
+ def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
+ pass
+
+ def start(self):
+ if self.running_as == 'thread':
+ terminate_flag = threading.Event()
+ job = Thread(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
+ elif self.running_as == 'process':
+ terminate_flag = multiprocessing.Event()
+ job = Process(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
+ job.start()
+ self.job = job
+ self.terminate_flag = terminate_flag
+
+ def terminate(self):
+ self.terminate_flag.set()
+
+ def join(self):
+ self.job.join()
+
+
+class Worker(ConcurrentNode):
+ def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 0, out_buffer_size: int = 0) -> None:
+ super().__init__(running_as, in_buffer_size, out_buffer_size)
+
+ def init(self) -> None:
+ """
+ This method is called the the thread is started, to initialize any resources that is only held in the thread.
+ """
+ pass
+
+ @abstractmethod
+ def work(self, *args, **kwargs) -> Union[Any, Dict[str, Any]]:
+ """
+ This method defines the job that the node should do for each input item.
+ A item obtained from the input queue is passed as arguments to this method, and the result is placed in the output queue.
+ The method is executed concurrently with other nodes.
+ """
+ pass
+
+ def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
+ self.init()
+ try:
+ while True:
+ item = _get_queue_item(input, terminate_flag)
+ result = self.work(item.data)
+ _put_queue_item(output, _ItemWrapper(result, item.id), terminate_flag)
+
+ except Terminate:
+ return
+
+
+class Provider(ConcurrentNode):
+ """
+ A node that provides data to successive nodes. It takes no input and provides data to the output queue.
+ """
+ def __init__(self, running_as: Literal['thread', 'process'], out_buffer_size: int = 1) -> None:
+ super().__init__(running_as, 0, out_buffer_size)
+
+ def init(self) -> None:
+ """
+ This method is called the the thread or process is started, to initialize any resources that is only held in the thread or process.
+ """
+ pass
+
+ @abstractmethod
+ def provide(self) -> Generator[Any, None, None]:
+ pass
+
+ def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
+ self.init()
+ try:
+ for data in self.provide():
+ _put_queue_item(output, _ItemWrapper(data), terminate_flag)
+ except Terminate:
+ return
+
+
+class WorkerFunction(Worker):
+ def __init__(self, fn: Callable, running_as: 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
+ super().__init__(running_as, in_buffer_size, out_buffer_size)
+ self.fn = fn
+
+ def work(self, *args, **kwargs):
+ return self.fn(*args, **kwargs)
+
+
+class ProviderFunction(Provider):
+ def __init__(self, fn: Callable, running_as: 'thread', out_buffer_size: int = 1) -> None:
+ super().__init__(running_as, out_buffer_size)
+ self.fn = fn
+
+ def provide(self):
+ for item in self.fn():
+ yield item
+
+
+class Link:
+ def __init__(self, src: Queue, dst: Queue):
+ self.src = src
+ self.dst = dst
+
+ def _thread_fn(self):
+ try:
+ while True:
+ item = _get_queue_item(self.src, self.terminate_flag)
+ _put_queue_item(self.dst, item, self.terminate_flag)
+ except Terminate:
+ return
+
+ def start(self):
+ self.terminate_flag = threading.Event()
+ self.thread = Thread(target=self._thread_fn)
+ self.thread.start()
+
+ def terminate(self):
+ self.terminate_flag.set()
+
+ def join(self):
+ self.thread.join()
+
+
+class Graph(Node):
+ """
+ Graph pipeline of nodes and links
+ """
+ nodes: List[Node]
+ links: List[Link]
+
+ def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
+ super().__init__(in_buffer_size, out_buffer_size)
+ self.nodes = []
+ self.links = []
+
+ def add(self, node: Node):
+ self.nodes.append(node)
+
+ def link(self, src: Union[Node, Tuple[Node, str]], dst: Union[Node, Tuple[Node, str]]):
+ """
+ Links the output of the source node to the input of the destination node.
+ If the source or destination node is None, the pipeline's input or output is used.
+ """
+ src_queue = self.input if src is None else src.output
+ dst_queue = self.output if dst is None else dst.input
+ self.links.append(Link(src_queue, dst_queue))
+
+ def chain(self, nodes: Iterable[Node]):
+ """
+ Link the output of each node to the input of the next node.
+ """
+ nodes = list(nodes)
+ for i in range(len(nodes) - 1):
+ self.link(nodes[i], nodes[i + 1])
+
+ def start(self):
+ for node in self.nodes:
+ node.start()
+ for link in self.links:
+ link.start()
+
+ def terminate(self):
+ for node in self.nodes:
+ node.terminate()
+ for link in self.links:
+ link.terminate()
+
+ def join(self):
+ for node in self.nodes:
+ node.join()
+ for link in self.links:
+ link.join()
+
+ def __iter__(self):
+ providers = [node for node in self.nodes if isinstance(node, Provider)]
+ if len(providers) == 0:
+ raise ValueError("No provider node found in the pipeline. If you want to iterate over the pipeline, the pipeline must be driven by a provider node.")
+ with self:
+ # while all(provider.job.is_alive() for provider in providers):
+ while True:
+ yield self.get()
+
+ def __call__(self, data: Any) -> Any:
+ """
+ Submit data to the pipeline's input queue, and return the output data asynchronously.
+ NOTE: The pipeline must be streamed (i.e., every output item is uniquely associated with an input item) for this to work.
+ """
+ # TODO
+
+
+class Sequential(Graph):
+ """
+ Pipeline of nodes in sequential order, where each node takes the output of the previous node as input.
+ The order of input and output items is preserved (FIFO)
+ """
+ def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
+ """
+ Initialize the pipeline with a list of nodes to execute sequentially.
+ ### Parameters:
+ - nodes: List of nodes or functions to execute sequentially. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes.
+ - function_running_as: Whether to wrap the function as a thread or process worker. Defaults to 'thread'.
+ - in_buffer_size: Maximum size of the input queue of the pipeline. Defaults to 0 (unlimited).
+ - out_buffer_size: Maximum size of the output queue of the pipeline. Defaults to 0 (unlimited).
+ """
+ super().__init__(in_buffer_size, out_buffer_size)
+ for node in nodes:
+ if isinstance(node, Node):
+ pass
+ elif isinstance(node, Callable):
+ if inspect.isgeneratorfunction(node):
+ node = ProviderFunction(node, function_running_as)
+ else:
+ node = WorkerFunction(node, function_running_as)
+ else:
+ raise ValueError(f"Invalid node type: {type(node)}")
+ self.add(node)
+ self.chain([None, *self.nodes, None])
+
+
+class Parallel(Node):
+ """
+ A FIFO node that runs multiple nodes in parallel to process the input items. Each input item is handed to one of the nodes whoever is available.
+ NOTE: It is FIFO if and only if all the nested nodes are FIFO.
+ """
+ nodes: List[Node]
+
+ def __init__(self, nodes: Iterable[Node], in_buffer_size: int = 1, out_buffer_size: int = 1, function_running_as: Literal['thread', 'process'] = 'thread'):
+ super().__init__(in_buffer_size, out_buffer_size)
+ self.nodes = []
+ for node in nodes:
+ if isinstance(node, Node):
+ pass
+ elif isinstance(node, Callable):
+ if inspect.isgeneratorfunction(node):
+ node = ProviderFunction(node, function_running_as)
+ else:
+ node = WorkerFunction(node, function_running_as)
+ else:
+ raise ValueError(f"Invalid node type: {type(node)}")
+ self.nodes.append(node)
+ self.output_order = Queue()
+ self.lock = threading.Lock()
+
+ def _in_thread_fn(self, node: Node):
+ try:
+ while True:
+ with self.lock:
+ # A better idea: first make sure its node is vacant, then get it a new item.
+ # Currently we will not be able to know which node is busy util there is at least one item already waiting in the queue of the node.
+ # This could lead to suboptimal scheduling.
+ item = _get_queue_item(self.input, self.terminate_flag)
+ self.output_order.put(node.output)
+ _put_queue_item(node.input, item, self.terminate_flag)
+ except Terminate:
+ return
+
+ def _out_thread_fn(self):
+ try:
+ while True:
+ queue = _get_queue_item(self.output_order, self.terminate_flag)
+ item = _get_queue_item(queue, self.terminate_flag)
+ _put_queue_item(self.output, item, self.terminate_flag)
+ except Terminate:
+ return
+
+ def start(self):
+ self.terminate_flag = threading.Event()
+ self.in_threads = []
+ for node in self.nodes:
+ thread = Thread(target=self._in_thread_fn, args=(node,))
+ thread.start()
+ self.in_threads.append(thread)
+ thread = Thread(target=self._out_thread_fn)
+ thread.start()
+ self.out_thread = thread
+ for node in self.nodes:
+ node.start()
+
+ def terminate(self):
+ self.terminate_flag.set()
+ for node in self.nodes:
+ node.terminate()
+
+ def join(self):
+ for thread in self.in_threads:
+ thread.join()
+ self.out_thread.join()
+
+
+class UnorderedParallel(Graph):
+ """
+ Pipeline of nodes in parallel, where each input item is handed to one of the nodes whoever is available.
+ NOTE: The order of the output items is NOT guaranteed to be the same as the input items, depending on how fast the nodes handle their input.
+ """
+ def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
+ """
+ Initialize the pipeline with a list of nodes to execute in parallel. If a function is given, it is wrapped in a worker node.
+ ### Parameters:
+ - nodes: List of nodes or functions to execute in parallel. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes.
+ - function_running_as: Whether to wrap the function as a thread or process worker. Defaults to 'thread'.
+ - in_buffer_size: Maximum size of the input queue of the pipeline. Defaults to 0 (unlimited).
+ - out_buffer_size: Maximum size of the output queue of the pipeline. Defaults to 0 (unlimited).
+ """
+ super().__init__(in_buffer_size, out_buffer_size)
+ for node in nodes:
+ if isinstance(node, Node):
+ pass
+ elif isinstance(node, Callable):
+ if inspect.isgeneratorfunction(node):
+ node = ProviderFunction(node, function_running_as)
+ else:
+ node = WorkerFunction(node, function_running_as)
+ else:
+ raise ValueError(f"Invalid node type: {type(node)}")
+ self.add(node)
+ for i in range(len(nodes)):
+ self.chain([None, self.nodes[i], None])
+
+
+class Batch(ConcurrentNode):
+ """
+ Groups every `batch_size` items into a batch (a list of items) and passes the batch to successive nodes.
+ The `patience` parameter specifies the maximum time to wait for a batch to be filled before sending it to the next node,
+ i.e., when the earliest item in the batch is out of `patience` seconds, the batch is sent regardless of its size.
+ """
+ def __init__(self, batch_size: int, patience: float = None, in_buffer_size: int = 1, out_buffer_size: int = 1):
+ assert batch_size > 0, "Batch size must be greater than 0."
+ super().__init__('thread', in_buffer_size, out_buffer_size)
+ self.batch_size = batch_size
+ self.patience = patience
+
+ def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
+ try:
+ while True:
+ batch_id, batch_data = [], []
+ # Try to fill the batch
+ for i in range(self.batch_size):
+ if i == 0 or self.patience is None:
+ timeout = None
+ else:
+ timeout = self.patience - (time.time() - earliest_time)
+ if timeout < 0:
+ break
+ try:
+ item = _get_queue_item(input, terminate_flag, timeout)
+ except Empty:
+ break
+
+ if i == 0:
+ earliest_time = time.time()
+ batch_data.append(item.data)
+ batch_id.append(item.id)
+
+ batch = _ItemWrapper(batch_data, batch_id)
+ _put_queue_item(output, batch, terminate_flag)
+ except Terminate:
+ return
+
+
+class Unbatch(ConcurrentNode):
+ """
+ Ungroups every batch (a list of items) into individual items and passes them to successive nodes.
+ """
+ def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
+ super().__init__('thread', in_buffer_size, out_buffer_size)
+
+ def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
+ try:
+ while True:
+ batch = _get_queue_item(input, terminate_flag)
+ for id, data in zip(batch.id or itertools.repeat(None), batch.data):
+ item = _ItemWrapper(data, id)
+ _put_queue_item(output, item, terminate_flag)
+ except Terminate:
+ return
+
+
+class Buffer(Node):
+ "A FIFO node that buffers items in a queue. Usefull achieve better temporal balance when its successor node has a variable processing time."
+ def __init__(self, size: int):
+ super().__init__(size, size)
+ self.size = size
+ self.input = self.output = Queue(maxsize=size)
\ No newline at end of file
diff --git a/models/moge/utils/tools.py b/models/moge/utils/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..3687f6938fe34433d149a1a8405be7eed5f23c37
--- /dev/null
+++ b/models/moge/utils/tools.py
@@ -0,0 +1,289 @@
+from typing import *
+import time
+from pathlib import Path
+from numbers import Number
+from functools import wraps
+import warnings
+import math
+import json
+import os
+import importlib
+import importlib.util
+
+
+def catch_exception(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ import traceback
+ print(f"Exception in {fn.__name__}", end='r')
+ # print({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})
+ traceback.print_exc(chain=False)
+ time.sleep(0.1)
+ return None
+ return wrapper
+
+
+class CallbackOnException:
+ def __init__(self, callback: Callable, exception: type):
+ self.exception = exception
+ self.callback = callback
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if isinstance(exc_val, self.exception):
+ self.callback()
+ return True
+ return False
+
+def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
+ for k, v in d.items():
+ if isinstance(v, dict):
+ for sub_key in traverse_nested_dict_keys(v):
+ yield (k, ) + sub_key
+ else:
+ yield (k, )
+
+
+def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
+ for k in keys:
+ d = d.get(k, default)
+ if d is None:
+ break
+ return d
+
+def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
+ for k in keys[:-1]:
+ d = d.setdefault(k, {})
+ d[keys[-1]] = value
+
+
+def key_average(list_of_dicts: list) -> Dict[str, Any]:
+ """
+ Returns a dictionary with the average value of each key in the input list of dictionaries.
+ """
+ _nested_dict_keys = set()
+ for d in list_of_dicts:
+ _nested_dict_keys.update(traverse_nested_dict_keys(d))
+ _nested_dict_keys = sorted(_nested_dict_keys)
+ result = {}
+ for k in _nested_dict_keys:
+ values = []
+ for d in list_of_dicts:
+ v = get_nested_dict(d, k)
+ if v is not None and not math.isnan(v):
+ values.append(v)
+ avg = sum(values) / len(values) if values else float('nan')
+ set_nested_dict(result, k, avg)
+ return result
+
+
+def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
+ """
+ Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
+ """
+ items = []
+ if parent_key is None:
+ parent_key = ()
+ for k, v in d.items():
+ new_key = parent_key + (k, )
+ if isinstance(v, MutableMapping):
+ items.extend(flatten_nested_dict(v, new_key).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
+
+
+def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Unflattens a single-level dictionary into a nested dictionary, with keys as tuples.
+ """
+ result = {}
+ for k, v in d.items():
+ sub_dict = result
+ for k_ in k[:-1]:
+ if k_ not in sub_dict:
+ sub_dict[k_] = {}
+ sub_dict = sub_dict[k_]
+ sub_dict[k[-1]] = v
+ return result
+
+
+def read_jsonl(file):
+ import json
+ with open(file, 'r') as f:
+ data = f.readlines()
+ return [json.loads(line) for line in data]
+
+
+def write_jsonl(data: List[dict], file):
+ import json
+ with open(file, 'w') as f:
+ for item in data:
+ f.write(json.dumps(item) + '\n')
+
+
+def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
+ import pandas as pd
+ data = [flatten_nested_dict(d) for d in data]
+ df = pd.DataFrame(data)
+ df = df.sort_index(axis=1)
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
+ return df
+
+
+def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
+ if isinstance(d, str):
+ for old, new in mapping.items():
+ d = d.replace(old, new)
+ elif isinstance(d, list):
+ for i, item in enumerate(d):
+ d[i] = recursive_replace(item, mapping)
+ elif isinstance(d, dict):
+ for k, v in d.items():
+ d[k] = recursive_replace(v, mapping)
+ return d
+
+
+class timeit:
+ _history: Dict[str, List['timeit']] = {}
+
+ def __init__(self, name: str = None, verbose: bool = True, average: bool = False):
+ self.name = name
+ self.verbose = verbose
+ self.start = None
+ self.end = None
+ self.average = average
+ if average and name not in timeit._history:
+ timeit._history[name] = []
+
+ def __call__(self, func: Callable):
+ import inspect
+ if inspect.iscoroutinefunction(func):
+ async def wrapper(*args, **kwargs):
+ with timeit(self.name or func.__qualname__):
+ ret = await func(*args, **kwargs)
+ return ret
+ return wrapper
+ else:
+ def wrapper(*args, **kwargs):
+ with timeit(self.name or func.__qualname__):
+ ret = func(*args, **kwargs)
+ return ret
+ return wrapper
+
+ def __enter__(self):
+ self.start = time.time()
+ return self
+
+ @property
+ def time(self) -> float:
+ assert self.start is not None, "Time not yet started."
+ assert self.end is not None, "Time not yet ended."
+ return self.end - self.start
+
+ @property
+ def average_time(self) -> float:
+ assert self.average, "Average time not available."
+ return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])
+
+ @property
+ def history(self) -> List['timeit']:
+ return timeit._history.get(self.name, [])
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.end = time.time()
+ if self.average:
+ timeit._history[self.name].append(self)
+ if self.verbose:
+ if self.average:
+ avg = self.average_time
+ print(f"{self.name or 'It'} took {avg:.6f} seconds in average.")
+ else:
+ print(f"{self.name or 'It'} took {self.time:.6f} seconds.")
+
+
+def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
+ first = strings[0]
+
+ for start in range(len(first)):
+ if any(s[start] != strings[0][start] for s in strings):
+ break
+
+ for end in range(1, min(len(s) for s in strings)):
+ if any(s[-end] != first[-end] for s in strings):
+ break
+
+ return [s[start:len(s) - end + 1] for s in strings]
+
+
+def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
+ from concurrent.futures import ThreadPoolExecutor
+ from contextlib import nullcontext
+ from tqdm import tqdm
+
+ if pbar is not None:
+ pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
+ else:
+ pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)
+
+ def decorator(fn: Callable):
+ with (
+ ThreadPoolExecutor(max_workers=num_workers) as executor,
+ pbar
+ ):
+ pbar.refresh()
+ @catch_exception
+ @suppress_traceback
+ def _fn(input):
+ ret = fn(input)
+ pbar.update()
+ return ret
+ executor.map(_fn, inputs)
+ executor.shutdown(wait=True)
+
+ return decorator
+
+
+def suppress_traceback(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ e.__traceback__ = e.__traceback__.tb_next.tb_next
+ raise
+ return wrapper
+
+
+class no_warnings:
+ def __init__(self, action: str = 'ignore', **kwargs):
+ self.action = action
+ self.filter_kwargs = kwargs
+
+ def __call__(self, fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ with warnings.catch_warnings():
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+ return fn(*args, **kwargs)
+ return wrapper
+
+ def __enter__(self):
+ self.warnings_manager = warnings.catch_warnings()
+ self.warnings_manager.__enter__()
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
+
+
+def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str):
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ return module
\ No newline at end of file
diff --git a/models/moge/utils/vis.py b/models/moge/utils/vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb9c2378b58ec26ac5067b7ffcbd749a8ad968ce
--- /dev/null
+++ b/models/moge/utils/vis.py
@@ -0,0 +1,65 @@
+from typing import *
+
+import numpy as np
+import matplotlib
+
+
+def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
+ if mask is None:
+ depth = np.where(depth > 0, depth, np.nan)
+ else:
+ depth = np.where((depth > 0) & mask, depth, np.nan)
+ disp = 1 / depth
+ if normalize:
+ min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99)
+ disp = (disp - min_disp) / (max_disp - min_disp)
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], 0)
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+ return colored
+
+
+def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray:
+ if mask is not None:
+ depth = np.where(mask, depth, np.nan)
+
+ min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999)
+ depth = (depth - min_depth) / (max_depth - min_depth)
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](depth)[..., :3], 0)
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+ return colored
+
+
+def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
+ if mask is not None:
+ disparity = np.where(mask, disparity, np.nan)
+
+ if normalize:
+ min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999)
+ disparity = (disparity - min_disp) / (max_disp - min_disp)
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity)[..., :3], 0)
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+ return colored
+
+
+def colorize_segmentation(segmentation: np.ndarray, cmap: str = 'Set1') -> np.ndarray:
+ colored = matplotlib.colormaps[cmap]((segmentation % 20) / 20)[..., :3]
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+ return colored
+
+
+def colorize_normal(normal: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
+ if mask is not None:
+ normal = np.where(mask[..., None], normal, 0)
+ normal = normal * [0.5, -0.5, -0.5] + 0.5
+ normal = (normal.clip(0, 1) * 255).astype(np.uint8)
+ return normal
+
+
+def colorize_error_map(error_map: np.ndarray, mask: np.ndarray = None, cmap: str = 'plasma', value_range: Tuple[float, float] = None):
+ vmin, vmax = value_range if value_range is not None else (np.nanmin(error_map), np.nanmax(error_map))
+ cmap = matplotlib.colormaps[cmap]
+ colorized_error_map = cmap(((error_map - vmin) / (vmax - vmin)).clip(0, 1))[..., :3]
+ if mask is not None:
+ colorized_error_map = np.where(mask[..., None], colorized_error_map, 0)
+ colorized_error_map = np.ascontiguousarray((colorized_error_map.clip(0, 1) * 255).astype(np.uint8))
+ return colorized_error_map
diff --git a/models/moge/utils/webfile.py b/models/moge/utils/webfile.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e98abf8413e1c9f408849b74f4d2025d25511b6
--- /dev/null
+++ b/models/moge/utils/webfile.py
@@ -0,0 +1,73 @@
+import requests
+from typing import *
+
+__all__ = ["WebFile"]
+
+
+class WebFile:
+ def __init__(self, url: str, session: Optional[requests.Session] = None, headers: Optional[Dict[str, str]] = None, size: Optional[int] = None):
+ self.url = url
+ self.session = session or requests.Session()
+ self.session.headers.update(headers or {})
+ self._offset = 0
+ self.size = size if size is not None else self._fetch_size()
+
+ def _fetch_size(self):
+ with self.session.get(self.url, stream=True) as response:
+ response.raise_for_status()
+ content_length = response.headers.get("Content-Length")
+ if content_length is None:
+ raise ValueError("Missing Content-Length in header")
+ return int(content_length)
+
+ def _fetch_data(self, offset: int, n: int) -> bytes:
+ headers = {"Range": f"bytes={offset}-{min(offset + n - 1, self.size)}"}
+ response = self.session.get(self.url, headers=headers)
+ response.raise_for_status()
+ return response.content
+
+ def seekable(self) -> bool:
+ return True
+
+ def tell(self) -> int:
+ return self._offset
+
+ def available(self) -> int:
+ return self.size - self._offset
+
+ def seek(self, offset: int, whence: int = 0) -> None:
+ if whence == 0:
+ new_offset = offset
+ elif whence == 1:
+ new_offset = self._offset + offset
+ elif whence == 2:
+ new_offset = self.size + offset
+ else:
+ raise ValueError("Invalid value for whence")
+
+ self._offset = max(0, min(new_offset, self.size))
+
+ def read(self, n: Optional[int] = None) -> bytes:
+ if n is None or n < 0:
+ n = self.available()
+ else:
+ n = min(n, self.available())
+
+ if n == 0:
+ return b''
+
+ data = self._fetch_data(self._offset, n)
+ self._offset += len(data)
+
+ return data
+
+ def close(self) -> None:
+ pass
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
+
+
\ No newline at end of file
diff --git a/models/moge/utils/webzipfile.py b/models/moge/utils/webzipfile.py
new file mode 100644
index 0000000000000000000000000000000000000000..25ed1d3cd34720335eb001d77a278539ffef569b
--- /dev/null
+++ b/models/moge/utils/webzipfile.py
@@ -0,0 +1,128 @@
+from typing import *
+import io
+import os
+from zipfile import (
+ ZipInfo, BadZipFile, ZipFile, ZipExtFile,
+ sizeFileHeader, structFileHeader, stringFileHeader,
+ _FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS,
+ _MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED
+)
+import struct
+from requests import Session
+
+from .webfile import WebFile
+
+
+class _SharedWebFile(WebFile):
+ def __init__(self, webfile: WebFile, pos: int):
+ super().__init__(webfile.url, webfile.session, size=webfile.size)
+ self.seek(pos)
+
+
+class WebZipFile(ZipFile):
+ "Lock-free version of ZipFile that reads from a WebFile, allowing for concurrent reads."
+ def __init__(self, url: str, session: Optional[Session] = None, headers: Optional[Dict[str, str]] = None):
+ """Open the ZIP file with mode read 'r', write 'w', exclusive create 'x',
+ or append 'a'."""
+ webf = WebFile(url, session=session, headers=headers)
+ super().__init__(webf, mode='r')
+
+ def open(self, name, mode="r", pwd=None, *, force_zip64=False):
+ """Return file-like object for 'name'.
+
+ name is a string for the file name within the ZIP file, or a ZipInfo
+ object.
+
+ mode should be 'r' to read a file already in the ZIP file, or 'w' to
+ write to a file newly added to the archive.
+
+ pwd is the password to decrypt files (only used for reading).
+
+ When writing, if the file size is not known in advance but may exceed
+ 2 GiB, pass force_zip64 to use the ZIP64 format, which can handle large
+ files. If the size is known in advance, it is best to pass a ZipInfo
+ instance for name, with zinfo.file_size set.
+ """
+ if mode not in {"r", "w"}:
+ raise ValueError('open() requires mode "r" or "w"')
+ if pwd and (mode == "w"):
+ raise ValueError("pwd is only supported for reading files")
+ if not self.fp:
+ raise ValueError(
+ "Attempt to use ZIP archive that was already closed")
+
+ assert mode == "r", "Only read mode is supported for now"
+
+ # Make sure we have an info object
+ if isinstance(name, ZipInfo):
+ # 'name' is already an info object
+ zinfo = name
+ elif mode == 'w':
+ zinfo = ZipInfo(name)
+ zinfo.compress_type = self.compression
+ zinfo._compresslevel = self.compresslevel
+ else:
+ # Get info object for name
+ zinfo = self.getinfo(name)
+
+ if mode == 'w':
+ return self._open_to_write(zinfo, force_zip64=force_zip64)
+
+ if self._writing:
+ raise ValueError("Can't read from the ZIP file while there "
+ "is an open writing handle on it. "
+ "Close the writing handle before trying to read.")
+
+ # Open for reading:
+ self._fileRefCnt += 1
+ zef_file = _SharedWebFile(self.fp, zinfo.header_offset)
+
+ try:
+ # Skip the file header:
+ fheader = zef_file.read(sizeFileHeader)
+ if len(fheader) != sizeFileHeader:
+ raise BadZipFile("Truncated file header")
+ fheader = struct.unpack(structFileHeader, fheader)
+ if fheader[_FH_SIGNATURE] != stringFileHeader:
+ raise BadZipFile("Bad magic number for file header")
+
+ fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
+ if fheader[_FH_EXTRA_FIELD_LENGTH]:
+ zef_file.seek(fheader[_FH_EXTRA_FIELD_LENGTH], whence=1)
+
+ if zinfo.flag_bits & _MASK_COMPRESSED_PATCH:
+ # Zip 2.7: compressed patched data
+ raise NotImplementedError("compressed patched data (flag bit 5)")
+
+ if zinfo.flag_bits & _MASK_STRONG_ENCRYPTION:
+ # strong encryption
+ raise NotImplementedError("strong encryption (flag bit 6)")
+
+ if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & _MASK_UTF_FILENAME:
+ # UTF-8 filename
+ fname_str = fname.decode("utf-8")
+ else:
+ fname_str = fname.decode(self.metadata_encoding or "cp437")
+
+ if fname_str != zinfo.orig_filename:
+ raise BadZipFile(
+ 'File name in directory %r and header %r differ.'
+ % (zinfo.orig_filename, fname))
+
+ # check for encrypted flag & handle password
+ is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED
+ if is_encrypted:
+ if not pwd:
+ pwd = self.pwd
+ if pwd and not isinstance(pwd, bytes):
+ raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__)
+ if not pwd:
+ raise RuntimeError("File %r is encrypted, password "
+ "required for extraction" % name)
+ else:
+ pwd = None
+
+ return ZipExtFile(zef_file, mode, zinfo, pwd, True)
+ except:
+ zef_file.close()
+ raise
\ No newline at end of file
diff --git a/models/monoD/depth_anything/__init__.py b/models/monoD/depth_anything/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/models/monoD/depth_anything/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/models/monoD/depth_anything/blocks.py b/models/monoD/depth_anything/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..38dbcfeffc0c38ef51bcb20dfd347e50b2a60616
--- /dev/null
+++ b/models/monoD/depth_anything/blocks.py
@@ -0,0 +1,153 @@
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape*8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ self.size=size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = nn.functional.interpolate(
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
diff --git a/models/monoD/depth_anything/build.py b/models/monoD/depth_anything/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..76173f3359420fe415ea0fb236beae7b8d1442e4
--- /dev/null
+++ b/models/monoD/depth_anything/build.py
@@ -0,0 +1,100 @@
+import argparse
+import cv2
+import numpy as np
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Compose
+
+from models.monoD.depth_anything.dpt import DPT_DINOv2
+from models.monoD.depth_anything.util.transform import (
+ Resize, NormalizeImage, PrepareForNet
+)
+
+
+def build(config):
+ """
+ Build the model from the config
+ NOTE: the config should contain the following
+ - encoder: the encoder type of the model
+ - load_from: the path to the pretrained model
+ """
+ args = config
+ assert args.encoder in ['vits', 'vitb', 'vitl']
+ if args.encoder == 'vits':
+ depth_anything = DPT_DINOv2(encoder='vits', features=64,
+ out_channels=[48, 96, 192, 384],
+ localhub=args.localhub).cuda()
+ elif args.encoder == 'vitb':
+ depth_anything = DPT_DINOv2(encoder='vitb', features=128,
+ out_channels=[96, 192, 384, 768],
+ localhub=args.localhub).cuda()
+ else:
+ depth_anything = DPT_DINOv2(encoder='vitl', features=256,
+ out_channels=[256, 512, 1024, 1024],
+ localhub=args.localhub).cuda()
+ depth_anything.load_state_dict(torch.load(args.load_from,
+ map_location='cpu'), strict=True)
+ total_params = sum(param.numel() for param in depth_anything.parameters())
+ print('Total parameters: {:.2f}M'.format(total_params / 1e6))
+ depth_anything.eval()
+
+ return depth_anything
+
+class DepthAnything(nn.Module):
+ def __init__(self, args):
+ super(DepthAnything, self).__init__()
+
+ # build the chosen model
+ self.dpAny = build(args)
+
+ def infer(self, rgbs):
+ """
+ Infer the depth map from the input RGB image
+
+ Args:
+ rgbs: the input RGB image B x 3 x H x W (Cuda Tensor)
+
+ Asserts:
+ the input should be a cuda tensor
+ """
+ assert (rgbs.is_cuda)&(len(rgbs.shape) == 4)
+ T, C, H, W = rgbs.shape
+ # prepare the input
+ Resizer = Resize(
+ width=518,
+ height=518,
+ resize_target=False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method='lower_bound',
+ image_interpolation_method=cv2.INTER_CUBIC,
+ )
+ #NOTE: step 1 Resize
+ width, height = Resizer.get_size(
+ rgbs.shape[2], rgbs.shape[3]
+ )
+ rgbs = F.interpolate(
+ rgbs, (int(height), int(width)), mode='bicubic', align_corners=False
+ )
+ #NOTE: step 2 NormalizeImage
+ mean_ = torch.tensor([0.485, 0.456, 0.406],
+ device=rgbs.device).view(1, 3, 1, 1)
+ std_ = torch.tensor([0.229, 0.224, 0.225],
+ device=rgbs.device).view(1, 3, 1, 1)
+ rgbs = (rgbs - mean_)/std_
+ #NOTE: step 3 PrepareForNet
+
+ # get the depth map
+
+ disp = self.dpAny(rgbs)
+ disp = F.interpolate(
+ disp[:,None], (H, W),
+ mode='bilinear', align_corners=False
+ )
+ # clamping the farthest depth to 100x of the nearest
+ depth_map = disp
+
+ return depth_map
+
diff --git a/models/monoD/depth_anything/dpt.py b/models/monoD/depth_anything/dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..c21f64e59e83417f732544e1234f6cfd62112beb
--- /dev/null
+++ b/models/monoD/depth_anything/dpt.py
@@ -0,0 +1,283 @@
+import torch
+import torch.nn as nn
+
+from .blocks import FeatureFusionBlock, _make_scratch
+import torch.nn.functional as F
+
+
+def _make_fusion_block(features, use_bn, size = None):
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ size=size,
+ )
+
+
+class DPTHead(nn.Module):
+ def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
+ super(DPTHead, self).__init__()
+
+ self.nclass = nclass
+ self.use_clstoken = use_clstoken
+
+ self.projects = nn.ModuleList([
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ) for out_channel in out_channels
+ ])
+
+ self.resize_layers = nn.ModuleList([
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0],
+ out_channels=out_channels[0],
+ kernel_size=4,
+ stride=4,
+ padding=0),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1],
+ out_channels=out_channels[1],
+ kernel_size=2,
+ stride=2,
+ padding=0),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3],
+ out_channels=out_channels[3],
+ kernel_size=3,
+ stride=2,
+ padding=1)
+ ])
+
+ if use_clstoken:
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(
+ nn.Sequential(
+ nn.Linear(2 * in_channels, in_channels),
+ nn.GELU()))
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ groups=1,
+ expand=False,
+ )
+
+ self.scratch.stem_transpose = None
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ if nclass > 1:
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
+ )
+ else:
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
+
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+ # nn.ReLU(True),
+ # nn.Identity(),
+ )
+
+ def forward(self, out_features, patch_h, patch_w):
+ out = []
+ for i, x in enumerate(out_features):
+ if self.use_clstoken:
+ x, cls_token = x[0], x[1]
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ else:
+ x = x[0]
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+
+ out.append(x)
+
+ layer_1, layer_2, layer_3, layer_4 = out
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv1(path_1)
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+ out = self.scratch.output_conv2(out)
+
+ return out
+
+class DPTHeadEnc(nn.Module):
+ def __init__(self, nclass, in_channels,
+ features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False, out_c = 128,):
+ super(DPTHeadEnc, self).__init__()
+
+ self.nclass = nclass
+ self.use_clstoken = use_clstoken
+
+ self.projects = nn.ModuleList([
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ) for out_channel in out_channels
+ ])
+
+ self.resize_layers = nn.ModuleList([
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0],
+ out_channels=out_channels[0],
+ kernel_size=4,
+ stride=4,
+ padding=0),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1],
+ out_channels=out_channels[1],
+ kernel_size=2,
+ stride=2,
+ padding=0),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3],
+ out_channels=out_channels[3],
+ kernel_size=3,
+ stride=2,
+ padding=1)
+ ])
+
+ if use_clstoken:
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(
+ nn.Sequential(
+ nn.Linear(2 * in_channels, in_channels),
+ nn.GELU()))
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ groups=1,
+ expand=False,
+ )
+
+ self.scratch.stem_transpose = None
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ head_features_1 = features
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(head_features_1, out_c, kernel_size=3, stride=1, padding=1),
+
+ )
+
+
+ def forward(self, out_features, patch_h, patch_w, enc_only=True):
+ out = []
+ for i, x in enumerate(out_features):
+ if self.use_clstoken:
+ x, cls_token = x[0], x[1]
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ else:
+ x = x[0]
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+
+ out.append(x)
+ layer_1, layer_2, layer_3, layer_4 = out
+
+ if enc_only==True:
+ layer_1_rs = F.interpolate(layer_1,
+ (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+ layer_2_rs = F.interpolate(layer_2,
+ (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+ layer_3_rs = F.interpolate(layer_3,
+ (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+ layer_4_rs = F.interpolate(layer_4,
+ (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+
+ return layer_4_rs+layer_3_rs+layer_2_rs+layer_1_rs
+ else:
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+
+ return out
+
+class DPT_DINOv2(nn.Module):
+ def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, localhub=True):
+ super(DPT_DINOv2, self).__init__()
+
+ assert encoder in ['vits', 'vitb', 'vitl']
+
+ # in case the Internet connection is not stable, please load the DINOv2 locally
+ if localhub:
+ self.pretrained = torch.hub.load('models/torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
+ else:
+ self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
+
+ dim = self.pretrained.blocks[0].attn.qkv.in_features
+
+ self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
+
+ def forward(self, x):
+ h, w = x.shape[-2:]
+
+ features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
+
+ patch_h, patch_w = h // 14, w // 14
+
+ depth = self.depth_head(features, patch_h, patch_w)
+ depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
+ depth = F.relu(depth)
+
+ return depth.squeeze(1)
+
+
+if __name__ == '__main__':
+ depth_anything = DPT_DINOv2()
+ depth_anything.load_state_dict(torch.load('checkpoints/depth_anything_dinov2_vitl14.pth'))
+
\ No newline at end of file
diff --git a/models/monoD/depth_anything/util/transform.py b/models/monoD/depth_anything/util/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..7beab14a9a1e18a8e46e7666fe5bdec223074155
--- /dev/null
+++ b/models/monoD/depth_anything/util/transform.py
@@ -0,0 +1,248 @@
+import random
+from PIL import Image, ImageOps, ImageFilter
+import torch
+from torchvision import transforms
+import torch.nn.functional as F
+
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as least as possbile
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ if "semseg_mask" in sample:
+ # sample["semseg_mask"] = cv2.resize(
+ # sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
+ # )
+ sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode='nearest').numpy()[0, 0]
+
+ if "mask" in sample:
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ # sample["mask"] = sample["mask"].astype(bool)
+
+ # print(sample['image'].shape, sample['depth'].shape)
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ if "semseg_mask" in sample:
+ sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
+ sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
+
+ return sample
diff --git a/models/monoD/depth_anything_v2/dinov2.py b/models/monoD/depth_anything_v2/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e5a2dd97387b0be8108f0d4665bdf8d7b3b8ae4
--- /dev/null
+++ b/models/monoD/depth_anything_v2/dinov2.py
@@ -0,0 +1,438 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+import inspect
+
+logger = logging.getLogger("dinov2")
+
+def check_function_params(func, *params):
+ sig = inspect.signature(func)
+ return all(param in sig.parameters for param in params)
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+ # w0, h0 = w0 + 0.1, h0 + 0.1
+
+ sqrt_N = math.sqrt(N)
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+ scale_factor=(sx, sy),
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
+ mode="bicubic",
+ antialias=self.interpolate_antialias
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2]
+ assert int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ b, nc, w, h = x.shape
+ w_p, h_p = w // self.patch_size, h // self.patch_size
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ if getattr(blk, "TloraMLP", None) is not None:
+ x = blk(x, w_p=w_p, h_p=h_p)
+ else:
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ b, nc, w, h = x.shape
+ w_p, h_p = w // self.patch_size, h // self.patch_size
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ if getattr(blk, "TloraMLP", None) is not None:
+ x = blk(x, w_p=w_p, h_p=h_p)
+ else:
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ return_register_tokens: bool = False,
+ norm=True
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+
+ if self.num_register_tokens > 0:
+ # get register tokens
+ register_tokens = [out[:, 1 : 1 + self.num_register_tokens] for out in outputs]
+
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ elif return_register_tokens:
+ return tuple(zip(outputs, register_tokens))
+
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def DINOv2(model_name):
+ model_zoo = {
+ "vits": vit_small,
+ "vitb": vit_base,
+ "vitl": vit_large,
+ "vitg": vit_giant2
+ }
+
+ return model_zoo[model_name](
+ img_size=518,
+ patch_size=14,
+ init_values=1.0,
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
+ block_chunks=0,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1
+ )
diff --git a/models/monoD/depth_anything_v2/dinov2_layers/__init__.py b/models/monoD/depth_anything_v2/dinov2_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1
--- /dev/null
+++ b/models/monoD/depth_anything_v2/dinov2_layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/models/monoD/depth_anything_v2/dinov2_layers/attention.py b/models/monoD/depth_anything_v2/dinov2_layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..815a2bf53dbec496f6a184ed7d03bcecb7124262
--- /dev/null
+++ b/models/monoD/depth_anything_v2/dinov2_layers/attention.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
\ No newline at end of file
diff --git a/models/monoD/depth_anything_v2/dinov2_layers/block.py b/models/monoD/depth_anything_v2/dinov2_layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1
--- /dev/null
+++ b/models/monoD/depth_anything_v2/dinov2_layers/block.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/models/monoD/depth_anything_v2/dinov2_layers/drop_path.py b/models/monoD/depth_anything_v2/dinov2_layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1
--- /dev/null
+++ b/models/monoD/depth_anything_v2/dinov2_layers/drop_path.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/models/monoD/depth_anything_v2/dinov2_layers/layer_scale.py b/models/monoD/depth_anything_v2/dinov2_layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1
--- /dev/null
+++ b/models/monoD/depth_anything_v2/dinov2_layers/layer_scale.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/models/monoD/depth_anything_v2/dinov2_layers/mlp.py b/models/monoD/depth_anything_v2/dinov2_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018
--- /dev/null
+++ b/models/monoD/depth_anything_v2/dinov2_layers/mlp.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/models/monoD/depth_anything_v2/dinov2_layers/patch_embed.py b/models/monoD/depth_anything_v2/dinov2_layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5b333999406babab7a7786d4f8148db00313b17
--- /dev/null
+++ b/models/monoD/depth_anything_v2/dinov2_layers/patch_embed.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/models/monoD/depth_anything_v2/dinov2_layers/swiglu_ffn.py b/models/monoD/depth_anything_v2/dinov2_layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e
--- /dev/null
+++ b/models/monoD/depth_anything_v2/dinov2_layers/swiglu_ffn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/models/monoD/depth_anything_v2/dpt.py b/models/monoD/depth_anything_v2/dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..87decfe7b3398904a4754a77636cb3080b85198e
--- /dev/null
+++ b/models/monoD/depth_anything_v2/dpt.py
@@ -0,0 +1,249 @@
+import cv2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Compose
+
+from .dinov2 import DINOv2
+from .util.blocks import FeatureFusionBlock, _make_scratch
+from .util.transform import Resize, NormalizeImage, PrepareForNet
+
+
+def _make_fusion_block(features, use_bn, size=None):
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ size=size,
+ )
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_feature, out_feature):
+ super().__init__()
+
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(out_feature),
+ nn.ReLU(True)
+ )
+
+ def forward(self, x):
+ return self.conv_block(x)
+
+class DPTHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ features=256,
+ use_bn=False,
+ out_channels=[256, 512, 1024, 1024],
+ use_clstoken=False
+ ):
+ super(DPTHead, self).__init__()
+
+ self.use_clstoken = use_clstoken
+
+ self.projects = nn.ModuleList([
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ) for out_channel in out_channels
+ ])
+
+ self.resize_layers = nn.ModuleList([
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0],
+ out_channels=out_channels[0],
+ kernel_size=4,
+ stride=4,
+ padding=0),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1],
+ out_channels=out_channels[1],
+ kernel_size=2,
+ stride=2,
+ padding=0),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3],
+ out_channels=out_channels[3],
+ kernel_size=3,
+ stride=2,
+ padding=1)
+ ])
+
+ if use_clstoken:
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(
+ nn.Sequential(
+ nn.Linear(2 * in_channels, in_channels),
+ nn.GELU()))
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ groups=1,
+ expand=False,
+ )
+
+ self.scratch.stem_transpose = None
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True),
+ nn.Identity(),
+ )
+ self.scratch.output_conv3 = nn.Sequential(
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+ nn.Sigmoid(),
+ )
+ self.scratch.output_conv4 = nn.Sequential(
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+ nn.Sigmoid(),
+ )
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ torch.nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ torch.nn.init.xavier_uniform_(m.weight)
+ torch.nn.init.constant_(m.bias, 0)
+
+ def forward(self, out_features, patch_h,
+ patch_w, last_feat=False, Bs=1, ego_cond=None, track_cond=None):
+ out = []
+ for i, x in enumerate(out_features):
+ if self.use_clstoken:
+ x, cls_token = x[0], x[1]
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ else:
+ x = x[0]
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+
+ out.append(x)
+
+ layer_1, layer_2, layer_3, layer_4 = out
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+ out = self.scratch.output_conv1(path_1)
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+
+ if last_feat: # only adapt on the last output layer
+ out_extra = out.clone()
+ unc_metric = self.scratch.output_conv3(out)
+ rel_out = self.scratch.output_conv2(out)
+ dyn_prob = self.scratch.output_conv4(out)
+
+ if last_feat:
+ return unc_metric, rel_out, out_extra, dyn_prob
+ else:
+ return unc_metric, rel_out, dyn_prob
+
+
+class DepthAnythingV2(nn.Module):
+
+ def __init__(
+ self,
+ encoder='vitl',
+ features=256,
+ out_channels=[256, 512, 1024, 1024],
+ use_bn=False,
+ use_clstoken=False,
+ max_depth=65.0
+ ):
+ super(DepthAnythingV2, self).__init__()
+
+ self.intermediate_layer_idx = {
+ 'vits': [2, 5, 8, 11],
+ 'vitb': [2, 5, 8, 11],
+ 'vitl': [4, 11, 17, 23],
+ 'vitg': [9, 19, 29, 39]
+ }
+ self.max_depth = max_depth
+ self.encoder = encoder
+ self.pretrained = DINOv2(model_name=encoder)
+
+ self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
+
+ def forward(self, x):
+ patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
+ features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
+ depth = self.depth_head(features, patch_h, patch_w) * self.max_depth
+ # depth = F.relu(depth)
+
+ return depth.squeeze(1)
+
+ @torch.no_grad()
+ def infer_image(self, raw_image, input_size=518):
+ image, (h, w) = self.image2tensor(raw_image, input_size)
+ depth = self.forward(image)
+
+ depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
+
+ return depth.cpu().numpy()
+
+ def image2tensor(self, raw_image, input_size=518):
+ transform = Compose([
+ Resize(
+ width=input_size,
+ height=input_size,
+ resize_target=False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method='lower_bound',
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ ])
+
+ h, w = raw_image.shape[:2]
+
+ image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
+
+ image = transform({'image': image})['image']
+ image = torch.from_numpy(image).unsqueeze(0)
+
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+ image = image.to(DEVICE)
+
+ return image, (h, w)
diff --git a/models/monoD/depth_anything_v2/util/blocks.py b/models/monoD/depth_anything_v2/util/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..382ea183a40264056142afffc201c992a2b01d37
--- /dev/null
+++ b/models/monoD/depth_anything_v2/util/blocks.py
@@ -0,0 +1,148 @@
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ if self.bn == True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn == True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn == True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ self.size=size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+
+ output = self.out_conv(output)
+
+ return output
diff --git a/models/monoD/depth_anything_v2/util/transform.py b/models/monoD/depth_anything_v2/util/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..d670b20cba7bcaf76f1f10c495b9cfa724da7b74
--- /dev/null
+++ b/models/monoD/depth_anything_v2/util/transform.py
@@ -0,0 +1,156 @@
+import numpy as np
+import cv2
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as least as possbile
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
+
+ # resize sample
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
+ if self.__resize_target:
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
+
+ if "mask" in sample:
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ return sample
\ No newline at end of file
diff --git a/models/monoD/depth_pro/__init__.py b/models/monoD/depth_pro/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52080b686a64851c4bb62003884fbdeb55dced9a
--- /dev/null
+++ b/models/monoD/depth_pro/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+"""Depth Pro package."""
+
+from .depth_pro import create_model_and_transforms # noqa
+from .utils import load_rgb # noqa
diff --git a/models/monoD/depth_pro/cli/__init__.py b/models/monoD/depth_pro/cli/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..54ac5722c5db5e9a6846f12fea9efc00f3e385e5
--- /dev/null
+++ b/models/monoD/depth_pro/cli/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+"""Depth Pro CLI and tools."""
+
+from .run import main as run_main # noqa
diff --git a/models/monoD/depth_pro/cli/run.py b/models/monoD/depth_pro/cli/run.py
new file mode 100755
index 0000000000000000000000000000000000000000..3545a99993810b7d602f63c057640645b215f2b2
--- /dev/null
+++ b/models/monoD/depth_pro/cli/run.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+"""Sample script to run DepthPro.
+
+Copyright (C) 2024 Apple Inc. All Rights Reserved.
+"""
+
+
+import argparse
+import logging
+from pathlib import Path
+
+import numpy as np
+import PIL.Image
+import torch
+from matplotlib import pyplot as plt
+from tqdm import tqdm
+
+from depth_pro import create_model_and_transforms, load_rgb
+
+LOGGER = logging.getLogger(__name__)
+
+
+def get_torch_device() -> torch.device:
+ """Get the Torch device."""
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda:0")
+ elif torch.backends.mps.is_available():
+ device = torch.device("mps")
+ return device
+
+
+def run(args):
+ """Run Depth Pro on a sample image."""
+ if args.verbose:
+ logging.basicConfig(level=logging.INFO)
+
+ # Load model.
+ model, transform = create_model_and_transforms(
+ device=get_torch_device(),
+ precision=torch.half,
+ )
+ model.eval()
+
+ image_paths = [args.image_path]
+ if args.image_path.is_dir():
+ image_paths = args.image_path.glob("**/*")
+ relative_path = args.image_path
+ else:
+ relative_path = args.image_path.parent
+
+ if not args.skip_display:
+ plt.ion()
+ fig = plt.figure()
+ ax_rgb = fig.add_subplot(121)
+ ax_disp = fig.add_subplot(122)
+
+ for image_path in tqdm(image_paths):
+ # Load image and focal length from exif info (if found.).
+ try:
+ LOGGER.info(f"Loading image {image_path} ...")
+ image, _, f_px = load_rgb(image_path)
+ except Exception as e:
+ LOGGER.error(str(e))
+ continue
+ # Run prediction. If `f_px` is provided, it is used to estimate the final metric depth,
+ # otherwise the model estimates `f_px` to compute the depth metricness.
+ prediction = model.infer(transform(image), f_px=f_px)
+
+ # Extract the depth and focal length.
+ depth = prediction["depth"].detach().cpu().numpy().squeeze()
+ if f_px is not None:
+ LOGGER.debug(f"Focal length (from exif): {f_px:0.2f}")
+ elif prediction["focallength_px"] is not None:
+ focallength_px = prediction["focallength_px"].detach().cpu().item()
+ LOGGER.info(f"Estimated focal length: {focallength_px}")
+
+ inverse_depth = 1 / depth
+ # Visualize inverse depth instead of depth, clipped to [0.1m;250m] range for better visualization.
+ max_invdepth_vizu = min(inverse_depth.max(), 1 / 0.1)
+ min_invdepth_vizu = max(1 / 250, inverse_depth.min())
+ inverse_depth_normalized = (inverse_depth - min_invdepth_vizu) / (
+ max_invdepth_vizu - min_invdepth_vizu
+ )
+
+ # Save Depth as npz file.
+ if args.output_path is not None:
+ output_file = (
+ args.output_path
+ / image_path.relative_to(relative_path).parent
+ / image_path.stem
+ )
+ LOGGER.info(f"Saving depth map to: {str(output_file)}")
+ output_file.parent.mkdir(parents=True, exist_ok=True)
+ np.savez_compressed(output_file, depth=depth)
+
+ # Save as color-mapped "turbo" jpg image.
+ cmap = plt.get_cmap("turbo")
+ color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(
+ np.uint8
+ )
+ color_map_output_file = str(output_file) + ".jpg"
+ LOGGER.info(f"Saving color-mapped depth to: : {color_map_output_file}")
+ PIL.Image.fromarray(color_depth).save(
+ color_map_output_file, format="JPEG", quality=90
+ )
+
+ # Display the image and estimated depth map.
+ if not args.skip_display:
+ ax_rgb.imshow(image)
+ ax_disp.imshow(inverse_depth_normalized, cmap="turbo")
+ fig.canvas.draw()
+ fig.canvas.flush_events()
+
+ LOGGER.info("Done predicting depth!")
+ if not args.skip_display:
+ plt.show(block=True)
+
+
+def main():
+ """Run DepthPro inference example."""
+ parser = argparse.ArgumentParser(
+ description="Inference scripts of DepthPro with PyTorch models."
+ )
+ parser.add_argument(
+ "-i",
+ "--image-path",
+ type=Path,
+ default="./data/example.jpg",
+ help="Path to input image.",
+ )
+ parser.add_argument(
+ "-o",
+ "--output-path",
+ type=Path,
+ help="Path to store output files.",
+ )
+ parser.add_argument(
+ "--skip-display",
+ action="store_true",
+ help="Skip matplotlib display.",
+ )
+ parser.add_argument(
+ "-v",
+ "--verbose",
+ action="store_true",
+ help="Show verbose output."
+ )
+
+ run(parser.parse_args())
+
+
+if __name__ == "__main__":
+ main()
diff --git a/models/monoD/depth_pro/depth_pro.py b/models/monoD/depth_pro/depth_pro.py
new file mode 100644
index 0000000000000000000000000000000000000000..f31b4e16178c5e29a3ffcd1a2366fe585bc9c370
--- /dev/null
+++ b/models/monoD/depth_pro/depth_pro.py
@@ -0,0 +1,298 @@
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+# Depth Pro: Sharp Monocular Metric Depth in Less Than a Second
+
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Mapping, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torchvision.transforms import (
+ Compose,
+ ConvertImageDtype,
+ Lambda,
+ Normalize,
+ ToTensor,
+)
+
+from .network.decoder import MultiresConvDecoder
+from .network.encoder import DepthProEncoder
+from .network.fov import FOVNetwork
+from .network.vit_factory import VIT_CONFIG_DICT, ViTPreset, create_vit
+
+
+@dataclass
+class DepthProConfig:
+ """Configuration for DepthPro."""
+
+ patch_encoder_preset: ViTPreset
+ image_encoder_preset: ViTPreset
+ decoder_features: int
+
+ checkpoint_uri: Optional[str] = None
+ fov_encoder_preset: Optional[ViTPreset] = None
+ use_fov_head: bool = True
+
+
+DEFAULT_MONODEPTH_CONFIG_DICT = DepthProConfig(
+ patch_encoder_preset="dinov2l16_384",
+ image_encoder_preset="dinov2l16_384",
+ checkpoint_uri="./checkpoints/depth_pro.pt",
+ decoder_features=256,
+ use_fov_head=True,
+ fov_encoder_preset="dinov2l16_384",
+)
+
+
+def create_backbone_model(
+ preset: ViTPreset
+) -> Tuple[nn.Module, ViTPreset]:
+ """Create and load a backbone model given a config.
+
+ Args:
+ ----
+ preset: A backbone preset to load pre-defind configs.
+
+ Returns:
+ -------
+ A Torch module and the associated config.
+
+ """
+ if preset in VIT_CONFIG_DICT:
+ config = VIT_CONFIG_DICT[preset]
+ model = create_vit(preset=preset, use_pretrained=False)
+ else:
+ raise KeyError(f"Preset {preset} not found.")
+
+ return model, config
+
+
+def create_model_and_transforms(
+ config: DepthProConfig = DEFAULT_MONODEPTH_CONFIG_DICT,
+ device: torch.device = torch.device("cpu"),
+ precision: torch.dtype = torch.float32,
+) -> Tuple[DepthPro, Compose]:
+ """Create a DepthPro model and load weights from `config.checkpoint_uri`.
+
+ Args:
+ ----
+ config: The configuration for the DPT model architecture.
+ device: The optional Torch device to load the model onto, default runs on "cpu".
+ precision: The optional precision used for the model, default is FP32.
+
+ Returns:
+ -------
+ The Torch DepthPro model and associated Transform.
+
+ """
+ patch_encoder, patch_encoder_config = create_backbone_model(
+ preset=config.patch_encoder_preset
+ )
+ image_encoder, _ = create_backbone_model(
+ preset=config.image_encoder_preset
+ )
+
+ fov_encoder = None
+ if config.use_fov_head and config.fov_encoder_preset is not None:
+ fov_encoder, _ = create_backbone_model(preset=config.fov_encoder_preset)
+
+ dims_encoder = patch_encoder_config.encoder_feature_dims
+ hook_block_ids = patch_encoder_config.encoder_feature_layer_ids
+ encoder = DepthProEncoder(
+ dims_encoder=dims_encoder,
+ patch_encoder=patch_encoder,
+ image_encoder=image_encoder,
+ hook_block_ids=hook_block_ids,
+ decoder_features=config.decoder_features,
+ )
+ decoder = MultiresConvDecoder(
+ dims_encoder=[config.decoder_features] + list(encoder.dims_encoder),
+ dim_decoder=config.decoder_features,
+ )
+ model = DepthPro(
+ encoder=encoder,
+ decoder=decoder,
+ last_dims=(32, 1),
+ use_fov_head=config.use_fov_head,
+ fov_encoder=fov_encoder,
+ ).to(device)
+
+ if precision == torch.half:
+ model.half()
+
+ transform = Compose(
+ [
+ ToTensor(),
+ Lambda(lambda x: x.to(device)),
+ Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ ConvertImageDtype(precision),
+ ]
+ )
+
+ if config.checkpoint_uri is not None:
+ state_dict = torch.load(config.checkpoint_uri, map_location="cpu")
+ missing_keys, unexpected_keys = model.load_state_dict(
+ state_dict=state_dict, strict=True
+ )
+
+ if len(unexpected_keys) != 0:
+ raise KeyError(
+ f"Found unexpected keys when loading monodepth: {unexpected_keys}"
+ )
+
+ # fc_norm is only for the classification head,
+ # which we would not use. We only use the encoding.
+ missing_keys = [key for key in missing_keys if "fc_norm" not in key]
+ if len(missing_keys) != 0:
+ raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}")
+
+ return model, transform
+
+
+class DepthPro(nn.Module):
+ """DepthPro network."""
+
+ def __init__(
+ self,
+ encoder: DepthProEncoder,
+ decoder: MultiresConvDecoder,
+ last_dims: tuple[int, int],
+ use_fov_head: bool = True,
+ fov_encoder: Optional[nn.Module] = None,
+ ):
+ """Initialize DepthPro.
+
+ Args:
+ ----
+ encoder: The DepthProEncoder backbone.
+ decoder: The MultiresConvDecoder decoder.
+ last_dims: The dimension for the last convolution layers.
+ use_fov_head: Whether to use the field-of-view head.
+ fov_encoder: A separate encoder for the field of view.
+
+ """
+ super().__init__()
+
+ self.encoder = encoder
+ self.decoder = decoder
+
+ dim_decoder = decoder.dim_decoder
+ self.head = nn.Sequential(
+ nn.Conv2d(
+ dim_decoder, dim_decoder // 2, kernel_size=3, stride=1, padding=1
+ ),
+ nn.ConvTranspose2d(
+ in_channels=dim_decoder // 2,
+ out_channels=dim_decoder // 2,
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ ),
+ nn.Conv2d(
+ dim_decoder // 2,
+ last_dims[0],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ nn.ReLU(True),
+ nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
+ nn.ReLU(),
+ )
+
+ # Set the final convolution layer's bias to be 0.
+ self.head[4].bias.data.fill_(0)
+
+ # Set the FOV estimation head.
+ if use_fov_head:
+ self.fov = FOVNetwork(num_features=dim_decoder, fov_encoder=fov_encoder)
+
+ @property
+ def img_size(self) -> int:
+ """Return the internal image size of the network."""
+ return self.encoder.img_size
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Decode by projection and fusion of multi-resolution encodings.
+
+ Args:
+ ----
+ x (torch.Tensor): Input image.
+
+ Returns:
+ -------
+ The canonical inverse depth map [m] and the optional estimated field of view [deg].
+
+ """
+ _, _, H, W = x.shape
+ assert H == self.img_size and W == self.img_size
+
+ encodings = self.encoder(x)
+ features, features_0 = self.decoder(encodings)
+ canonical_inverse_depth = self.head(features)
+
+ fov_deg = None
+ if hasattr(self, "fov"):
+ fov_deg = self.fov.forward(x, features_0.detach())
+
+ return canonical_inverse_depth, fov_deg
+
+ @torch.no_grad()
+ def infer(
+ self,
+ x: torch.Tensor,
+ f_px: Optional[Union[float, torch.Tensor]] = None,
+ interpolation_mode="bilinear",
+ ) -> Mapping[str, torch.Tensor]:
+ """Infer depth and fov for a given image.
+
+ If the image is not at network resolution, it is resized to 1536x1536 and
+ the estimated depth is resized to the original image resolution.
+ Note: if the focal length is given, the estimated value is ignored and the provided
+ focal length is use to generate the metric depth values.
+
+ Args:
+ ----
+ x (torch.Tensor): Input image
+ f_px (torch.Tensor): Optional focal length in pixels corresponding to `x`.
+ interpolation_mode (str): Interpolation function for downsampling/upsampling.
+
+ Returns:
+ -------
+ Tensor dictionary (torch.Tensor): depth [m], focallength [pixels].
+
+ """
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
+ _, _, H, W = x.shape
+ resize = H != self.img_size or W != self.img_size
+
+ if resize:
+ x = nn.functional.interpolate(
+ x,
+ size=(self.img_size, self.img_size),
+ mode=interpolation_mode,
+ align_corners=False,
+ )
+
+ canonical_inverse_depth, fov_deg = self.forward(x)
+ if f_px is None:
+ f_px = 0.5 * W / torch.tan(0.5 * torch.deg2rad(fov_deg.to(torch.float)))
+
+ inverse_depth = canonical_inverse_depth * (W / f_px)
+ f_px = f_px.squeeze()
+
+ if resize:
+ inverse_depth = nn.functional.interpolate(
+ inverse_depth, size=(H, W), mode=interpolation_mode, align_corners=False
+ )
+
+ depth = 1.0 / torch.clamp(inverse_depth, min=1e-4, max=1e4)
+
+ return {
+ "depth": depth.squeeze(),
+ "focallength_px": f_px,
+ }
diff --git a/models/monoD/depth_pro/eval/boundary_metrics.py b/models/monoD/depth_pro/eval/boundary_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7650dbb60b990ed66b4444a1bbb7f7eaaed1390
--- /dev/null
+++ b/models/monoD/depth_pro/eval/boundary_metrics.py
@@ -0,0 +1,332 @@
+from typing import List, Tuple
+
+import numpy as np
+
+
+def connected_component(r: np.ndarray, c: np.ndarray) -> List[List[int]]:
+ """Find connected components in the given row and column indices.
+
+ Args:
+ ----
+ r (np.ndarray): Row indices.
+ c (np.ndarray): Column indices.
+
+ Yields:
+ ------
+ List[int]: Indices of connected components.
+
+ """
+ indices = [0]
+ for i in range(1, r.size):
+ if r[i] == r[indices[-1]] and c[i] == c[indices[-1]] + 1:
+ indices.append(i)
+ else:
+ yield indices
+ indices = [i]
+ yield indices
+
+
+def nms_horizontal(ratio: np.ndarray, threshold: float) -> np.ndarray:
+ """Apply Non-Maximum Suppression (NMS) horizontally on the given ratio matrix.
+
+ Args:
+ ----
+ ratio (np.ndarray): Input ratio matrix.
+ threshold (float): Threshold for NMS.
+
+ Returns:
+ -------
+ np.ndarray: Binary mask after applying NMS.
+
+ """
+ mask = np.zeros_like(ratio, dtype=bool)
+ r, c = np.nonzero(ratio > threshold)
+ if len(r) == 0:
+ return mask
+ for ids in connected_component(r, c):
+ values = [ratio[r[i], c[i]] for i in ids]
+ mi = np.argmax(values)
+ mask[r[ids[mi]], c[ids[mi]]] = True
+ return mask
+
+
+def nms_vertical(ratio: np.ndarray, threshold: float) -> np.ndarray:
+ """Apply Non-Maximum Suppression (NMS) vertically on the given ratio matrix.
+
+ Args:
+ ----
+ ratio (np.ndarray): Input ratio matrix.
+ threshold (float): Threshold for NMS.
+
+ Returns:
+ -------
+ np.ndarray: Binary mask after applying NMS.
+
+ """
+ return np.transpose(nms_horizontal(np.transpose(ratio), threshold))
+
+
+def fgbg_depth(
+ d: np.ndarray, t: float
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """Find foreground-background relations between neighboring pixels.
+
+ Args:
+ ----
+ d (np.ndarray): Depth matrix.
+ t (float): Threshold for comparison.
+
+ Returns:
+ -------
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
+ left, top, right, and bottom foreground-background relations.
+
+ """
+ right_is_big_enough = (d[..., :, 1:] / d[..., :, :-1]) > t
+ left_is_big_enough = (d[..., :, :-1] / d[..., :, 1:]) > t
+ bottom_is_big_enough = (d[..., 1:, :] / d[..., :-1, :]) > t
+ top_is_big_enough = (d[..., :-1, :] / d[..., 1:, :]) > t
+ return (
+ left_is_big_enough,
+ top_is_big_enough,
+ right_is_big_enough,
+ bottom_is_big_enough,
+ )
+
+
+def fgbg_depth_thinned(
+ d: np.ndarray, t: float
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """Find foreground-background relations between neighboring pixels with Non-Maximum Suppression.
+
+ Args:
+ ----
+ d (np.ndarray): Depth matrix.
+ t (float): Threshold for NMS.
+
+ Returns:
+ -------
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
+ left, top, right, and bottom foreground-background relations with NMS applied.
+
+ """
+ right_is_big_enough = nms_horizontal(d[..., :, 1:] / d[..., :, :-1], t)
+ left_is_big_enough = nms_horizontal(d[..., :, :-1] / d[..., :, 1:], t)
+ bottom_is_big_enough = nms_vertical(d[..., 1:, :] / d[..., :-1, :], t)
+ top_is_big_enough = nms_vertical(d[..., :-1, :] / d[..., 1:, :], t)
+ return (
+ left_is_big_enough,
+ top_is_big_enough,
+ right_is_big_enough,
+ bottom_is_big_enough,
+ )
+
+
+def fgbg_binary_mask(
+ d: np.ndarray,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """Find foreground-background relations between neighboring pixels in binary masks.
+
+ Args:
+ ----
+ d (np.ndarray): Binary depth matrix.
+
+ Returns:
+ -------
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
+ left, top, right, and bottom foreground-background relations in binary masks.
+
+ """
+ assert d.dtype == bool
+ right_is_big_enough = d[..., :, 1:] & ~d[..., :, :-1]
+ left_is_big_enough = d[..., :, :-1] & ~d[..., :, 1:]
+ bottom_is_big_enough = d[..., 1:, :] & ~d[..., :-1, :]
+ top_is_big_enough = d[..., :-1, :] & ~d[..., 1:, :]
+ return (
+ left_is_big_enough,
+ top_is_big_enough,
+ right_is_big_enough,
+ bottom_is_big_enough,
+ )
+
+
+def edge_recall_matting(pr: np.ndarray, gt: np.ndarray, t: float) -> float:
+ """Calculate edge recall for image matting.
+
+ Args:
+ ----
+ pr (np.ndarray): Predicted depth matrix.
+ gt (np.ndarray): Ground truth binary mask.
+ t (float): Threshold for NMS.
+
+ Returns:
+ -------
+ float: Edge recall value.
+
+ """
+ assert gt.dtype == bool
+ ap, bp, cp, dp = fgbg_depth_thinned(pr, t)
+ ag, bg, cg, dg = fgbg_binary_mask(gt)
+ return 0.25 * (
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ag), 1)
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bg), 1)
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cg), 1)
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dg), 1)
+ )
+
+
+def boundary_f1(
+ pr: np.ndarray,
+ gt: np.ndarray,
+ t: float,
+ return_p: bool = False,
+ return_r: bool = False,
+) -> float:
+ """Calculate Boundary F1 score.
+
+ Args:
+ ----
+ pr (np.ndarray): Predicted depth matrix.
+ gt (np.ndarray): Ground truth depth matrix.
+ t (float): Threshold for comparison.
+ return_p (bool, optional): If True, return precision. Defaults to False.
+ return_r (bool, optional): If True, return recall. Defaults to False.
+
+ Returns:
+ -------
+ float: Boundary F1 score, or precision, or recall depending on the flags.
+
+ """
+ ap, bp, cp, dp = fgbg_depth(pr, t)
+ ag, bg, cg, dg = fgbg_depth(gt, t)
+
+ r = 0.25 * (
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ag), 1)
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bg), 1)
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cg), 1)
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dg), 1)
+ )
+ p = 0.25 * (
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ap), 1)
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bp), 1)
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cp), 1)
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dp), 1)
+ )
+ if r + p == 0:
+ return 0.0
+ if return_p:
+ return p
+ if return_r:
+ return r
+ return 2 * (r * p) / (r + p)
+
+
+def get_thresholds_and_weights(
+ t_min: float, t_max: float, N: int
+) -> Tuple[np.ndarray, np.ndarray]:
+ """Generate thresholds and weights for the given range.
+
+ Args:
+ ----
+ t_min (float): Minimum threshold.
+ t_max (float): Maximum threshold.
+ N (int): Number of thresholds.
+
+ Returns:
+ -------
+ Tuple[np.ndarray, np.ndarray]: Array of thresholds and corresponding weights.
+
+ """
+ thresholds = np.linspace(t_min, t_max, N)
+ weights = thresholds / thresholds.sum()
+ return thresholds, weights
+
+
+def invert_depth(depth: np.ndarray, eps: float = 1e-6) -> np.ndarray:
+ """Inverts a depth map with numerical stability.
+
+ Args:
+ ----
+ depth (np.ndarray): Depth map to be inverted.
+ eps (float): Minimum value to avoid division by zero (default is 1e-6).
+
+ Returns:
+ -------
+ np.ndarray: Inverted depth map.
+
+ """
+ inverse_depth = 1.0 / depth.clip(min=eps)
+ return inverse_depth
+
+
+def SI_boundary_F1(
+ predicted_depth: np.ndarray,
+ target_depth: np.ndarray,
+ t_min: float = 1.05,
+ t_max: float = 1.25,
+ N: int = 10,
+) -> float:
+ """Calculate Scale-Invariant Boundary F1 Score for depth-based ground-truth.
+
+ Args:
+ ----
+ predicted_depth (np.ndarray): Predicted depth matrix.
+ target_depth (np.ndarray): Ground truth depth matrix.
+ t_min (float, optional): Minimum threshold. Defaults to 1.05.
+ t_max (float, optional): Maximum threshold. Defaults to 1.25.
+ N (int, optional): Number of thresholds. Defaults to 10.
+
+ Returns:
+ -------
+ float: Scale-Invariant Boundary F1 Score.
+
+ """
+ assert predicted_depth.ndim == target_depth.ndim == 2
+ thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
+ f1_scores = np.array(
+ [
+ boundary_f1(invert_depth(predicted_depth), invert_depth(target_depth), t)
+ for t in thresholds
+ ]
+ )
+ return np.sum(f1_scores * weights)
+
+
+def SI_boundary_Recall(
+ predicted_depth: np.ndarray,
+ target_mask: np.ndarray,
+ t_min: float = 1.05,
+ t_max: float = 1.25,
+ N: int = 10,
+ alpha_threshold: float = 0.1,
+) -> float:
+ """Calculate Scale-Invariant Boundary Recall Score for mask-based ground-truth.
+
+ Args:
+ ----
+ predicted_depth (np.ndarray): Predicted depth matrix.
+ target_mask (np.ndarray): Ground truth binary mask.
+ t_min (float, optional): Minimum threshold. Defaults to 1.05.
+ t_max (float, optional): Maximum threshold. Defaults to 1.25.
+ N (int, optional): Number of thresholds. Defaults to 10.
+ alpha_threshold (float, optional): Threshold for alpha masking. Defaults to 0.1.
+
+ Returns:
+ -------
+ float: Scale-Invariant Boundary Recall Score.
+
+ """
+ assert predicted_depth.ndim == target_mask.ndim == 2
+ thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
+ thresholded_target = target_mask > alpha_threshold
+
+ recall_scores = np.array(
+ [
+ edge_recall_matting(
+ invert_depth(predicted_depth), thresholded_target, t=float(t)
+ )
+ for t in thresholds
+ ]
+ )
+ weighted_recall = np.sum(recall_scores * weights)
+ return weighted_recall
diff --git a/models/monoD/depth_pro/eval/dis5k_sample_list.txt b/models/monoD/depth_pro/eval/dis5k_sample_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..81da1dcdb786da7bbec861604f50d4f039f695ef
--- /dev/null
+++ b/models/monoD/depth_pro/eval/dis5k_sample_list.txt
@@ -0,0 +1,200 @@
+DIS5K/DIS-TE1/im/12#Graphics#4#TrafficSign#8245751856_821be14f86_o.jpg
+DIS5K/DIS-TE1/im/13#Insect#4#Butterfly#16023994688_7ff8cdccb1_o.jpg
+DIS5K/DIS-TE1/im/14#Kitchenware#4#Kitchenware#IMG_20210520_205538.jpg
+DIS5K/DIS-TE1/im/14#Kitchenware#8#SweetStand#4848284981_fc90f54b50_o.jpg
+DIS5K/DIS-TE1/im/17#Non-motor Vehicle#4#Cart#15012855035_d10b57014f_o.jpg
+DIS5K/DIS-TE1/im/2#Aircraft#5#Kite#13104545564_5afceec9bd_o.jpg
+DIS5K/DIS-TE1/im/20#Sports#10#Skateboarding#8472763540_bb2390e928_o.jpg
+DIS5K/DIS-TE1/im/21#Tool#14#Sword#32473146960_dcc6b77848_o.jpg
+DIS5K/DIS-TE1/im/21#Tool#15#Tapeline#9680492386_2d2020f282_o.jpg
+DIS5K/DIS-TE1/im/21#Tool#4#Flag#507752845_ef852100f0_o.jpg
+DIS5K/DIS-TE1/im/21#Tool#6#Key#11966089533_3becd78b44_o.jpg
+DIS5K/DIS-TE1/im/21#Tool#8#Scale#31946428472_d28def471b_o.jpg
+DIS5K/DIS-TE1/im/22#Weapon#4#Rifle#8472656430_3eb908b211_o.jpg
+DIS5K/DIS-TE1/im/8#Electronics#3#Earphone#1177468301_641df8c267_o.jpg
+DIS5K/DIS-TE1/im/8#Electronics#9#MusicPlayer#2235782872_7d47847bb4_o.jpg
+DIS5K/DIS-TE2/im/11#Furniture#13#Ladder#3878434417_2ed740586e_o.jpg
+DIS5K/DIS-TE2/im/13#Insect#1#Ant#27047700955_3b3a1271f8_o.jpg
+DIS5K/DIS-TE2/im/13#Insect#11#Spider#5567179191_38d1f65589_o.jpg
+DIS5K/DIS-TE2/im/13#Insect#8#Locust#5237933769_e6687c05e4_o.jpg
+DIS5K/DIS-TE2/im/14#Kitchenware#2#DishRack#70838854_40cf689da7_o.jpg
+DIS5K/DIS-TE2/im/14#Kitchenware#8#SweetStand#8467929412_fef7f4275d_o.jpg
+DIS5K/DIS-TE2/im/16#Music Instrument#2#Harp#28058219806_28e05ff24a_o.jpg
+DIS5K/DIS-TE2/im/17#Non-motor Vehicle#1#BabyCarriage#29794777180_2e1695a0cf_o.jpg
+DIS5K/DIS-TE2/im/19#Ship#3#Sailboat#22442908623_5977e3becf_o.jpg
+DIS5K/DIS-TE2/im/2#Aircraft#5#Kite#44654358051_1400e71cc4_o.jpg
+DIS5K/DIS-TE2/im/21#Tool#11#Stand#IMG_20210520_205442.jpg
+DIS5K/DIS-TE2/im/21#Tool#17#Tripod#9318977876_34615ec9a0_o.jpg
+DIS5K/DIS-TE2/im/5#Artifact#3#Handcraft#50860882577_8482143b1b_o.jpg
+DIS5K/DIS-TE2/im/8#Electronics#10#Robot#3093360210_fee54dc5c5_o.jpg
+DIS5K/DIS-TE2/im/8#Electronics#6#Microphone#47411477652_6da66cbc10_o.jpg
+DIS5K/DIS-TE3/im/14#Kitchenware#4#Kitchenware#2451122898_ef883175dd_o.jpg
+DIS5K/DIS-TE3/im/15#Machine#4#SewingMachine#9311164128_97ba1d3947_o.jpg
+DIS5K/DIS-TE3/im/16#Music Instrument#2#Harp#7670920550_59e992fd7b_o.jpg
+DIS5K/DIS-TE3/im/17#Non-motor Vehicle#1#BabyCarriage#8389984877_1fddf8715c_o.jpg
+DIS5K/DIS-TE3/im/17#Non-motor Vehicle#3#Carriage#5947122724_98e0fc3d1f_o.jpg
+DIS5K/DIS-TE3/im/2#Aircraft#2#Balloon#2487168092_641505883f_o.jpg
+DIS5K/DIS-TE3/im/2#Aircraft#4#Helicopter#8401177591_06c71c8df2_o.jpg
+DIS5K/DIS-TE3/im/20#Sports#1#Archery#12520003103_faa43ea3e0_o.jpg
+DIS5K/DIS-TE3/im/21#Tool#11#Stand#IMG_20210709_221507.jpg
+DIS5K/DIS-TE3/im/21#Tool#2#Clip#5656649687_63d0c6696d_o.jpg
+DIS5K/DIS-TE3/im/21#Tool#6#Key#12878459244_6387a140ea_o.jpg
+DIS5K/DIS-TE3/im/3#Aquatic#1#Lobster#109214461_f52b4b6093_o.jpg
+DIS5K/DIS-TE3/im/4#Architecture#19#Windmill#20195851863_2627117e0e_o.jpg
+DIS5K/DIS-TE3/im/5#Artifact#2#Cage#5821476369_ea23927487_o.jpg
+DIS5K/DIS-TE3/im/8#Electronics#7#MobileHolder#49732997896_7f53c290b5_o.jpg
+DIS5K/DIS-TE4/im/13#Insect#6#Centipede#15302179708_a267850881_o.jpg
+DIS5K/DIS-TE4/im/17#Non-motor Vehicle#11#Tricycle#5771069105_a3aef6f665_o.jpg
+DIS5K/DIS-TE4/im/17#Non-motor Vehicle#2#Bicycle#4245936196_fdf812dcb7_o.jpg
+DIS5K/DIS-TE4/im/17#Non-motor Vehicle#9#ShoppingCart#4674052920_a5b7a2b236_o.jpg
+DIS5K/DIS-TE4/im/18#Plant#1#Bonsai#3539420884_ca8973e2c0_o.jpg
+DIS5K/DIS-TE4/im/2#Aircraft#6#Parachute#33590416634_9d6f2325e7_o.jpg
+DIS5K/DIS-TE4/im/20#Sports#1#Archery#46924476515_0be1caa684_o.jpg
+DIS5K/DIS-TE4/im/20#Sports#8#Racket#19337607166_dd1985fb59_o.jpg
+DIS5K/DIS-TE4/im/21#Tool#6#Key#3193329588_839b0c74ce_o.jpg
+DIS5K/DIS-TE4/im/5#Artifact#2#Cage#5821886526_0573ba2d0d_o.jpg
+DIS5K/DIS-TE4/im/5#Artifact#3#Handcraft#50105138282_3c1d02c968_o.jpg
+DIS5K/DIS-TE4/im/8#Electronics#1#Antenna#4305034305_874f21a701_o.jpg
+DIS5K/DIS-TR/im/1#Accessories#1#Bag#15554964549_3105e51b6f_o.jpg
+DIS5K/DIS-TR/im/1#Accessories#1#Bag#41104261980_098a6c4a56_o.jpg
+DIS5K/DIS-TR/im/1#Accessories#2#Clothes#2284764037_871b2e8ca4_o.jpg
+DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#1824643784_70d0134156_o.jpg
+DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#3590020230_37b09a29b3_o.jpg
+DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#4809652879_4da8a69f3b_o.jpg
+DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#792204934_f9b28f99b4_o.jpg
+DIS5K/DIS-TR/im/1#Accessories#5#Jewelry#13909132974_c4750c5fb7_o.jpg
+DIS5K/DIS-TR/im/1#Accessories#7#Shoe#2483391615_9199ece8d6_o.jpg
+DIS5K/DIS-TR/im/1#Accessories#8#Watch#4343266960_f6633b029b_o.jpg
+DIS5K/DIS-TR/im/10#Frame#2#BicycleFrame#17897573_42964dd104_o.jpg
+DIS5K/DIS-TR/im/10#Frame#5#Rack#15898634812_64807069ff_o.jpg
+DIS5K/DIS-TR/im/10#Frame#5#Rack#23928546819_c184cb0b60_o.jpg
+DIS5K/DIS-TR/im/11#Furniture#19#Shower#6189119596_77bcfe80ee_o.jpg
+DIS5K/DIS-TR/im/11#Furniture#2#Bench#3263647075_9306e280b5_o.jpg
+DIS5K/DIS-TR/im/11#Furniture#5#CoatHanger#12774091054_cd5ff520ef_o.jpg
+DIS5K/DIS-TR/im/11#Furniture#6#DentalChair#13878156865_d0439dcb32_o.jpg
+DIS5K/DIS-TR/im/11#Furniture#9#Easel#5861024714_2070cd480c_o.jpg
+DIS5K/DIS-TR/im/12#Graphics#4#TrafficSign#40621867334_f3c32ec189_o.jpg
+DIS5K/DIS-TR/im/13#Insect#1#Ant#3295038190_db5dd0d4f4_o.jpg
+DIS5K/DIS-TR/im/13#Insect#10#Mosquito#24341339_a88a1dad4c_o.jpg
+DIS5K/DIS-TR/im/13#Insect#11#Spider#27171518270_63b78069ff_o.jpg
+DIS5K/DIS-TR/im/13#Insect#11#Spider#49925050281_fa727c154e_o.jpg
+DIS5K/DIS-TR/im/13#Insect#2#Beatle#279616486_2f1e64f591_o.jpg
+DIS5K/DIS-TR/im/13#Insect#3#Bee#43892067695_82cf3e536b_o.jpg
+DIS5K/DIS-TR/im/13#Insect#6#Centipede#20874281788_3e15c90a1c_o.jpg
+DIS5K/DIS-TR/im/13#Insect#7#Dragonfly#14106671120_1b824d77e4_o.jpg
+DIS5K/DIS-TR/im/13#Insect#8#Locust#21637491048_676ef7c9f7_o.jpg
+DIS5K/DIS-TR/im/13#Insect#9#Mantis#1381120202_9dff6987b2_o.jpg
+DIS5K/DIS-TR/im/14#Kitchenware#1#Cup#12812517473_327d6474b8_o.jpg
+DIS5K/DIS-TR/im/14#Kitchenware#10#WineGlass#6402491641_389275d4d1_o.jpg
+DIS5K/DIS-TR/im/14#Kitchenware#3#Hydrovalve#3129932040_8c05825004_o.jpg
+DIS5K/DIS-TR/im/14#Kitchenware#4#Kitchenware#2881934780_87d5218ebb_o.jpg
+DIS5K/DIS-TR/im/14#Kitchenware#4#Kitchenware#IMG_20210520_205527.jpg
+DIS5K/DIS-TR/im/14#Kitchenware#6#Spoon#32989113501_b69eccf0df_o.jpg
+DIS5K/DIS-TR/im/14#Kitchenware#8#SweetStand#2867322189_c56d1e0b87_o.jpg
+DIS5K/DIS-TR/im/15#Machine#1#Gear#19217846720_f5f2807475_o.jpg
+DIS5K/DIS-TR/im/15#Machine#2#Machine#1620160659_9571b7a7ab_o.jpg
+DIS5K/DIS-TR/im/16#Music Instrument#2#Harp#6012801603_1a6e2c16a6_o.jpg
+DIS5K/DIS-TR/im/16#Music Instrument#5#Trombone#8683292118_d223c17ccb_o.jpg
+DIS5K/DIS-TR/im/16#Music Instrument#6#Trumpet#8393262740_b8c216142c_o.jpg
+DIS5K/DIS-TR/im/16#Music Instrument#8#Violin#1511267391_40e4949d68_o.jpg
+DIS5K/DIS-TR/im/17#Non-motor Vehicle#1#BabyCarriage#6989512997_38b3dbc88b_o.jpg
+DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#14627183228_b2d68cf501_o.jpg
+DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#2932226475_1b2403e549_o.jpg
+DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#5420155648_86459905b8_o.jpg
+DIS5K/DIS-TR/im/17#Non-motor Vehicle#2#Bicycle#IMG_20210513_134904.jpg
+DIS5K/DIS-TR/im/17#Non-motor Vehicle#3#Carriage#3311962551_6f211b7bd6_o.jpg
+DIS5K/DIS-TR/im/17#Non-motor Vehicle#4#Cart#2609732026_baf7fff3a1_o.jpg
+DIS5K/DIS-TR/im/17#Non-motor Vehicle#5#Handcart#5821282211_201cefeaf2_o.jpg
+DIS5K/DIS-TR/im/17#Non-motor Vehicle#7#Mower#5779003232_3bb3ae531a_o.jpg
+DIS5K/DIS-TR/im/17#Non-motor Vehicle#9#ShoppingCart#10051622843_ace07e32b8_o.jpg
+DIS5K/DIS-TR/im/17#Non-motor Vehicle#9#ShoppingCart#8075259294_f23e243849_o.jpg
+DIS5K/DIS-TR/im/18#Plant#2#Tree#44800999741_e377e16dbb_o.jpg
+DIS5K/DIS-TR/im/2#Aircraft#1#Airplane#2631761913_3ac67d0223_o.jpg
+DIS5K/DIS-TR/im/2#Aircraft#1#Airplane#37707911566_e908a261b6_o.jpg
+DIS5K/DIS-TR/im/2#Aircraft#3#HangGlider#2557220131_b8506920c5_o.jpg
+DIS5K/DIS-TR/im/2#Aircraft#4#Helicopter#6215659280_5dbd9b4546_o.jpg
+DIS5K/DIS-TR/im/2#Aircraft#6#Parachute#20185790493_e56fcaf8c6_o.jpg
+DIS5K/DIS-TR/im/20#Sports#1#Archery#3871269982_ae4c59a7eb_o.jpg
+DIS5K/DIS-TR/im/20#Sports#9#RockClimbing#9662433268_51299bc50e_o.jpg
+DIS5K/DIS-TR/im/21#Tool#14#Sword#26258479365_2950d7fa37_o.jpg
+DIS5K/DIS-TR/im/21#Tool#15#Tapeline#15505703447_e0fdeaa5a6_o.jpg
+DIS5K/DIS-TR/im/21#Tool#4#Flag#26678602024_9b665742de_o.jpg
+DIS5K/DIS-TR/im/21#Tool#4#Flag#5774823110_d603ce3cc8_o.jpg
+DIS5K/DIS-TR/im/21#Tool#5#Hook#6867989814_dba18d673c_o.jpg
+DIS5K/DIS-TR/im/22#Weapon#4#Rifle#4451713125_cd91719189_o.jpg
+DIS5K/DIS-TR/im/3#Aquatic#2#Seadragon#4910944581_913139b238_o.jpg
+DIS5K/DIS-TR/im/4#Architecture#12#Scaffold#3661448960_8aff24cc4d_o.jpg
+DIS5K/DIS-TR/im/4#Architecture#13#Sculpture#6385318715_9a88d4eba7_o.jpg
+DIS5K/DIS-TR/im/4#Architecture#17#Well#5011603479_75cf42808a_o.jpg
+DIS5K/DIS-TR/im/5#Artifact#2#Cage#4892828841_7f1bc05682_o.jpg
+DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#15404211628_9e9ff2ce2e_o.jpg
+DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#3200169865_7c84cfcccf_o.jpg
+DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#5859295071_c217e7c22f_o.jpg
+DIS5K/DIS-TR/im/6#Automobile#10#SteeringWheel#17200338026_f1e2122d8e_o.jpg
+DIS5K/DIS-TR/im/6#Automobile#3#Car#3780893425_1a7d275e09_o.jpg
+DIS5K/DIS-TR/im/6#Automobile#5#Crane#15282506502_1b1132a7c3_o.jpg
+DIS5K/DIS-TR/im/7#Electrical#1#Cable#16767791875_8e6df41752_o.jpg
+DIS5K/DIS-TR/im/7#Electrical#1#Cable#3291433361_38747324c4_o.jpg
+DIS5K/DIS-TR/im/7#Electrical#1#Cable#4195104238_12a754c61a_o.jpg
+DIS5K/DIS-TR/im/7#Electrical#1#Cable#49645415132_61e5664ecf_o.jpg
+DIS5K/DIS-TR/im/7#Electrical#1#Cable#IMG_20210521_232406.jpg
+DIS5K/DIS-TR/im/7#Electrical#10#UtilityPole#3298312021_92f431e3e9_o.jpg
+DIS5K/DIS-TR/im/7#Electrical#10#UtilityPole#47950134773_fbfff63f4e_o.jpg
+DIS5K/DIS-TR/im/7#Electrical#11#VacuumCleaner#5448403677_6a29e21881_o.jpg
+DIS5K/DIS-TR/im/7#Electrical#2#CeilingLamp#611568868_680ed5d39f_o.jpg
+DIS5K/DIS-TR/im/7#Electrical#3#Fan#3391683115_990525a693_o.jpg
+DIS5K/DIS-TR/im/7#Electrical#6#StreetLamp#150049122_0692266618_o.jpg
+DIS5K/DIS-TR/im/7#Electrical#9#TransmissionTower#31433908671_7e7e277dfe_o.jpg
+DIS5K/DIS-TR/im/8#Electronics#1#Antenna#8727884873_e0622ee5c4_o.jpg
+DIS5K/DIS-TR/im/8#Electronics#2#Camcorder#4172690390_7e5f280ace_o.jpg
+DIS5K/DIS-TR/im/8#Electronics#3#Earphone#413984555_f290febdf5_o.jpg
+DIS5K/DIS-TR/im/8#Electronics#5#Headset#30574225373_3717ed9fa4_o.jpg
+DIS5K/DIS-TR/im/8#Electronics#6#Microphone#538006482_4aae4f5bd6_o.jpg
+DIS5K/DIS-TR/im/8#Electronics#9#MusicPlayer#1306012480_2ea80d2afd_o.jpg
+DIS5K/DIS-TR/im/9#Entertainment#1#GymEquipment#33071754135_8f3195cbd1_o.jpg
+DIS5K/DIS-TR/im/9#Entertainment#2#KidsPlayground#2305807849_be53d724ea_o.jpg
+DIS5K/DIS-TR/im/9#Entertainment#2#KidsPlayground#3862040422_5bbf903204_o.jpg
+DIS5K/DIS-TR/im/9#Entertainment#3#OutdoorFitnessEquipment#10814507005_3dacaa28b3_o.jpg
+DIS5K/DIS-TR/im/9#Entertainment#4#FerrisWheel#81640293_4b0ee62040_o.jpg
+DIS5K/DIS-TR/im/9#Entertainment#5#Swing#49867339188_08073f4b76_o.jpg
+DIS5K/DIS-VD/im/1#Accessories#1#Bag#6815402415_e01c1a41e6_o.jpg
+DIS5K/DIS-VD/im/1#Accessories#5#Jewelry#2744070193_1486582e8d_o.jpg
+DIS5K/DIS-VD/im/10#Frame#1#BasketballHoop#IMG_20210521_232650.jpg
+DIS5K/DIS-VD/im/10#Frame#5#Rack#6156611713_49ebf12b1e_o.jpg
+DIS5K/DIS-VD/im/11#Furniture#11#Handrail#3276641240_1b84b5af85_o.jpg
+DIS5K/DIS-VD/im/11#Furniture#13#Ladder#33423266_5391cf47e9_o.jpg
+DIS5K/DIS-VD/im/11#Furniture#17#Table#3725111755_4fc101e7ab_o.jpg
+DIS5K/DIS-VD/im/11#Furniture#2#Bench#35556410400_7235b58070_o.jpg
+DIS5K/DIS-VD/im/11#Furniture#4#Chair#3301769985_e49de6739f_o.jpg
+DIS5K/DIS-VD/im/11#Furniture#6#DentalChair#23811071619_2a95c3a688_o.jpg
+DIS5K/DIS-VD/im/11#Furniture#9#Easel#8322807354_df6d56542e_o.jpg
+DIS5K/DIS-VD/im/13#Insect#10#Mosquito#12391674863_0cdf430d3f_o.jpg
+DIS5K/DIS-VD/im/13#Insect#7#Dragonfly#14693028899_344ea118f2_o.jpg
+DIS5K/DIS-VD/im/14#Kitchenware#10#WineGlass#4450148455_8f460f541a_o.jpg
+DIS5K/DIS-VD/im/14#Kitchenware#3#Hydrovalve#IMG_20210520_203410.jpg
+DIS5K/DIS-VD/im/15#Machine#3#PlowHarrow#34521712846_df4babb024_o.jpg
+DIS5K/DIS-VD/im/16#Music Instrument#5#Trombone#6222242743_e7189405cd_o.jpg
+DIS5K/DIS-VD/im/17#Non-motor Vehicle#12#Wheel#25677578797_ea47e1d9e8_o.jpg
+DIS5K/DIS-VD/im/17#Non-motor Vehicle#2#Bicycle#5153474856_21560b081b_o.jpg
+DIS5K/DIS-VD/im/17#Non-motor Vehicle#7#Mower#16992510572_8a6ff27398_o.jpg
+DIS5K/DIS-VD/im/19#Ship#2#Canoe#40571458163_7faf8b73d9_o.jpg
+DIS5K/DIS-VD/im/2#Aircraft#1#Airplane#4270588164_66a619e834_o.jpg
+DIS5K/DIS-VD/im/2#Aircraft#4#Helicopter#86789665_650b94b2ee_o.jpg
+DIS5K/DIS-VD/im/20#Sports#14#Wakesurfing#5589577652_5061c168d2_o.jpg
+DIS5K/DIS-VD/im/21#Tool#10#Spade#37018312543_63b21b0784_o.jpg
+DIS5K/DIS-VD/im/21#Tool#14#Sword#24789047250_42df9bf422_o.jpg
+DIS5K/DIS-VD/im/21#Tool#18#Umbrella#IMG_20210513_140445.jpg
+DIS5K/DIS-VD/im/21#Tool#6#Key#43939732715_5a6e28b518_o.jpg
+DIS5K/DIS-VD/im/22#Weapon#1#Cannon#12758066705_90b54295e7_o.jpg
+DIS5K/DIS-VD/im/22#Weapon#4#Rifle#8019368790_fb6dc469a7_o.jpg
+DIS5K/DIS-VD/im/3#Aquatic#5#Shrimp#2582833427_7a99e7356e_o.jpg
+DIS5K/DIS-VD/im/4#Architecture#12#Scaffold#1013402687_590750354e_o.jpg
+DIS5K/DIS-VD/im/4#Architecture#13#Sculpture#17176841759_272a3ed6e3_o.jpg
+DIS5K/DIS-VD/im/4#Architecture#14#Stair#15079108505_0d11281624_o.jpg
+DIS5K/DIS-VD/im/4#Architecture#19#Windmill#2928111082_ceb3051c04_o.jpg
+DIS5K/DIS-VD/im/4#Architecture#3#Crack#3551574032_17dd106d31_o.jpg
+DIS5K/DIS-VD/im/4#Architecture#5#GasStation#4564307581_c3069bdc62_o.jpg
+DIS5K/DIS-VD/im/4#Architecture#8#ObservationTower#2704526950_d4f0ddc807_o.jpg
+DIS5K/DIS-VD/im/5#Artifact#3#Handcraft#10873642323_1bafce3aa5_o.jpg
+DIS5K/DIS-VD/im/6#Automobile#11#Tractor#8594504006_0c2c557d85_o.jpg
+DIS5K/DIS-VD/im/8#Electronics#3#Earphone#8106454803_1178d867cc_o.jpg
\ No newline at end of file
diff --git a/models/monoD/depth_pro/network/__init__.py b/models/monoD/depth_pro/network/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..74882c0eacac7e9bde0e13008fab31037eae671d
--- /dev/null
+++ b/models/monoD/depth_pro/network/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+"""Depth Pro network blocks."""
diff --git a/models/monoD/depth_pro/network/decoder.py b/models/monoD/depth_pro/network/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..770665fcd3e47948388d5da43487d9e75dc0f3fc
--- /dev/null
+++ b/models/monoD/depth_pro/network/decoder.py
@@ -0,0 +1,206 @@
+"""Copyright (C) 2024 Apple Inc. All Rights Reserved.
+
+Dense Prediction Transformer Decoder architecture.
+
+Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
+"""
+
+from __future__ import annotations
+
+from typing import Iterable
+
+import torch
+from torch import nn
+
+
+class MultiresConvDecoder(nn.Module):
+ """Decoder for multi-resolution encodings."""
+
+ def __init__(
+ self,
+ dims_encoder: Iterable[int],
+ dim_decoder: int,
+ ):
+ """Initialize multiresolution convolutional decoder.
+
+ Args:
+ ----
+ dims_encoder: Expected dims at each level from the encoder.
+ dim_decoder: Dim of decoder features.
+
+ """
+ super().__init__()
+ self.dims_encoder = list(dims_encoder)
+ self.dim_decoder = dim_decoder
+ self.dim_out = dim_decoder
+
+ num_encoders = len(self.dims_encoder)
+
+ # At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
+ # when the dimensions mismatch. Otherwise we do not do anything, which is
+ # the default behavior of monodepth.
+ conv0 = (
+ nn.Conv2d(self.dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
+ if self.dims_encoder[0] != dim_decoder
+ else nn.Identity()
+ )
+
+ convs = [conv0]
+ for i in range(1, num_encoders):
+ convs.append(
+ nn.Conv2d(
+ self.dims_encoder[i],
+ dim_decoder,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ )
+ )
+
+ self.convs = nn.ModuleList(convs)
+
+ fusions = []
+ for i in range(num_encoders):
+ fusions.append(
+ FeatureFusionBlock2d(
+ num_features=dim_decoder,
+ deconv=(i != 0),
+ batch_norm=False,
+ )
+ )
+ self.fusions = nn.ModuleList(fusions)
+
+ def forward(self, encodings: torch.Tensor) -> torch.Tensor:
+ """Decode the multi-resolution encodings."""
+ num_levels = len(encodings)
+ num_encoders = len(self.dims_encoder)
+
+ if num_levels != num_encoders:
+ raise ValueError(
+ f"Got encoder output levels={num_levels}, expected levels={num_encoders+1}."
+ )
+
+ # Project features of different encoder dims to the same decoder dim.
+ # Fuse features from the lowest resolution (num_levels-1)
+ # to the highest (0).
+ features = self.convs[-1](encodings[-1])
+ lowres_features = features
+ features = self.fusions[-1](features)
+ for i in range(num_levels - 2, -1, -1):
+ features_i = self.convs[i](encodings[i])
+ features = self.fusions[i](features, features_i)
+ return features, lowres_features
+
+
+class ResidualBlock(nn.Module):
+ """Generic implementation of residual blocks.
+
+ This implements a generic residual block from
+ He et al. - Identity Mappings in Deep Residual Networks (2016),
+ https://arxiv.org/abs/1603.05027
+ which can be further customized via factory functions.
+ """
+
+ def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
+ """Initialize ResidualBlock."""
+ super().__init__()
+ self.residual = residual
+ self.shortcut = shortcut
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply residual block."""
+ delta_x = self.residual(x)
+
+ if self.shortcut is not None:
+ x = self.shortcut(x)
+
+ return x + delta_x
+
+
+class FeatureFusionBlock2d(nn.Module):
+ """Feature fusion for DPT."""
+
+ def __init__(
+ self,
+ num_features: int,
+ deconv: bool = False,
+ batch_norm: bool = False,
+ ):
+ """Initialize feature fusion block.
+
+ Args:
+ ----
+ num_features: Input and output dimensions.
+ deconv: Whether to use deconv before the final output conv.
+ batch_norm: Whether to use batch normalization in resnet blocks.
+
+ """
+ super().__init__()
+
+ self.resnet1 = self._residual_block(num_features, batch_norm)
+ self.resnet2 = self._residual_block(num_features, batch_norm)
+
+ self.use_deconv = deconv
+ if deconv:
+ self.deconv = nn.ConvTranspose2d(
+ in_channels=num_features,
+ out_channels=num_features,
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=False,
+ )
+
+ self.out_conv = nn.Conv2d(
+ num_features,
+ num_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ )
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
+ """Process and fuse input features."""
+ x = x0
+
+ if x1 is not None:
+ res = self.resnet1(x1)
+ x = self.skip_add.add(x, res)
+
+ x = self.resnet2(x)
+
+ if self.use_deconv:
+ x = self.deconv(x)
+ x = self.out_conv(x)
+
+ return x
+
+ @staticmethod
+ def _residual_block(num_features: int, batch_norm: bool):
+ """Create a residual block."""
+
+ def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
+ layers = [
+ nn.ReLU(False),
+ nn.Conv2d(
+ num_features,
+ num_features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=not batch_norm,
+ ),
+ ]
+ if batch_norm:
+ layers.append(nn.BatchNorm2d(dim))
+ return layers
+
+ residual = nn.Sequential(
+ *_create_block(dim=num_features, batch_norm=batch_norm),
+ *_create_block(dim=num_features, batch_norm=batch_norm),
+ )
+ return ResidualBlock(residual)
diff --git a/models/monoD/depth_pro/network/encoder.py b/models/monoD/depth_pro/network/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3a3da17d47bf91662463520afaf413f08676c3b
--- /dev/null
+++ b/models/monoD/depth_pro/network/encoder.py
@@ -0,0 +1,332 @@
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+# DepthProEncoder combining patch and image encoders.
+
+from __future__ import annotations
+
+import math
+from typing import Iterable, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DepthProEncoder(nn.Module):
+ """DepthPro Encoder.
+
+ An encoder aimed at creating multi-resolution encodings from Vision Transformers.
+ """
+
+ def __init__(
+ self,
+ dims_encoder: Iterable[int],
+ patch_encoder: nn.Module,
+ image_encoder: nn.Module,
+ hook_block_ids: Iterable[int],
+ decoder_features: int,
+ ):
+ """Initialize DepthProEncoder.
+
+ The framework
+ 1. creates an image pyramid,
+ 2. generates overlapping patches with a sliding window at each pyramid level,
+ 3. creates batched encodings via vision transformer backbones,
+ 4. produces multi-resolution encodings.
+
+ Args:
+ ----
+ img_size: Backbone image resolution.
+ dims_encoder: Dimensions of the encoder at different layers.
+ patch_encoder: Backbone used for patches.
+ image_encoder: Backbone used for global image encoder.
+ hook_block_ids: Hooks to obtain intermediate features for the patch encoder model.
+ decoder_features: Number of feature output in the decoder.
+
+ """
+ super().__init__()
+
+ self.dims_encoder = list(dims_encoder)
+ self.patch_encoder = patch_encoder
+ self.image_encoder = image_encoder
+ self.hook_block_ids = list(hook_block_ids)
+
+ patch_encoder_embed_dim = patch_encoder.embed_dim
+ image_encoder_embed_dim = image_encoder.embed_dim
+
+ self.out_size = int(
+ patch_encoder.patch_embed.img_size[0] // patch_encoder.patch_embed.patch_size[0]
+ )
+
+ def _create_project_upsample_block(
+ dim_in: int,
+ dim_out: int,
+ upsample_layers: int,
+ dim_int: Optional[int] = None,
+ ) -> nn.Module:
+ if dim_int is None:
+ dim_int = dim_out
+ # Projection.
+ blocks = [
+ nn.Conv2d(
+ in_channels=dim_in,
+ out_channels=dim_int,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ ]
+
+ # Upsampling.
+ blocks += [
+ nn.ConvTranspose2d(
+ in_channels=dim_int if i == 0 else dim_out,
+ out_channels=dim_out,
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=False,
+ )
+ for i in range(upsample_layers)
+ ]
+
+ return nn.Sequential(*blocks)
+
+ self.upsample_latent0 = _create_project_upsample_block(
+ dim_in=patch_encoder_embed_dim,
+ dim_int=self.dims_encoder[0],
+ dim_out=decoder_features,
+ upsample_layers=3,
+ )
+ self.upsample_latent1 = _create_project_upsample_block(
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[0], upsample_layers=2
+ )
+
+ self.upsample0 = _create_project_upsample_block(
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[1], upsample_layers=1
+ )
+ self.upsample1 = _create_project_upsample_block(
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[2], upsample_layers=1
+ )
+ self.upsample2 = _create_project_upsample_block(
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[3], upsample_layers=1
+ )
+
+ self.upsample_lowres = nn.ConvTranspose2d(
+ in_channels=image_encoder_embed_dim,
+ out_channels=self.dims_encoder[3],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ )
+ self.fuse_lowres = nn.Conv2d(
+ in_channels=(self.dims_encoder[3] + self.dims_encoder[3]),
+ out_channels=self.dims_encoder[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ )
+
+ # Obtain intermediate outputs of the blocks.
+ self.patch_encoder.blocks[self.hook_block_ids[0]].register_forward_hook(
+ self._hook0
+ )
+ self.patch_encoder.blocks[self.hook_block_ids[1]].register_forward_hook(
+ self._hook1
+ )
+
+ def _hook0(self, model, input, output):
+ self.backbone_highres_hook0 = output
+
+ def _hook1(self, model, input, output):
+ self.backbone_highres_hook1 = output
+
+ @property
+ def img_size(self) -> int:
+ """Return the full image size of the SPN network."""
+ return self.patch_encoder.patch_embed.img_size[0] * 4
+
+ def _create_pyramid(
+ self, x: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Create a 3-level image pyramid."""
+ # Original resolution: 1536 by default.
+ x0 = x
+
+ # Middle resolution: 768 by default.
+ x1 = F.interpolate(
+ x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False
+ )
+
+ # Low resolution: 384 by default, corresponding to the backbone resolution.
+ x2 = F.interpolate(
+ x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False
+ )
+
+ return x0, x1, x2
+
+ def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
+ """Split the input into small patches with sliding window."""
+ patch_size = 384
+ patch_stride = int(patch_size * (1 - overlap_ratio))
+
+ image_size = x.shape[-1]
+ steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1
+
+ x_patch_list = []
+ for j in range(steps):
+ j0 = j * patch_stride
+ j1 = j0 + patch_size
+
+ for i in range(steps):
+ i0 = i * patch_stride
+ i1 = i0 + patch_size
+ x_patch_list.append(x[..., j0:j1, i0:i1])
+
+ return torch.cat(x_patch_list, dim=0)
+
+ def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
+ """Merge the patched input into a image with sliding window."""
+ steps = int(math.sqrt(x.shape[0] // batch_size))
+
+ idx = 0
+
+ output_list = []
+ for j in range(steps):
+ output_row_list = []
+ for i in range(steps):
+ output = x[batch_size * idx : batch_size * (idx + 1)]
+
+ if j != 0:
+ output = output[..., padding:, :]
+ if i != 0:
+ output = output[..., :, padding:]
+ if j != steps - 1:
+ output = output[..., :-padding, :]
+ if i != steps - 1:
+ output = output[..., :, :-padding]
+
+ output_row_list.append(output)
+ idx += 1
+
+ output_row = torch.cat(output_row_list, dim=-1)
+ output_list.append(output_row)
+ output = torch.cat(output_list, dim=-2)
+ return output
+
+ def reshape_feature(
+ self, embeddings: torch.Tensor, width, height, cls_token_offset=1
+ ):
+ """Discard class token and reshape 1D feature map to a 2D grid."""
+ b, hw, c = embeddings.shape
+
+ # Remove class token.
+ if cls_token_offset > 0:
+ embeddings = embeddings[:, cls_token_offset:, :]
+
+ # Shape: (batch, height, width, dim) -> (batch, dim, height, width)
+ embeddings = embeddings.reshape(b, height, width, c).permute(0, 3, 1, 2)
+ return embeddings
+
+ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+ """Encode input at multiple resolutions.
+
+ Args:
+ ----
+ x (torch.Tensor): Input image.
+
+ Returns:
+ -------
+ Multi resolution encoded features.
+
+ """
+ batch_size = x.shape[0]
+
+ # Step 0: create a 3-level image pyramid.
+ x0, x1, x2 = self._create_pyramid(x)
+
+ # Step 1: split to create batched overlapped mini-images at the backbone (BeiT/ViT/Dino)
+ # resolution.
+ # 5x5 @ 384x384 at the highest resolution (1536x1536).
+ x0_patches = self.split(x0, overlap_ratio=0.25)
+ # 3x3 @ 384x384 at the middle resolution (768x768).
+ x1_patches = self.split(x1, overlap_ratio=0.5)
+ # 1x1 # 384x384 at the lowest resolution (384x384).
+ x2_patches = x2
+
+ # Concatenate all the sliding window patches and form a batch of size (35=5x5+3x3+1x1).
+ x_pyramid_patches = torch.cat(
+ (x0_patches, x1_patches, x2_patches),
+ dim=0,
+ )
+
+ # Step 2: Run the backbone (BeiT) model and get the result of large batch size.
+ x_pyramid_encodings = self.patch_encoder(x_pyramid_patches)
+ x_pyramid_encodings = self.reshape_feature(
+ x_pyramid_encodings, self.out_size, self.out_size
+ )
+
+ # Step 3: merging.
+ # Merge highres latent encoding.
+ x_latent0_encodings = self.reshape_feature(
+ self.backbone_highres_hook0,
+ self.out_size,
+ self.out_size,
+ )
+ x_latent0_features = self.merge(
+ x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
+ )
+
+ x_latent1_encodings = self.reshape_feature(
+ self.backbone_highres_hook1,
+ self.out_size,
+ self.out_size,
+ )
+ x_latent1_features = self.merge(
+ x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
+ )
+
+ # Split the 35 batch size from pyramid encoding back into 5x5+3x3+1x1.
+ x0_encodings, x1_encodings, x2_encodings = torch.split(
+ x_pyramid_encodings,
+ [len(x0_patches), len(x1_patches), len(x2_patches)],
+ dim=0,
+ )
+
+ # 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps.
+ x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=3)
+
+ # 48x84 feature maps by merging 3x3 @ 24x24 patches with overlaps.
+ x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=6)
+
+ # 24x24 feature maps.
+ x2_features = x2_encodings
+
+ # Apply the image encoder model.
+ x_global_features = self.image_encoder(x2_patches)
+ x_global_features = self.reshape_feature(
+ x_global_features, self.out_size, self.out_size
+ )
+
+ # Upsample feature maps.
+ x_latent0_features = self.upsample_latent0(x_latent0_features)
+ x_latent1_features = self.upsample_latent1(x_latent1_features)
+
+ x0_features = self.upsample0(x0_features)
+ x1_features = self.upsample1(x1_features)
+ x2_features = self.upsample2(x2_features)
+
+ x_global_features = self.upsample_lowres(x_global_features)
+ x_global_features = self.fuse_lowres(
+ torch.cat((x2_features, x_global_features), dim=1)
+ )
+
+ return [
+ x_latent0_features,
+ x_latent1_features,
+ x0_features,
+ x1_features,
+ x_global_features,
+ ]
diff --git a/models/monoD/depth_pro/network/fov.py b/models/monoD/depth_pro/network/fov.py
new file mode 100644
index 0000000000000000000000000000000000000000..5900286509ca9535d4d29679b88055b5b6aed938
--- /dev/null
+++ b/models/monoD/depth_pro/network/fov.py
@@ -0,0 +1,82 @@
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+# Field of View network architecture.
+
+from typing import Optional
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class FOVNetwork(nn.Module):
+ """Field of View estimation network."""
+
+ def __init__(
+ self,
+ num_features: int,
+ fov_encoder: Optional[nn.Module] = None,
+ ):
+ """Initialize the Field of View estimation block.
+
+ Args:
+ ----
+ num_features: Number of features used.
+ fov_encoder: Optional encoder to bring additional network capacity.
+
+ """
+ super().__init__()
+
+ # Create FOV head.
+ fov_head0 = [
+ nn.Conv2d(
+ num_features, num_features // 2, kernel_size=3, stride=2, padding=1
+ ), # 128 x 24 x 24
+ nn.ReLU(True),
+ ]
+ fov_head = [
+ nn.Conv2d(
+ num_features // 2, num_features // 4, kernel_size=3, stride=2, padding=1
+ ), # 64 x 12 x 12
+ nn.ReLU(True),
+ nn.Conv2d(
+ num_features // 4, num_features // 8, kernel_size=3, stride=2, padding=1
+ ), # 32 x 6 x 6
+ nn.ReLU(True),
+ nn.Conv2d(num_features // 8, 1, kernel_size=6, stride=1, padding=0),
+ ]
+ if fov_encoder is not None:
+ self.encoder = nn.Sequential(
+ fov_encoder, nn.Linear(fov_encoder.embed_dim, num_features // 2)
+ )
+ self.downsample = nn.Sequential(*fov_head0)
+ else:
+ fov_head = fov_head0 + fov_head
+ self.head = nn.Sequential(*fov_head)
+
+ def forward(self, x: torch.Tensor, lowres_feature: torch.Tensor) -> torch.Tensor:
+ """Forward the fov network.
+
+ Args:
+ ----
+ x (torch.Tensor): Input image.
+ lowres_feature (torch.Tensor): Low resolution feature.
+
+ Returns:
+ -------
+ The field of view tensor.
+
+ """
+ if hasattr(self, "encoder"):
+ x = F.interpolate(
+ x,
+ size=None,
+ scale_factor=0.25,
+ mode="bilinear",
+ align_corners=False,
+ )
+ x = self.encoder(x)[:, 1:].permute(0, 2, 1)
+ lowres_feature = self.downsample(lowres_feature)
+ x = x.reshape_as(lowres_feature) + lowres_feature
+ else:
+ x = lowres_feature
+ return self.head(x)
diff --git a/models/monoD/depth_pro/network/vit.py b/models/monoD/depth_pro/network/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6c3768a1dcedccd99a58f9507f4edac3cde9da0
--- /dev/null
+++ b/models/monoD/depth_pro/network/vit.py
@@ -0,0 +1,123 @@
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+
+
+try:
+ from timm.layers import resample_abs_pos_embed
+except ImportError as err:
+ print("ImportError: {0}".format(err))
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+
+def make_vit_b16_backbone(
+ model,
+ encoder_feature_dims,
+ encoder_feature_layer_ids,
+ vit_features,
+ start_index=1,
+ use_grad_checkpointing=False,
+) -> nn.Module:
+ """Make a ViTb16 backbone for the DPT model."""
+ if use_grad_checkpointing:
+ model.set_grad_checkpointing()
+
+ vit_model = nn.Module()
+ vit_model.hooks = encoder_feature_layer_ids
+ vit_model.model = model
+ vit_model.features = encoder_feature_dims
+ vit_model.vit_features = vit_features
+ vit_model.model.start_index = start_index
+ vit_model.model.patch_size = vit_model.model.patch_embed.patch_size
+ vit_model.model.is_vit = True
+ vit_model.model.forward = vit_model.model.forward_features
+
+ return vit_model
+
+
+def forward_features_eva_fixed(self, x):
+ """Encode features."""
+ x = self.patch_embed(x)
+ x, rot_pos_embed = self._pos_embed(x)
+ for blk in self.blocks:
+ if self.grad_checkpointing:
+ x = checkpoint(blk, x, rot_pos_embed)
+ else:
+ x = blk(x, rot_pos_embed)
+ x = self.norm(x)
+ return x
+
+
+def resize_vit(model: nn.Module, img_size) -> nn.Module:
+ """Resample the ViT module to the given size."""
+ patch_size = model.patch_embed.patch_size
+ model.patch_embed.img_size = img_size
+ grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
+ model.patch_embed.grid_size = grid_size
+
+ pos_embed = resample_abs_pos_embed(
+ model.pos_embed,
+ grid_size, # img_size
+ num_prefix_tokens=(
+ 0 if getattr(model, "no_embed_class", False) else model.num_prefix_tokens
+ ),
+ )
+ model.pos_embed = torch.nn.Parameter(pos_embed)
+
+ return model
+
+
+def resize_patch_embed(model: nn.Module, new_patch_size=(16, 16)) -> nn.Module:
+ """Resample the ViT patch size to the given one."""
+ # interpolate patch embedding
+ if hasattr(model, "patch_embed"):
+ old_patch_size = model.patch_embed.patch_size
+
+ if (
+ new_patch_size[0] != old_patch_size[0]
+ or new_patch_size[1] != old_patch_size[1]
+ ):
+ patch_embed_proj = model.patch_embed.proj.weight
+ patch_embed_proj_bias = model.patch_embed.proj.bias
+ use_bias = True if patch_embed_proj_bias is not None else False
+ _, _, h, w = patch_embed_proj.shape
+
+ new_patch_embed_proj = torch.nn.functional.interpolate(
+ patch_embed_proj,
+ size=[new_patch_size[0], new_patch_size[1]],
+ mode="bicubic",
+ align_corners=False,
+ )
+ new_patch_embed_proj = (
+ new_patch_embed_proj * (h / new_patch_size[0]) * (w / new_patch_size[1])
+ )
+
+ model.patch_embed.proj = nn.Conv2d(
+ in_channels=model.patch_embed.proj.in_channels,
+ out_channels=model.patch_embed.proj.out_channels,
+ kernel_size=new_patch_size,
+ stride=new_patch_size,
+ bias=use_bias,
+ )
+
+ if use_bias:
+ model.patch_embed.proj.bias = patch_embed_proj_bias
+
+ model.patch_embed.proj.weight = torch.nn.Parameter(new_patch_embed_proj)
+
+ model.patch_size = new_patch_size
+ model.patch_embed.patch_size = new_patch_size
+ model.patch_embed.img_size = (
+ int(
+ model.patch_embed.img_size[0]
+ * new_patch_size[0]
+ / old_patch_size[0]
+ ),
+ int(
+ model.patch_embed.img_size[1]
+ * new_patch_size[1]
+ / old_patch_size[1]
+ ),
+ )
+
+ return model
diff --git a/models/monoD/depth_pro/network/vit_factory.py b/models/monoD/depth_pro/network/vit_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cd899f650978043c2c83348670beaf597e9ca30
--- /dev/null
+++ b/models/monoD/depth_pro/network/vit_factory.py
@@ -0,0 +1,124 @@
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+# Factory functions to build and load ViT models.
+
+
+from __future__ import annotations
+
+import logging
+import types
+from dataclasses import dataclass
+from typing import Dict, List, Literal, Optional
+
+import timm
+import torch
+import torch.nn as nn
+
+from .vit import (
+ forward_features_eva_fixed,
+ make_vit_b16_backbone,
+ resize_patch_embed,
+ resize_vit,
+)
+
+LOGGER = logging.getLogger(__name__)
+
+
+ViTPreset = Literal[
+ "dinov2l16_384",
+]
+
+
+@dataclass
+class ViTConfig:
+ """Configuration for ViT."""
+
+ in_chans: int
+ embed_dim: int
+
+ img_size: int = 384
+ patch_size: int = 16
+
+ # In case we need to rescale the backbone when loading from timm.
+ timm_preset: Optional[str] = None
+ timm_img_size: int = 384
+ timm_patch_size: int = 16
+
+ # The following 2 parameters are only used by DPT. See dpt_factory.py.
+ encoder_feature_layer_ids: List[int] = None
+ """The layers in the Beit/ViT used to constructs encoder features for DPT."""
+ encoder_feature_dims: List[int] = None
+ """The dimension of features of encoder layers from Beit/ViT features for DPT."""
+
+
+VIT_CONFIG_DICT: Dict[ViTPreset, ViTConfig] = {
+ "dinov2l16_384": ViTConfig(
+ in_chans=3,
+ embed_dim=1024,
+ encoder_feature_layer_ids=[5, 11, 17, 23],
+ encoder_feature_dims=[256, 512, 1024, 1024],
+ img_size=384,
+ patch_size=16,
+ timm_preset="vit_large_patch14_dinov2",
+ timm_img_size=518,
+ timm_patch_size=14,
+ ),
+}
+
+
+def create_vit(
+ preset: ViTPreset,
+ use_pretrained: bool = False,
+ checkpoint_uri: str | None = None,
+ use_grad_checkpointing: bool = False,
+) -> nn.Module:
+ """Create and load a VIT backbone module.
+
+ Args:
+ ----
+ preset: The VIT preset to load the pre-defined config.
+ use_pretrained: Load pretrained weights if True, default is False.
+ checkpoint_uri: Checkpoint to load the wights from.
+ use_grad_checkpointing: Use grandient checkpointing.
+
+ Returns:
+ -------
+ A Torch ViT backbone module.
+
+ """
+ config = VIT_CONFIG_DICT[preset]
+
+ img_size = (config.img_size, config.img_size)
+ patch_size = (config.patch_size, config.patch_size)
+
+ if "eva02" in preset:
+ model = timm.create_model(config.timm_preset, pretrained=use_pretrained)
+ model.forward_features = types.MethodType(forward_features_eva_fixed, model)
+ else:
+ model = timm.create_model(
+ config.timm_preset, pretrained=use_pretrained, dynamic_img_size=True
+ )
+ model = make_vit_b16_backbone(
+ model,
+ encoder_feature_dims=config.encoder_feature_dims,
+ encoder_feature_layer_ids=config.encoder_feature_layer_ids,
+ vit_features=config.embed_dim,
+ use_grad_checkpointing=use_grad_checkpointing,
+ )
+ if config.patch_size != config.timm_patch_size:
+ model.model = resize_patch_embed(model.model, new_patch_size=patch_size)
+ if config.img_size != config.timm_img_size:
+ model.model = resize_vit(model.model, img_size=img_size)
+
+ if checkpoint_uri is not None:
+ state_dict = torch.load(checkpoint_uri, map_location="cpu")
+ missing_keys, unexpected_keys = model.load_state_dict(
+ state_dict=state_dict, strict=False
+ )
+
+ if len(unexpected_keys) != 0:
+ raise KeyError(f"Found unexpected keys when loading vit: {unexpected_keys}")
+ if len(missing_keys) != 0:
+ raise KeyError(f"Keys are missing when loading vit: {missing_keys}")
+
+ LOGGER.info(model)
+ return model.model
diff --git a/models/monoD/depth_pro/test_depth_pro.py b/models/monoD/depth_pro/test_depth_pro.py
new file mode 100644
index 0000000000000000000000000000000000000000..87c015226eea64ca8a94b57b8f475d5ecfb74b4a
--- /dev/null
+++ b/models/monoD/depth_pro/test_depth_pro.py
@@ -0,0 +1,27 @@
+from PIL import Image
+from models.monoD import depth_pro
+
+# Load model and preprocessing transform
+model, transform = depth_pro.create_model_and_transforms()
+model.eval()
+
+# Load and preprocess an image.
+image_path = "assets/dance/00000.jpg"
+image, _, f_px = depth_pro.load_rgb(image_path)
+image = transform(image)
+
+# Run inference.
+import time
+t0 = time.time()
+prediction = model.infer(image, f_px=f_px)
+depth = prediction["depth"] # Depth in [m].
+focallength_px = prediction["focallength_px"] # Focal length in pixels.
+import cv2
+import numpy as np
+depth = depth.clamp(0,30).squeeze().detach().cpu().numpy()
+depth = (depth - depth.min())/(depth.max()-depth.min()) * 255.0
+depth = depth.astype(np.uint8)
+depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
+cv2.imwrite("depth.png", depth)
+print(f"Time: {time.time() - t0:.2f}s")
+import pdb; pdb.set_trace()
\ No newline at end of file
diff --git a/models/monoD/depth_pro/utils.py b/models/monoD/depth_pro/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a401def2e1d6a2dd96b204e962569e9da5e0ef1
--- /dev/null
+++ b/models/monoD/depth_pro/utils.py
@@ -0,0 +1,112 @@
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+import pillow_heif
+from PIL import ExifTags, Image, TiffTags
+from pillow_heif import register_heif_opener
+
+register_heif_opener()
+LOGGER = logging.getLogger(__name__)
+
+
+def extract_exif(img_pil: Image) -> Dict[str, Any]:
+ """Return exif information as a dictionary.
+
+ Args:
+ ----
+ img_pil: A Pillow image.
+
+ Returns:
+ -------
+ A dictionary with extracted EXIF information.
+
+ """
+ # Get full exif description from get_ifd(0x8769):
+ # cf https://pillow.readthedocs.io/en/stable/releasenotes/8.2.0.html#image-getexif-exif-and-gps-ifd
+ img_exif = img_pil.getexif().get_ifd(0x8769)
+ exif_dict = {ExifTags.TAGS[k]: v for k, v in img_exif.items() if k in ExifTags.TAGS}
+
+ tiff_tags = img_pil.getexif()
+ tiff_dict = {
+ TiffTags.TAGS_V2[k].name: v
+ for k, v in tiff_tags.items()
+ if k in TiffTags.TAGS_V2
+ }
+ return {**exif_dict, **tiff_dict}
+
+
+def fpx_from_f35(width: float, height: float, f_mm: float = 50) -> float:
+ """Convert a focal length given in mm (35mm film equivalent) to pixels."""
+ return f_mm * np.sqrt(width**2.0 + height**2.0) / np.sqrt(36**2 + 24**2)
+
+
+def load_rgb(
+ path: Union[Path, str], auto_rotate: bool = True, remove_alpha: bool = True
+) -> Tuple[np.ndarray, List[bytes], float]:
+ """Load an RGB image.
+
+ Args:
+ ----
+ path: The url to the image to load.
+ auto_rotate: Rotate the image based on the EXIF data, default is True.
+ remove_alpha: Remove the alpha channel, default is True.
+
+ Returns:
+ -------
+ img: The image loaded as a numpy array.
+ icc_profile: The color profile of the image.
+ f_px: The optional focal length in pixels, extracting from the exif data.
+
+ """
+ LOGGER.debug(f"Loading image {path} ...")
+
+ path = Path(path)
+ if path.suffix.lower() in [".heic"]:
+ heif_file = pillow_heif.open_heif(path, convert_hdr_to_8bit=True)
+ img_pil = heif_file.to_pillow()
+ else:
+ img_pil = Image.open(path)
+
+ img_exif = extract_exif(img_pil)
+ icc_profile = img_pil.info.get("icc_profile", None)
+
+ # Rotate the image.
+ if auto_rotate:
+ exif_orientation = img_exif.get("Orientation", 1)
+ if exif_orientation == 3:
+ img_pil = img_pil.transpose(Image.ROTATE_180)
+ elif exif_orientation == 6:
+ img_pil = img_pil.transpose(Image.ROTATE_270)
+ elif exif_orientation == 8:
+ img_pil = img_pil.transpose(Image.ROTATE_90)
+ elif exif_orientation != 1:
+ LOGGER.warning(f"Ignoring image orientation {exif_orientation}.")
+
+ img = np.array(img_pil)
+ # Convert to RGB if single channel.
+ if img.ndim < 3 or img.shape[2] == 1:
+ img = np.dstack((img, img, img))
+
+ if remove_alpha:
+ img = img[:, :, :3]
+
+ LOGGER.debug(f"\tHxW: {img.shape[0]}x{img.shape[1]}")
+
+ # Extract the focal length from exif data.
+ f_35mm = img_exif.get(
+ "FocalLengthIn35mmFilm",
+ img_exif.get(
+ "FocalLenIn35mmFilm", img_exif.get("FocalLengthIn35mmFormat", None)
+ ),
+ )
+ if f_35mm is not None and f_35mm > 0:
+ LOGGER.debug(f"\tfocal length @ 35mm film: {f_35mm}mm")
+ f_px = fpx_from_f35(img.shape[1], img.shape[0], f_35mm)
+ else:
+ f_px = None
+
+ return img, icc_profile, f_px
diff --git a/models/monoD/zoeDepth/__init__.py b/models/monoD/zoeDepth/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/models/monoD/zoeDepth/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/models/monoD/zoeDepth/midas_c/__init__.py b/models/monoD/zoeDepth/midas_c/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..6e09b2cf4b1bc7d94ff59e8b3cb9fc2fc82779c4
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/__init__.py
@@ -0,0 +1 @@
+from .hubconf import *
\ No newline at end of file
diff --git a/models/monoD/zoeDepth/midas_c/hubconf.py b/models/monoD/zoeDepth/midas_c/hubconf.py
new file mode 100755
index 0000000000000000000000000000000000000000..9b281847befd0ed2f6102232ba98af01085901ae
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/hubconf.py
@@ -0,0 +1,435 @@
+dependencies = ["torch"]
+
+import torch
+
+from .midas.dpt_depth import DPTDepthModel
+from .midas.midas_net import MidasNet
+from .midas.midas_net_custom import MidasNet_small
+
+def DPT_BEiT_L_512(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS DPT_BEiT_L_512 model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = DPTDepthModel(
+ path=None,
+ backbone="beitl16_512",
+ non_negative=True,
+ )
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def DPT_BEiT_L_384(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS DPT_BEiT_L_384 model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = DPTDepthModel(
+ path=None,
+ backbone="beitl16_384",
+ non_negative=True,
+ )
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def DPT_BEiT_B_384(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS DPT_BEiT_B_384 model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = DPTDepthModel(
+ path=None,
+ backbone="beitb16_384",
+ non_negative=True,
+ )
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def DPT_SwinV2_L_384(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS DPT_SwinV2_L_384 model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = DPTDepthModel(
+ path=None,
+ backbone="swin2l24_384",
+ non_negative=True,
+ )
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def DPT_SwinV2_B_384(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS DPT_SwinV2_B_384 model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = DPTDepthModel(
+ path=None,
+ backbone="swin2b24_384",
+ non_negative=True,
+ )
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def DPT_SwinV2_T_256(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS DPT_SwinV2_T_256 model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = DPTDepthModel(
+ path=None,
+ backbone="swin2t16_256",
+ non_negative=True,
+ )
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def DPT_Swin_L_384(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS DPT_Swin_L_384 model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = DPTDepthModel(
+ path=None,
+ backbone="swinl12_384",
+ non_negative=True,
+ )
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def DPT_Next_ViT_L_384(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = DPTDepthModel(
+ path=None,
+ backbone="next_vit_large_6m",
+ non_negative=True,
+ )
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def DPT_LeViT_224(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS DPT_LeViT_224 model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = DPTDepthModel(
+ path=None,
+ backbone="levit_384",
+ non_negative=True,
+ head_features_1=64,
+ head_features_2=8,
+ )
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def DPT_Large(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS DPT-Large model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = DPTDepthModel(
+ path=None,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def DPT_Hybrid(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS DPT-Hybrid model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = DPTDepthModel(
+ path=None,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def MiDaS(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS v2.1 model for monocular depth estimation
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = MidasNet()
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+def MiDaS_small(pretrained=True, **kwargs):
+ """ # This docstring shows up in hub.help()
+ MiDaS v2.1 small model for monocular depth estimation on resource-constrained devices
+ pretrained (bool): load pretrained weights into model
+ """
+
+ model = MidasNet_small(None, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True})
+
+ if pretrained:
+ checkpoint = (
+ "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt"
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+
+def transforms():
+ import cv2
+ from torchvision.transforms import Compose
+ from midas.transforms import Resize, NormalizeImage, PrepareForNet
+ from midas import transforms
+
+ transforms.default_transform = Compose(
+ [
+ lambda img: {"image": img / 255.0},
+ Resize(
+ 384,
+ 384,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method="upper_bound",
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+ ]
+ )
+
+ transforms.small_transform = Compose(
+ [
+ lambda img: {"image": img / 255.0},
+ Resize(
+ 256,
+ 256,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method="upper_bound",
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+ ]
+ )
+
+ transforms.dpt_transform = Compose(
+ [
+ lambda img: {"image": img / 255.0},
+ Resize(
+ 384,
+ 384,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method="minimal",
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ PrepareForNet(),
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+ ]
+ )
+
+ transforms.beit512_transform = Compose(
+ [
+ lambda img: {"image": img / 255.0},
+ Resize(
+ 512,
+ 512,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method="minimal",
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ PrepareForNet(),
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+ ]
+ )
+
+ transforms.swin384_transform = Compose(
+ [
+ lambda img: {"image": img / 255.0},
+ Resize(
+ 384,
+ 384,
+ resize_target=None,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=32,
+ resize_method="minimal",
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ PrepareForNet(),
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+ ]
+ )
+
+ transforms.swin256_transform = Compose(
+ [
+ lambda img: {"image": img / 255.0},
+ Resize(
+ 256,
+ 256,
+ resize_target=None,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=32,
+ resize_method="minimal",
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ PrepareForNet(),
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+ ]
+ )
+
+ transforms.levit_transform = Compose(
+ [
+ lambda img: {"image": img / 255.0},
+ Resize(
+ 224,
+ 224,
+ resize_target=None,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=32,
+ resize_method="minimal",
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ PrepareForNet(),
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+ ]
+ )
+
+ return transforms
diff --git a/models/monoD/zoeDepth/midas_c/midas/backbones/beit.py b/models/monoD/zoeDepth/midas_c/midas/backbones/beit.py
new file mode 100755
index 0000000000000000000000000000000000000000..8f25dd1ed756f4c468a2753299dc38b8eb101d82
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/backbones/beit.py
@@ -0,0 +1,196 @@
+import timm
+import torch
+import types
+
+import numpy as np
+import torch.nn.functional as F
+
+from .utils import forward_adapted_unflatten, make_backbone_default
+from timm.models.beit import gen_relative_position_index
+from torch.utils.checkpoint import checkpoint
+from typing import Optional
+
+
+def forward_beit(pretrained, x):
+ return forward_adapted_unflatten(pretrained, x, "forward_features")
+
+
+def patch_embed_forward(self, x):
+ """
+ Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes.
+ """
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ return x
+
+
+def _get_rel_pos_bias(self, window_size):
+ """
+ Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
+ """
+ old_height = 2 * self.window_size[0] - 1
+ old_width = 2 * self.window_size[1] - 1
+
+ new_height = 2 * window_size[0] - 1
+ new_width = 2 * window_size[1] - 1
+
+ old_relative_position_bias_table = self.relative_position_bias_table
+
+ old_num_relative_distance = self.num_relative_distance
+ new_num_relative_distance = new_height * new_width + 3
+
+ old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3]
+
+ old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
+ new_sub_table = F.interpolate(old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear")
+ new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
+
+ new_relative_position_bias_table = torch.cat(
+ [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]])
+
+ key = str(window_size[1]) + "," + str(window_size[0])
+ if key not in self.relative_position_indices.keys():
+ self.relative_position_indices[key] = gen_relative_position_index(window_size)
+
+ relative_position_bias = new_relative_position_bias_table[
+ self.relative_position_indices[key].view(-1)].view(
+ window_size[0] * window_size[1] + 1,
+ window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ return relative_position_bias.unsqueeze(0)
+
+
+def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
+ """
+ Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes.
+ """
+ B, N, C = x.shape
+
+ qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if self.relative_position_bias_table is not None:
+ window_size = tuple(np.array(resolution) // 16)
+ attn = attn + self._get_rel_pos_bias(window_size)
+ if shared_rel_pos_bias is not None:
+ attn = attn + shared_rel_pos_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
+ """
+ Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes.
+ """
+ if self.gamma_1 is None:
+ x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution,
+ shared_rel_pos_bias=shared_rel_pos_bias))
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+def beit_forward_features(self, x):
+ """
+ Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes.
+ """
+ resolution = x.shape[2:]
+
+ x = self.patch_embed(x)
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
+ else:
+ x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias)
+ x = self.norm(x)
+ return x
+
+
+def _make_beit_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[0, 4, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+ start_index_readout=1,
+):
+ backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index,
+ start_index_readout)
+
+ backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed)
+ backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model)
+
+ for block in backbone.model.blocks:
+ attn = block.attn
+ attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn)
+ attn.forward = types.MethodType(attention_forward, attn)
+ attn.relative_position_indices = {}
+
+ block.forward = types.MethodType(block_forward, block)
+
+ return backbone
+
+
+def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("beit_large_patch16_512", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
+
+ features = [256, 512, 1024, 1024]
+
+ return _make_beit_backbone(
+ model,
+ features=features,
+ size=[512, 512],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("beit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
+ return _make_beit_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("beit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
+ return _make_beit_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ )
diff --git a/models/monoD/zoeDepth/midas_c/midas/backbones/levit.py b/models/monoD/zoeDepth/midas_c/midas/backbones/levit.py
new file mode 100755
index 0000000000000000000000000000000000000000..6d023a98702a0451806d26f33f8bccf931814f10
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/backbones/levit.py
@@ -0,0 +1,106 @@
+import timm
+import torch
+import torch.nn as nn
+import numpy as np
+
+from .utils import activations, get_activation, Transpose
+
+
+def forward_levit(pretrained, x):
+ pretrained.model.forward_features(x)
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+
+ layer_1 = pretrained.act_postprocess1(layer_1)
+ layer_2 = pretrained.act_postprocess2(layer_2)
+ layer_3 = pretrained.act_postprocess3(layer_3)
+
+ return layer_1, layer_2, layer_3
+
+
+def _make_levit_backbone(
+ model,
+ hooks=[3, 11, 21],
+ patch_grid=[14, 14]
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+
+ pretrained.activations = activations
+
+ patch_grid_size = np.array(patch_grid, dtype=int)
+
+ pretrained.act_postprocess1 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist()))
+ )
+ pretrained.act_postprocess3 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist()))
+ )
+
+ return pretrained
+
+
+class ConvTransposeNorm(nn.Sequential):
+ """
+ Modification of
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm
+ such that ConvTranspose2d is used instead of Conv2d.
+ """
+
+ def __init__(
+ self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
+ groups=1, bn_weight_init=1):
+ super().__init__()
+ self.add_module('c',
+ nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
+ self.add_module('bn', nn.BatchNorm2d(out_chs))
+
+ nn.init.constant_(self.bn.weight, bn_weight_init)
+
+ @torch.no_grad()
+ def fuse(self):
+ c, bn = self._modules.values()
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
+ w = c.weight * w[:, None, None, None]
+ b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
+ m = nn.ConvTranspose2d(
+ w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
+ padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
+ m.weight.data.copy_(w)
+ m.bias.data.copy_(b)
+ return m
+
+
+def stem_b4_transpose(in_chs, out_chs, activation):
+ """
+ Modification of
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16
+ such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half.
+ """
+ return nn.Sequential(
+ ConvTransposeNorm(in_chs, out_chs, 3, 2, 1),
+ activation(),
+ ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1),
+ activation())
+
+
+def _make_pretrained_levit_384(pretrained, hooks=None):
+ model = timm.create_model("levit_384", pretrained=pretrained)
+
+ hooks = [3, 11, 21] if hooks == None else hooks
+ return _make_levit_backbone(
+ model,
+ hooks=hooks
+ )
diff --git a/models/monoD/zoeDepth/midas_c/midas/backbones/next_vit.py b/models/monoD/zoeDepth/midas_c/midas/backbones/next_vit.py
new file mode 100755
index 0000000000000000000000000000000000000000..8afdd8b743b5ab023a359dc3b721e601b1a40d11
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/backbones/next_vit.py
@@ -0,0 +1,39 @@
+import timm
+
+import torch.nn as nn
+
+from pathlib import Path
+from .utils import activations, forward_default, get_activation
+
+from ..external.next_vit.classification.nextvit import *
+
+
+def forward_next_vit(pretrained, x):
+ return forward_default(pretrained, x, "forward")
+
+
+def _make_next_vit_backbone(
+ model,
+ hooks=[2, 6, 36, 39],
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ return pretrained
+
+
+def _make_pretrained_next_vit_large_6m(hooks=None):
+ model = timm.create_model("nextvit_large")
+
+ hooks = [2, 6, 36, 39] if hooks == None else hooks
+ return _make_next_vit_backbone(
+ model,
+ hooks=hooks,
+ )
diff --git a/models/monoD/zoeDepth/midas_c/midas/backbones/swin.py b/models/monoD/zoeDepth/midas_c/midas/backbones/swin.py
new file mode 100755
index 0000000000000000000000000000000000000000..f8c71367e3e78b087f80b2ab3e2f495a9c372f1a
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/backbones/swin.py
@@ -0,0 +1,13 @@
+import timm
+
+from .swin_common import _make_swin_backbone
+
+
+def _make_pretrained_swinl12_384(pretrained, hooks=None):
+ model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained)
+
+ hooks = [1, 1, 17, 1] if hooks == None else hooks
+ return _make_swin_backbone(
+ model,
+ hooks=hooks
+ )
diff --git a/models/monoD/zoeDepth/midas_c/midas/backbones/swin2.py b/models/monoD/zoeDepth/midas_c/midas/backbones/swin2.py
new file mode 100755
index 0000000000000000000000000000000000000000..ce4c8f1d6fc1807a207dc6b9a261c6f7b14a87a3
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/backbones/swin2.py
@@ -0,0 +1,34 @@
+import timm
+
+from .swin_common import _make_swin_backbone
+
+
+def _make_pretrained_swin2l24_384(pretrained, hooks=None):
+ model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained)
+
+ hooks = [1, 1, 17, 1] if hooks == None else hooks
+ return _make_swin_backbone(
+ model,
+ hooks=hooks
+ )
+
+
+def _make_pretrained_swin2b24_384(pretrained, hooks=None):
+ model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained)
+
+ hooks = [1, 1, 17, 1] if hooks == None else hooks
+ return _make_swin_backbone(
+ model,
+ hooks=hooks
+ )
+
+
+def _make_pretrained_swin2t16_256(pretrained, hooks=None):
+ model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained)
+
+ hooks = [1, 1, 5, 1] if hooks == None else hooks
+ return _make_swin_backbone(
+ model,
+ hooks=hooks,
+ patch_grid=[64, 64]
+ )
diff --git a/models/monoD/zoeDepth/midas_c/midas/backbones/swin_common.py b/models/monoD/zoeDepth/midas_c/midas/backbones/swin_common.py
new file mode 100755
index 0000000000000000000000000000000000000000..94d63d408f18511179d90b3ac6f697385d1e556d
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/backbones/swin_common.py
@@ -0,0 +1,52 @@
+import torch
+
+import torch.nn as nn
+import numpy as np
+
+from .utils import activations, forward_default, get_activation, Transpose
+
+
+def forward_swin(pretrained, x):
+ return forward_default(pretrained, x)
+
+
+def _make_swin_backbone(
+ model,
+ hooks=[1, 1, 17, 1],
+ patch_grid=[96, 96]
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ if hasattr(model, "patch_grid"):
+ used_patch_grid = model.patch_grid
+ else:
+ used_patch_grid = patch_grid
+
+ patch_grid_size = np.array(used_patch_grid, dtype=int)
+
+ pretrained.act_postprocess1 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist()))
+ )
+ pretrained.act_postprocess3 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist()))
+ )
+ pretrained.act_postprocess4 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist()))
+ )
+
+ return pretrained
diff --git a/models/monoD/zoeDepth/midas_c/midas/backbones/utils.py b/models/monoD/zoeDepth/midas_c/midas/backbones/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..0558899dddcfccec5f01a764d4f21738eb612149
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/backbones/utils.py
@@ -0,0 +1,249 @@
+import torch
+
+import torch.nn as nn
+
+
+class Slice(nn.Module):
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ return x[:, self.start_index:]
+
+
+class AddReadout(nn.Module):
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index:] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+ def forward(self, x):
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
+ features = torch.cat((x[:, self.start_index:], readout), -1)
+
+ return self.project(features)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x):
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+
+
+activations = {}
+
+
+def get_activation(name):
+ def hook(model, input, output):
+ activations[name] = output
+
+ return hook
+
+
+def forward_default(pretrained, x, function_name="forward_features"):
+ exec(f"pretrained.model.{function_name}(x)")
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ if hasattr(pretrained, "act_postprocess1"):
+ layer_1 = pretrained.act_postprocess1(layer_1)
+ if hasattr(pretrained, "act_postprocess2"):
+ layer_2 = pretrained.act_postprocess2(layer_2)
+ if hasattr(pretrained, "act_postprocess3"):
+ layer_3 = pretrained.act_postprocess3(layer_3)
+ if hasattr(pretrained, "act_postprocess4"):
+ layer_4 = pretrained.act_postprocess4(layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def forward_adapted_unflatten(pretrained, x, function_name="forward_features"):
+ b, c, h, w = x.shape
+
+ exec(f"glob = pretrained.model.{function_name}(x)")
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+ unflatten = nn.Sequential(
+ nn.Unflatten(
+ 2,
+ torch.Size(
+ [
+ h // pretrained.model.patch_size[1],
+ w // pretrained.model.patch_size[0],
+ ]
+ ),
+ )
+ )
+
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten(layer_3)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten(layer_4)
+
+ layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+ return readout_oper
+
+
+def make_backbone_default(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+ start_index_readout=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout)
+
+ # 32, 48, 136, 384
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ return pretrained
diff --git a/models/monoD/zoeDepth/midas_c/midas/backbones/vit.py b/models/monoD/zoeDepth/midas_c/midas/backbones/vit.py
new file mode 100755
index 0000000000000000000000000000000000000000..413f9693bd4548342280e329c9128c1a52cea920
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/backbones/vit.py
@@ -0,0 +1,221 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper,
+ make_backbone_default, Transpose)
+
+
+def forward_vit(pretrained, x):
+ return forward_adapted_unflatten(pretrained, x, "forward_flex")
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index:],
+ )
+
+ gs_old = int(math.sqrt(len(posemb_grid)))
+
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+
+def forward_flex(self, x):
+ b, c, h, w = x.shape
+
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+ )
+
+ B = x.shape[0]
+
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ if self.no_embed_class:
+ x = x + pos_embed
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if not self.no_embed_class:
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
+ return x
+
+
+def _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+ start_index_readout=1,
+):
+ pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index,
+ start_index_readout)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ patch_size=[16, 16],
+ number_stages=2,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+
+ used_number_stages = 0 if use_vit_only else number_stages
+ for s in range(used_number_stages):
+ pretrained.model.patch_embed.backbone.stages[s].register_forward_hook(
+ get_activation(str(s + 1))
+ )
+ for s in range(used_number_stages, 4):
+ pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1)))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ for s in range(used_number_stages):
+ value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())
+ exec(f"pretrained.act_postprocess{s + 1}=value")
+ for s in range(used_number_stages, 4):
+ if s < number_stages:
+ final_layer = nn.ConvTranspose2d(
+ in_channels=features[s],
+ out_channels=features[s],
+ kernel_size=4 // (2 ** s),
+ stride=4 // (2 ** s),
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ )
+ elif s > number_stages:
+ final_layer = nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ )
+ else:
+ final_layer = None
+
+ layers = [
+ readout_oper[s],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[s],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ ]
+ if final_layer is not None:
+ layers.append(final_layer)
+
+ value = nn.Sequential(*layers)
+ exec(f"pretrained.act_postprocess{s + 1}=value")
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = patch_size
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
diff --git a/models/monoD/zoeDepth/midas_c/midas/base_model.py b/models/monoD/zoeDepth/midas_c/midas/base_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+ def load(self, path):
+ """Load model from file.
+
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+
+ self.load_state_dict(parameters)
diff --git a/models/monoD/zoeDepth/midas_c/midas/blocks.py b/models/monoD/zoeDepth/midas_c/midas/blocks.py
new file mode 100755
index 0000000000000000000000000000000000000000..6d87a00680bb6ed9a6d7c3043ea30a1e90361794
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/blocks.py
@@ -0,0 +1,439 @@
+import torch
+import torch.nn as nn
+
+from .backbones.beit import (
+ _make_pretrained_beitl16_512,
+ _make_pretrained_beitl16_384,
+ _make_pretrained_beitb16_384,
+ forward_beit,
+)
+from .backbones.swin_common import (
+ forward_swin,
+)
+from .backbones.swin2 import (
+ _make_pretrained_swin2l24_384,
+ _make_pretrained_swin2b24_384,
+ _make_pretrained_swin2t16_256,
+)
+from .backbones.swin import (
+ _make_pretrained_swinl12_384,
+)
+from .backbones.levit import (
+ _make_pretrained_levit_384,
+ forward_levit,
+)
+from .backbones.vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None,
+ use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]):
+ if backbone == "beitl16_512":
+ pretrained = _make_pretrained_beitl16_512(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # BEiT_512-L (backbone)
+ elif backbone == "beitl16_384":
+ pretrained = _make_pretrained_beitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # BEiT_384-L (backbone)
+ elif backbone == "beitb16_384":
+ pretrained = _make_pretrained_beitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # BEiT_384-B (backbone)
+ elif backbone == "swin2l24_384":
+ pretrained = _make_pretrained_swin2l24_384(
+ use_pretrained, hooks=hooks
+ )
+ scratch = _make_scratch(
+ [192, 384, 768, 1536], features, groups=groups, expand=expand
+ ) # Swin2-L/12to24 (backbone)
+ elif backbone == "swin2b24_384":
+ pretrained = _make_pretrained_swin2b24_384(
+ use_pretrained, hooks=hooks
+ )
+ scratch = _make_scratch(
+ [128, 256, 512, 1024], features, groups=groups, expand=expand
+ ) # Swin2-B/12to24 (backbone)
+ elif backbone == "swin2t16_256":
+ pretrained = _make_pretrained_swin2t16_256(
+ use_pretrained, hooks=hooks
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # Swin2-T/16 (backbone)
+ elif backbone == "swinl12_384":
+ pretrained = _make_pretrained_swinl12_384(
+ use_pretrained, hooks=hooks
+ )
+ scratch = _make_scratch(
+ [192, 384, 768, 1536], features, groups=groups, expand=expand
+ ) # Swin-L/12 (backbone)
+ elif backbone == "next_vit_large_6m":
+ from .backbones.next_vit import _make_pretrained_next_vit_large_6m
+ pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks)
+ scratch = _make_scratch(
+ in_features, features, groups=groups, expand=expand
+ ) # Next-ViT-L on ImageNet-1K-6M (backbone)
+ elif backbone == "levit_384":
+ pretrained = _make_pretrained_levit_384(
+ use_pretrained, hooks=hooks
+ )
+ scratch = _make_scratch(
+ [384, 512, 768], features, groups=groups, expand=expand
+ ) # LeViT 384 (backbone)
+ elif backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ assert False
+
+ return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape*8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+ pretrained = nn.Module()
+
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+ return pretrained
+
+
+def _make_resnet_backbone(resnet):
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+
+ return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+ """Interpolation module.
+ """
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+
+ return x
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ output += self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+ # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ self.size=size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = nn.functional.interpolate(
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
+
diff --git a/models/monoD/zoeDepth/midas_c/midas/dpt_depth.py b/models/monoD/zoeDepth/midas_c/midas/dpt_depth.py
new file mode 100755
index 0000000000000000000000000000000000000000..456483dbb806b691dd31eb976411e8f2e0f15a55
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/dpt_depth.py
@@ -0,0 +1,166 @@
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_beit,
+ forward_swin,
+ forward_levit,
+ forward_vit,
+)
+from .backbones.levit import stem_b4_transpose
+from timm.layers import get_act_layer
+
+
+def _make_fusion_block(features, use_bn, size = None):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ size=size,
+ )
+
+
+class DPT(BaseModel):
+ def __init__(
+ self,
+ head,
+ features=256,
+ backbone="vitb_rn50_384",
+ readout="project",
+ channels_last=False,
+ use_bn=False,
+ **kwargs
+ ):
+
+ super(DPT, self).__init__()
+
+ self.channels_last = channels_last
+
+ # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the
+ # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments.
+ hooks = {
+ "beitl16_512": [5, 11, 17, 23],
+ "beitl16_384": [5, 11, 17, 23],
+ "beitb16_384": [2, 5, 8, 11],
+ "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1]
+ "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
+ "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1]
+ "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
+ "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39]
+ "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21]
+ "vitb_rn50_384": [0, 1, 8, 11],
+ "vitb16_384": [2, 5, 8, 11],
+ "vitl16_384": [5, 11, 17, 23],
+ }[backbone]
+
+ if "next_vit" in backbone:
+ in_features = {
+ "next_vit_large_6m": [96, 256, 512, 1024],
+ }[backbone]
+ else:
+ in_features = None
+
+ # Instantiate backbone and reassemble blocks
+ self.pretrained, self.scratch = _make_encoder(
+ backbone,
+ features,
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
+ groups=1,
+ expand=False,
+ exportable=False,
+ hooks=hooks,
+ use_readout=readout,
+ in_features=in_features,
+ )
+
+ self.number_layers = len(hooks) if hooks is not None else 4
+ size_refinenet3 = None
+ self.scratch.stem_transpose = None
+
+ if "beit" in backbone:
+ self.forward_transformer = forward_beit
+ elif "swin" in backbone:
+ self.forward_transformer = forward_swin
+ elif "next_vit" in backbone:
+ from .backbones.next_vit import forward_next_vit
+ self.forward_transformer = forward_next_vit
+ elif "levit" in backbone:
+ self.forward_transformer = forward_levit
+ size_refinenet3 = 7
+ self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish"))
+ else:
+ self.forward_transformer = forward_vit
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3)
+ if self.number_layers >= 4:
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ self.scratch.output_conv = head
+
+
+ def forward(self, x):
+ if self.channels_last == True:
+ x.contiguous(memory_format=torch.channels_last)
+
+ layers = self.forward_transformer(self.pretrained, x)
+ if self.number_layers == 3:
+ layer_1, layer_2, layer_3 = layers
+ else:
+ layer_1, layer_2, layer_3, layer_4 = layers
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ if self.number_layers >= 4:
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ if self.number_layers == 3:
+ path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
+ else:
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ if self.scratch.stem_transpose is not None:
+ path_1 = self.scratch.stem_transpose(path_1)
+
+ out = self.scratch.output_conv(path_1)
+
+ return out
+
+
+class DPTDepthModel(DPT):
+ def __init__(self, path=None, non_negative=True, **kwargs):
+ features = kwargs["features"] if "features" in kwargs else 256
+ head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features
+ head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32
+ kwargs.pop("head_features_1", None)
+ kwargs.pop("head_features_2", None)
+
+ head = nn.Sequential(
+ nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ super().__init__(head, **kwargs)
+
+ if path is not None:
+ self.load(path)
+
+ def forward(self, x):
+ return super().forward(x).squeeze(dim=1)
diff --git a/models/monoD/zoeDepth/midas_c/midas/midas_net.py b/models/monoD/zoeDepth/midas_c/midas/midas_net.py
new file mode 100755
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=256, non_negative=True):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet, self).__init__()
+
+ use_pretrained = False if path is None else True
+
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
diff --git a/models/monoD/zoeDepth/midas_c/midas/midas_net_custom.py b/models/monoD/zoeDepth/midas_c/midas/midas_net_custom.py
new file mode 100755
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+ blocks={'expand': True}):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet_small, self).__init__()
+
+ use_pretrained = False if path else True
+
+ self.channels_last = channels_last
+ self.blocks = blocks
+ self.backbone = backbone
+
+ self.groups = 1
+
+ features1=features
+ features2=features
+ features3=features
+ features4=features
+ self.expand = False
+ if "expand" in self.blocks and self.blocks['expand'] == True:
+ self.expand = True
+ features1=features
+ features2=features*2
+ features3=features*4
+ features4=features*8
+
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+ self.scratch.activation = nn.ReLU(False)
+
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+ self.scratch.activation,
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+ if self.channels_last==True:
+ print("self.channels_last = ", self.channels_last)
+ x.contiguous(memory_format=torch.channels_last)
+
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
+ prev_previous_type = nn.Identity()
+ prev_previous_name = ''
+ previous_type = nn.Identity()
+ previous_name = ''
+ for name, module in m.named_modules():
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+ # print("FUSED ", prev_previous_name, previous_name, name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+ # print("FUSED ", prev_previous_name, previous_name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+ # print("FUSED ", previous_name, name)
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+ prev_previous_type = previous_type
+ prev_previous_name = previous_name
+ previous_type = type(module)
+ previous_name = name
\ No newline at end of file
diff --git a/models/monoD/zoeDepth/midas_c/midas/model_loader.py b/models/monoD/zoeDepth/midas_c/midas/model_loader.py
new file mode 100755
index 0000000000000000000000000000000000000000..f1cd1f2d43054bfd3d650587c7b2ed35f1347c9e
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/model_loader.py
@@ -0,0 +1,242 @@
+import cv2
+import torch
+
+from midas.dpt_depth import DPTDepthModel
+from midas.midas_net import MidasNet
+from midas.midas_net_custom import MidasNet_small
+from midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+from torchvision.transforms import Compose
+
+default_models = {
+ "dpt_beit_large_512": "weights/dpt_beit_large_512.pt",
+ "dpt_beit_large_384": "weights/dpt_beit_large_384.pt",
+ "dpt_beit_base_384": "weights/dpt_beit_base_384.pt",
+ "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt",
+ "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt",
+ "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt",
+ "dpt_swin_large_384": "weights/dpt_swin_large_384.pt",
+ "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt",
+ "dpt_levit_224": "weights/dpt_levit_224.pt",
+ "dpt_large_384": "weights/dpt_large_384.pt",
+ "dpt_hybrid_384": "weights/dpt_hybrid_384.pt",
+ "midas_v21_384": "weights/midas_v21_384.pt",
+ "midas_v21_small_256": "weights/midas_v21_small_256.pt",
+ "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml",
+}
+
+
+def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False):
+ """Load the specified network.
+
+ Args:
+ device (device): the torch device used
+ model_path (str): path to saved model
+ model_type (str): the type of the model to be loaded
+ optimize (bool): optimize the model to half-integer on CUDA?
+ height (int): inference encoder image height
+ square (bool): resize to a square resolution?
+
+ Returns:
+ The loaded network, the transform which prepares images as input to the network and the dimensions of the
+ network input
+ """
+ if "openvino" in model_type:
+ from openvino.runtime import Core
+
+ keep_aspect_ratio = not square
+
+ if model_type == "dpt_beit_large_512":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="beitl16_512",
+ non_negative=True,
+ )
+ net_w, net_h = 512, 512
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_beit_large_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="beitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_beit_base_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="beitb16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_swin2_large_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="swin2l24_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ keep_aspect_ratio = False
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_swin2_base_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="swin2b24_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ keep_aspect_ratio = False
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_swin2_tiny_256":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="swin2t16_256",
+ non_negative=True,
+ )
+ net_w, net_h = 256, 256
+ keep_aspect_ratio = False
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_swin_large_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="swinl12_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ keep_aspect_ratio = False
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_next_vit_large_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="next_vit_large_6m",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers
+ # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of
+ # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py
+ # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e)
+ elif model_type == "dpt_levit_224":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="levit_384",
+ non_negative=True,
+ head_features_1=64,
+ head_features_2=8,
+ )
+ net_w, net_h = 224, 224
+ keep_aspect_ratio = False
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_large_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21_384":
+ model = MidasNet(model_path, non_negative=True)
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ elif model_type == "midas_v21_small_256":
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+ non_negative=True, blocks={'expand': True})
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ elif model_type == "openvino_midas_v21_small_256":
+ ie = Core()
+ uncompiled_model = ie.read_model(model=model_path)
+ model = ie.compile_model(uncompiled_model, "CPU")
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ else:
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
+ assert False
+
+ if not "openvino" in model_type:
+ print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6))
+ else:
+ print("Model loaded, optimized with OpenVINO")
+
+ if "openvino" in model_type:
+ keep_aspect_ratio = False
+
+ if height is not None:
+ net_w, net_h = height, height
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=keep_aspect_ratio,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ if not "openvino" in model_type:
+ model.eval()
+
+ if optimize and (device == torch.device("cuda")):
+ if not "openvino" in model_type:
+ model = model.to(memory_format=torch.channels_last)
+ model = model.half()
+ else:
+ print("Error: OpenVINO models are already optimized. No optimization to half-float possible.")
+ exit()
+
+ if not "openvino" in model_type:
+ model.to(device)
+
+ return model, transform, net_w, net_h
diff --git a/models/monoD/zoeDepth/midas_c/midas/transforms.py b/models/monoD/zoeDepth/midas_c/midas/transforms.py
new file mode 100755
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/models/monoD/zoeDepth/midas_c/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as least as possbile
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "disparity" in sample:
+ disparity = sample["disparity"].astype(np.float32)
+ sample["disparity"] = np.ascontiguousarray(disparity)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ return sample
diff --git a/models/monoD/zoeDepth/models/__init__.py b/models/monoD/zoeDepth/models/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9
--- /dev/null
+++ b/models/monoD/zoeDepth/models/__init__.py
@@ -0,0 +1,24 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
diff --git a/models/monoD/zoeDepth/models/base_models/__init__.py b/models/monoD/zoeDepth/models/base_models/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9
--- /dev/null
+++ b/models/monoD/zoeDepth/models/base_models/__init__.py
@@ -0,0 +1,24 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
diff --git a/models/monoD/zoeDepth/models/base_models/midas.py b/models/monoD/zoeDepth/models/base_models/midas.py
new file mode 100755
index 0000000000000000000000000000000000000000..b32478d3e4aedfbfbb6d31e5e6b43176da12f0e1
--- /dev/null
+++ b/models/monoD/zoeDepth/models/base_models/midas.py
@@ -0,0 +1,382 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.nn as nn
+import numpy as np
+from torchvision.transforms import Normalize
+import importlib
+
+def denormalize(x):
+ """Reverses the imagenet normalization applied to the input.
+
+ Args:
+ x (torch.Tensor - shape(N,3,H,W)): input tensor
+
+ Returns:
+ torch.Tensor - shape(N,3,H,W): Denormalized input
+ """
+ mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
+ std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
+ return x * std + mean
+
+def get_activation(name, bank):
+ def hook(model, input, output):
+ bank[name] = output
+ return hook
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ ):
+ """Init.
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ print("Params passed to Resize transform:")
+ print("\twidth: ", width)
+ print("\theight: ", height)
+ print("\tresize_target: ", resize_target)
+ print("\tkeep_aspect_ratio: ", keep_aspect_ratio)
+ print("\tensure_multiple_of: ", ensure_multiple_of)
+ print("\tresize_method: ", resize_method)
+
+ self.__width = width
+ self.__height = height
+
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of)
+ * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of)
+ * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as least as possbile
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, x):
+ width, height = self.get_size(*x.shape[-2:][::-1])
+ return nn.functional.interpolate(x, (int(height), int(width)), mode='bilinear', align_corners=True)
+
+class PrepForMidas(object):
+ def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True):
+ if isinstance(img_size, int):
+ img_size = (img_size, img_size)
+ net_h, net_w = img_size
+ self.normalization = Normalize(
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+ self.resizer = Resize(net_w, net_h, keep_aspect_ratio=keep_aspect_ratio, ensure_multiple_of=32, resize_method=resize_mode) \
+ if do_resize else nn.Identity()
+
+ def __call__(self, x):
+ return self.normalization(self.resizer(x))
+
+
+class MidasCore(nn.Module):
+ def __init__(self, midas, trainable=False, fetch_features=True, layer_names=('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'), freeze_bn=False, keep_aspect_ratio=True,
+ img_size=384, **kwargs):
+ """Midas Base model used for multi-scale feature extraction.
+
+ Args:
+ midas (torch.nn.Module): Midas model.
+ trainable (bool, optional): Train midas model. Defaults to False.
+ fetch_features (bool, optional): Extract multi-scale features. Defaults to True.
+ layer_names (tuple, optional): Layers used for feature extraction. Order = (head output features, last layer features, ...decoder features). Defaults to ('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1').
+ freeze_bn (bool, optional): Freeze BatchNorm. Generally results in better finetuning performance. Defaults to False.
+ keep_aspect_ratio (bool, optional): Keep the aspect ratio of input images while resizing. Defaults to True.
+ img_size (int, tuple, optional): Input resolution. Defaults to 384.
+ """
+ super().__init__()
+ self.core = midas
+ self.output_channels = None
+ self.core_out = {}
+ self.trainable = trainable
+ self.fetch_features = fetch_features
+ # midas.scratch.output_conv = nn.Identity()
+ self.handles = []
+ # self.layer_names = ['out_conv','l4_rn', 'r4', 'r3', 'r2', 'r1']
+ self.layer_names = layer_names
+
+ self.set_trainable(trainable)
+ self.set_fetch_features(fetch_features)
+
+ self.prep = PrepForMidas(keep_aspect_ratio=keep_aspect_ratio,
+ img_size=img_size, do_resize=kwargs.get('do_resize', True))
+
+ if freeze_bn:
+ self.freeze_bn()
+
+ def set_trainable(self, trainable):
+ self.trainable = trainable
+ if trainable:
+ self.unfreeze()
+ else:
+ self.freeze()
+ return self
+
+ def set_fetch_features(self, fetch_features):
+ self.fetch_features = fetch_features
+ if fetch_features:
+ if len(self.handles) == 0:
+ self.attach_hooks(self.core)
+ else:
+ self.remove_hooks()
+ return self
+
+ def freeze(self):
+ for p in self.parameters():
+ p.requires_grad = False
+ self.trainable = False
+ return self
+
+ def unfreeze(self):
+ for p in self.parameters():
+ p.requires_grad = True
+ self.trainable = True
+ return self
+
+ def freeze_bn(self):
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ return self
+
+ def forward(self, x, denorm=False, return_rel_depth=False):
+ with torch.no_grad():
+ if denorm:
+ x = denormalize(x)
+ x = self.prep(x)
+ # print("Shape after prep: ", x.shape)
+
+ with torch.set_grad_enabled(self.trainable):
+
+ # print("Input size to Midascore", x.shape)
+ rel_depth = self.core(x)
+ # print("Output from midas shape", rel_depth.shape)
+ if not self.fetch_features:
+ return rel_depth
+ out = [self.core_out[k] for k in self.layer_names]
+
+ if return_rel_depth:
+ return rel_depth, out
+ return out
+
+ def get_rel_pos_params(self):
+ for name, p in self.core.pretrained.named_parameters():
+ if "relative_position" in name:
+ yield p
+
+ def get_enc_params_except_rel_pos(self):
+ for name, p in self.core.pretrained.named_parameters():
+ if "relative_position" not in name:
+ yield p
+
+ def freeze_encoder(self, freeze_rel_pos=False):
+ if freeze_rel_pos:
+ for p in self.core.pretrained.parameters():
+ p.requires_grad = False
+ else:
+ for p in self.get_enc_params_except_rel_pos():
+ p.requires_grad = False
+ return self
+
+ def attach_hooks(self, midas):
+ if len(self.handles) > 0:
+ self.remove_hooks()
+ if "out_conv" in self.layer_names:
+ self.handles.append(list(midas.scratch.output_conv.children())[
+ 3].register_forward_hook(get_activation("out_conv", self.core_out)))
+ if "r4" in self.layer_names:
+ self.handles.append(midas.scratch.refinenet4.register_forward_hook(
+ get_activation("r4", self.core_out)))
+ if "r3" in self.layer_names:
+ self.handles.append(midas.scratch.refinenet3.register_forward_hook(
+ get_activation("r3", self.core_out)))
+ if "r2" in self.layer_names:
+ self.handles.append(midas.scratch.refinenet2.register_forward_hook(
+ get_activation("r2", self.core_out)))
+ if "r1" in self.layer_names:
+ self.handles.append(midas.scratch.refinenet1.register_forward_hook(
+ get_activation("r1", self.core_out)))
+ if "l4_rn" in self.layer_names:
+ self.handles.append(midas.scratch.layer4_rn.register_forward_hook(
+ get_activation("l4_rn", self.core_out)))
+
+ return self
+
+ def remove_hooks(self):
+ for h in self.handles:
+ h.remove()
+ return self
+
+ def __del__(self):
+ self.remove_hooks()
+
+ def set_output_channels(self, model_type):
+ self.output_channels = MIDAS_SETTINGS[model_type]
+
+ @staticmethod
+ def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_midas=True, fetch_features=False, freeze_bn=True, force_keep_ar=False, force_reload=False, **kwargs):
+ if midas_model_type not in MIDAS_SETTINGS:
+ raise ValueError(
+ f"Invalid model type: {midas_model_type}. Must be one of {list(MIDAS_SETTINGS.keys())}")
+ if "img_size" in kwargs:
+ kwargs = MidasCore.parse_img_size(kwargs)
+ img_size = kwargs.pop("img_size", [384, 384])
+ print("img_size", img_size)
+ hubconf = importlib.import_module(f"models.monoD.zoeDepth.midas_c.hubconf")
+ midas = getattr(hubconf, midas_model_type)(pretrained=False)
+ ckpt_path = "models/monoD/zoeDepth/ckpts/dpt_beit_large_384.pt"
+ midas_ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
+ midas.load_state_dict(midas_ckpt)
+ # midas = torch.hub.load("intel-isl/MiDaS", midas_model_type,
+ # pretrained=use_pretrained_midas, force_reload=force_reload)
+ kwargs.update({'keep_aspect_ratio': force_keep_ar})
+ midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features,
+ freeze_bn=freeze_bn, img_size=img_size, **kwargs)
+ midas_core.set_output_channels(midas_model_type)
+ return midas_core
+
+ @staticmethod
+ def build_from_config(config):
+ return MidasCore.build(**config)
+
+ @staticmethod
+ def parse_img_size(config):
+ assert 'img_size' in config
+ if isinstance(config['img_size'], str):
+ assert "," in config['img_size'], "img_size should be a string with comma separated img_size=H,W"
+ config['img_size'] = list(map(int, config['img_size'].split(",")))
+ assert len(
+ config['img_size']) == 2, "img_size should be a string with comma separated img_size=H,W"
+ elif isinstance(config['img_size'], int):
+ config['img_size'] = [config['img_size'], config['img_size']]
+ else:
+ assert isinstance(config['img_size'], list) and len(
+ config['img_size']) == 2, "img_size should be a list of H,W"
+ return config
+
+
+nchannels2models = {
+ tuple([256]*5): ["DPT_BEiT_L_384", "DPT_BEiT_L_512", "DPT_BEiT_B_384", "DPT_SwinV2_L_384", "DPT_SwinV2_B_384", "DPT_SwinV2_T_256", "DPT_Large", "DPT_Hybrid"],
+ (512, 256, 128, 64, 64): ["MiDaS_small"]
+}
+
+# Model name to number of output channels
+MIDAS_SETTINGS = {m: k for k, v in nchannels2models.items()
+ for m in v
+ }
diff --git a/models/monoD/zoeDepth/models/builder.py b/models/monoD/zoeDepth/models/builder.py
new file mode 100755
index 0000000000000000000000000000000000000000..a42fb42f18dd26e313ef2ed3a983381eaaf36dbd
--- /dev/null
+++ b/models/monoD/zoeDepth/models/builder.py
@@ -0,0 +1,52 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+from importlib import import_module
+from models.monoD.zoeDepth.models.depth_model import DepthModel
+
+def build_model(config) -> DepthModel:
+ """Builds a model from a config. The model is specified by the model name and version in the config. The model is then constructed using the build_from_config function of the model interface.
+ This function should be used to construct models for training and evaluation.
+
+ Args:
+ config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s) saved in its root model folder.
+
+ Returns:
+ torch.nn.Module: Model corresponding to name and version as specified in config
+ """
+ module_name = f"models.monoD.zoeDepth.models.{config.model}"
+ try:
+ module = import_module(module_name)
+ except ModuleNotFoundError as e:
+ # print the original error message
+ print(e)
+ raise ValueError(
+ f"Model {config.model} not found. Refer above error for details.") from e
+ try:
+ get_version = getattr(module, "get_version")
+ except AttributeError as e:
+ raise ValueError(
+ f"Model {config.model} has no get_version function.") from e
+
+ return get_version(config.version_name).build_from_config(config)
diff --git a/models/monoD/zoeDepth/models/depth_model.py b/models/monoD/zoeDepth/models/depth_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..fc421c108ea3928c9add62b4c190500d9bd4eda1
--- /dev/null
+++ b/models/monoD/zoeDepth/models/depth_model.py
@@ -0,0 +1,152 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import transforms
+import PIL.Image
+from PIL import Image
+from typing import Union
+
+
+class DepthModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.device = 'cpu'
+
+ def to(self, device) -> nn.Module:
+ self.device = device
+ return super().to(device)
+
+ def forward(self, x, *args, **kwargs):
+ raise NotImplementedError
+
+ def _infer(self, x: torch.Tensor):
+ """
+ Inference interface for the model
+ Args:
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
+ Returns:
+ torch.Tensor: output tensor of shape (b, 1, h, w)
+ """
+ return self(x)['metric_depth']
+
+ def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor:
+ """
+ Inference interface for the model with padding augmentation
+ Padding augmentation fixes the boundary artifacts in the output depth map.
+ Boundary artifacts are sometimes caused by the fact that the model is trained on NYU raw dataset which has a black or white border around the image.
+ This augmentation pads the input image and crops the prediction back to the original size / view.
+
+ Note: This augmentation is not required for the models trained with 'avoid_boundary'=True.
+ Args:
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
+ pad_input (bool, optional): whether to pad the input or not. Defaults to True.
+ fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3.
+ fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3.
+ upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'.
+ padding_mode (str, optional): padding mode. Defaults to "reflect".
+ Returns:
+ torch.Tensor: output tensor of shape (b, 1, h, w)
+ """
+ # assert x is nchw and c = 3
+ assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim())
+ assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1])
+
+ if pad_input:
+ assert fh > 0 or fw > 0, "atlease one of fh and fw must be greater than 0"
+ pad_h = int(np.sqrt(x.shape[2]/2) * fh)
+ pad_w = int(np.sqrt(x.shape[3]/2) * fw)
+ padding = [pad_w, pad_w]
+ if pad_h > 0:
+ padding += [pad_h, pad_h]
+
+ x = F.pad(x, padding, mode=padding_mode, **kwargs)
+ out = self._infer(x)
+ if out.shape[-2:] != x.shape[-2:]:
+ out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False)
+ if pad_input:
+ # crop to the original size, handling the case where pad_h and pad_w is 0
+ if pad_h > 0:
+ out = out[:, :, pad_h:-pad_h,:]
+ if pad_w > 0:
+ out = out[:, :, :, pad_w:-pad_w]
+ return out
+
+ def infer_with_flip_aug(self, x, pad_input: bool=True, **kwargs) -> torch.Tensor:
+ """
+ Inference interface for the model with horizontal flip augmentation
+ Horizontal flip augmentation improves the accuracy of the model by averaging the output of the model with and without horizontal flip.
+ Args:
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
+ pad_input (bool, optional): whether to use padding augmentation. Defaults to True.
+ Returns:
+ torch.Tensor: output tensor of shape (b, 1, h, w)
+ """
+ # infer with horizontal flip and average
+ out = self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs)
+ out_flip = self._infer_with_pad_aug(torch.flip(x, dims=[3]), pad_input=pad_input, **kwargs)
+ out = (out + torch.flip(out_flip, dims=[3])) / 2
+ return out
+
+ def infer(self, x, pad_input: bool=True, with_flip_aug: bool=True, **kwargs) -> torch.Tensor:
+ """
+ Inference interface for the model
+ Args:
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
+ pad_input (bool, optional): whether to use padding augmentation. Defaults to True.
+ with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True.
+ Returns:
+ torch.Tensor: output tensor of shape (b, 1, h, w)
+ """
+ if with_flip_aug:
+ return self.infer_with_flip_aug(x, pad_input=pad_input, **kwargs)
+ else:
+ return self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs)
+
+ @torch.no_grad()
+ def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]:
+ """
+ Inference interface for the model for PIL image
+ Args:
+ pil_img (PIL.Image.Image): input PIL image
+ pad_input (bool, optional): whether to use padding augmentation. Defaults to True.
+ with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True.
+ output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy".
+ """
+ x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device)
+ out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs)
+ if output_type == "numpy":
+ return out_tensor.squeeze().cpu().numpy()
+ elif output_type == "pil":
+ # uint16 is required for depth pil image
+ out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16)
+ return Image.fromarray(out_16bit_numpy)
+ elif output_type == "tensor":
+ return out_tensor.squeeze().cpu()
+ else:
+ raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'")
+
\ No newline at end of file
diff --git a/models/monoD/zoeDepth/models/layers/attractor.py b/models/monoD/zoeDepth/models/layers/attractor.py
new file mode 100755
index 0000000000000000000000000000000000000000..2a8efe645adea1d88a12e2ac5cc6bb2a251eef9d
--- /dev/null
+++ b/models/monoD/zoeDepth/models/layers/attractor.py
@@ -0,0 +1,208 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.nn as nn
+
+
+@torch.jit.script
+def exp_attractor(dx, alpha: float = 300, gamma: int = 2):
+ """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin centermmary for exp_attractor
+
+ Args:
+ dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
+ alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
+ gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.
+
+ Returns:
+ torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc
+ """
+ return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx)
+
+
+@torch.jit.script
+def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
+ """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center
+ This is the default one according to the accompanying paper.
+
+ Args:
+ dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
+ alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
+ gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.
+
+ Returns:
+ torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc
+ """
+ return dx.div(1+alpha*dx.pow(gamma))
+
+
+class AttractorLayer(nn.Module):
+ def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
+ alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
+ """
+ Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth)
+ """
+ super().__init__()
+
+ self.n_attractors = n_attractors
+ self.n_bins = n_bins
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+ self.alpha = alpha
+ self.gamma = gamma
+ self.kind = kind
+ self.attractor_type = attractor_type
+ self.memory_efficient = memory_efficient
+
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm
+ nn.ReLU(inplace=True)
+ )
+
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
+ """
+ Args:
+ x (torch.Tensor) : feature block; shape - n, c, h, w
+ b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w
+
+ Returns:
+ tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w
+ """
+ if prev_b_embedding is not None:
+ if interpolate:
+ prev_b_embedding = nn.functional.interpolate(
+ prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
+ x = x + prev_b_embedding
+
+ A = self._net(x)
+ eps = 1e-3
+ A = A + eps
+ n, c, h, w = A.shape
+ A = A.view(n, self.n_attractors, 2, h, w)
+ A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w
+ A_normed = A[:, :, 0, ...] # n, na, h, w
+
+ b_prev = nn.functional.interpolate(
+ b_prev, (h, w), mode='bilinear', align_corners=True)
+ b_centers = b_prev
+
+ if self.attractor_type == 'exp':
+ dist = exp_attractor
+ else:
+ dist = inv_attractor
+
+ if not self.memory_efficient:
+ func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
+ # .shape N, nbins, h, w
+ delta_c = func(dist(A_normed.unsqueeze(
+ 2) - b_centers.unsqueeze(1)), dim=1)
+ else:
+ delta_c = torch.zeros_like(b_centers, device=b_centers.device)
+ for i in range(self.n_attractors):
+ # .shape N, nbins, h, w
+ delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers)
+
+ if self.kind == 'mean':
+ delta_c = delta_c / self.n_attractors
+
+ b_new_centers = b_centers + delta_c
+ B_centers = (self.max_depth - self.min_depth) * \
+ b_new_centers + self.min_depth
+ B_centers, _ = torch.sort(B_centers, dim=1)
+ B_centers = torch.clip(B_centers, self.min_depth, self.max_depth)
+ return b_new_centers, B_centers
+
+
+class AttractorLayerUnnormed(nn.Module):
+ def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
+ alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
+ """
+ Attractor layer for bin centers. Bin centers are unbounded
+ """
+ super().__init__()
+
+ self.n_attractors = n_attractors
+ self.n_bins = n_bins
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+ self.alpha = alpha
+ self.gamma = gamma
+ self.kind = kind
+ self.attractor_type = attractor_type
+ self.memory_efficient = memory_efficient
+
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0),
+ nn.Softplus()
+ )
+
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
+ """
+ Args:
+ x (torch.Tensor) : feature block; shape - n, c, h, w
+ b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w
+
+ Returns:
+ tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version
+ """
+ if prev_b_embedding is not None:
+ if interpolate:
+ prev_b_embedding = nn.functional.interpolate(
+ prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
+ x = x + prev_b_embedding
+
+ A = self._net(x)
+ n, c, h, w = A.shape
+
+ b_prev = nn.functional.interpolate(
+ b_prev, (h, w), mode='bilinear', align_corners=True)
+ b_centers = b_prev
+
+ if self.attractor_type == 'exp':
+ dist = exp_attractor
+ else:
+ dist = inv_attractor
+
+ if not self.memory_efficient:
+ func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
+ # .shape N, nbins, h, w
+ delta_c = func(
+ dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1)
+ else:
+ delta_c = torch.zeros_like(b_centers, device=b_centers.device)
+ for i in range(self.n_attractors):
+ delta_c += dist(A[:, i, ...].unsqueeze(1) -
+ b_centers) # .shape N, nbins, h, w
+
+ if self.kind == 'mean':
+ delta_c = delta_c / self.n_attractors
+
+ b_new_centers = b_centers + delta_c
+ B_centers = b_new_centers
+
+ return b_new_centers, B_centers
diff --git a/models/monoD/zoeDepth/models/layers/dist_layers.py b/models/monoD/zoeDepth/models/layers/dist_layers.py
new file mode 100755
index 0000000000000000000000000000000000000000..3208405dfb78fdfc28d5765e5a6d5dbe31967a23
--- /dev/null
+++ b/models/monoD/zoeDepth/models/layers/dist_layers.py
@@ -0,0 +1,121 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.nn as nn
+
+
+def log_binom(n, k, eps=1e-7):
+ """ log(nCk) using stirling approximation """
+ n = n + eps
+ k = k + eps
+ return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps)
+
+
+class LogBinomial(nn.Module):
+ def __init__(self, n_classes=256, act=torch.softmax):
+ """Compute log binomial distribution for n_classes
+
+ Args:
+ n_classes (int, optional): number of output classes. Defaults to 256.
+ """
+ super().__init__()
+ self.K = n_classes
+ self.act = act
+ self.register_buffer('k_idx', torch.arange(
+ 0, n_classes).view(1, -1, 1, 1))
+ self.register_buffer('K_minus_1', torch.Tensor(
+ [self.K-1]).view(1, -1, 1, 1))
+
+ def forward(self, x, t=1., eps=1e-4):
+ """Compute log binomial distribution for x
+
+ Args:
+ x (torch.Tensor - NCHW): probabilities
+ t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1..
+ eps (float, optional): Small number for numerical stability. Defaults to 1e-4.
+
+ Returns:
+ torch.Tensor -NCHW: log binomial distribution logbinomial(p;t)
+ """
+ if x.ndim == 3:
+ x = x.unsqueeze(1) # make it nchw
+
+ one_minus_x = torch.clamp(1 - x, eps, 1)
+ x = torch.clamp(x, eps, 1)
+ y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \
+ torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x)
+ return self.act(y/t, dim=1)
+
+
+class ConditionalLogBinomial(nn.Module):
+ def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax):
+ """Conditional Log Binomial distribution
+
+ Args:
+ in_features (int): number of input channels in main feature
+ condition_dim (int): number of input channels in condition feature
+ n_classes (int, optional): Number of classes. Defaults to 256.
+ bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2.
+ p_eps (float, optional): small eps value. Defaults to 1e-4.
+ max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50.
+ min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7.
+ """
+ super().__init__()
+ self.p_eps = p_eps
+ self.max_temp = max_temp
+ self.min_temp = min_temp
+ self.log_binomial_transform = LogBinomial(n_classes, act=act)
+ bottleneck = (in_features + condition_dim) // bottleneck_factor
+ self.mlp = nn.Sequential(
+ nn.Conv2d(in_features + condition_dim, bottleneck,
+ kernel_size=1, stride=1, padding=0),
+ nn.GELU(),
+ # 2 for p linear norm, 2 for t linear norm
+ nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0),
+ nn.Softplus()
+ )
+
+ def forward(self, x, cond):
+ """Forward pass
+
+ Args:
+ x (torch.Tensor - NCHW): Main feature
+ cond (torch.Tensor - NCHW): condition feature
+
+ Returns:
+ torch.Tensor: Output log binomial distribution
+ """
+ pt = self.mlp(torch.concat((x, cond), dim=1))
+ p, t = pt[:, :2, ...], pt[:, 2:, ...]
+
+ p = p + self.p_eps
+ p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...])
+
+ t = t + self.p_eps
+ t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...])
+ t = t.unsqueeze(1)
+ t = (self.max_temp - self.min_temp) * t + self.min_temp
+
+ return self.log_binomial_transform(p, t)
diff --git a/models/monoD/zoeDepth/models/layers/localbins_layers.py b/models/monoD/zoeDepth/models/layers/localbins_layers.py
new file mode 100755
index 0000000000000000000000000000000000000000..f94481605c3e6958ce50e73b2eb31d9f0c07dc67
--- /dev/null
+++ b/models/monoD/zoeDepth/models/layers/localbins_layers.py
@@ -0,0 +1,169 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.nn as nn
+
+
+class SeedBinRegressor(nn.Module):
+ def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
+ """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval.
+
+ Args:
+ in_features (int): input channels
+ n_bins (int, optional): Number of bin centers. Defaults to 16.
+ mlp_dim (int, optional): Hidden dimension. Defaults to 256.
+ min_depth (float, optional): Min depth value. Defaults to 1e-3.
+ max_depth (float, optional): Max depth value. Defaults to 10.
+ """
+ super().__init__()
+ self.version = "1_1"
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
+ nn.ReLU(inplace=True)
+ )
+
+ def forward(self, x):
+ """
+ Returns tensor of bin_width vectors (centers). One vector b for every pixel
+ """
+ B = self._net(x)
+ eps = 1e-3
+ B = B + eps
+ B_widths_normed = B / B.sum(dim=1, keepdim=True)
+ B_widths = (self.max_depth - self.min_depth) * \
+ B_widths_normed # .shape NCHW
+ # pad has the form (left, right, top, bottom, front, back)
+ B_widths = nn.functional.pad(
+ B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth)
+ B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW
+
+ B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...])
+ return B_widths_normed, B_centers
+
+
+class SeedBinRegressorUnnormed(nn.Module):
+ def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
+ """Bin center regressor network. Bin centers are unbounded
+
+ Args:
+ in_features (int): input channels
+ n_bins (int, optional): Number of bin centers. Defaults to 16.
+ mlp_dim (int, optional): Hidden dimension. Defaults to 256.
+ min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
+ max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
+ """
+ super().__init__()
+ self.version = "1_1"
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
+ nn.Softplus()
+ )
+
+ def forward(self, x):
+ """
+ Returns tensor of bin_width vectors (centers). One vector b for every pixel
+ """
+ B_centers = self._net(x)
+ return B_centers, B_centers
+
+
+class Projector(nn.Module):
+ def __init__(self, in_features, out_features, mlp_dim=128):
+ """Projector MLP
+
+ Args:
+ in_features (int): input channels
+ out_features (int): output channels
+ mlp_dim (int, optional): hidden dimension. Defaults to 128.
+ """
+ super().__init__()
+
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mlp_dim, out_features, 1, 1, 0),
+ )
+
+ def forward(self, x):
+ return self._net(x)
+
+
+
+class LinearSplitter(nn.Module):
+ def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10):
+ super().__init__()
+
+ self.prev_nbins = prev_nbins
+ self.split_factor = split_factor
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.GELU(),
+ nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0),
+ nn.ReLU()
+ )
+
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
+ """
+ x : feature block; shape - n, c, h, w
+ b_prev : previous bin widths normed; shape - n, prev_nbins, h, w
+ """
+ if prev_b_embedding is not None:
+ if interpolate:
+ prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
+ x = x + prev_b_embedding
+ S = self._net(x)
+ eps = 1e-3
+ S = S + eps
+ n, c, h, w = S.shape
+ S = S.view(n, self.prev_nbins, self.split_factor, h, w)
+ S_normed = S / S.sum(dim=2, keepdim=True) # fractional splits
+
+ b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True)
+
+
+ b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) # renormalize for gurantees
+ # print(b_prev.shape, S_normed.shape)
+ # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat?
+ b = b_prev.unsqueeze(2) * S_normed
+ b = b.flatten(1,2) # .shape n, prev_nbins * split_factor, h, w
+
+ # calculate bin centers for loss calculation
+ B_widths = (self.max_depth - self.min_depth) * b # .shape N, nprev * splitfactor, H, W
+ # pad has the form (left, right, top, bottom, front, back)
+ B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth)
+ B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW
+
+ B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:,1:,...])
+ return b, B_centers
\ No newline at end of file
diff --git a/models/monoD/zoeDepth/models/layers/patch_transformer.py b/models/monoD/zoeDepth/models/layers/patch_transformer.py
new file mode 100755
index 0000000000000000000000000000000000000000..99d9e51a06b981bae45ce7dd64eaef19a4121991
--- /dev/null
+++ b/models/monoD/zoeDepth/models/layers/patch_transformer.py
@@ -0,0 +1,91 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.nn as nn
+
+
+class PatchTransformerEncoder(nn.Module):
+ def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4, use_class_token=False):
+ """ViT-like transformer block
+
+ Args:
+ in_channels (int): Input channels
+ patch_size (int, optional): patch size. Defaults to 10.
+ embedding_dim (int, optional): Embedding dimension in transformer model. Defaults to 128.
+ num_heads (int, optional): number of attention heads. Defaults to 4.
+ use_class_token (bool, optional): Whether to use extra token at the start for global accumulation (called as "class token"). Defaults to False.
+ """
+ super(PatchTransformerEncoder, self).__init__()
+ self.use_class_token = use_class_token
+ encoder_layers = nn.TransformerEncoderLayer(
+ embedding_dim, num_heads, dim_feedforward=1024)
+ self.transformer_encoder = nn.TransformerEncoder(
+ encoder_layers, num_layers=4) # takes shape S,N,E
+
+ self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim,
+ kernel_size=patch_size, stride=patch_size, padding=0)
+
+ def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'):
+ """Generate positional encodings
+
+ Args:
+ sequence_length (int): Sequence length
+ embedding_dim (int): Embedding dimension
+
+ Returns:
+ torch.Tensor SBE: Positional encodings
+ """
+ position = torch.arange(
+ 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1)
+ index = torch.arange(
+ 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0)
+ div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim))
+ pos_encoding = position * div_term
+ pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1)
+ pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1)
+ return pos_encoding
+
+
+ def forward(self, x):
+ """Forward pass
+
+ Args:
+ x (torch.Tensor - NCHW): Input feature tensor
+
+ Returns:
+ torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim
+ """
+ embeddings = self.embedding_convPxP(x).flatten(
+ 2) # .shape = n,c,s = n, embedding_dim, s
+ if self.use_class_token:
+ # extra special token at start ?
+ embeddings = nn.functional.pad(embeddings, (1, 0))
+
+ # change to S,N,E format required by transformer
+ embeddings = embeddings.permute(2, 0, 1)
+ S, N, E = embeddings.shape
+ embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device)
+ x = self.transformer_encoder(embeddings) # .shape = S, N, E
+ return x
diff --git a/models/monoD/zoeDepth/models/model_io.py b/models/monoD/zoeDepth/models/model_io.py
new file mode 100755
index 0000000000000000000000000000000000000000..6bf52b67e94809a052c7af39cae4462f0f974141
--- /dev/null
+++ b/models/monoD/zoeDepth/models/model_io.py
@@ -0,0 +1,91 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+
+def load_state_dict(model, state_dict):
+ """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict.
+
+ DataParallel prefixes state_dict keys with 'module.' when saving.
+ If the model is not a DataParallel model but the state_dict is, then prefixes are removed.
+ If the model is a DataParallel model but the state_dict is not, then prefixes are added.
+ """
+ state_dict = state_dict.get('model', state_dict)
+ # if model is a DataParallel model, then state_dict keys are prefixed with 'module.'
+
+ do_prefix = isinstance(
+ model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
+ state = {}
+ for k, v in state_dict.items():
+ if k.startswith('module.') and not do_prefix:
+ k = k[7:]
+
+ if not k.startswith('module.') and do_prefix:
+ k = 'module.' + k
+
+ state[k] = v
+ model.load_state_dict(state)
+ print("Loaded successfully")
+ return model
+
+
+def load_wts(model, checkpoint_path):
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
+ return load_state_dict(model, ckpt)
+
+
+def load_state_dict_from_url(model, url, **kwargs):
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs)
+ return load_state_dict(model, state_dict)
+
+
+def load_state_from_resource(model, resource: str):
+ """Loads weights to the model from a given resource. A resource can be of following types:
+ 1. URL. Prefixed with "url::"
+ e.g. url::http(s)://url.resource.com/ckpt.pt
+
+ 2. Local path. Prefixed with "local::"
+ e.g. local::/path/to/ckpt.pt
+
+
+ Args:
+ model (torch.nn.Module): Model
+ resource (str): resource string
+
+ Returns:
+ torch.nn.Module: Model with loaded weights
+ """
+ print(f"Using pretrained resource {resource}")
+
+ if resource.startswith('url::'):
+ url = resource.split('url::')[1]
+ return load_state_dict_from_url(model, url, progress=True)
+
+ elif resource.startswith('local::'):
+ path = resource.split('local::')[1]
+ return load_wts(model, path)
+
+ else:
+ raise ValueError("Invalid resource type, only url:: and local:: are supported")
+
\ No newline at end of file
diff --git a/models/monoD/zoeDepth/models/zoedepth/__init__.py b/models/monoD/zoeDepth/models/zoedepth/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..cc33f737d238766559f0e3a8def3c0b568f23b7f
--- /dev/null
+++ b/models/monoD/zoeDepth/models/zoedepth/__init__.py
@@ -0,0 +1,31 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+from .zoedepth_v1 import ZoeDepth
+
+all_versions = {
+ "v1": ZoeDepth,
+}
+
+get_version = lambda v : all_versions[v]
\ No newline at end of file
diff --git a/models/monoD/zoeDepth/models/zoedepth/config_zoedepth.json b/models/monoD/zoeDepth/models/zoedepth/config_zoedepth.json
new file mode 100755
index 0000000000000000000000000000000000000000..99beb2dcd886006ba87805bddbe408b6d5fdff78
--- /dev/null
+++ b/models/monoD/zoeDepth/models/zoedepth/config_zoedepth.json
@@ -0,0 +1,58 @@
+{
+ "model": {
+ "name": "ZoeDepth",
+ "version_name": "v1",
+ "n_bins": 64,
+ "bin_embedding_dim": 128,
+ "bin_centers_type": "softplus",
+ "n_attractors":[16, 8, 4, 1],
+ "attractor_alpha": 1000,
+ "attractor_gamma": 2,
+ "attractor_kind" : "mean",
+ "attractor_type" : "inv",
+ "midas_model_type" : "DPT_BEiT_L_384",
+ "min_temp": 0.0212,
+ "max_temp": 50.0,
+ "output_distribution": "logbinomial",
+ "memory_efficient": true,
+ "inverse_midas": false,
+ "img_size": [384, 512]
+ },
+
+ "train": {
+ "train_midas": true,
+ "use_pretrained_midas": true,
+ "trainer": "zoedepth",
+ "epochs": 5,
+ "bs": 16,
+ "optim_kwargs": {"lr": 0.000161, "wd": 0.01},
+ "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true},
+ "same_lr": false,
+ "w_si": 1,
+ "w_domain": 0.2,
+ "w_reg": 0,
+ "w_grad": 0,
+ "avoid_boundary": false,
+ "random_crop": false,
+ "input_width": 640,
+ "input_height": 480,
+ "midas_lr_factor": 1,
+ "encoder_lr_factor":10,
+ "pos_enc_lr_factor":10,
+ "freeze_midas_bn": true
+
+ },
+
+ "infer":{
+ "train_midas": false,
+ "use_pretrained_midas": false,
+ "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt",
+ "force_keep_ar": true
+ },
+
+ "eval":{
+ "train_midas": false,
+ "use_pretrained_midas": false,
+ "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt"
+ }
+}
\ No newline at end of file
diff --git a/models/monoD/zoeDepth/models/zoedepth/config_zoedepth_kitti.json b/models/monoD/zoeDepth/models/zoedepth/config_zoedepth_kitti.json
new file mode 100755
index 0000000000000000000000000000000000000000..b51802aa44b91c39e15aacaac4b5ab6bec884414
--- /dev/null
+++ b/models/monoD/zoeDepth/models/zoedepth/config_zoedepth_kitti.json
@@ -0,0 +1,22 @@
+{
+ "model": {
+ "bin_centers_type": "normed",
+ "img_size": [384, 768]
+ },
+
+ "train": {
+ },
+
+ "infer":{
+ "train_midas": false,
+ "use_pretrained_midas": false,
+ "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt",
+ "force_keep_ar": true
+ },
+
+ "eval":{
+ "train_midas": false,
+ "use_pretrained_midas": false,
+ "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt"
+ }
+}
\ No newline at end of file
diff --git a/models/monoD/zoeDepth/models/zoedepth/zoedepth_v1.py b/models/monoD/zoeDepth/models/zoedepth/zoedepth_v1.py
new file mode 100755
index 0000000000000000000000000000000000000000..7a10db5655b7002dbed9e75fbf095ed1ca37544b
--- /dev/null
+++ b/models/monoD/zoeDepth/models/zoedepth/zoedepth_v1.py
@@ -0,0 +1,256 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import itertools
+
+import torch
+import torch.nn as nn
+from models.monoD.zoeDepth.models.depth_model import DepthModel
+from models.monoD.zoeDepth.models.base_models.midas import MidasCore
+from models.monoD.zoeDepth.models.layers.attractor import (
+ AttractorLayer, AttractorLayerUnnormed
+)
+from models.monoD.zoeDepth.models.layers.dist_layers import (
+ ConditionalLogBinomial
+ )
+from models.monoD.zoeDepth.models.layers.localbins_layers import (
+ Projector, SeedBinRegressor,
+ SeedBinRegressorUnnormed
+ )
+from models.monoD.zoeDepth.models.model_io import load_state_from_resource
+
+
+class ZoeDepth(DepthModel):
+ def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10,
+ n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True,
+ midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs):
+ """ZoeDepth model. This is the version of ZoeDepth that has a single metric head
+
+ Args:
+ core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features
+ n_bins (int, optional): Number of bin centers. Defaults to 64.
+ bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers.
+ For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus".
+ bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128.
+ min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3.
+ max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10.
+ n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1].
+ attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300.
+ attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2.
+ attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'.
+ attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'.
+ min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5.
+ max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50.
+ train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True.
+ midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10.
+ encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10.
+ pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10.
+ """
+ super().__init__()
+
+ self.core = core
+ self.max_depth = max_depth
+ self.min_depth = min_depth
+ self.min_temp = min_temp
+ self.bin_centers_type = bin_centers_type
+
+ self.midas_lr_factor = midas_lr_factor
+ self.encoder_lr_factor = encoder_lr_factor
+ self.pos_enc_lr_factor = pos_enc_lr_factor
+ self.train_midas = train_midas
+ self.inverse_midas = inverse_midas
+
+ if self.encoder_lr_factor <= 0:
+ self.core.freeze_encoder(
+ freeze_rel_pos=self.pos_enc_lr_factor <= 0)
+
+ N_MIDAS_OUT = 32
+ btlnck_features = self.core.output_channels[0]
+ num_out_features = self.core.output_channels[1:]
+
+ self.conv2 = nn.Conv2d(btlnck_features, btlnck_features,
+ kernel_size=1, stride=1, padding=0) # btlnck conv
+
+ if bin_centers_type == "normed":
+ SeedBinRegressorLayer = SeedBinRegressor
+ Attractor = AttractorLayer
+ elif bin_centers_type == "softplus":
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
+ Attractor = AttractorLayerUnnormed
+ elif bin_centers_type == "hybrid1":
+ SeedBinRegressorLayer = SeedBinRegressor
+ Attractor = AttractorLayerUnnormed
+ elif bin_centers_type == "hybrid2":
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
+ Attractor = AttractorLayer
+ else:
+ raise ValueError(
+ "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'")
+
+ self.seed_bin_regressor = SeedBinRegressorLayer(
+ btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth)
+ self.seed_projector = Projector(btlnck_features, bin_embedding_dim)
+ self.projectors = nn.ModuleList([
+ Projector(num_out, bin_embedding_dim)
+ for num_out in num_out_features
+ ])
+ self.attractors = nn.ModuleList([
+ Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth,
+ alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type)
+ for i in range(len(num_out_features))
+ ])
+
+ last_in = N_MIDAS_OUT + 1 # +1 for relative depth
+
+ # use log binomial instead of softmax
+ self.conditional_log_binomial = ConditionalLogBinomial(
+ last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp)
+
+ def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs):
+ """
+ Args:
+ x (torch.Tensor): Input image tensor of shape (B, C, H, W)
+ return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False.
+ denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False.
+ return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False.
+
+ Returns:
+ dict: Dictionary containing the following keys:
+ - rel_depth (torch.Tensor): Relative depth map of shape (B, H, W)
+ - metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W)
+ - bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). Present only if return_final_centers is True
+ - probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). Present only if return_probs is True
+
+ """
+ b, c, h, w = x.shape
+ # print("input shape ", x.shape)
+ self.orig_input_width = w
+ self.orig_input_height = h
+ rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True)
+ # print("output shapes", rel_depth.shape, out.shape)
+
+ outconv_activation = out[0]
+ btlnck = out[1]
+ x_blocks = out[2:]
+
+ x_d0 = self.conv2(btlnck)
+ x = x_d0
+ _, seed_b_centers = self.seed_bin_regressor(x)
+
+ if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2':
+ b_prev = (seed_b_centers - self.min_depth) / \
+ (self.max_depth - self.min_depth)
+ else:
+ b_prev = seed_b_centers
+
+ prev_b_embedding = self.seed_projector(x)
+
+ # unroll this loop for better performance
+ for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks):
+ b_embedding = projector(x)
+ b, b_centers = attractor(
+ b_embedding, b_prev, prev_b_embedding, interpolate=True)
+ b_prev = b.clone()
+ prev_b_embedding = b_embedding.clone()
+
+ last = outconv_activation
+
+ if self.inverse_midas:
+ # invert depth followed by normalization
+ rel_depth = 1.0 / (rel_depth + 1e-6)
+ rel_depth = (rel_depth - rel_depth.min()) / \
+ (rel_depth.max() - rel_depth.min())
+ # concat rel depth with last. First interpolate rel depth to last size
+ rel_cond = rel_depth.unsqueeze(1)
+ rel_cond = nn.functional.interpolate(
+ rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True)
+ last = torch.cat([last, rel_cond], dim=1)
+
+ b_embedding = nn.functional.interpolate(
+ b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)
+ x = self.conditional_log_binomial(last, b_embedding)
+
+ # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
+ # print(x.shape, b_centers.shape)
+ b_centers = nn.functional.interpolate(
+ b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
+ out = torch.sum(x * b_centers, dim=1, keepdim=True)
+
+ # Structure output dict
+ output = dict(metric_depth=out)
+ if return_final_centers or return_probs:
+ output['bin_centers'] = b_centers
+
+ if return_probs:
+ output['probs'] = x
+
+ return output
+
+ def get_lr_params(self, lr):
+ """
+ Learning rate configuration for different layers of the model
+ Args:
+ lr (float) : Base learning rate
+ Returns:
+ list : list of parameters to optimize and their learning rates, in the format required by torch optimizers.
+ """
+ param_conf = []
+ if self.train_midas:
+ if self.encoder_lr_factor > 0:
+ param_conf.append({'params': self.core.get_enc_params_except_rel_pos(
+ ), 'lr': lr / self.encoder_lr_factor})
+
+ if self.pos_enc_lr_factor > 0:
+ param_conf.append(
+ {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor})
+
+ midas_params = self.core.core.scratch.parameters()
+ midas_lr_factor = self.midas_lr_factor
+ param_conf.append(
+ {'params': midas_params, 'lr': lr / midas_lr_factor})
+
+ remaining_modules = []
+ for name, child in self.named_children():
+ if name != 'core':
+ remaining_modules.append(child)
+ remaining_params = itertools.chain(
+ *[child.parameters() for child in remaining_modules])
+
+ param_conf.append({'params': remaining_params, 'lr': lr})
+
+ return param_conf
+
+ @staticmethod
+ def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs):
+ core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas,
+ train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs)
+ model = ZoeDepth(core, **kwargs)
+ if pretrained_resource:
+ assert isinstance(pretrained_resource, str), "pretrained_resource must be a string"
+ model = load_state_from_resource(model, pretrained_resource)
+ return model
+
+ @staticmethod
+ def build_from_config(config):
+ return ZoeDepth.build(**config)
diff --git a/models/monoD/zoeDepth/models/zoedepth_nk/__init__.py b/models/monoD/zoeDepth/models/zoedepth_nk/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..513a278b939c10c010e3c0250ec73544d5663886
--- /dev/null
+++ b/models/monoD/zoeDepth/models/zoedepth_nk/__init__.py
@@ -0,0 +1,31 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+from .zoedepth_nk_v1 import ZoeDepthNK
+
+all_versions = {
+ "v1": ZoeDepthNK,
+}
+
+get_version = lambda v : all_versions[v]
\ No newline at end of file
diff --git a/models/monoD/zoeDepth/models/zoedepth_nk/config_zoedepth_nk.json b/models/monoD/zoeDepth/models/zoedepth_nk/config_zoedepth_nk.json
new file mode 100755
index 0000000000000000000000000000000000000000..42bab2a3ad159a09599a5aba270c491021a3cf1a
--- /dev/null
+++ b/models/monoD/zoeDepth/models/zoedepth_nk/config_zoedepth_nk.json
@@ -0,0 +1,67 @@
+{
+ "model": {
+ "name": "ZoeDepthNK",
+ "version_name": "v1",
+ "bin_conf" : [
+ {
+ "name": "nyu",
+ "n_bins": 64,
+ "min_depth": 1e-3,
+ "max_depth": 10.0
+ },
+ {
+ "name": "kitti",
+ "n_bins": 64,
+ "min_depth": 1e-3,
+ "max_depth": 80.0
+ }
+ ],
+ "bin_embedding_dim": 128,
+ "bin_centers_type": "softplus",
+ "n_attractors":[16, 8, 4, 1],
+ "attractor_alpha": 1000,
+ "attractor_gamma": 2,
+ "attractor_kind" : "mean",
+ "attractor_type" : "inv",
+ "min_temp": 0.0212,
+ "max_temp": 50.0,
+ "memory_efficient": true,
+ "midas_model_type" : "DPT_BEiT_L_384",
+ "img_size": [384, 512]
+ },
+
+ "train": {
+ "train_midas": true,
+ "use_pretrained_midas": true,
+ "trainer": "zoedepth_nk",
+ "epochs": 5,
+ "bs": 16,
+ "optim_kwargs": {"lr": 0.0002512, "wd": 0.01},
+ "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true},
+ "same_lr": false,
+ "w_si": 1,
+ "w_domain": 100,
+ "avoid_boundary": false,
+ "random_crop": false,
+ "input_width": 640,
+ "input_height": 480,
+ "w_grad": 0,
+ "w_reg": 0,
+ "midas_lr_factor": 10,
+ "encoder_lr_factor":10,
+ "pos_enc_lr_factor":10
+ },
+
+ "infer": {
+ "train_midas": false,
+ "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt",
+ "use_pretrained_midas": false,
+ "force_keep_ar": true
+ },
+
+ "eval": {
+ "train_midas": false,
+ "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt",
+ "use_pretrained_midas": false
+ }
+}
\ No newline at end of file
diff --git a/models/monoD/zoeDepth/models/zoedepth_nk/zoedepth_nk_v1.py b/models/monoD/zoeDepth/models/zoedepth_nk/zoedepth_nk_v1.py
new file mode 100755
index 0000000000000000000000000000000000000000..a5ebad4981c66bcc09bf6894f95134352fe3c31f
--- /dev/null
+++ b/models/monoD/zoeDepth/models/zoedepth_nk/zoedepth_nk_v1.py
@@ -0,0 +1,342 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import itertools
+
+import torch
+import torch.nn as nn
+
+from models.monoD.zoeDepth.models.depth_model import DepthModel
+from models.monoD.zoeDepth.models.base_models.midas import MidasCore
+from models.monoD.zoeDepth.models.layers.attractor import (
+ AttractorLayer, AttractorLayerUnnormed
+)
+from models.monoD.zoeDepth.models.layers.dist_layers import (
+ ConditionalLogBinomial
+)
+from models.monoD.zoeDepth.models.layers.localbins_layers import (
+ Projector, SeedBinRegressor, SeedBinRegressorUnnormed
+ )
+from models.monoD.zoeDepth.models.layers.patch_transformer import (
+ PatchTransformerEncoder
+)
+from models.monoD.zoeDepth.models.model_io import load_state_from_resource
+
+
+class ZoeDepthNK(DepthModel):
+ def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128,
+ n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp',
+ min_temp=5, max_temp=50,
+ memory_efficient=False, train_midas=True,
+ is_midas_pretrained=True, midas_lr_factor=1, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs):
+ """ZoeDepthNK model. This is the version of ZoeDepth that has two metric heads and uses a learned router to route to experts.
+
+ Args:
+ core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features
+
+ bin_conf (List[dict]): A list of dictionaries that contain the bin configuration for each metric head. Each dictionary should contain the following keys:
+ "name" (str, typically same as the dataset name), "n_bins" (int), "min_depth" (float), "max_depth" (float)
+
+ The length of this list determines the number of metric heads.
+ bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers.
+ For "softplus", softplus activation is used and thus are unbounded. Defaults to "normed".
+ bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128.
+
+ n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1].
+ attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300.
+ attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2.
+ attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'.
+ attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'.
+
+ min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5.
+ max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50.
+
+ memory_efficient (bool, optional): Whether to use memory efficient version of attractor layers. Memory efficient version is slower but is recommended incase of multiple metric heads in order save GPU memory. Defaults to False.
+
+ train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True.
+ is_midas_pretrained (bool, optional): Is "core" pretrained? Defaults to True.
+ midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10.
+ encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10.
+ pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10.
+
+ """
+
+ super().__init__()
+
+ self.core = core
+ self.bin_conf = bin_conf
+ self.min_temp = min_temp
+ self.max_temp = max_temp
+ self.memory_efficient = memory_efficient
+ self.train_midas = train_midas
+ self.is_midas_pretrained = is_midas_pretrained
+ self.midas_lr_factor = midas_lr_factor
+ self.encoder_lr_factor = encoder_lr_factor
+ self.pos_enc_lr_factor = pos_enc_lr_factor
+ self.inverse_midas = inverse_midas
+
+ N_MIDAS_OUT = 32
+ btlnck_features = self.core.output_channels[0]
+ num_out_features = self.core.output_channels[1:]
+ # self.scales = [16, 8, 4, 2] # spatial scale factors
+
+ self.conv2 = nn.Conv2d(
+ btlnck_features, btlnck_features, kernel_size=1, stride=1, padding=0)
+
+ # Transformer classifier on the bottleneck
+ self.patch_transformer = PatchTransformerEncoder(
+ btlnck_features, 1, 128, use_class_token=True)
+ self.mlp_classifier = nn.Sequential(
+ nn.Linear(128, 128),
+ nn.ReLU(),
+ nn.Linear(128, 2)
+ )
+
+ if bin_centers_type == "normed":
+ SeedBinRegressorLayer = SeedBinRegressor
+ Attractor = AttractorLayer
+ elif bin_centers_type == "softplus":
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
+ Attractor = AttractorLayerUnnormed
+ elif bin_centers_type == "hybrid1":
+ SeedBinRegressorLayer = SeedBinRegressor
+ Attractor = AttractorLayerUnnormed
+ elif bin_centers_type == "hybrid2":
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
+ Attractor = AttractorLayer
+ else:
+ raise ValueError(
+ "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'")
+ self.bin_centers_type = bin_centers_type
+ # We have bins for each bin conf.
+ # Create a map (ModuleDict) of 'name' -> seed_bin_regressor
+ self.seed_bin_regressors = nn.ModuleDict(
+ {conf['name']: SeedBinRegressorLayer(btlnck_features, conf["n_bins"], mlp_dim=bin_embedding_dim//2, min_depth=conf["min_depth"], max_depth=conf["max_depth"])
+ for conf in bin_conf}
+ )
+
+ self.seed_projector = Projector(
+ btlnck_features, bin_embedding_dim, mlp_dim=bin_embedding_dim//2)
+ self.projectors = nn.ModuleList([
+ Projector(num_out, bin_embedding_dim, mlp_dim=bin_embedding_dim//2)
+ for num_out in num_out_features
+ ])
+
+ # Create a map (ModuleDict) of 'name' -> attractors (ModuleList)
+ self.attractors = nn.ModuleDict(
+ {conf['name']: nn.ModuleList([
+ Attractor(bin_embedding_dim, n_attractors[i],
+ mlp_dim=bin_embedding_dim, alpha=attractor_alpha,
+ gamma=attractor_gamma, kind=attractor_kind,
+ attractor_type=attractor_type, memory_efficient=memory_efficient,
+ min_depth=conf["min_depth"], max_depth=conf["max_depth"])
+ for i in range(len(n_attractors))
+ ])
+ for conf in bin_conf}
+ )
+
+ last_in = N_MIDAS_OUT
+ # conditional log binomial for each bin conf
+ self.conditional_log_binomial = nn.ModuleDict(
+ {conf['name']: ConditionalLogBinomial(last_in, bin_embedding_dim, conf['n_bins'], bottleneck_factor=4, min_temp=self.min_temp, max_temp=self.max_temp)
+ for conf in bin_conf}
+ )
+
+ def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs):
+ """
+ Args:
+ x (torch.Tensor): Input image tensor of shape (B, C, H, W). Assumes all images are from the same domain.
+ return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False.
+ denorm (bool, optional): Whether to denormalize the input image. Defaults to False.
+ return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False.
+
+ Returns:
+ dict: Dictionary of outputs with keys:
+ - "rel_depth": Relative depth map of shape (B, 1, H, W)
+ - "metric_depth": Metric depth map of shape (B, 1, H, W)
+ - "domain_logits": Domain logits of shape (B, 2)
+ - "bin_centers": Bin centers of shape (B, N, H, W). Present only if return_final_centers is True
+ - "probs": Bin probabilities of shape (B, N, H, W). Present only if return_probs is True
+ """
+ b, c, h, w = x.shape
+ self.orig_input_width = w
+ self.orig_input_height = h
+ rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True)
+
+ outconv_activation = out[0]
+ btlnck = out[1]
+ x_blocks = out[2:]
+
+ x_d0 = self.conv2(btlnck)
+ x = x_d0
+
+ # Predict which path to take
+ embedding = self.patch_transformer(x)[0] # N, E
+ domain_logits = self.mlp_classifier(embedding) # N, 2
+ domain_vote = torch.softmax(domain_logits.sum(
+ dim=0, keepdim=True), dim=-1) # 1, 2
+
+ # Get the path
+ bin_conf_name = ["nyu", "kitti"][torch.argmax(
+ domain_vote, dim=-1).squeeze().item()]
+
+ try:
+ conf = [c for c in self.bin_conf if c.name == bin_conf_name][0]
+ except IndexError:
+ raise ValueError(
+ f"bin_conf_name {bin_conf_name} not found in bin_confs")
+
+ min_depth = conf['min_depth']
+ max_depth = conf['max_depth']
+
+ seed_bin_regressor = self.seed_bin_regressors[bin_conf_name]
+ _, seed_b_centers = seed_bin_regressor(x)
+ if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2':
+ b_prev = (seed_b_centers - min_depth)/(max_depth - min_depth)
+ else:
+ b_prev = seed_b_centers
+ prev_b_embedding = self.seed_projector(x)
+
+ attractors = self.attractors[bin_conf_name]
+ for projector, attractor, x in zip(self.projectors, attractors, x_blocks):
+ b_embedding = projector(x)
+ b, b_centers = attractor(
+ b_embedding, b_prev, prev_b_embedding, interpolate=True)
+ b_prev = b
+ prev_b_embedding = b_embedding
+
+ last = outconv_activation
+
+ b_centers = nn.functional.interpolate(
+ b_centers, last.shape[-2:], mode='bilinear', align_corners=True)
+ b_embedding = nn.functional.interpolate(
+ b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)
+
+ clb = self.conditional_log_binomial[bin_conf_name]
+ x = clb(last, b_embedding)
+
+ # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
+ # print(x.shape, b_centers.shape)
+ # b_centers = nn.functional.interpolate(b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
+ out = torch.sum(x * b_centers, dim=1, keepdim=True)
+
+ output = dict(domain_logits=domain_logits, metric_depth=out)
+ if return_final_centers or return_probs:
+ output['bin_centers'] = b_centers
+
+ if return_probs:
+ output['probs'] = x
+ return output
+
+ def get_lr_params(self, lr):
+ """
+ Learning rate configuration for different layers of the model
+
+ Args:
+ lr (float) : Base learning rate
+ Returns:
+ list : list of parameters to optimize and their learning rates, in the format required by torch optimizers.
+ """
+ param_conf = []
+ if self.train_midas:
+ def get_rel_pos_params():
+ for name, p in self.core.core.pretrained.named_parameters():
+ if "relative_position" in name:
+ yield p
+
+ def get_enc_params_except_rel_pos():
+ for name, p in self.core.core.pretrained.named_parameters():
+ if "relative_position" not in name:
+ yield p
+
+ encoder_params = get_enc_params_except_rel_pos()
+ rel_pos_params = get_rel_pos_params()
+ midas_params = self.core.core.scratch.parameters()
+ midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0
+ param_conf.extend([
+ {'params': encoder_params, 'lr': lr / self.encoder_lr_factor},
+ {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor},
+ {'params': midas_params, 'lr': lr / midas_lr_factor}
+ ])
+
+ remaining_modules = []
+ for name, child in self.named_children():
+ if name != 'core':
+ remaining_modules.append(child)
+ remaining_params = itertools.chain(
+ *[child.parameters() for child in remaining_modules])
+ param_conf.append({'params': remaining_params, 'lr': lr})
+ return param_conf
+
+ def get_conf_parameters(self, conf_name):
+ """
+ Returns parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
+ """
+ params = []
+ for name, child in self.named_children():
+ if isinstance(child, nn.ModuleDict):
+ for bin_conf_name, module in child.items():
+ if bin_conf_name == conf_name:
+ params += list(module.parameters())
+ return params
+
+ def freeze_conf(self, conf_name):
+ """
+ Freezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
+ """
+ for p in self.get_conf_parameters(conf_name):
+ p.requires_grad = False
+
+ def unfreeze_conf(self, conf_name):
+ """
+ Unfreezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
+ """
+ for p in self.get_conf_parameters(conf_name):
+ p.requires_grad = True
+
+ def freeze_all_confs(self):
+ """
+ Freezes all the parameters of all the ModuleDicts children
+ """
+ for name, child in self.named_children():
+ if isinstance(child, nn.ModuleDict):
+ for bin_conf_name, module in child.items():
+ for p in module.parameters():
+ p.requires_grad = False
+
+ @staticmethod
+ def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs):
+ core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas,
+ train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs)
+ model = ZoeDepthNK(core, **kwargs)
+ if pretrained_resource:
+ pretrained_resource="local::./models/monoD/zoeDepth/ckpts/ZoeD_M12_NK.pt"
+ assert isinstance(pretrained_resource, str), "pretrained_resource must be a string"
+ model = load_state_from_resource(model, pretrained_resource)
+ return model
+
+ @staticmethod
+ def build_from_config(config):
+
+ return ZoeDepthNK.build(**config)
diff --git a/models/monoD/zoeDepth/utils/__init__.py b/models/monoD/zoeDepth/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9
--- /dev/null
+++ b/models/monoD/zoeDepth/utils/__init__.py
@@ -0,0 +1,24 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
diff --git a/models/monoD/zoeDepth/utils/arg_utils.py b/models/monoD/zoeDepth/utils/arg_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..8a3004ec3679c0a40fd8961253733fb4343ad545
--- /dev/null
+++ b/models/monoD/zoeDepth/utils/arg_utils.py
@@ -0,0 +1,33 @@
+
+
+def infer_type(x): # hacky way to infer type from string args
+ if not isinstance(x, str):
+ return x
+
+ try:
+ x = int(x)
+ return x
+ except ValueError:
+ pass
+
+ try:
+ x = float(x)
+ return x
+ except ValueError:
+ pass
+
+ return x
+
+
+def parse_unknown(unknown_args):
+ clean = []
+ for a in unknown_args:
+ if "=" in a:
+ k, v = a.split("=")
+ clean.extend([k, v])
+ else:
+ clean.append(a)
+
+ keys = clean[::2]
+ values = clean[1::2]
+ return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)}
diff --git a/models/monoD/zoeDepth/utils/config.py b/models/monoD/zoeDepth/utils/config.py
new file mode 100755
index 0000000000000000000000000000000000000000..c1b8209af82ae7803ffd1f6f4023b2ffda29f195
--- /dev/null
+++ b/models/monoD/zoeDepth/utils/config.py
@@ -0,0 +1,437 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import json
+import os
+
+from models.monoD.zoeDepth.utils.easydict import EasyDict as edict
+
+from models.monoD.zoeDepth.utils.arg_utils import infer_type
+import pathlib
+import platform
+
+ROOT = pathlib.Path(__file__).parent.parent.resolve()
+
+HOME_DIR = os.path.expanduser("~")
+
+COMMON_CONFIG = {
+ "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"),
+ "project": "ZoeDepth",
+ "tags": '',
+ "notes": "",
+ "gpu": None,
+ "root": ".",
+ "uid": None,
+ "print_losses": False
+}
+
+DATASETS_CONFIG = {
+ "kitti": {
+ "dataset": "kitti",
+ "min_depth": 0.001,
+ "max_depth": 80,
+ "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+ "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+ "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
+ "input_height": 352,
+ "input_width": 1216, # 704
+ "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+ "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+ "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",
+
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+
+ "do_random_rotate": True,
+ "degree": 1.0,
+ "do_kb_crop": True,
+ "garg_crop": True,
+ "eigen_crop": False,
+ "use_right": False
+ },
+ "kitti_test": {
+ "dataset": "kitti",
+ "min_depth": 0.001,
+ "max_depth": 80,
+ "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+ "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+ "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
+ "input_height": 352,
+ "input_width": 1216,
+ "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+ "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+ "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",
+
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+
+ "do_random_rotate": False,
+ "degree": 1.0,
+ "do_kb_crop": True,
+ "garg_crop": True,
+ "eigen_crop": False,
+ "use_right": False
+ },
+ "nyu": {
+ "dataset": "nyu",
+ "avoid_boundary": False,
+ "min_depth": 1e-3, # originally 0.1
+ "max_depth": 10,
+ "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"),
+ "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"),
+ "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt",
+ "input_height": 480,
+ "input_width": 640,
+ "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"),
+ "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"),
+ "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt",
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 10,
+ "min_depth_diff": -10,
+ "max_depth_diff": 10,
+
+ "do_random_rotate": True,
+ "degree": 1.0,
+ "do_kb_crop": False,
+ "garg_crop": False,
+ "eigen_crop": True
+ },
+ "ibims": {
+ "dataset": "ibims",
+ "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"),
+ "eigen_crop": True,
+ "garg_crop": False,
+ "do_kb_crop": False,
+ "min_depth_eval": 0,
+ "max_depth_eval": 10,
+ "min_depth": 1e-3,
+ "max_depth": 10
+ },
+ "sunrgbd": {
+ "dataset": "sunrgbd",
+ "sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"),
+ "eigen_crop": True,
+ "garg_crop": False,
+ "do_kb_crop": False,
+ "min_depth_eval": 0,
+ "max_depth_eval": 8,
+ "min_depth": 1e-3,
+ "max_depth": 10
+ },
+ "diml_indoor": {
+ "dataset": "diml_indoor",
+ "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"),
+ "eigen_crop": True,
+ "garg_crop": False,
+ "do_kb_crop": False,
+ "min_depth_eval": 0,
+ "max_depth_eval": 10,
+ "min_depth": 1e-3,
+ "max_depth": 10
+ },
+ "diml_outdoor": {
+ "dataset": "diml_outdoor",
+ "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"),
+ "eigen_crop": False,
+ "garg_crop": True,
+ "do_kb_crop": False,
+ "min_depth_eval": 2,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 80
+ },
+ "diode_indoor": {
+ "dataset": "diode_indoor",
+ "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"),
+ "eigen_crop": True,
+ "garg_crop": False,
+ "do_kb_crop": False,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 10,
+ "min_depth": 1e-3,
+ "max_depth": 10
+ },
+ "diode_outdoor": {
+ "dataset": "diode_outdoor",
+ "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"),
+ "eigen_crop": False,
+ "garg_crop": True,
+ "do_kb_crop": False,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 80
+ },
+ "hypersim_test": {
+ "dataset": "hypersim_test",
+ "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"),
+ "eigen_crop": True,
+ "garg_crop": False,
+ "do_kb_crop": False,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 10
+ },
+ "vkitti": {
+ "dataset": "vkitti",
+ "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"),
+ "eigen_crop": False,
+ "garg_crop": True,
+ "do_kb_crop": True,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 80
+ },
+ "vkitti2": {
+ "dataset": "vkitti2",
+ "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"),
+ "eigen_crop": False,
+ "garg_crop": True,
+ "do_kb_crop": True,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 80,
+ },
+ "ddad": {
+ "dataset": "ddad",
+ "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"),
+ "eigen_crop": False,
+ "garg_crop": True,
+ "do_kb_crop": True,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 80,
+ },
+}
+
+ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"]
+ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"]
+ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR
+
+COMMON_TRAINING_CONFIG = {
+ "dataset": "nyu",
+ "distributed": True,
+ "workers": 16,
+ "clip_grad": 0.1,
+ "use_shared_dict": False,
+ "shared_dict": None,
+ "use_amp": False,
+
+ "aug": True,
+ "random_crop": False,
+ "random_translate": False,
+ "translate_prob": 0.2,
+ "max_translation": 100,
+
+ "validate_every": 0.25,
+ "log_images_every": 0.1,
+ "prefetch": False,
+}
+
+
+def flatten(config, except_keys=('bin_conf')):
+ def recurse(inp):
+ if isinstance(inp, dict):
+ for key, value in inp.items():
+ if key in except_keys:
+ yield (key, value)
+ if isinstance(value, dict):
+ yield from recurse(value)
+ else:
+ yield (key, value)
+
+ return dict(list(recurse(config)))
+
+
+def split_combined_args(kwargs):
+ """Splits the arguments that are combined with '__' into multiple arguments.
+ Combined arguments should have equal number of keys and values.
+ Keys are separated by '__' and Values are separated with ';'.
+ For example, '__n_bins__lr=256;0.001'
+
+ Args:
+ kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format.
+
+ Returns:
+ dict: Parsed dict with the combined arguments split into individual key-value pairs.
+ """
+ new_kwargs = dict(kwargs)
+ for key, value in kwargs.items():
+ if key.startswith("__"):
+ keys = key.split("__")[1:]
+ values = value.split(";")
+ assert len(keys) == len(
+ values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})"
+ for k, v in zip(keys, values):
+ new_kwargs[k] = v
+ return new_kwargs
+
+
+def parse_list(config, key, dtype=int):
+ """Parse a list of values for the key if the value is a string. The values are separated by a comma.
+ Modifies the config in place.
+ """
+ if key in config:
+ if isinstance(config[key], str):
+ config[key] = list(map(dtype, config[key].split(',')))
+ assert isinstance(config[key], list) and all([isinstance(e, dtype) for e in config[key]]
+ ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}."
+
+
+def get_model_config(model_name, model_version=None):
+ """Find and parse the .json config file for the model.
+
+ Args:
+ model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory.
+ model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None.
+
+ Returns:
+ easydict: the config dictionary for the model.
+ """
+ config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json"
+ config_file = os.path.join(ROOT, "models", model_name, config_fname)
+ if not os.path.exists(config_file):
+ return None
+
+ with open(config_file, "r") as f:
+ config = edict(json.load(f))
+
+ # handle dictionary inheritance
+ # only training config is supported for inheritance
+ if "inherit" in config.train and config.train.inherit is not None:
+ inherit_config = get_model_config(config.train["inherit"]).train
+ for key, value in inherit_config.items():
+ if key not in config.train:
+ config.train[key] = value
+ return edict(config)
+
+
+def update_model_config(config, mode, model_name, model_version=None, strict=False):
+ model_config = get_model_config(model_name, model_version)
+ if model_config is not None:
+ config = {**config, **
+ flatten({**model_config.model, **model_config[mode]})}
+ elif strict:
+ raise ValueError(f"Config file for model {model_name} not found.")
+ return config
+
+
+def check_choices(name, value, choices):
+ # return # No checks in dev branch
+ if value not in choices:
+ raise ValueError(f"{name} {value} not in supported choices {choices}")
+
+
+KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase",
+ "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1
+
+
+def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs):
+ """Main entry point to get the config for the model.
+
+ Args:
+ model_name (str): name of the desired model.
+ mode (str, optional): "train" or "infer". Defaults to 'train'.
+ dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None.
+
+ Keyword Args: key-value pairs of arguments to overwrite the default config.
+
+ The order of precedence for overwriting the config is (Higher precedence first):
+ # 1. overwrite_kwargs
+ # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json
+ # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json
+ # 4. common_config: Default config for all models specified in COMMON_CONFIG
+
+ Returns:
+ easydict: The config dictionary for the model.
+ """
+
+
+ check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"])
+ check_choices("Mode", mode, ["train", "infer", "eval"])
+ if mode == "train":
+ check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None])
+
+ config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG})
+ config = update_model_config(config, mode, model_name)
+
+ # update with model version specific config
+ version_name = overwrite_kwargs.get("version_name", config["version_name"])
+ config = update_model_config(config, mode, model_name, version_name)
+
+ # update with config version if specified
+ config_version = overwrite_kwargs.get("config_version", None)
+ if config_version is not None:
+ print("Overwriting config with config_version", config_version)
+ config = update_model_config(config, mode, model_name, config_version)
+
+ # update with overwrite_kwargs
+ # Combined args are useful for hyperparameter search
+ overwrite_kwargs = split_combined_args(overwrite_kwargs)
+ config = {**config, **overwrite_kwargs}
+
+ # Casting to bool # TODO: Not necessary. Remove and test
+ for key in KEYS_TYPE_BOOL:
+ if key in config:
+ config[key] = bool(config[key])
+
+ # Model specific post processing of config
+ parse_list(config, "n_attractors")
+
+ # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs
+ if 'bin_conf' in config and 'n_bins' in overwrite_kwargs:
+ bin_conf = config['bin_conf'] # list of dicts
+ n_bins = overwrite_kwargs['n_bins']
+ new_bin_conf = []
+ for conf in bin_conf:
+ conf['n_bins'] = n_bins
+ new_bin_conf.append(conf)
+ config['bin_conf'] = new_bin_conf
+
+ if mode == "train":
+ orig_dataset = dataset
+ if dataset == "mix":
+ dataset = 'nyu' # Use nyu as default for mix. Dataset config is changed accordingly while loading the dataloader
+ if dataset is not None:
+ config['project'] = f"MonoDepth3-{orig_dataset}" # Set project for wandb
+
+ if dataset is not None:
+ config['dataset'] = dataset
+ config = {**DATASETS_CONFIG[dataset], **config}
+
+
+ config['model'] = model_name
+ typed_config = {k: infer_type(v) for k, v in config.items()}
+ # add hostname to config
+ config['hostname'] = platform.node()
+ return edict(typed_config)
+
+
+def change_dataset(config, new_dataset):
+ config.update(DATASETS_CONFIG[new_dataset])
+ return config
diff --git a/models/monoD/zoeDepth/utils/easydict/__init__.py b/models/monoD/zoeDepth/utils/easydict/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..15928179b0182c6045d98bc0a7be1c6ca45f675e
--- /dev/null
+++ b/models/monoD/zoeDepth/utils/easydict/__init__.py
@@ -0,0 +1,158 @@
+"""
+EasyDict
+Copy/pasted from https://github.com/makinacorpus/easydict
+Original author: Mathieu Leplatre
+"""
+
+class EasyDict(dict):
+ """
+ Get attributes
+
+ >>> d = EasyDict({'foo':3})
+ >>> d['foo']
+ 3
+ >>> d.foo
+ 3
+ >>> d.bar
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'bar'
+
+ Works recursively
+
+ >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
+ >>> isinstance(d.bar, dict)
+ True
+ >>> d.bar.x
+ 1
+
+ Bullet-proof
+
+ >>> EasyDict({})
+ {}
+ >>> EasyDict(d={})
+ {}
+ >>> EasyDict(None)
+ {}
+ >>> d = {'a': 1}
+ >>> EasyDict(**d)
+ {'a': 1}
+ >>> EasyDict((('a', 1), ('b', 2)))
+ {'a': 1, 'b': 2}
+
+ Set attributes
+
+ >>> d = EasyDict()
+ >>> d.foo = 3
+ >>> d.foo
+ 3
+ >>> d.bar = {'prop': 'value'}
+ >>> d.bar.prop
+ 'value'
+ >>> d
+ {'foo': 3, 'bar': {'prop': 'value'}}
+ >>> d.bar.prop = 'newer'
+ >>> d.bar.prop
+ 'newer'
+
+
+ Values extraction
+
+ >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
+ >>> isinstance(d.bar, list)
+ True
+ >>> from operator import attrgetter
+ >>> list(map(attrgetter('x'), d.bar))
+ [1, 3]
+ >>> list(map(attrgetter('y'), d.bar))
+ [2, 4]
+ >>> d = EasyDict()
+ >>> list(d.keys())
+ []
+ >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
+ >>> d.foo
+ 3
+ >>> d.bar.x
+ 1
+
+ Still like a dict though
+
+ >>> o = EasyDict({'clean':True})
+ >>> list(o.items())
+ [('clean', True)]
+
+ And like a class
+
+ >>> class Flower(EasyDict):
+ ... power = 1
+ ...
+ >>> f = Flower()
+ >>> f.power
+ 1
+ >>> f = Flower({'height': 12})
+ >>> f.height
+ 12
+ >>> f['power']
+ 1
+ >>> sorted(f.keys())
+ ['height', 'power']
+
+ update and pop items
+ >>> d = EasyDict(a=1, b='2')
+ >>> e = EasyDict(c=3.0, a=9.0)
+ >>> d.update(e)
+ >>> d.c
+ 3.0
+ >>> d['c']
+ 3.0
+ >>> d.get('c')
+ 3.0
+ >>> d.update(a=4, b=4)
+ >>> d.b
+ 4
+ >>> d.pop('a')
+ 4
+ >>> d.a
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'a'
+ """
+ def __init__(self, d=None, **kwargs):
+ if d is None:
+ d = {}
+ else:
+ d = dict(d)
+ if kwargs:
+ d.update(**kwargs)
+ for k, v in d.items():
+ setattr(self, k, v)
+ # Class attributes
+ for k in self.__class__.__dict__.keys():
+ if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'):
+ setattr(self, k, getattr(self, k))
+
+ def __setattr__(self, name, value):
+ if isinstance(value, (list, tuple)):
+ value = [self.__class__(x)
+ if isinstance(x, dict) else x for x in value]
+ elif isinstance(value, dict) and not isinstance(value, self.__class__):
+ value = self.__class__(value)
+ super(EasyDict, self).__setattr__(name, value)
+ super(EasyDict, self).__setitem__(name, value)
+
+ __setitem__ = __setattr__
+
+ def update(self, e=None, **f):
+ d = e or dict()
+ d.update(f)
+ for k in d:
+ setattr(self, k, d[k])
+
+ def pop(self, k, d=None):
+ delattr(self, k)
+ return super(EasyDict, self).pop(k, d)
+
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
\ No newline at end of file
diff --git a/models/monoD/zoeDepth/utils/geometry.py b/models/monoD/zoeDepth/utils/geometry.py
new file mode 100755
index 0000000000000000000000000000000000000000..e3da8c75b5a8e39b4b58a4dcd827b84d79b9115c
--- /dev/null
+++ b/models/monoD/zoeDepth/utils/geometry.py
@@ -0,0 +1,98 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import numpy as np
+
+def get_intrinsics(H,W):
+ """
+ Intrinsics for a pinhole camera model.
+ Assume fov of 55 degrees and central principal point.
+ """
+ f = 0.5 * W / np.tan(0.5 * 55 * np.pi / 180.0)
+ cx = 0.5 * W
+ cy = 0.5 * H
+ return np.array([[f, 0, cx],
+ [0, f, cy],
+ [0, 0, 1]])
+
+def depth_to_points(depth, R=None, t=None):
+
+ K = get_intrinsics(depth.shape[1], depth.shape[2])
+ Kinv = np.linalg.inv(K)
+ if R is None:
+ R = np.eye(3)
+ if t is None:
+ t = np.zeros(3)
+
+ # M converts from your coordinate to PyTorch3D's coordinate system
+ M = np.eye(3)
+ M[0, 0] = -1.0
+ M[1, 1] = -1.0
+
+ height, width = depth.shape[1:3]
+
+ x = np.arange(width)
+ y = np.arange(height)
+ coord = np.stack(np.meshgrid(x, y), -1)
+ coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1) # z=1
+ coord = coord.astype(np.float32)
+ # coord = torch.as_tensor(coord, dtype=torch.float32, device=device)
+ coord = coord[None] # bs, h, w, 3
+
+ D = depth[:, :, :, None, None]
+ # print(D.shape, Kinv[None, None, None, ...].shape, coord[:, :, :, :, None].shape )
+ pts3D_1 = D * Kinv[None, None, None, ...] @ coord[:, :, :, :, None]
+ # pts3D_1 live in your coordinate system. Convert them to Py3D's
+ pts3D_1 = M[None, None, None, ...] @ pts3D_1
+ # from reference to targe tviewpoint
+ pts3D_2 = R[None, None, None, ...] @ pts3D_1 + t[None, None, None, :, None]
+ # pts3D_2 = pts3D_1
+ # depth_2 = pts3D_2[:, :, :, 2, :] # b,1,h,w
+ return pts3D_2[:, :, :, :3, 0][0]
+
+
+def create_triangles(h, w, mask=None):
+ """
+ Reference: https://github.com/google-research/google-research/blob/e96197de06613f1b027d20328e06d69829fa5a89/infinite_nature/render_utils.py#L68
+ Creates mesh triangle indices from a given pixel grid size.
+ This function is not and need not be differentiable as triangle indices are
+ fixed.
+ Args:
+ h: (int) denoting the height of the image.
+ w: (int) denoting the width of the image.
+ Returns:
+ triangles: 2D numpy array of indices (int) with shape (2(W-1)(H-1) x 3)
+ """
+ x, y = np.meshgrid(range(w - 1), range(h - 1))
+ tl = y * w + x
+ tr = y * w + x + 1
+ bl = (y + 1) * w + x
+ br = (y + 1) * w + x + 1
+ triangles = np.array([tl, bl, tr, br, tr, bl])
+ triangles = np.transpose(triangles, (1, 2, 0)).reshape(
+ ((w - 1) * (h - 1) * 2, 3))
+ if mask is not None:
+ mask = mask.reshape(-1)
+ triangles = triangles[mask[triangles].all(1)]
+ return triangles
diff --git a/models/monoD/zoeDepth/utils/misc.py b/models/monoD/zoeDepth/utils/misc.py
new file mode 100755
index 0000000000000000000000000000000000000000..4bbe403d3669829eecdf658458c76aa5e87e2b33
--- /dev/null
+++ b/models/monoD/zoeDepth/utils/misc.py
@@ -0,0 +1,368 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+"""Miscellaneous utility functions."""
+
+from scipy import ndimage
+
+import base64
+import math
+import re
+from io import BytesIO
+
+import matplotlib
+import matplotlib.cm
+import numpy as np
+import requests
+import torch
+import torch.distributed as dist
+import torch.nn
+import torch.nn as nn
+import torch.utils.data.distributed
+from PIL import Image
+from torchvision.transforms import ToTensor
+
+
+class RunningAverage:
+ def __init__(self):
+ self.avg = 0
+ self.count = 0
+
+ def append(self, value):
+ self.avg = (value + self.count * self.avg) / (self.count + 1)
+ self.count += 1
+
+ def get_value(self):
+ return self.avg
+
+
+def denormalize(x):
+ """Reverses the imagenet normalization applied to the input.
+
+ Args:
+ x (torch.Tensor - shape(N,3,H,W)): input tensor
+
+ Returns:
+ torch.Tensor - shape(N,3,H,W): Denormalized input
+ """
+ mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
+ std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
+ return x * std + mean
+
+
+class RunningAverageDict:
+ """A dictionary of running averages."""
+ def __init__(self):
+ self._dict = None
+
+ def update(self, new_dict):
+ if new_dict is None:
+ return
+
+ if self._dict is None:
+ self._dict = dict()
+ for key, value in new_dict.items():
+ self._dict[key] = RunningAverage()
+
+ for key, value in new_dict.items():
+ self._dict[key].append(value)
+
+ def get_value(self):
+ if self._dict is None:
+ return None
+ return {key: value.get_value() for key, value in self._dict.items()}
+
+
+def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
+ """Converts a depth map to a color image.
+
+ Args:
+ value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
+ vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
+ vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
+ cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'.
+ invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
+ invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
+ background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255).
+ gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
+ value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
+
+ Returns:
+ numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
+ """
+ if isinstance(value, torch.Tensor):
+ value = value.detach().cpu().numpy()
+
+ value = value.squeeze()
+ if invalid_mask is None:
+ invalid_mask = value == invalid_val
+ mask = np.logical_not(invalid_mask)
+
+ # normalize
+ vmin = np.percentile(value[mask],2) if vmin is None else vmin
+ vmax = np.percentile(value[mask],85) if vmax is None else vmax
+ if vmin != vmax:
+ value = (value - vmin) / (vmax - vmin) # vmin..vmax
+ else:
+ # Avoid 0-division
+ value = value * 0.
+
+ # squeeze last dim if it exists
+ # grey out the invalid values
+
+ value[invalid_mask] = np.nan
+ cmapper = matplotlib.cm.get_cmap(cmap)
+ if value_transform:
+ value = value_transform(value)
+ # value = value / value.max()
+ value = cmapper(value, bytes=True) # (nxmx4)
+
+ # img = value[:, :, :]
+ img = value[...]
+ img[invalid_mask] = background_color
+
+ # return img.transpose((2, 0, 1))
+ if gamma_corrected:
+ # gamma correction
+ img = img / 255
+ img = np.power(img, 2.2)
+ img = img * 255
+ img = img.astype(np.uint8)
+ return img
+
+
+def count_parameters(model, include_all=False):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad or include_all)
+
+
+def compute_errors(gt, pred):
+ """Compute metrics for 'pred' compared to 'gt'
+
+ Args:
+ gt (numpy.ndarray): Ground truth values
+ pred (numpy.ndarray): Predicted values
+
+ gt.shape should be equal to pred.shape
+
+ Returns:
+ dict: Dictionary containing the following metrics:
+ 'a1': Delta1 accuracy: Fraction of pixels that are within a scale factor of 1.25
+ 'a2': Delta2 accuracy: Fraction of pixels that are within a scale factor of 1.25^2
+ 'a3': Delta3 accuracy: Fraction of pixels that are within a scale factor of 1.25^3
+ 'abs_rel': Absolute relative error
+ 'rmse': Root mean squared error
+ 'log_10': Absolute log10 error
+ 'sq_rel': Squared relative error
+ 'rmse_log': Root mean squared error on the log scale
+ 'silog': Scale invariant log error
+ """
+ thresh = np.maximum((gt / pred), (pred / gt))
+ a1 = (thresh < 1.25).mean()
+ a2 = (thresh < 1.25 ** 2).mean()
+ a3 = (thresh < 1.25 ** 3).mean()
+
+ abs_rel = np.mean(np.abs(gt - pred) / gt)
+ sq_rel = np.mean(((gt - pred) ** 2) / gt)
+
+ rmse = (gt - pred) ** 2
+ rmse = np.sqrt(rmse.mean())
+
+ rmse_log = (np.log(gt) - np.log(pred)) ** 2
+ rmse_log = np.sqrt(rmse_log.mean())
+
+ err = np.log(pred) - np.log(gt)
+ silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
+
+ log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean()
+ return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log,
+ silog=silog, sq_rel=sq_rel)
+
+
+def compute_metrics(gt, pred, interpolate=True, garg_crop=False, eigen_crop=True, dataset='nyu', min_depth_eval=0.1, max_depth_eval=10, **kwargs):
+ """Compute metrics of predicted depth maps. Applies cropping and masking as necessary or specified via arguments. Refer to compute_errors for more details on metrics.
+ """
+ if 'config' in kwargs:
+ config = kwargs['config']
+ garg_crop = config.garg_crop
+ eigen_crop = config.eigen_crop
+ min_depth_eval = config.min_depth_eval
+ max_depth_eval = config.max_depth_eval
+
+ if gt.shape[-2:] != pred.shape[-2:] and interpolate:
+ pred = nn.functional.interpolate(
+ pred, gt.shape[-2:], mode='bilinear', align_corners=True)
+
+ pred = pred.squeeze().cpu().numpy()
+ pred[pred < min_depth_eval] = min_depth_eval
+ pred[pred > max_depth_eval] = max_depth_eval
+ pred[np.isinf(pred)] = max_depth_eval
+ pred[np.isnan(pred)] = min_depth_eval
+
+ gt_depth = gt.squeeze().cpu().numpy()
+ valid_mask = np.logical_and(
+ gt_depth > min_depth_eval, gt_depth < max_depth_eval)
+
+ if garg_crop or eigen_crop:
+ gt_height, gt_width = gt_depth.shape
+ eval_mask = np.zeros(valid_mask.shape)
+
+ if garg_crop:
+ eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
+ int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
+
+ elif eigen_crop:
+ # print("-"*10, " EIGEN CROP ", "-"*10)
+ if dataset == 'kitti':
+ eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height),
+ int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
+ else:
+ # assert gt_depth.shape == (480, 640), "Error: Eigen crop is currently only valid for (480, 640) images"
+ eval_mask[45:471, 41:601] = 1
+ else:
+ eval_mask = np.ones(valid_mask.shape)
+ valid_mask = np.logical_and(valid_mask, eval_mask)
+ return compute_errors(gt_depth[valid_mask], pred[valid_mask])
+
+
+#################################### Model uilts ################################################
+
+
+def parallelize(config, model, find_unused_parameters=True):
+
+ if config.gpu is not None:
+ torch.cuda.set_device(config.gpu)
+ model = model.cuda(config.gpu)
+
+ config.multigpu = False
+ if config.distributed:
+ # Use DDP
+ config.multigpu = True
+ config.rank = config.rank * config.ngpus_per_node + config.gpu
+ dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url,
+ world_size=config.world_size, rank=config.rank)
+ config.batch_size = int(config.batch_size / config.ngpus_per_node)
+ # config.batch_size = 8
+ config.workers = int(
+ (config.num_workers + config.ngpus_per_node - 1) / config.ngpus_per_node)
+ print("Device", config.gpu, "Rank", config.rank, "batch size",
+ config.batch_size, "Workers", config.workers)
+ torch.cuda.set_device(config.gpu)
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = model.cuda(config.gpu)
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu], output_device=config.gpu,
+ find_unused_parameters=find_unused_parameters)
+
+ elif config.gpu is None:
+ # Use DP
+ config.multigpu = True
+ model = model.cuda()
+ model = torch.nn.DataParallel(model)
+
+ return model
+
+
+#################################################################################################
+
+
+#####################################################################################################
+
+
+class colors:
+ '''Colors class:
+ Reset all colors with colors.reset
+ Two subclasses fg for foreground and bg for background.
+ Use as colors.subclass.colorname.
+ i.e. colors.fg.red or colors.bg.green
+ Also, the generic bold, disable, underline, reverse, strikethrough,
+ and invisible work with the main class
+ i.e. colors.bold
+ '''
+ reset = '\033[0m'
+ bold = '\033[01m'
+ disable = '\033[02m'
+ underline = '\033[04m'
+ reverse = '\033[07m'
+ strikethrough = '\033[09m'
+ invisible = '\033[08m'
+
+ class fg:
+ black = '\033[30m'
+ red = '\033[31m'
+ green = '\033[32m'
+ orange = '\033[33m'
+ blue = '\033[34m'
+ purple = '\033[35m'
+ cyan = '\033[36m'
+ lightgrey = '\033[37m'
+ darkgrey = '\033[90m'
+ lightred = '\033[91m'
+ lightgreen = '\033[92m'
+ yellow = '\033[93m'
+ lightblue = '\033[94m'
+ pink = '\033[95m'
+ lightcyan = '\033[96m'
+
+ class bg:
+ black = '\033[40m'
+ red = '\033[41m'
+ green = '\033[42m'
+ orange = '\033[43m'
+ blue = '\033[44m'
+ purple = '\033[45m'
+ cyan = '\033[46m'
+ lightgrey = '\033[47m'
+
+
+def printc(text, color):
+ print(f"{color}{text}{colors.reset}")
+
+############################################
+
+def get_image_from_url(url):
+ response = requests.get(url)
+ img = Image.open(BytesIO(response.content)).convert("RGB")
+ return img
+
+def url_to_torch(url, size=(384, 384)):
+ img = get_image_from_url(url)
+ img = img.resize(size, Image.ANTIALIAS)
+ img = torch.from_numpy(np.asarray(img)).float()
+ img = img.permute(2, 0, 1)
+ img.div_(255)
+ return img
+
+def pil_to_batched_tensor(img):
+ return ToTensor()(img).unsqueeze(0)
+
+def save_raw_16bit(depth, fpath="raw.png"):
+ if isinstance(depth, torch.Tensor):
+ depth = depth.squeeze().cpu().numpy()
+
+ assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array"
+ assert depth.ndim == 2, "Depth must be 2D"
+ depth = depth * 256 # scale for 16-bit png
+ depth = depth.astype(np.uint16)
+ depth = Image.fromarray(depth)
+ depth.save(fpath)
+ print("Saved raw depth to", fpath)
\ No newline at end of file
diff --git a/models/vggt/setup.py b/models/vggt/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..2774caecaa1df7dff23e77d3b6dd8175c933e2f2
--- /dev/null
+++ b/models/vggt/setup.py
@@ -0,0 +1,8 @@
+from setuptools import setup, find_packages
+
+setup(
+ name='vggt',
+ version='0.1',
+ packages=find_packages(),
+ description='vggt local package',
+)
\ No newline at end of file
diff --git a/models/vggt/vggt/__init__.py b/models/vggt/vggt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0519ecba6ea913e21689ec692e81e9e4973fbf73
--- /dev/null
+++ b/models/vggt/vggt/__init__.py
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/models/vggt/vggt/heads/camera_head.py b/models/vggt/vggt/heads/camera_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee95d35859ef0e4532d24b91ed1d6d208fa21a71
--- /dev/null
+++ b/models/vggt/vggt/heads/camera_head.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.vggt.vggt.layers import Mlp
+from models.vggt.vggt.layers.block import Block
+from models.vggt.vggt.heads.head_act import activate_pose
+
+
+class CameraHead(nn.Module):
+ """
+ CameraHead predicts camera parameters from token representations using iterative refinement.
+
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
+ """
+
+ def __init__(
+ self,
+ dim_in: int = 2048,
+ trunk_depth: int = 4,
+ pose_encoding_type: str = "absT_quaR_FoV",
+ num_heads: int = 16,
+ mlp_ratio: int = 4,
+ init_values: float = 0.01,
+ trans_act: str = "linear",
+ quat_act: str = "linear",
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
+ ):
+ super().__init__()
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ self.target_dim = 9
+ else:
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
+
+ self.trans_act = trans_act
+ self.quat_act = quat_act
+ self.fl_act = fl_act
+ self.trunk_depth = trunk_depth
+
+ # Build the trunk using a sequence of transformer blocks.
+ self.trunk = nn.Sequential(
+ *[
+ Block(
+ dim=dim_in,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ init_values=init_values,
+ )
+ for _ in range(trunk_depth)
+ ]
+ )
+
+ # Normalizations for camera token and trunk output.
+ self.token_norm = nn.LayerNorm(dim_in)
+ self.trunk_norm = nn.LayerNorm(dim_in)
+
+ # Learnable empty camera pose token.
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
+
+ # Module for producing modulation parameters: shift, scale, and a gate.
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
+
+ # Adaptive layer normalization without affine parameters.
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
+ self.pose_branch = Mlp(
+ in_features=dim_in,
+ hidden_features=dim_in // 2,
+ out_features=self.target_dim,
+ drop=0,
+ )
+
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
+ """
+ Forward pass to predict camera parameters.
+
+ Args:
+ aggregated_tokens_list (list): List of token tensors from the network;
+ the last tensor is used for prediction.
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
+
+ Returns:
+ list: A list of predicted camera encodings (post-activation) from each iteration.
+ """
+ # Use tokens from the last block for camera prediction.
+ tokens = aggregated_tokens_list[-1]
+
+ # Extract the camera tokens
+ pose_tokens = tokens[:, :, 0]
+ pose_tokens = self.token_norm(pose_tokens)
+
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
+ return pred_pose_enc_list
+
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
+ """
+ Iteratively refine camera pose predictions.
+
+ Args:
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
+ num_iterations (int): Number of refinement iterations.
+
+ Returns:
+ list: List of activated camera encodings from each iteration.
+ """
+ B, S, C = pose_tokens.shape # S is expected to be 1.
+ pred_pose_enc = None
+ pred_pose_enc_list = []
+
+ for _ in range(num_iterations):
+ # Use a learned empty pose for the first iteration.
+ if pred_pose_enc is None:
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
+ else:
+ # Detach the previous prediction to avoid backprop through time.
+ pred_pose_enc = pred_pose_enc.detach()
+ module_input = self.embed_pose(pred_pose_enc)
+
+ # Generate modulation parameters and split them into shift, scale, and gate components.
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
+
+ # Adaptive layer normalization and modulation.
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
+
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
+ # Compute the delta update for the pose encoding.
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
+
+ if pred_pose_enc is None:
+ pred_pose_enc = pred_pose_enc_delta
+ else:
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
+
+ # Apply final activation functions for translation, quaternion, and field-of-view.
+ activated_pose = activate_pose(
+ pred_pose_enc,
+ trans_act=self.trans_act,
+ quat_act=self.quat_act,
+ fl_act=self.fl_act,
+ )
+ pred_pose_enc_list.append(activated_pose)
+
+ return pred_pose_enc_list
+
+
+def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ """
+ Modulate the input tensor using scaling and shifting parameters.
+ """
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
+ return x * (1 + scale) + shift
diff --git a/models/vggt/vggt/heads/dpt_head.py b/models/vggt/vggt/heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa5cf6e1cee2eb7cb2ad9538d0d168f97a590382
--- /dev/null
+++ b/models/vggt/vggt/heads/dpt_head.py
@@ -0,0 +1,497 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
+
+
+import os
+from typing import List, Dict, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .head_act import activate_head
+from .utils import create_uv_grid, position_grid_to_embed
+
+
+class DPTHead(nn.Module):
+ """
+ DPT Head for dense prediction tasks.
+
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
+ backbone and produces dense predictions by fusing multi-scale features.
+
+ Args:
+ dim_in (int): Input dimension (channels).
+ patch_size (int, optional): Patch size. Default is 14.
+ output_dim (int, optional): Number of output channels. Default is 4.
+ activation (str, optional): Activation type. Default is "inv_log".
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
+ out_channels (List[int], optional): Output channels for each intermediate layer.
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
+ """
+
+ def __init__(
+ self,
+ dim_in: int,
+ patch_size: int = 14,
+ output_dim: int = 4,
+ activation: str = "inv_log",
+ conf_activation: str = "expp1",
+ features: int = 256,
+ out_channels: List[int] = [256, 512, 1024, 1024],
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
+ pos_embed: bool = True,
+ feature_only: bool = False,
+ down_ratio: int = 1,
+ ) -> None:
+ super(DPTHead, self).__init__()
+ self.patch_size = patch_size
+ self.activation = activation
+ self.conf_activation = conf_activation
+ self.pos_embed = pos_embed
+ self.feature_only = feature_only
+ self.down_ratio = down_ratio
+ self.intermediate_layer_idx = intermediate_layer_idx
+
+ self.norm = nn.LayerNorm(dim_in)
+
+ # Projection layers for each output channel from tokens.
+ self.projects = nn.ModuleList(
+ [
+ nn.Conv2d(
+ in_channels=dim_in,
+ out_channels=oc,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ for oc in out_channels
+ ]
+ )
+
+ # Resize layers for upsampling feature maps.
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
+ ),
+ ]
+ )
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ expand=False,
+ )
+
+ # Attach additional modules to scratch.
+ self.scratch.stem_transpose = None
+ self.scratch.refinenet1 = _make_fusion_block(features)
+ self.scratch.refinenet2 = _make_fusion_block(features)
+ self.scratch.refinenet3 = _make_fusion_block(features)
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ if feature_only:
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
+ else:
+ self.scratch.output_conv1 = nn.Conv2d(
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
+ )
+ conv2_in_channels = head_features_1 // 2
+
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
+ )
+
+ def forward(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_chunk_size: int = 8,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Forward pass through the DPT head, supports processing by chunking frames.
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
+ If None or larger than S, all frames are processed at once. Default: 8.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]:
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
+ """
+ B, S, _, H, W = images.shape
+
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
+ if frames_chunk_size is None or frames_chunk_size >= S:
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
+
+ # Otherwise, process frames in chunks to manage memory usage
+ assert frames_chunk_size > 0
+
+ # Process frames in batches
+ all_preds = []
+ all_conf = []
+
+ for frames_start_idx in range(0, S, frames_chunk_size):
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
+
+ # Process batch of frames
+ if self.feature_only:
+ chunk_output = self._forward_impl(
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
+ )
+ all_preds.append(chunk_output)
+ else:
+ chunk_preds, chunk_conf = self._forward_impl(
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
+ )
+ all_preds.append(chunk_preds)
+ all_conf.append(chunk_conf)
+
+ # Concatenate results along the sequence dimension
+ if self.feature_only:
+ return torch.cat(all_preds, dim=1)
+ else:
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
+
+ def _forward_impl(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_start_idx: int = None,
+ frames_end_idx: int = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Implementation of the forward pass through the DPT head.
+
+ This method processes a specific chunk of frames from the sequence.
+
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W].
+ patch_start_idx (int): Starting index for patch tokens.
+ frames_start_idx (int, optional): Starting index for frames to process.
+ frames_end_idx (int, optional): Ending index for frames to process.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
+ """
+ if frames_start_idx is not None and frames_end_idx is not None:
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
+
+ B, S, _, H, W = images.shape
+
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
+
+ out = []
+ dpt_idx = 0
+
+ for layer_idx in self.intermediate_layer_idx:
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
+
+ # Select frames if processing a chunk
+ if frames_start_idx is not None and frames_end_idx is not None:
+ x = x[:, frames_start_idx:frames_end_idx]
+
+ x = x.view(B * S, -1, x.shape[-1])
+
+ x = self.norm(x)
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[dpt_idx](x)
+ if self.pos_embed:
+ x = self._apply_pos_embed(x, W, H)
+ x = self.resize_layers[dpt_idx](x)
+
+ out.append(x)
+ dpt_idx += 1
+
+ # Fuse features from multiple layers.
+ out = self.scratch_forward(out)
+ # Interpolate fused output to match target image resolution.
+ out = custom_interpolate(
+ out,
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ if self.pos_embed:
+ out = self._apply_pos_embed(out, W, H)
+
+ if self.feature_only:
+ return out.view(B, S, *out.shape[1:])
+
+ out = self.scratch.output_conv2(out)
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
+
+ preds = preds.view(B, S, *preds.shape[1:])
+ conf = conf.view(B, S, *conf.shape[1:])
+ return preds, conf
+
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
+ """
+ Apply positional embedding to tensor x.
+ """
+ patch_w = x.shape[-1]
+ patch_h = x.shape[-2]
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
+ pos_embed = pos_embed * ratio
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
+ return x + pos_embed
+
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Forward pass through the fusion blocks.
+
+ Args:
+ features (List[Tensor]): List of feature maps from different layers.
+
+ Returns:
+ Tensor: Fused feature map.
+ """
+ layer_1, layer_2, layer_3, layer_4 = features
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ del layer_4_rn, layer_4
+
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
+ del layer_3_rn, layer_3
+
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
+ del layer_2_rn, layer_2
+
+ out = self.scratch.refinenet1(out, layer_1_rn)
+ del layer_1_rn, layer_1
+
+ out = self.scratch.output_conv1(out)
+ return out
+
+
+################################################################################
+# Modules
+################################################################################
+
+
+def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(inplace=True),
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=size,
+ has_residual=has_residual,
+ groups=groups,
+ )
+
+
+def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
+ scratch = nn.Module()
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn, groups=1):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+ self.groups = groups
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ self.norm1 = None
+ self.norm2 = None
+
+ self.activation = activation
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.norm1 is not None:
+ out = self.norm1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.norm2 is not None:
+ out = self.norm2(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None,
+ has_residual=True,
+ groups=1,
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+ self.groups = groups
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
+ )
+
+ if has_residual:
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.has_residual = has_residual
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+ self.size = size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if self.has_residual:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+ output = self.out_conv(output)
+
+ return output
+
+
+def custom_interpolate(
+ x: torch.Tensor,
+ size: Tuple[int, int] = None,
+ scale_factor: float = None,
+ mode: str = "bilinear",
+ align_corners: bool = True,
+) -> torch.Tensor:
+ """
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
+ """
+ if size is None:
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
+
+ INT_MAX = 1610612736
+
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
+
+ if input_elements > INT_MAX:
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
+ interpolated_chunks = [
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
+ ]
+ x = torch.cat(interpolated_chunks, dim=0)
+ return x.contiguous()
+ else:
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
diff --git a/models/vggt/vggt/heads/head_act.py b/models/vggt/vggt/heads/head_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dedfcf1180a653dddc99623e60df625e5897489
--- /dev/null
+++ b/models/vggt/vggt/heads/head_act.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn.functional as F
+
+
+def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
+ """
+ Activate pose parameters with specified activation functions.
+
+ Args:
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
+ trans_act: Activation type for translation component
+ quat_act: Activation type for quaternion component
+ fl_act: Activation type for focal length component
+
+ Returns:
+ Activated pose parameters tensor
+ """
+ T = pred_pose_enc[..., :3]
+ quat = pred_pose_enc[..., 3:7]
+ fl = pred_pose_enc[..., 7:] # or fov
+
+ T = base_pose_act(T, trans_act)
+ quat = base_pose_act(quat, quat_act)
+ fl = base_pose_act(fl, fl_act) # or fov
+
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
+
+ return pred_pose_enc
+
+
+def base_pose_act(pose_enc, act_type="linear"):
+ """
+ Apply basic activation function to pose parameters.
+
+ Args:
+ pose_enc: Tensor containing encoded pose parameters
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
+
+ Returns:
+ Activated pose parameters
+ """
+ if act_type == "linear":
+ return pose_enc
+ elif act_type == "inv_log":
+ return inverse_log_transform(pose_enc)
+ elif act_type == "exp":
+ return torch.exp(pose_enc)
+ elif act_type == "relu":
+ return F.relu(pose_enc)
+ else:
+ raise ValueError(f"Unknown act_type: {act_type}")
+
+
+def activate_head(out, activation="norm_exp", conf_activation="expp1"):
+ """
+ Process network output to extract 3D points and confidence values.
+
+ Args:
+ out: Network output tensor (B, C, H, W)
+ activation: Activation type for 3D points
+ conf_activation: Activation type for confidence values
+
+ Returns:
+ Tuple of (3D points tensor, confidence tensor)
+ """
+ # Move channels from last dim to the 4th dimension => (B, H, W, C)
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
+
+ # Split into xyz (first C-1 channels) and confidence (last channel)
+ xyz = fmap[:, :, :, :-1]
+ conf = fmap[:, :, :, -1]
+
+ if activation == "norm_exp":
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+ xyz_normed = xyz / d
+ pts3d = xyz_normed * torch.expm1(d)
+ elif activation == "norm":
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
+ elif activation == "exp":
+ pts3d = torch.exp(xyz)
+ elif activation == "relu":
+ pts3d = F.relu(xyz)
+ elif activation == "inv_log":
+ pts3d = inverse_log_transform(xyz)
+ elif activation == "xy_inv_log":
+ xy, z = xyz.split([2, 1], dim=-1)
+ z = inverse_log_transform(z)
+ pts3d = torch.cat([xy * z, z], dim=-1)
+ elif activation == "sigmoid":
+ pts3d = torch.sigmoid(xyz)
+ elif activation == "linear":
+ pts3d = xyz
+ else:
+ raise ValueError(f"Unknown activation: {activation}")
+
+ if conf_activation == "expp1":
+ conf_out = 1 + conf.exp()
+ elif conf_activation == "expp0":
+ conf_out = conf.exp()
+ elif conf_activation == "sigmoid":
+ conf_out = torch.sigmoid(conf)
+ else:
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
+
+ return pts3d, conf_out
+
+
+def inverse_log_transform(y):
+ """
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
+
+ Args:
+ y: Input tensor
+
+ Returns:
+ Transformed tensor
+ """
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
diff --git a/models/vggt/vggt/heads/scale_head.py b/models/vggt/vggt/heads/scale_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fdf09b21583dcfd532151fcf02febe01d43c6e9
--- /dev/null
+++ b/models/vggt/vggt/heads/scale_head.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.vggt.vggt.layers import Mlp
+from models.vggt.vggt.layers.block import Block
+from models.vggt.vggt.heads.head_act import activate_pose
+
+
+class ScaleHead(nn.Module):
+ """
+ ScaleHead predicts camera parameters from token representations using iterative refinement.
+
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
+ """
+
+ def __init__(
+ self,
+ dim_in: int = 2048,
+ trunk_depth: int = 4,
+ pose_encoding_type: str = "absT_quaR_FoV",
+ num_heads: int = 16,
+ mlp_ratio: int = 4,
+ init_values: float = 0.01,
+ trans_act: str = "linear",
+ quat_act: str = "linear",
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
+ ):
+ super().__init__()
+
+ self.target_dim = 2
+
+ self.trans_act = trans_act
+ self.quat_act = quat_act
+ self.fl_act = fl_act
+ self.trunk_depth = trunk_depth
+
+ # Build the trunk using a sequence of transformer blocks.
+ self.trunk = nn.Sequential(
+ *[
+ Block(
+ dim=dim_in,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ init_values=init_values,
+ )
+ for _ in range(trunk_depth)
+ ]
+ )
+
+ # Normalizations for camera token and trunk output.
+ self.token_norm = nn.LayerNorm(dim_in)
+ self.trunk_norm = nn.LayerNorm(dim_in)
+
+ # Learnable empty camera pose token.
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
+
+ # Module for producing modulation parameters: shift, scale, and a gate.
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
+
+ # Adaptive layer normalization without affine parameters.
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
+ self.pose_branch = Mlp(
+ in_features=dim_in,
+ hidden_features=dim_in // 2,
+ out_features=self.target_dim,
+ drop=0,
+ )
+
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
+ """
+ Forward pass to predict camera parameters.
+
+ Args:
+ aggregated_tokens_list (list): List of token tensors from the network;
+ the last tensor is used for prediction.
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
+
+ Returns:
+ list: A list of predicted camera encodings (post-activation) from each iteration.
+ """
+ # Use tokens from the last block for camera prediction.
+ tokens = aggregated_tokens_list[-1]
+
+ # Extract the camera tokens
+ pose_tokens = tokens[:, :, 5]
+ pose_tokens = self.token_norm(pose_tokens)
+
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
+ return pred_pose_enc_list
+
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
+ """
+ Iteratively refine camera pose predictions.
+
+ Args:
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
+ num_iterations (int): Number of refinement iterations.
+
+ Returns:
+ list: List of activated camera encodings from each iteration.
+ """
+ B, S, C = pose_tokens.shape # S is expected to be 1.
+ pred_pose_enc = None
+ pred_pose_enc_list = []
+
+ for _ in range(num_iterations):
+ # Use a learned empty pose for the first iteration.
+ if pred_pose_enc is None:
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
+ else:
+ # Detach the previous prediction to avoid backprop through time.
+ pred_pose_enc = pred_pose_enc.detach()
+ module_input = self.embed_pose(pred_pose_enc)
+
+ # Generate modulation parameters and split them into shift, scale, and gate components.
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
+
+ # Adaptive layer normalization and modulation.
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
+
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
+ # Compute the delta update for the pose encoding.
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
+
+ if pred_pose_enc is None:
+ pred_pose_enc = pred_pose_enc_delta
+ else:
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
+
+ # Apply final activation functions for translation, quaternion, and field-of-view.
+ activated_pose = activate_pose(
+ pred_pose_enc,
+ trans_act=self.trans_act,
+ quat_act=self.quat_act,
+ fl_act=self.fl_act,
+ )
+ activated_pose_proc = activated_pose.clone()
+ activated_pose_proc[...,:1] = activated_pose_proc[...,:1].clamp(min=1e-5, max=1e3)
+ activated_pose_proc[...,1:] = activated_pose_proc[...,1:]*1e-2
+ pred_pose_enc_list.append(activated_pose_proc)
+
+ return pred_pose_enc_list
+
+
+def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ """
+ Modulate the input tensor using scaling and shifting parameters.
+ """
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
+ return x * (1 + scale) + shift
diff --git a/models/vggt/vggt/heads/track_head.py b/models/vggt/vggt/heads/track_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ec7199bd185060989c236997f93b93f4fc77825
--- /dev/null
+++ b/models/vggt/vggt/heads/track_head.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+from .dpt_head import DPTHead
+from .track_modules.base_track_predictor import BaseTrackerPredictor
+
+
+class TrackHead(nn.Module):
+ """
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
+ The tracking is performed iteratively, refining predictions over multiple iterations.
+ """
+
+ def __init__(
+ self,
+ dim_in,
+ patch_size=14,
+ features=128,
+ iters=4,
+ predict_conf=True,
+ stride=2,
+ corr_levels=7,
+ corr_radius=4,
+ hidden_size=384,
+ ):
+ """
+ Initialize the TrackHead module.
+
+ Args:
+ dim_in (int): Input dimension of tokens from the backbone.
+ patch_size (int): Size of image patches used in the vision transformer.
+ features (int): Number of feature channels in the feature extractor output.
+ iters (int): Number of refinement iterations for tracking predictions.
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
+ stride (int): Stride value for the tracker predictor.
+ corr_levels (int): Number of correlation pyramid levels
+ corr_radius (int): Radius for correlation computation, controlling the search area.
+ hidden_size (int): Size of hidden layers in the tracker network.
+ """
+ super().__init__()
+
+ self.patch_size = patch_size
+
+ # Feature extractor based on DPT architecture
+ # Processes tokens into feature maps for tracking
+ self.feature_extractor = DPTHead(
+ dim_in=dim_in,
+ patch_size=patch_size,
+ features=features,
+ feature_only=True, # Only output features, no activation
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
+ pos_embed=False,
+ )
+
+ # Tracker module that predicts point trajectories
+ # Takes feature maps and predicts coordinates and visibility
+ self.tracker = BaseTrackerPredictor(
+ latent_dim=features, # Match the output_dim of feature extractor
+ predict_conf=predict_conf,
+ stride=stride,
+ corr_levels=corr_levels,
+ corr_radius=corr_radius,
+ hidden_size=hidden_size,
+ )
+
+ self.iters = iters
+
+ def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
+ """
+ Forward pass of the TrackHead.
+
+ Args:
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
+ B = batch size, S = sequence length.
+ patch_start_idx (int): Starting index for patch tokens.
+ query_points (torch.Tensor, optional): Initial query points to track.
+ If None, points are initialized by the tracker.
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
+
+ Returns:
+ tuple:
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
+ """
+ B, S, _, H, W = images.shape
+
+ # Extract features from tokens
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
+ feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
+
+ # Use default iterations if not specified
+ if iters is None:
+ iters = self.iters
+
+ # Perform tracking using the extracted features
+ coord_preds, vis_scores, conf_scores = self.tracker(
+ query_points=query_points,
+ fmaps=feature_maps,
+ iters=iters,
+ )
+
+ return coord_preds, vis_scores, conf_scores
diff --git a/models/vggt/vggt/heads/track_modules/__init__.py b/models/vggt/vggt/heads/track_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/models/vggt/vggt/heads/track_modules/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/models/vggt/vggt/heads/track_modules/base_track_predictor.py b/models/vggt/vggt/heads/track_modules/base_track_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ce8ec4b66fff236e015d1bcaf85c8237a52be7a
--- /dev/null
+++ b/models/vggt/vggt/heads/track_modules/base_track_predictor.py
@@ -0,0 +1,209 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+from .blocks import EfficientUpdateFormer, CorrBlock
+from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
+from .modules import Mlp
+
+
+class BaseTrackerPredictor(nn.Module):
+ def __init__(
+ self,
+ stride=1,
+ corr_levels=5,
+ corr_radius=4,
+ latent_dim=128,
+ hidden_size=384,
+ use_spaceatt=True,
+ depth=6,
+ max_scale=518,
+ predict_conf=True,
+ ):
+ super(BaseTrackerPredictor, self).__init__()
+ """
+ The base template to create a track predictor
+
+ Modified from https://github.com/facebookresearch/co-tracker/
+ and https://github.com/facebookresearch/vggsfm
+ """
+
+ self.stride = stride
+ self.latent_dim = latent_dim
+ self.corr_levels = corr_levels
+ self.corr_radius = corr_radius
+ self.hidden_size = hidden_size
+ self.max_scale = max_scale
+ self.predict_conf = predict_conf
+
+ self.flows_emb_dim = latent_dim // 2
+
+ self.corr_mlp = Mlp(
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
+ hidden_features=self.hidden_size,
+ out_features=self.latent_dim,
+ )
+
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
+
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
+
+ space_depth = depth if use_spaceatt else 0
+ time_depth = depth
+
+ self.updateformer = EfficientUpdateFormer(
+ space_depth=space_depth,
+ time_depth=time_depth,
+ input_dim=self.transformer_dim,
+ hidden_size=self.hidden_size,
+ output_dim=self.latent_dim + 2,
+ mlp_ratio=4.0,
+ add_space_attn=use_spaceatt,
+ )
+
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
+
+ # A linear layer to update track feats at each iteration
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
+
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ if predict_conf:
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
+ """
+ query_points: B x N x 2, the number of batches, tracks, and xy
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
+ note HH and WW is the size of feature maps instead of original images
+ """
+ B, N, D = query_points.shape
+ B, S, C, HH, WW = fmaps.shape
+
+ assert D == 2, "Input points must be 2D coordinates"
+
+ # apply a layernorm to fmaps here
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
+
+ # Scale the input query_points because we may downsample the images
+ # by down_ratio or self.stride
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
+ # its query_points should be query_points/4
+ if down_ratio > 1:
+ query_points = query_points / float(down_ratio)
+
+ query_points = query_points / float(self.stride)
+
+ # Init with coords as the query points
+ # It means the search will start from the position of query points at the reference frames
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
+
+ # Sample/extract the features of the query points in the query frame
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
+
+ # init track feats by query feats
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
+ # back up the init coords
+ coords_backup = coords.clone()
+
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
+
+ coord_preds = []
+
+ # Iterative Refinement
+ for _ in range(iters):
+ # Detach the gradients from the last iteration
+ # (in my experience, not very important for performance)
+ coords = coords.detach()
+
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
+
+ corr_dim = fcorrs.shape[3]
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
+ fcorrs_ = self.corr_mlp(fcorrs_)
+
+ # Movement of current coords relative to query points
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
+
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
+ flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
+
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
+
+ # Concatenate them as the input for the transformers
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
+
+ # 2D positional embed
+ # TODO: this can be much simplified
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
+
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
+
+ x = transformer_input + sampled_pos_emb
+
+ # Add the query ref token to the track feats
+ query_ref_token = torch.cat(
+ [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
+ )
+ x = x + query_ref_token.to(x.device).to(x.dtype)
+
+ # B, N, S, C
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
+
+ # Compute the delta coordinates and delta track features
+ delta, _ = self.updateformer(x)
+
+ # BN, S, C
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
+ delta_coords_ = delta[:, :, :2]
+ delta_feats_ = delta[:, :, 2:]
+
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
+
+ # Update the track features
+ track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
+
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
+
+ # B x S x N x 2
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
+
+ # Force coord0 as query
+ # because we assume the query points should not be changed
+ coords[:, 0] = coords_backup[:, 0]
+
+ # The predicted tracks are in the original image scale
+ if down_ratio > 1:
+ coord_preds.append(coords * self.stride * down_ratio)
+ else:
+ coord_preds.append(coords * self.stride)
+
+ # B, S, N
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
+ if apply_sigmoid:
+ vis_e = torch.sigmoid(vis_e)
+
+ if self.predict_conf:
+ conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
+ if apply_sigmoid:
+ conf_e = torch.sigmoid(conf_e)
+ else:
+ conf_e = None
+
+ if return_feat:
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
+ else:
+ return coord_preds, vis_e, conf_e
diff --git a/models/vggt/vggt/heads/track_modules/blocks.py b/models/vggt/vggt/heads/track_modules/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e7763f4fd8f515662421db192594380dbb574e5
--- /dev/null
+++ b/models/vggt/vggt/heads/track_modules/blocks.py
@@ -0,0 +1,246 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Modified from https://github.com/facebookresearch/co-tracker/
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import bilinear_sampler
+from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
+
+
+class EfficientUpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
+ """
+
+ def __init__(
+ self,
+ space_depth=6,
+ time_depth=6,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ add_space_attn=True,
+ num_virtual_tracks=64,
+ ):
+ super().__init__()
+
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.add_space_attn = add_space_attn
+
+ # Add input LayerNorm before linear projection
+ self.input_norm = nn.LayerNorm(input_dim)
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+
+ # Add output LayerNorm before final projection
+ self.output_norm = nn.LayerNorm(hidden_size)
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
+ self.num_virtual_tracks = num_virtual_tracks
+
+ if self.add_space_attn:
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
+ else:
+ self.virual_tracks = None
+
+ self.time_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=nn.MultiheadAttention,
+ )
+ for _ in range(time_depth)
+ ]
+ )
+
+ if add_space_attn:
+ self.space_virtual_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=nn.MultiheadAttention,
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_point2virtual_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ self.space_virtual2point_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
+
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor, mask=None):
+ # Apply input LayerNorm
+ input_tensor = self.input_norm(input_tensor)
+ tokens = self.input_transform(input_tensor)
+
+ init_tokens = tokens
+
+ B, _, T, _ = tokens.shape
+
+ if self.add_space_attn:
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
+
+ _, N, _, _ = tokens.shape
+
+ j = 0
+ for i in range(len(self.time_blocks)):
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+
+ time_tokens = self.time_blocks[i](time_tokens)
+
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
+
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
+
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
+ j += 1
+
+ if self.add_space_attn:
+ tokens = tokens[:, : N - self.num_virtual_tracks]
+
+ tokens = tokens + init_tokens
+
+ # Apply output LayerNorm before final projection
+ tokens = self.output_norm(tokens)
+ flow = self.flow_head(tokens)
+
+ return flow, None
+
+
+class CorrBlock:
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
+ """
+ Build a pyramid of feature maps from the input.
+
+ fmaps: Tensor (B, S, C, H, W)
+ num_levels: number of pyramid levels (each downsampled by factor 2)
+ radius: search radius for sampling correlation
+ multiple_track_feats: if True, split the target features per pyramid level
+ padding_mode: passed to grid_sample / bilinear_sampler
+ """
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+ self.num_levels = num_levels
+ self.radius = radius
+ self.padding_mode = padding_mode
+ self.multiple_track_feats = multiple_track_feats
+
+ # Build pyramid: each level is half the spatial resolution of the previous
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
+ current_fmaps = fmaps
+ for i in range(num_levels - 1):
+ B, S, C, H, W = current_fmaps.shape
+ # Merge batch & sequence dimensions
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
+ # Avg pool down by factor 2
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
+ _, _, H_new, W_new = current_fmaps.shape
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
+ self.fmaps_pyramid.append(current_fmaps)
+
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
+ # This grid is added to the (scaled) coordinate centroids.
+ r = self.radius
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+ # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
+ self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
+
+ def corr_sample(self, targets, coords):
+ """
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
+ volume, sample it immediately, then discard it. This saves GPU memory.
+
+ Args:
+ targets: Tensor (B, S, N, C) — features for the current targets.
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
+
+ Returns:
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
+ """
+ B, S, N, C = targets.shape
+
+ # If you have multiple track features, split them per level.
+ if self.multiple_track_feats:
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
+
+ out_pyramid = []
+ for i, fmaps in enumerate(self.fmaps_pyramid):
+ # Get current spatial resolution H, W for this pyramid level.
+ B, S, C, H, W = fmaps.shape
+ # Reshape feature maps for correlation computation:
+ # fmap2s: (B, S, C, H*W)
+ fmap2s = fmaps.view(B, S, C, H * W)
+ # Choose appropriate target features.
+ fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
+
+ # Compute correlation directly
+ corrs = compute_corr_level(fmap1, fmap2s, C)
+ corrs = corrs.view(B, S, N, H, W)
+
+ # Prepare sampling grid:
+ # Scale down the coordinates for the current level.
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
+ # Make sure our precomputed delta grid is on the same device/dtype.
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
+ # Now the grid for grid_sample is:
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
+ coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
+
+ # Sample from the correlation volume using bilinear interpolation.
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
+ corrs_sampled = bilinear_sampler(
+ corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
+ )
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
+ corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
+ out_pyramid.append(corrs_sampled)
+
+ # Concatenate all levels along the last dimension.
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
+ return out
+
+
+def compute_corr_level(fmap1, fmap2s, C):
+ # fmap1: (B, S, N, C)
+ # fmap2s: (B, S, C, H*W)
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
+ corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
+ return corrs / math.sqrt(C)
diff --git a/models/vggt/vggt/heads/track_modules/modules.py b/models/vggt/vggt/heads/track_modules/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b090ddc4a9db01c8dd3564f9053e1ca9cdde93a
--- /dev/null
+++ b/models/vggt/vggt/heads/track_modules/modules.py
@@ -0,0 +1,218 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from typing import Callable
+import collections
+from torch import Tensor
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+
+class ResidualBlock(nn.Module):
+ """
+ ResidualBlock: construct a block of two conv layers with residual connections
+ """
+
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes,
+ planes,
+ kernel_size=kernel_size,
+ padding=1,
+ stride=stride,
+ padding_mode="zeros",
+ )
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=kernel_size,
+ padding=1,
+ padding_mode="zeros",
+ )
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+ else:
+ raise NotImplementedError
+
+ if stride == 1:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
+ self.norm3,
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class AttnBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
+ mlp_ratio=4.0,
+ **block_kwargs
+ ):
+ """
+ Self attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, mask=None):
+ # Prepare the mask for PyTorch's attention (it expects a different format)
+ # attn_mask = mask if mask is not None else None
+ # Normalize before attention
+ x = self.norm1(x)
+
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
+
+ attn_output, _ = self.attn(x, x, x)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class CrossAttnBlock(nn.Module):
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
+ """
+ Cross attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm_context = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.cross_attn = nn.MultiheadAttention(
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
+ )
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, context, mask=None):
+ # Normalize inputs
+ x = self.norm1(x)
+ context = self.norm_context(context)
+
+ # Apply cross attention
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
diff --git a/models/vggt/vggt/heads/track_modules/utils.py b/models/vggt/vggt/heads/track_modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51d01d39cdc10388a04dab5db7cf409b31bde766
--- /dev/null
+++ b/models/vggt/vggt/heads/track_modules/utils.py
@@ -0,0 +1,226 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from https://github.com/facebookresearch/vggsfm
+# and https://github.com/facebookresearch/co-tracker/tree/main
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Union
+
+
+def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
+ """
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid_size: The grid size.
+ Returns:
+ - pos_embed: The generated 2D positional embedding.
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0)
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if return_grid:
+ return (
+ pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
+ grid,
+ )
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid: The grid to generate the embedding from.
+
+ Returns:
+ - emb: The generated 2D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb[None].float()
+
+
+def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
+
+ Args:
+ - xy: The coordinates to generate the embedding from.
+ - C: The size of the embedding.
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
+
+ Returns:
+ - pe: The generated 2D positional embedding.
+ """
+ B, N, D = xy.shape
+ assert D == 2
+
+ x = xy[:, :, 0:1]
+ y = xy[:, :, 1:2]
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
+ if cat_coords:
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
+ return pe
+
+
+def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
+ r"""Sample a tensor using bilinear interpolation
+
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
+ convention.
+
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
+ :math:`B` is the batch size, :math:`C` is the number of channels,
+ :math:`H` is the height of the image, and :math:`W` is the width of the
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
+
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
+ that in this case the order of the components is slightly different
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
+
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
+ left-most image pixel :math:`W-1` to the center of the right-most
+ pixel.
+
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
+ the left-most pixel :math:`W` to the right edge of the right-most
+ pixel.
+
+ Similar conventions apply to the :math:`y` for the range
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
+ :math:`[0,T-1]` and :math:`[0,T]`.
+
+ Args:
+ input (Tensor): batch of input images.
+ coords (Tensor): batch of coordinates.
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
+
+ Returns:
+ Tensor: sampled points.
+ """
+ coords = coords.detach().clone()
+ ############################################################
+ # IMPORTANT:
+ coords = coords.to(input.device).to(input.dtype)
+ ############################################################
+
+ sizes = input.shape[2:]
+
+ assert len(sizes) in [2, 3]
+
+ if len(sizes) == 3:
+ # t x y -> x y t to match dimensions T H W in grid_sample
+ coords = coords[..., [1, 2, 0]]
+
+ if align_corners:
+ scale = torch.tensor(
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
+ )
+ else:
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
+
+ coords.mul_(scale) # coords = coords * scale
+ coords.sub_(1) # coords = coords - 1
+
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
+
+
+def sample_features4d(input, coords):
+ r"""Sample spatial features
+
+ `sample_features4d(input, coords)` samples the spatial features
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
+
+ The field is sampled at coordinates :attr:`coords` using bilinear
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
+
+ The output tensor has one feature per point, and has shape :math:`(B,
+ R, C)`.
+
+ Args:
+ input (Tensor): spatial features.
+ coords (Tensor): points.
+
+ Returns:
+ Tensor: sampled features.
+ """
+
+ B, _, _, _ = input.shape
+
+ # B R 2 -> B R 1 2
+ coords = coords.unsqueeze(2)
+
+ # B C R 1
+ feats = bilinear_sampler(input, coords)
+
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
diff --git a/models/vggt/vggt/heads/utils.py b/models/vggt/vggt/heads/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..533fc8ae67a75cd0a94d5ca96dc5a0513446c64f
--- /dev/null
+++ b/models/vggt/vggt/heads/utils.py
@@ -0,0 +1,109 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+
+def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
+ """
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
+
+ Args:
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
+ embed_dim: Output channel dimension for embeddings
+
+ Returns:
+ Tensor of shape (H, W, embed_dim) with positional embeddings
+ """
+ H, W, grid_dim = pos_grid.shape
+ assert grid_dim == 2
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
+
+ # Process x and y coordinates separately
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
+
+ # Combine and reshape
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
+
+ return emb.view(H, W, embed_dim) # [H, W, D]
+
+
+def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ device = pos.device
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / omega_0**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb.float()
+
+
+# Inspired by https://github.com/microsoft/moge
+
+
+def create_uv_grid(
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
+) -> torch.Tensor:
+ """
+ Create a normalized UV grid of shape (width, height, 2).
+
+ The grid spans horizontally and vertically according to an aspect ratio,
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
+
+ Args:
+ width (int): Number of points horizontally.
+ height (int): Number of points vertically.
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
+ device (torch.device, optional): Device on which the tensor is created.
+
+ Returns:
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
+ """
+ # Derive aspect ratio if not explicitly provided
+ if aspect_ratio is None:
+ aspect_ratio = float(width) / float(height)
+
+ # Compute normalized spans for X and Y
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
+ span_x = aspect_ratio / diag_factor
+ span_y = 1.0 / diag_factor
+
+ # Establish the linspace boundaries
+ left_x = -span_x * (width - 1) / width
+ right_x = span_x * (width - 1) / width
+ top_y = -span_y * (height - 1) / height
+ bottom_y = span_y * (height - 1) / height
+
+ # Generate 1D coordinates
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
+
+ # Create 2D meshgrid (width x height) and stack into UV
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
+ uv_grid = torch.stack((uu, vv), dim=-1)
+
+ return uv_grid
diff --git a/models/vggt/vggt/layers/__init__.py b/models/vggt/vggt/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1
--- /dev/null
+++ b/models/vggt/vggt/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/models/vggt/vggt/layers/attention.py b/models/vggt/vggt/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab3089ce0c7493342ef0cf373dfe74a1df2b9563
--- /dev/null
+++ b/models/vggt/vggt/layers/attention.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+import torch.nn.functional as F
+
+XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = nn.LayerNorm,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.fused_attn = fused_attn
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.rope = rope
+
+ def forward(self, x: Tensor, pos=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ if self.rope is not None:
+ q = self.rope(q, pos)
+ k = self.rope(k, pos)
+
+ if self.fused_attn:
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ )
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
+ assert pos is None
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/models/vggt/vggt/layers/block.py b/models/vggt/vggt/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f89e4da7121effca97151d1d8429586e422346e
--- /dev/null
+++ b/models/vggt/vggt/layers/block.py
@@ -0,0 +1,259 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None,
+ ) -> None:
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ qk_norm=qk_norm,
+ fused_attn=fused_attn,
+ rope=rope,
+ )
+
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, pos=None) -> Tensor:
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ pos=pos,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x, pos=pos)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+ pos=None,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ if pos is not None:
+ # if necessary, apply rope to the subset
+ pos = pos[brange]
+ residual = residual_func(x_subset, pos=pos)
+ else:
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/models/vggt/vggt/layers/drop_path.py b/models/vggt/vggt/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/models/vggt/vggt/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/models/vggt/vggt/layers/layer_scale.py b/models/vggt/vggt/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386
--- /dev/null
+++ b/models/vggt/vggt/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/models/vggt/vggt/layers/mlp.py b/models/vggt/vggt/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/models/vggt/vggt/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/models/vggt/vggt/layers/patch_embed.py b/models/vggt/vggt/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339
--- /dev/null
+++ b/models/vggt/vggt/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/models/vggt/vggt/layers/rope.py b/models/vggt/vggt/layers/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d5d33304e55dbd05687bd86752a47a80e5f82df
--- /dev/null
+++ b/models/vggt/vggt/layers/rope.py
@@ -0,0 +1,188 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+# Implementation of 2D Rotary Position Embeddings (RoPE).
+
+# This module provides a clean implementation of 2D Rotary Position Embeddings,
+# which extends the original RoPE concept to handle 2D spatial positions.
+
+# Inspired by:
+# https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# https://github.com/naver-ai/rope-vit
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Dict, Tuple
+
+
+class PositionGetter:
+ """Generates and caches 2D spatial positions for patches in a grid.
+
+ This class efficiently manages the generation of spatial coordinates for patches
+ in a 2D grid, caching results to avoid redundant computations.
+
+ Attributes:
+ position_cache: Dictionary storing precomputed position tensors for different
+ grid dimensions.
+ """
+
+ def __init__(self):
+ """Initializes the position generator with an empty cache."""
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
+ """Generates spatial positions for a batch of patches.
+
+ Args:
+ batch_size: Number of samples in the batch.
+ height: Height of the grid in patches.
+ width: Width of the grid in patches.
+ device: Target device for the position tensor.
+
+ Returns:
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
+ for each position in the grid, repeated for each batch item.
+ """
+ if (height, width) not in self.position_cache:
+ y_coords = torch.arange(height, device=device)
+ x_coords = torch.arange(width, device=device)
+ positions = torch.cartesian_prod(y_coords, x_coords)
+ self.position_cache[height, width] = positions
+
+ cached_positions = self.position_cache[height, width]
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
+
+
+class RotaryPositionEmbedding2D(nn.Module):
+ """2D Rotary Position Embedding implementation.
+
+ This module applies rotary position embeddings to input tokens based on their
+ 2D spatial positions. It handles the position-dependent rotation of features
+ separately for vertical and horizontal dimensions.
+
+ Args:
+ frequency: Base frequency for the position embeddings. Default: 100.0
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
+
+ Attributes:
+ base_frequency: Base frequency for computing position embeddings.
+ scaling_factor: Factor to scale the computed frequencies.
+ frequency_cache: Cache for storing precomputed frequency components.
+ """
+
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
+ """Initializes the 2D RoPE module."""
+ super().__init__()
+ self.base_frequency = frequency
+ self.scaling_factor = scaling_factor
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
+
+ def _compute_frequency_components(
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Computes frequency components for rotary embeddings.
+
+ Args:
+ dim: Feature dimension (must be even).
+ seq_len: Maximum sequence length.
+ device: Target device for computations.
+ dtype: Data type for the computed tensors.
+
+ Returns:
+ Tuple of (cosine, sine) tensors for frequency components.
+ """
+ cache_key = (dim, seq_len, device, dtype)
+ if cache_key not in self.frequency_cache:
+ # Compute frequency bands
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
+ inv_freq = 1.0 / (self.base_frequency**exponents)
+
+ # Generate position-dependent frequencies
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
+
+ # Compute and cache frequency components
+ angles = angles.to(dtype)
+ angles = torch.cat((angles, angles), dim=-1)
+ cos_components = angles.cos().to(dtype)
+ sin_components = angles.sin().to(dtype)
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
+
+ return self.frequency_cache[cache_key]
+
+ @staticmethod
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
+ """Performs feature rotation by splitting and recombining feature dimensions.
+
+ Args:
+ x: Input tensor to rotate.
+
+ Returns:
+ Rotated feature tensor.
+ """
+ feature_dim = x.shape[-1]
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def _apply_1d_rope(
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
+ ) -> torch.Tensor:
+ """Applies 1D rotary position embeddings along one dimension.
+
+ Args:
+ tokens: Input token features.
+ positions: Position indices.
+ cos_comp: Cosine components for rotation.
+ sin_comp: Sine components for rotation.
+
+ Returns:
+ Tokens with applied rotary position embeddings.
+ """
+ # Embed positions with frequency components
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
+
+ # Apply rotation
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
+
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
+ """Applies 2D rotary position embeddings to input tokens.
+
+ Args:
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
+ The feature dimension (dim) must be divisible by 4.
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
+ the y and x coordinates for each token.
+
+ Returns:
+ Tensor of same shape as input with applied 2D rotary position embeddings.
+
+ Raises:
+ AssertionError: If input dimensions are invalid or positions are malformed.
+ """
+ # Validate inputs
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
+
+ # Compute feature dimension for each spatial direction
+ feature_dim = tokens.size(-1) // 2
+
+ # Get frequency components
+ max_position = int(positions.max()) + 1
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
+
+ # Split features for vertical and horizontal processing
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
+
+ # Apply RoPE separately for each dimension
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
+
+ # Combine processed features
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
diff --git a/models/vggt/vggt/layers/swiglu_ffn.py b/models/vggt/vggt/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..54fe8e90b7bedf6fbdbf09c6215844e3cc63f857
--- /dev/null
+++ b/models/vggt/vggt/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+# try:
+# if XFORMERS_ENABLED:
+# from xformers.ops import SwiGLU
+
+# XFORMERS_AVAILABLE = True
+# warnings.warn("xFormers is available (SwiGLU)")
+# else:
+# warnings.warn("xFormers is disabled (SwiGLU)")
+# raise ImportError
+# except ImportError:
+SwiGLU = SwiGLUFFN
+XFORMERS_AVAILABLE = False
+
+# warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/models/vggt/vggt/layers/vision_transformer.py b/models/vggt/vggt/layers/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..120cbe6c26650d212e50aefc497669abdc937467
--- /dev/null
+++ b/models/vggt/vggt/layers/vision_transformer.py
@@ -0,0 +1,407 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from torch.nn.init import trunc_normal_
+from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ qk_norm=False,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ # tricky but makes it work
+ self.use_checkpoint = False
+ #
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if self.interpolate_offset:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sx, sy)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (w0, h0)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=True, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
diff --git a/models/vggt/vggt/models/aggregator.py b/models/vggt/vggt/models/aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..3218af629d5c3e6026a1d6c33e442f291f5b94d3
--- /dev/null
+++ b/models/vggt/vggt/models/aggregator.py
@@ -0,0 +1,338 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple, Union, List, Dict, Any
+
+from models.vggt.vggt.layers import PatchEmbed
+from models.vggt.vggt.layers.block import Block
+from models.vggt.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from models.vggt.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+from torch.utils.checkpoint import checkpoint
+
+logger = logging.getLogger(__name__)
+
+_RESNET_MEAN = [0.485, 0.456, 0.406]
+_RESNET_STD = [0.229, 0.224, 0.225]
+
+
+class Aggregator(nn.Module):
+ """
+ The Aggregator applies alternating-attention over input frames,
+ as described in VGGT: Visual Geometry Grounded Transformer.
+
+
+ Args:
+ img_size (int): Image size in pixels.
+ patch_size (int): Size of each patch for PatchEmbed.
+ embed_dim (int): Dimension of the token embeddings.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
+ num_register_tokens (int): Number of register tokens.
+ block_fn (nn.Module): The block type used for attention (Block by default).
+ qkv_bias (bool): Whether to include bias in QKV projections.
+ proj_bias (bool): Whether to include bias in the output projection.
+ ffn_bias (bool): Whether to include bias in MLP layers.
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
+ qk_norm (bool): Whether to apply QK normalization.
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
+ init_values (float): Init scale for layer scale.
+ """
+
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=14,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4.0,
+ num_register_tokens=4,
+ block_fn=Block,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ patch_embed="dinov2_vitl14_reg",
+ aa_order=["frame", "global"],
+ aa_block_size=1,
+ qk_norm=True,
+ rope_freq=100,
+ init_values=0.01,
+ ):
+ super().__init__()
+
+ self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
+
+ # Initialize rotary position embedding if frequency > 0
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
+ self.position_getter = PositionGetter() if self.rope is not None else None
+
+ self.frame_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.global_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.depth = depth
+ self.aa_order = aa_order
+ self.patch_size = patch_size
+ self.aa_block_size = aa_block_size
+
+ # Validate that depth is divisible by aa_block_size
+ if self.depth % self.aa_block_size != 0:
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
+
+ self.aa_block_num = self.depth // self.aa_block_size
+
+ # Note: We have two camera tokens, one for the first frame and one for the rest
+ # The same applies for register tokens
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
+
+ # The patch tokens start after the camera and register tokens
+ self.patch_start_idx = 1 + num_register_tokens
+
+ # Initialize parameters with small values
+ nn.init.normal_(self.camera_token, std=1e-6)
+ nn.init.normal_(self.register_token, std=1e-6)
+
+ # Register normalization constants as buffers
+ for name, value in (
+ ("_resnet_mean", _RESNET_MEAN),
+ ("_resnet_std", _RESNET_STD),
+ ):
+ self.register_buffer(
+ name,
+ torch.FloatTensor(value).view(1, 1, 3, 1, 1),
+ persistent=False,
+ )
+
+ def __build_patch_embed__(
+ self,
+ patch_embed,
+ img_size,
+ patch_size,
+ num_register_tokens,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ block_chunks=0,
+ init_values=1.0,
+ embed_dim=1024,
+ ):
+ """
+ Build the patch embed layer. If 'conv', we use a
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
+ """
+
+ if "conv" in patch_embed:
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
+ else:
+ vit_models = {
+ "dinov2_vitl14_reg": vit_large,
+ "dinov2_vitb14_reg": vit_base,
+ "dinov2_vits14_reg": vit_small,
+ "dinov2_vitg2_reg": vit_giant2,
+ }
+
+ self.patch_embed = vit_models[patch_embed](
+ img_size=img_size,
+ patch_size=patch_size,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ block_chunks=block_chunks,
+ init_values=init_values,
+ )
+
+ # Disable gradient updates for mask token
+ if hasattr(self.patch_embed, "mask_token"):
+ self.patch_embed.mask_token.requires_grad_(False)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ ) -> Tuple[List[torch.Tensor], int]:
+ """
+ Args:
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+
+ Returns:
+ (list[torch.Tensor], int):
+ The list of outputs from the attention blocks,
+ and the patch_start_idx indicating where patch tokens begin.
+ """
+ B, S, C_in, H, W = images.shape
+
+ if C_in != 3:
+ raise ValueError(f"Expected 3 input channels, got {C_in}")
+
+ # Normalize images and reshape for patch embed
+ images = (images - self._resnet_mean) / self._resnet_std
+
+ # Reshape to [B*S, C, H, W] for patch embedding
+ images = images.view(B * S, C_in, H, W)
+ patch_tokens = self.patch_embed(images)
+
+ if isinstance(patch_tokens, dict):
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
+
+ _, P, C = patch_tokens.shape
+
+ # Expand camera and register tokens to match batch size and sequence length
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
+
+ # Concatenate special tokens with patch tokens
+ tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
+
+ pos = None
+ if self.rope is not None:
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
+
+ if self.patch_start_idx > 0:
+ # do not use position embedding for special tokens (camera and register tokens)
+ # so set pos to 0 for the special tokens
+ pos = pos + 1
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
+ pos = torch.cat([pos_special, pos], dim=1)
+
+ # update P because we added special tokens
+ _, P, C = tokens.shape
+
+ frame_idx = 0
+ global_idx = 0
+ output_list = []
+
+ for _ in range(self.aa_block_num):
+ for attn_type in self.aa_order:
+ if attn_type == "frame":
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
+ tokens, B, S, P, C, frame_idx, pos=pos
+ )
+ elif attn_type == "global":
+ tokens, global_idx, global_intermediates = self._process_global_attention(
+ tokens, B, S, P, C, global_idx, pos=pos
+ )
+ else:
+ raise ValueError(f"Unknown attention type: {attn_type}")
+
+ for i in range(len(frame_intermediates)):
+ # concat frame and global intermediates, [B x S x P x 2C]
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
+ output_list.append(concat_inter)
+
+ del concat_inter
+ del frame_intermediates
+ del global_intermediates
+ return output_list, self.patch_start_idx
+
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
+ """
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
+ """
+ # If needed, reshape tokens or positions:
+ if tokens.shape != (B * S, P, C):
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
+
+ if pos is not None and pos.shape != (B * S, P, 2):
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
+
+ intermediates = []
+
+ # by default, self.aa_block_size=1, which processes one block at a time
+ for _ in range(self.aa_block_size):
+ if self.training:
+ tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=False)
+ else:
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
+ frame_idx += 1
+ intermediates.append(tokens.view(B, S, P, C))
+
+ return tokens, frame_idx, intermediates
+
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
+ """
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
+ """
+ if tokens.shape != (B, S * P, C):
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
+
+ if pos is not None and pos.shape != (B, S * P, 2):
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
+
+ intermediates = []
+
+ # by default, self.aa_block_size=1, which processes one block at a time
+ for _ in range(self.aa_block_size):
+ if self.training:
+ tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=False)
+ else:
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
+ global_idx += 1
+ intermediates.append(tokens.view(B, S, P, C))
+
+ return tokens, global_idx, intermediates
+
+
+def slice_expand_and_flatten(token_tensor, B, S):
+ """
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
+ 1) Uses the first position (index=0) for the first frame only
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
+ 3) Expands both to match batch size B
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
+ followed by (S-1) second-position tokens
+ 5) Flattens to (B*S, X, C) for processing
+
+ Returns:
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
+ """
+
+ # Slice out the "query" tokens => shape (1, 1, ...)
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
+ # Slice out the "other" tokens => shape (1, S-1, ...)
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
+ # Concatenate => shape (B, S, ...)
+ combined = torch.cat([query, others], dim=1)
+
+ # Finally flatten => shape (B*S, ...)
+ combined = combined.view(B * S, *combined.shape[2:])
+ return combined
diff --git a/models/vggt/vggt/models/aggregator_front.py b/models/vggt/vggt/models/aggregator_front.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d7a24398bee9d6ddea7d2a78b3809707dcf3db6
--- /dev/null
+++ b/models/vggt/vggt/models/aggregator_front.py
@@ -0,0 +1,342 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple, Union, List, Dict, Any
+
+from models.vggt.vggt.layers import PatchEmbed
+from models.vggt.vggt.layers.block import Block
+from models.vggt.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from models.vggt.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+from torch.utils.checkpoint import checkpoint
+
+logger = logging.getLogger(__name__)
+
+_RESNET_MEAN = [0.485, 0.456, 0.406]
+_RESNET_STD = [0.229, 0.224, 0.225]
+
+
+class Aggregator(nn.Module):
+ """
+ The Aggregator applies alternating-attention over input frames,
+ as described in VGGT: Visual Geometry Grounded Transformer.
+
+
+ Args:
+ img_size (int): Image size in pixels.
+ patch_size (int): Size of each patch for PatchEmbed.
+ embed_dim (int): Dimension of the token embeddings.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
+ num_register_tokens (int): Number of register tokens.
+ block_fn (nn.Module): The block type used for attention (Block by default).
+ qkv_bias (bool): Whether to include bias in QKV projections.
+ proj_bias (bool): Whether to include bias in the output projection.
+ ffn_bias (bool): Whether to include bias in MLP layers.
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
+ qk_norm (bool): Whether to apply QK normalization.
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
+ init_values (float): Init scale for layer scale.
+ """
+
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=14,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4.0,
+ num_register_tokens=4,
+ block_fn=Block,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ patch_embed="dinov2_vitl14_reg",
+ aa_order=["frame", "global"],
+ aa_block_size=1,
+ qk_norm=True,
+ rope_freq=100,
+ init_values=0.01,
+ ):
+ super().__init__()
+
+ # self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
+
+ self.use_reentrant = False
+ # Initialize rotary position embedding if frequency > 0
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
+ self.position_getter = PositionGetter() if self.rope is not None else None
+
+ self.frame_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.global_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.depth = depth
+ self.aa_order = aa_order
+ self.patch_size = patch_size
+ self.aa_block_size = aa_block_size
+
+ # Validate that depth is divisible by aa_block_size
+ if self.depth % self.aa_block_size != 0:
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
+
+ self.aa_block_num = self.depth // self.aa_block_size
+
+ # Note: We have two camera tokens, one for the first frame and one for the rest
+ # The same applies for register tokens
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
+ self.scale_shift_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
+
+ # The patch tokens start after the camera and register tokens
+ self.patch_start_idx = 1 + num_register_tokens + 1
+
+ # Initialize parameters with small values
+ nn.init.normal_(self.camera_token, std=1e-6)
+ nn.init.normal_(self.register_token, std=1e-6)
+ nn.init.normal_(self.scale_shift_token, std=1e-6)
+
+ # Register normalization constants as buffers
+ for name, value in (
+ ("_resnet_mean", _RESNET_MEAN),
+ ("_resnet_std", _RESNET_STD),
+ ):
+ self.register_buffer(
+ name,
+ torch.FloatTensor(value).view(1, 1, 3, 1, 1),
+ persistent=False,
+ )
+
+ def __build_patch_embed__(
+ self,
+ patch_embed,
+ img_size,
+ patch_size,
+ num_register_tokens,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ block_chunks=0,
+ init_values=1.0,
+ embed_dim=1024,
+ ):
+ """
+ Build the patch embed layer. If 'conv', we use a
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
+ """
+
+ if "conv" in patch_embed:
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
+ else:
+ vit_models = {
+ "dinov2_vitl14_reg": vit_large,
+ "dinov2_vitb14_reg": vit_base,
+ "dinov2_vits14_reg": vit_small,
+ "dinov2_vitg2_reg": vit_giant2,
+ }
+
+ self.patch_embed = vit_models[patch_embed](
+ img_size=img_size,
+ patch_size=patch_size,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ block_chunks=block_chunks,
+ init_values=init_values,
+ )
+
+ # Disable gradient updates for mask token
+ if hasattr(self.patch_embed, "mask_token"):
+ self.patch_embed.mask_token.requires_grad_(False)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ patch_tokens: torch.Tensor,
+ ) -> Tuple[List[torch.Tensor], int]:
+ """
+ Args:
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+
+ Returns:
+ (list[torch.Tensor], int):
+ The list of outputs from the attention blocks,
+ and the patch_start_idx indicating where patch tokens begin.
+ """
+ B, S, C_in, H, W = images.shape
+
+ # if C_in != 3:
+ # raise ValueError(f"Expected 3 input channels, got {C_in}")
+
+ # # Normalize images and reshape for patch embed
+ # images = (images - self._resnet_mean) / self._resnet_std
+
+ # # Reshape to [B*S, C, H, W] for patch embedding
+ # images = images.view(B * S, C_in, H, W)
+ # patch_tokens = self.patch_embed(images)
+
+ if isinstance(patch_tokens, dict):
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
+
+ _, P, C = patch_tokens.shape
+ # Expand camera and register tokens to match batch size and sequence length
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
+ scale_shift_token = slice_expand_and_flatten(self.scale_shift_token, B, S)
+
+ # Concatenate special tokens with patch tokens
+ tokens = torch.cat([camera_token, register_token, scale_shift_token, patch_tokens], dim=1)
+
+ pos = None
+ if self.rope is not None:
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
+
+ if self.patch_start_idx > 0:
+ # do not use position embedding for special tokens (camera and register tokens)
+ # so set pos to 0 for the special tokens
+ pos = pos + 1
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
+ pos = torch.cat([pos_special, pos], dim=1)
+
+ # update P because we added special tokens
+ _, P, C = tokens.shape
+
+ frame_idx = 0
+ global_idx = 0
+ output_list = []
+
+ for _ in range(self.aa_block_num):
+ for attn_type in self.aa_order:
+ if attn_type == "frame":
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
+ tokens, B, S, P, C, frame_idx, pos=pos
+ )
+ elif attn_type == "global":
+ tokens, global_idx, global_intermediates = self._process_global_attention(
+ tokens, B, S, P, C, global_idx, pos=pos
+ )
+ else:
+ raise ValueError(f"Unknown attention type: {attn_type}")
+
+ for i in range(len(frame_intermediates)):
+ # concat frame and global intermediates, [B x S x P x 2C]
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
+ output_list.append(concat_inter)
+
+ del concat_inter
+ del frame_intermediates
+ del global_intermediates
+ return output_list, self.patch_start_idx
+
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
+ """
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
+ """
+ # If needed, reshape tokens or positions:
+ if tokens.shape != (B * S, P, C):
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
+
+ if pos is not None and pos.shape != (B * S, P, 2):
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
+
+ intermediates = []
+
+ # by default, self.aa_block_size=1, which processes one block at a time
+ for _ in range(self.aa_block_size):
+ if self.training:
+ tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
+ else:
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
+ frame_idx += 1
+ intermediates.append(tokens.view(B, S, P, C))
+
+ return tokens, frame_idx, intermediates
+
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
+ """
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
+ """
+ if tokens.shape != (B, S * P, C):
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
+
+ if pos is not None and pos.shape != (B, S * P, 2):
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
+
+ intermediates = []
+
+ # by default, self.aa_block_size=1, which processes one block at a time
+ for _ in range(self.aa_block_size):
+ if self.training:
+ tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
+ else:
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
+ global_idx += 1
+ intermediates.append(tokens.view(B, S, P, C))
+
+ return tokens, global_idx, intermediates
+
+
+def slice_expand_and_flatten(token_tensor, B, S):
+ """
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
+ 1) Uses the first position (index=0) for the first frame only
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
+ 3) Expands both to match batch size B
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
+ followed by (S-1) second-position tokens
+ 5) Flattens to (B*S, X, C) for processing
+
+ Returns:
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
+ """
+
+ # Slice out the "query" tokens => shape (1, 1, ...)
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
+ # Slice out the "other" tokens => shape (1, S-1, ...)
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
+ # Concatenate => shape (B, S, ...)
+ combined = torch.cat([query, others], dim=1)
+
+ # Finally flatten => shape (B*S, ...)
+ combined = combined.view(B * S, *combined.shape[2:])
+ return combined
diff --git a/models/vggt/vggt/models/tracker_front.py b/models/vggt/vggt/models/tracker_front.py
new file mode 100644
index 0000000000000000000000000000000000000000..88a7582b5ee58ae31a1bc334fb76544c51ea56d9
--- /dev/null
+++ b/models/vggt/vggt/models/tracker_front.py
@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
+
+from models.vggt.vggt.models.aggregator_front import Aggregator
+from models.vggt.vggt.heads.camera_head import CameraHead
+from models.vggt.vggt.heads.scale_head import ScaleHead
+from einops import rearrange
+from models.vggt.vggt.utils.loss import compute_loss
+from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+import torch.nn.functional as F
+
+class FrontTracker(nn.Module, PyTorchModelHubMixin):
+ def __init__(self, img_size=518,
+ patch_size=14, embed_dim=1024, base_model=None, use_checkpoint=True, use_scale_head=False):
+ super().__init__()
+
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
+ if use_scale_head:
+ self.scale_head = ScaleHead(dim_in=2 * embed_dim)
+ else:
+ self.scale_head = None
+ self.base_model = base_model
+ self.use_checkpoint = use_checkpoint
+ self.intermediate_layers = [4, 11, 17, 23]
+ self.residual_proj = nn.ModuleList([nn.Linear(2048, 1024) for _ in range(len(self.intermediate_layers))])
+ # init the residual proj
+ for i in range(len(self.intermediate_layers)):
+ nn.init.xavier_uniform_(self.residual_proj[i].weight)
+ nn.init.zeros_(self.residual_proj[i].bias)
+ # self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
+ # self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
+ # self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
+
+ def forward(self,
+ images: torch.Tensor,
+ annots = {},
+ **kwargs):
+ """
+ Forward pass of the FrontTracker model.
+
+ Args:
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
+ Default: None
+
+ Returns:
+ dict: A dictionary containing the following predictions:
+ - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
+ - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
+ - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
+ - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
+ - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
+ - images (torch.Tensor): Original input images, preserved for visualization
+
+ If query_points is provided, also includes:
+ - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
+ - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
+ - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
+ """
+
+ # If without batch dimension, add it
+ if len(images.shape) == 4:
+ images = images.unsqueeze(0)
+ B, T, C, H, W = images.shape
+ images = (images - self.base_model.image_mean) / self.base_model.image_std
+ H_14 = H // 14 * 14
+ W_14 = W // 14 * 14
+ image_14 = F.interpolate(images.view(B*T, C, H, W), (H_14, W_14), mode="bilinear", align_corners=False, antialias=True).view(B, T, C, H_14, W_14)
+
+ with torch.no_grad():
+ features = self.base_model.backbone.get_intermediate_layers(rearrange(image_14, 'b t c h w -> (b t) c h w'),
+ self.base_model.intermediate_layers, return_class_token=True)
+ # aggregate the features with checkpoint
+ aggregated_tokens_list, patch_start_idx = self.aggregator(image_14, patch_tokens=features[-1][0])
+
+ # enhance the features
+ enhanced_features = []
+ for layer_i, layer in enumerate(self.intermediate_layers):
+ # patch_feat_i = features[layer_i][0] + self.residual_proj[layer_i](aggregated_tokens_list[layer][:,:,patch_start_idx:,:].view(B*T, features[layer_i][0].shape[1], -1))
+ patch_feat_i = self.residual_proj[layer_i](aggregated_tokens_list[layer][:,:,patch_start_idx:,:].view(B*T, features[layer_i][0].shape[1], -1))
+ enhance_i = (patch_feat_i, features[layer_i][1])
+ enhanced_features.append(enhance_i)
+
+ predictions = {}
+
+ with torch.cuda.amp.autocast(enabled=False):
+ if self.camera_head is not None:
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
+ if self.scale_head is not None:
+ scale_list = self.scale_head(aggregated_tokens_list)
+ predictions["scale"] = scale_list[-1] # scale of the last iteration
+ # Predict points (and mask) with checkpoint
+ output = self.base_model.head(enhanced_features, image_14)
+ points, mask = output
+
+ # Post-process points and mask
+ points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
+ points = self.base_model._remap_points(points) # slightly improves the performance in case of very large output values
+ # prepare the predictions
+ predictions["images"] = (images * self.base_model.image_std + self.base_model.image_mean)*255.0
+ points = F.interpolate(points.permute(0, 3, 1, 2), (H, W), mode="bilinear", align_corners=False, antialias=True).permute(0, 2, 3, 1)
+ predictions["points_map"] = points
+ mask = F.interpolate(mask.unsqueeze(1), (H, W), mode="bilinear", align_corners=False, antialias=True).squeeze(1)
+ predictions["unc_metric"] = mask
+ predictions["pose_enc_list"] = pose_enc_list
+
+ if self.training:
+ loss = compute_loss(predictions, annots)
+ predictions["loss"] = loss
+
+ # rescale the points
+ if self.scale_head is not None:
+ points_scale = points * predictions["scale"].view(B*T, 1, 1, 2)[..., :1]
+ points_scale[..., 2:] += predictions["scale"].view(B*T, 1, 1, 2)[..., 1:]
+ predictions["points_map"] = points_scale
+
+ predictions["poses_pred"] = torch.eye(4)[None].repeat(predictions["images"].shape[1], 1, 1)[None]
+ predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
+ predictions["images"].shape[-2:])
+ return predictions
diff --git a/models/vggt/vggt/models/vggt.py b/models/vggt/vggt/models/vggt.py
new file mode 100644
index 0000000000000000000000000000000000000000..75587dc2cd16ca54466c0200dbfebff06578dbe3
--- /dev/null
+++ b/models/vggt/vggt/models/vggt.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
+
+from vggt.models.aggregator import Aggregator
+from vggt.heads.camera_head import CameraHead
+from vggt.heads.dpt_head import DPTHead
+from vggt.heads.track_head import TrackHead
+
+
+class VGGT(nn.Module, PyTorchModelHubMixin):
+ def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
+ super().__init__()
+
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
+ self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
+ self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
+ self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ query_points: torch.Tensor = None,
+ ):
+ """
+ Forward pass of the VGGT model.
+
+ Args:
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
+ Default: None
+
+ Returns:
+ dict: A dictionary containing the following predictions:
+ - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
+ - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
+ - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
+ - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
+ - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
+ - images (torch.Tensor): Original input images, preserved for visualization
+
+ If query_points is provided, also includes:
+ - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
+ - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
+ - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
+ """
+
+ # If without batch dimension, add it
+ if len(images.shape) == 4:
+ images = images.unsqueeze(0)
+ if query_points is not None and len(query_points.shape) == 2:
+ query_points = query_points.unsqueeze(0)
+
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images)
+
+ predictions = {}
+
+ with torch.cuda.amp.autocast(enabled=False):
+ if self.camera_head is not None:
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
+
+ if self.depth_head is not None:
+ depth, depth_conf = self.depth_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
+ )
+ predictions["depth"] = depth
+ predictions["depth_conf"] = depth_conf
+
+ if self.point_head is not None:
+ pts3d, pts3d_conf = self.point_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
+ )
+ predictions["world_points"] = pts3d
+ predictions["world_points_conf"] = pts3d_conf
+
+ if self.track_head is not None and query_points is not None:
+ track_list, vis, conf = self.track_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
+ )
+ predictions["track"] = track_list[-1] # track of the last iteration
+ predictions["vis"] = vis
+ predictions["conf"] = conf
+
+ predictions["images"] = images
+
+ return predictions
diff --git a/models/vggt/vggt/models/vggt_moe.py b/models/vggt/vggt/models/vggt_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..f122c4379a56dafdba05a2acb787f3ac8ff1149b
--- /dev/null
+++ b/models/vggt/vggt/models/vggt_moe.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
+
+from models.vggt.vggt.models.aggregator import Aggregator
+from models.vggt.vggt.heads.camera_head import CameraHead
+from models.vggt.vggt.heads.dpt_head import DPTHead
+from models.vggt.vggt.heads.track_head import TrackHead
+from models.vggt.vggt.utils.loss import compute_loss
+from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
+from models.vggt.vggt.utils.load_fn import preprocess_image
+from einops import rearrange
+import torch.nn.functional as F
+
+class VGGT_MoE(nn.Module, PyTorchModelHubMixin):
+ def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
+ super().__init__()
+
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
+ self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="sigmoid")
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ annots = {},
+ **kwargs):
+ """
+ Forward pass of the VGGT_MoE model.
+
+ Args:
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
+ Default: None
+
+ Returns:
+ dict: A dictionary containing the following predictions:
+ - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
+ - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
+ - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
+ - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
+ - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
+ - images (torch.Tensor): Original input images, preserved for visualization
+
+ If query_points is provided, also includes:
+ - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
+ - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
+ - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
+ """
+
+ # If without batch dimension, add it
+ B, T, C, H, W = images.shape
+ images_proc = preprocess_image(images.view(B*T, C, H, W).clone())
+ images_proc = rearrange(images_proc, '(b t) c h w -> b t c h w', b=B, t=T)
+ _, _, _, H_proc, W_proc = images_proc.shape
+
+ if len(images.shape) == 4:
+ images = images.unsqueeze(0)
+
+ with torch.no_grad():
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images_proc)
+
+ predictions = {}
+
+ with torch.cuda.amp.autocast(enabled=False):
+ if self.camera_head is not None:
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
+ predictions["pose_enc_list"] = pose_enc_list
+
+ if self.depth_head is not None:
+ depth, depth_conf = self.depth_head(
+ aggregated_tokens_list, images=images_proc, patch_start_idx=patch_start_idx
+ )
+ predictions["depth"] = depth
+ predictions["unc_metric"] = depth_conf.view(B*T, H_proc, W_proc)
+
+ predictions["images"] = (images)*255.0
+ # output the camera pose
+ predictions["poses_pred"] = torch.eye(4)[None].repeat(T, 1, 1)[None]
+ predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
+ images_proc.shape[-2:])
+ predictions["poses_pred"] = torch.inverse(predictions["poses_pred"])
+ points_map = depth_to_points_colmap(depth.view(B*T, H_proc, W_proc), predictions["intrs"].view(B*T, 3, 3))
+ predictions["points_map"] = points_map
+ #NOTE: resize back
+ predictions["points_map"] = F.interpolate(points_map.permute(0,3,1,2),
+ size=(H, W), mode='bilinear', align_corners=True).permute(0,2,3,1)
+ predictions["unc_metric"] = F.interpolate(predictions["unc_metric"][:,None],
+ size=(H, W), mode='bilinear', align_corners=True)[:,0]
+ predictions["intrs"][..., :1, :] *= W/W_proc
+ predictions["intrs"][..., 1:2, :] *= H/H_proc
+
+ if self.training:
+ loss = compute_loss(predictions, annots)
+ predictions["loss"] = loss
+
+ return predictions
diff --git a/models/vggt/vggt/utils/__init__.py b/models/vggt/vggt/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0519ecba6ea913e21689ec692e81e9e4973fbf73
--- /dev/null
+++ b/models/vggt/vggt/utils/__init__.py
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/models/vggt/vggt/utils/geometry.py b/models/vggt/vggt/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ebd25dbc6cac6b0095956524c4f0628410dd5cb
--- /dev/null
+++ b/models/vggt/vggt/utils/geometry.py
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import numpy as np
+
+
+def unproject_depth_map_to_point_map(
+ depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
+) -> np.ndarray:
+ """
+ Unproject a batch of depth maps to 3D world coordinates.
+
+ Args:
+ depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
+ extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
+ intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
+
+ Returns:
+ np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
+ """
+ if isinstance(depth_map, torch.Tensor):
+ depth_map = depth_map.cpu().numpy()
+ if isinstance(extrinsics_cam, torch.Tensor):
+ extrinsics_cam = extrinsics_cam.cpu().numpy()
+ if isinstance(intrinsics_cam, torch.Tensor):
+ intrinsics_cam = intrinsics_cam.cpu().numpy()
+
+ world_points_list = []
+ for frame_idx in range(depth_map.shape[0]):
+ cur_world_points, _, _ = depth_to_world_coords_points(
+ depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
+ )
+ world_points_list.append(cur_world_points)
+ world_points_array = np.stack(world_points_list, axis=0)
+
+ return world_points_array
+
+
+def depth_to_world_coords_points(
+ depth_map: np.ndarray,
+ extrinsic: np.ndarray,
+ intrinsic: np.ndarray,
+ eps=1e-8,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Convert a depth map to world coordinates.
+
+ Args:
+ depth_map (np.ndarray): Depth map of shape (H, W).
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).
+ """
+ if depth_map is None:
+ return None, None, None
+
+ # Valid depth mask
+ point_mask = depth_map > eps
+
+ # Convert depth map to camera coordinates
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
+
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
+ # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
+
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
+
+ # Apply the rotation and translation to the camera coordinates
+ world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
+
+ return world_coords_points, cam_coords_points, point_mask
+
+
+def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+ """
+ Convert a depth map to camera coordinates.
+
+ Args:
+ depth_map (np.ndarray): Depth map of shape (H, W).
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
+ """
+ H, W = depth_map.shape
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
+
+ # Intrinsic parameters
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
+
+ # Generate grid of pixel coordinates
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+
+ # Unproject to camera coordinates
+ x_cam = (u - cu) * depth_map / fu
+ y_cam = (v - cv) * depth_map / fv
+ z_cam = depth_map
+
+ # Stack to form camera coordinates
+ cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+ return cam_coords
+
+
+def closed_form_inverse_se3(se3, R=None, T=None):
+ """
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
+
+ If `R` and `T` are provided, they must correspond to the rotation and translation
+ components of `se3`. Otherwise, they will be extracted from `se3`.
+
+ Args:
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+ R (optional): Nx3x3 array or tensor of rotation matrices.
+ T (optional): Nx3x1 array or tensor of translation vectors.
+
+ Returns:
+ Inverted SE3 matrices with the same type and device as `se3`.
+
+ Shapes:
+ se3: (N, 4, 4)
+ R: (N, 3, 3)
+ T: (N, 3, 1)
+ """
+ # Check if se3 is a numpy array or a torch tensor
+ is_numpy = isinstance(se3, np.ndarray)
+
+ # Validate shapes
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
+
+ # Extract R and T if not provided
+ if R is None:
+ R = se3[:, :3, :3] # (N,3,3)
+ if T is None:
+ T = se3[:, :3, 3:] # (N,3,1)
+
+ # Transpose R
+ if is_numpy:
+ # Compute the transpose of the rotation for NumPy
+ R_transposed = np.transpose(R, (0, 2, 1))
+ # -R^T t for NumPy
+ top_right = -np.matmul(R_transposed, T)
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
+ else:
+ R_transposed = R.transpose(1, 2) # (N,3,3)
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+ inverted_matrix[:, :3, :3] = R_transposed
+ inverted_matrix[:, :3, 3:] = top_right
+
+ return inverted_matrix
diff --git a/models/vggt/vggt/utils/load_fn.py b/models/vggt/vggt/utils/load_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..35e2ddfa86e6863993afe6833eb5df5c4a419c73
--- /dev/null
+++ b/models/vggt/vggt/utils/load_fn.py
@@ -0,0 +1,200 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from PIL import Image
+from torchvision import transforms as TF
+
+
+def load_and_preprocess_images(image_path_list, mode="crop"):
+ """
+ A quick start function to load and preprocess images for model input.
+ This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
+
+ Args:
+ image_path_list (list): List of paths to image files
+ mode (str, optional): Preprocessing mode, either "crop" or "pad".
+ - "crop" (default): Sets width to 518px and center crops height if needed.
+ - "pad": Preserves all pixels by making the largest dimension 518px
+ and padding the smaller dimension to reach a square shape.
+
+ Returns:
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
+
+ Raises:
+ ValueError: If the input list is empty or if mode is invalid
+
+ Notes:
+ - Images with different dimensions will be padded with white (value=1.0)
+ - A warning is printed when images have different shapes
+ - When mode="crop": The function ensures width=518px while maintaining aspect ratio
+ and height is center-cropped if larger than 518px
+ - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
+ and the smaller dimension is padded to reach a square shape (518x518)
+ - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
+ """
+ # Check for empty list
+ if len(image_path_list) == 0:
+ raise ValueError("At least 1 image is required")
+
+ # Validate mode
+ if mode not in ["crop", "pad"]:
+ raise ValueError("Mode must be either 'crop' or 'pad'")
+
+ images = []
+ shapes = set()
+ to_tensor = TF.ToTensor()
+ target_size = 518
+
+ # First process all images and collect their shapes
+ for image_path in image_path_list:
+
+ # Open image
+ img = Image.open(image_path)
+
+ # If there's an alpha channel, blend onto white background:
+ if img.mode == "RGBA":
+ # Create white background
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
+ # Alpha composite onto the white background
+ img = Image.alpha_composite(background, img)
+
+ # Now convert to "RGB" (this step assigns white for transparent areas)
+ img = img.convert("RGB")
+
+ width, height = img.size
+
+ if mode == "pad":
+ # Make the largest dimension 518px while maintaining aspect ratio
+ if width >= height:
+ new_width = target_size
+ new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14
+ else:
+ new_height = target_size
+ new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14
+ else: # mode == "crop"
+ # Original behavior: set width to 518px
+ new_width = target_size
+ # Calculate height maintaining aspect ratio, divisible by 14
+ new_height = round(height * (new_width / width) / 14) * 14
+
+ # Resize with new dimensions (width, height)
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
+ img = to_tensor(img) # Convert to tensor (0, 1)
+
+ # Center crop height if it's larger than 518 (only in crop mode)
+ if mode == "crop" and new_height > target_size:
+ start_y = (new_height - target_size) // 2
+ img = img[:, start_y : start_y + target_size, :]
+
+ # For pad mode, pad to make a square of target_size x target_size
+ if mode == "pad":
+ h_padding = target_size - img.shape[1]
+ w_padding = target_size - img.shape[2]
+
+ if h_padding > 0 or w_padding > 0:
+ pad_top = h_padding // 2
+ pad_bottom = h_padding - pad_top
+ pad_left = w_padding // 2
+ pad_right = w_padding - pad_left
+
+ # Pad with white (value=1.0)
+ img = torch.nn.functional.pad(
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+ )
+
+ shapes.add((img.shape[1], img.shape[2]))
+ images.append(img)
+
+ # Check if we have different shapes
+ # In theory our model can also work well with different shapes
+ if len(shapes) > 1:
+ print(f"Warning: Found images with different shapes: {shapes}")
+ # Find maximum dimensions
+ max_height = max(shape[0] for shape in shapes)
+ max_width = max(shape[1] for shape in shapes)
+
+ # Pad images if necessary
+ padded_images = []
+ for img in images:
+ h_padding = max_height - img.shape[1]
+ w_padding = max_width - img.shape[2]
+
+ if h_padding > 0 or w_padding > 0:
+ pad_top = h_padding // 2
+ pad_bottom = h_padding - pad_top
+ pad_left = w_padding // 2
+ pad_right = w_padding - pad_left
+
+ img = torch.nn.functional.pad(
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+ )
+ padded_images.append(img)
+ images = padded_images
+
+ images = torch.stack(images) # concatenate images
+
+ # Ensure correct shape when single image
+ if len(image_path_list) == 1:
+ # Verify shape is (1, C, H, W)
+ if images.dim() == 3:
+ images = images.unsqueeze(0)
+
+ return images
+
+def preprocess_image(img_tensor, mode="crop", target_size=518):
+ """
+ Preprocess image tensor(s) to target size with crop or pad mode.
+ Args:
+ img_tensor (torch.Tensor): Image tensor of shape (C, H, W) or (T, C, H, W), values in [0, 1]
+ mode (str): 'crop' or 'pad'
+ target_size (int): Target size for width/height
+ Returns:
+ torch.Tensor: Preprocessed image tensor(s), same batch dim as input
+ """
+ if mode not in ["crop", "pad"]:
+ raise ValueError("Mode must be either 'crop' or 'pad'")
+ if img_tensor.dim() == 3:
+ tensors = [img_tensor]
+ squeeze = True
+ elif img_tensor.dim() == 4:
+ tensors = list(img_tensor)
+ squeeze = False
+ else:
+ raise ValueError("Input tensor must be (C, H, W) or (T, C, H, W)")
+ processed = []
+ for img in tensors:
+ C, H, W = img.shape
+ if mode == "pad":
+ if W >= H:
+ new_W = target_size
+ new_H = round(H * (new_W / W) / 14) * 14
+ else:
+ new_H = target_size
+ new_W = round(W * (new_H / H) / 14) * 14
+ out = torch.nn.functional.interpolate(img.unsqueeze(0), size=(new_H, new_W), mode="bicubic", align_corners=False).squeeze(0)
+ h_padding = target_size - new_H
+ w_padding = target_size - new_W
+ pad_top = h_padding // 2
+ pad_bottom = h_padding - pad_top
+ pad_left = w_padding // 2
+ pad_right = w_padding - pad_left
+ if h_padding > 0 or w_padding > 0:
+ out = torch.nn.functional.pad(
+ out, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+ )
+ else: # crop
+ new_W = target_size
+ new_H = round(H * (new_W / W) / 14) * 14
+ out = torch.nn.functional.interpolate(img.unsqueeze(0), size=(new_H, new_W), mode="bicubic", align_corners=False).squeeze(0)
+ if new_H > target_size:
+ start_y = (new_H - target_size) // 2
+ out = out[:, start_y : start_y + target_size, :]
+ processed.append(out)
+ result = torch.stack(processed)
+ if squeeze:
+ return result[0]
+ return result
diff --git a/models/vggt/vggt/utils/loss.py b/models/vggt/vggt/utils/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..72030360c2a9e0bb3f17fa5f1937ccc9015bf489
--- /dev/null
+++ b/models/vggt/vggt/utils/loss.py
@@ -0,0 +1,123 @@
+# This file contains the loss functions for FrontTracker
+
+import torch
+import torch.nn as nn
+import utils3d
+from models.moge.train.losses import (
+ affine_invariant_global_loss,
+ affine_invariant_local_loss,
+ edge_loss,
+ normal_loss,
+ mask_l2_loss,
+ mask_bce_loss,
+ monitoring,
+)
+import torch.nn.functional as F
+from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
+from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
+from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri, extri_intri_to_pose_encoding
+
+def compute_loss(predictions, annots):
+ """
+ Compute the loss for the FrontTracker model.
+ """
+
+ B, T, C, H, W = predictions["images"].shape
+ H_resize, W_resize = H, W
+
+ if "poses_gt" in annots.keys():
+ intrs, c2w_traj_gt = pose_enc2mat(annots["poses_gt"],
+ H_resize, W_resize, min(H, W))
+ else:
+ c2w_traj_gt = None
+
+ if "intrs_gt" in annots.keys():
+ intrs = annots["intrs_gt"].view(B, T, 3, 3)
+ fx_factor = W_resize / W
+ fy_factor = H_resize / H
+ intrs[:,:,0,:] *= fx_factor
+ intrs[:,:,1,:] *= fy_factor
+
+ if "depth_gt" in annots.keys():
+
+ metric_depth_gt = annots['depth_gt'].view(B*T, 1, H, W)
+ metric_depth_gt = F.interpolate(metric_depth_gt,
+ size=(H_resize, W_resize), mode='nearest')
+
+ _depths = metric_depth_gt[metric_depth_gt > 0].reshape(-1)
+ q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
+ q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
+ iqr = q75 - q25
+ upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
+ _depth_roi = torch.tensor(
+ [1e-1, upper_bound.item()],
+ dtype=metric_depth_gt.dtype,
+ device=metric_depth_gt.device
+ )
+ mask_roi = (metric_depth_gt > _depth_roi[0]) & (metric_depth_gt < _depth_roi[1])
+ # fin mask
+ gt_mask_fin = ((metric_depth_gt > 0)*(mask_roi)).float()
+ # filter the sky
+ inf_thres = 50*q25.clamp(min=200, max=1e3)
+ gt_mask_inf = (metric_depth_gt > inf_thres).float()
+ # gt mask
+ gt_mask = (metric_depth_gt > 0)*(metric_depth_gt < 10*q25)
+
+ points_map_gt = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
+
+ if annots["syn_real"] == 1:
+ ln_msk_l2, _ = mask_l2_loss(predictions["unc_metric"], gt_mask_fin[:,0], gt_mask_inf[:,0])
+ ln_msk_l2 = 50*ln_msk_l2.mean()
+ else:
+ ln_msk_l2 = 0 * points_map_gt.mean()
+
+ # loss1: global invariant loss
+ ln_depth_glob, _, gt_metric_scale, gt_metric_shift = affine_invariant_global_loss(predictions["points_map"], points_map_gt, gt_mask[:,0], align_resolution=32)
+ ln_depth_glob = 100*ln_depth_glob.mean()
+ # loss2: edge loss
+ ln_edge, _ = edge_loss(predictions["points_map"], points_map_gt, gt_mask[:,0])
+ ln_edge = ln_edge.mean()
+ # loss3: normal loss
+ ln_normal, _ = normal_loss(predictions["points_map"], points_map_gt, gt_mask[:,0])
+ ln_normal = ln_normal.mean()
+ #NOTE: loss4: consistent loss
+ norm_rescale = gt_metric_scale.mean()
+ points_map_gt_cons = points_map_gt.clone() / norm_rescale
+ if "scale" in predictions.keys():
+ scale_ = predictions["scale"].view(B*T, 2, 1, 1)[:,:1]
+ shift_ = predictions["scale"].view(B*T, 2, 1, 1)[:,1:]
+ else:
+ scale_ = torch.ones_like(predictions["points_map"])
+ shift_ = torch.zeros_like(predictions["points_map"])[..., 2:]
+
+ points_pred_cons = predictions["points_map"] * scale_
+ points_pred_cons[..., 2:] += shift_
+ pred_mask = predictions["unc_metric"].clone().clamp(min=5e-2)
+ ln_cons = torch.abs(points_pred_cons - points_map_gt_cons).norm(dim=-1) * pred_mask - 0.05 * torch.log(pred_mask)
+ ln_cons = 0.5*ln_cons[(1-gt_mask_inf.squeeze()).bool()].clamp(max=100).mean()
+ # loss5: scale shift loss
+ if "scale" in predictions.keys():
+ ln_scale_shift = torch.abs(scale_.squeeze() - gt_metric_scale / norm_rescale) + torch.abs(shift_.squeeze() - gt_metric_shift[:,2] / norm_rescale)
+ ln_scale_shift = 10*ln_scale_shift.mean()
+ else:
+ ln_scale_shift = 0 * ln_cons.mean()
+ # loss6: pose loss
+ c2w_traj_gt[...,:3, 3] /= norm_rescale
+ ln_pose = 0
+ for i_t, pose_enc_i in enumerate(predictions["pose_enc_list"]):
+ pose_enc_gt = extri_intri_to_pose_encoding(torch.inverse(c2w_traj_gt)[...,:3,:4], intrs, predictions["images"].shape[-2:])
+ T_loss = torch.abs(pose_enc_i[..., :3] - pose_enc_gt[..., :3]).mean()
+ R_loss = torch.abs(pose_enc_i[..., 3:7] - pose_enc_gt[..., 3:7]).mean()
+ K_loss = torch.abs(pose_enc_i[..., 7:] - pose_enc_gt[..., 7:]).mean()
+ pose_loss_i = 25*(T_loss + R_loss) + K_loss
+ ln_pose += 0.8**(len(predictions["pose_enc_list"]) - i_t - 1)*(pose_loss_i)
+ ln_pose = 5*ln_pose
+ if annots["syn_real"] == 1:
+ loss = ln_depth_glob + ln_edge + ln_normal + ln_cons + ln_scale_shift + ln_pose + ln_msk_l2
+ else:
+ loss = ln_cons + ln_pose
+ ln_scale_shift = 0*ln_scale_shift
+ return {"loss": loss, "ln_depth_glob": ln_depth_glob, "ln_edge": ln_edge, "ln_normal": ln_normal,
+ "ln_cons": ln_cons, "ln_scale_shift": ln_scale_shift,
+ "ln_pose": ln_pose, "ln_msk_l2": ln_msk_l2, "norm_scale": norm_rescale}
+
diff --git a/models/vggt/vggt/utils/pose_enc.py b/models/vggt/vggt/utils/pose_enc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f98b0878cb13451b8cdb80074349cbf2644c5fa
--- /dev/null
+++ b/models/vggt/vggt/utils/pose_enc.py
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from .rotation import quat_to_mat, mat_to_quat
+
+
+def extri_intri_to_pose_encoding(
+ extrinsics,
+ intrinsics,
+ image_size_hw=None, # e.g., (256, 512)
+ pose_encoding_type="absT_quaR_FoV",
+):
+ """Convert camera extrinsics and intrinsics to a compact pose encoding.
+
+ This function transforms camera parameters into a unified pose encoding format,
+ which can be used for various downstream tasks like pose prediction or representation.
+
+ Args:
+ extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
+ where B is batch size and S is sequence length.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
+ The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
+ intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
+ Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for computing field of view values. For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding to use. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+
+ Returns:
+ torch.Tensor: Encoded camera pose parameters with shape BxSx9.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ """
+
+ # extrinsics: BxSx3x4
+ # intrinsics: BxSx3x3
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
+ T = extrinsics[:, :, :3, 3] # BxSx3
+
+ quat = mat_to_quat(R)
+ # Note the order of h and w here
+ H, W = image_size_hw
+ fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
+ fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
+ pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
+ else:
+ raise NotImplementedError
+
+ return pose_encoding
+
+
+def pose_encoding_to_extri_intri(
+ pose_encoding,
+ image_size_hw=None, # e.g., (256, 512)
+ pose_encoding_type="absT_quaR_FoV",
+ build_intrinsics=True,
+):
+ """Convert a pose encoding back to camera extrinsics and intrinsics.
+
+ This function performs the inverse operation of extri_intri_to_pose_encoding,
+ reconstructing the full camera parameters from the compact encoding.
+
+ Args:
+ pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
+ where B is batch size and S is sequence length.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for reconstructing intrinsics from field of view values.
+ For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding used. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+ build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
+ If False, only extrinsics are returned and intrinsics will be None.
+
+ Returns:
+ tuple: (extrinsics, intrinsics)
+ - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
+ transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
+ a 3x1 translation vector.
+ - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
+ or None if build_intrinsics is False. Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point,
+ assumed to be at the center of the image (W/2, H/2).
+ """
+
+ intrinsics = None
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ T = pose_encoding[..., :3]
+ quat = pose_encoding[..., 3:7]
+ fov_h = pose_encoding[..., 7]
+ fov_w = pose_encoding[..., 8]
+
+ R = quat_to_mat(quat)
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
+
+ if build_intrinsics:
+ H, W = image_size_hw
+ fy = (H / 2.0) / torch.tan(fov_h / 2.0)
+ fx = (W / 2.0) / torch.tan(fov_w / 2.0)
+ intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
+ intrinsics[..., 0, 0] = fx
+ intrinsics[..., 1, 1] = fy
+ intrinsics[..., 0, 2] = W / 2
+ intrinsics[..., 1, 2] = H / 2
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
+ else:
+ raise NotImplementedError
+
+ return extrinsics, intrinsics
diff --git a/models/vggt/vggt/utils/rotation.py b/models/vggt/vggt/utils/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..657583e6915437c824c192d51939990b589a14fa
--- /dev/null
+++ b/models/vggt/vggt/utils/rotation.py
@@ -0,0 +1,138 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+
+def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Quaternion Order: XYZW or say ijkr, scalar-last
+
+ Convert rotations given as quaternions to rotation matrices.
+ Args:
+ quaternions: quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ i, j, k, r = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part last, as tensor of shape (..., 4).
+ Quaternion Order: XYZW or say ijkr, scalar-last
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
+
+ # Convert from rijk to ijkr
+ out = out[..., [1, 2, 3, 0]]
+
+ out = standardize_quaternion(out)
+
+ return out
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
diff --git a/models/vggt/vggt/utils/visual_track.py b/models/vggt/vggt/utils/visual_track.py
new file mode 100644
index 0000000000000000000000000000000000000000..796c114ccba00b5f7850e04b9444a6cd5c44b154
--- /dev/null
+++ b/models/vggt/vggt/utils/visual_track.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import cv2
+import torch
+import numpy as np
+import os
+
+
+def color_from_xy(x, y, W, H, cmap_name="hsv"):
+ """
+ Map (x, y) -> color in (R, G, B).
+ 1) Normalize x,y to [0,1].
+ 2) Combine them into a single scalar c in [0,1].
+ 3) Use matplotlib's colormap to convert c -> (R,G,B).
+
+ You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
+ """
+ import matplotlib.cm
+ import matplotlib.colors
+
+ x_norm = x / max(W - 1, 1)
+ y_norm = y / max(H - 1, 1)
+ # Simple combination:
+ c = (x_norm + y_norm) / 2.0
+
+ cmap = matplotlib.cm.get_cmap(cmap_name)
+ # cmap(c) -> (r,g,b,a) in [0,1]
+ rgba = cmap(c)
+ r, g, b = rgba[0], rgba[1], rgba[2]
+ return (r, g, b) # in [0,1], RGB order
+
+
+def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
+ """
+ Given all tracks in one sample (b), compute a (N,3) array of RGB color values
+ in [0,255]. The color is determined by the (x,y) position in the first
+ visible frame for each track.
+
+ Args:
+ tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
+ vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
+ image_width, image_height: used for normalizing (x, y).
+ cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
+
+ Returns:
+ track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
+ """
+ S, N, _ = tracks_b.shape
+ track_colors = np.zeros((N, 3), dtype=np.uint8)
+
+ if vis_mask_b is None:
+ # treat all as visible
+ vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
+
+ for i in range(N):
+ # Find first visible frame for track i
+ visible_frames = torch.where(vis_mask_b[:, i])[0]
+ if len(visible_frames) == 0:
+ # track is never visible; just assign black or something
+ track_colors[i] = (0, 0, 0)
+ continue
+
+ first_s = int(visible_frames[0].item())
+ # use that frame's (x,y)
+ x, y = tracks_b[first_s, i].tolist()
+
+ # map (x,y) -> (R,G,B) in [0,1]
+ r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
+ # scale to [0,255]
+ r, g, b = int(r * 255), int(g * 255), int(b * 255)
+ track_colors[i] = (r, g, b)
+
+ return track_colors
+
+
+def visualize_tracks_on_images(
+ images,
+ tracks,
+ track_vis_mask=None,
+ out_dir="track_visuals_concat_by_xy",
+ image_format="CHW", # "CHW" or "HWC"
+ normalize_mode="[0,1]",
+ cmap_name="hsv", # e.g. "hsv", "rainbow", "jet"
+ frames_per_row=4, # New parameter for grid layout
+ save_grid=True, # Flag to control whether to save the grid image
+):
+ """
+ Visualizes frames in a grid layout with specified frames per row.
+ Each track's color is determined by its (x,y) position
+ in the first visible frame (or frame 0 if always visible).
+ Finally convert the BGR result to RGB before saving.
+ Also saves each individual frame as a separate PNG file.
+
+ Args:
+ images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
+ tracks: torch.Tensor (S, N, 2), last dim = (x, y).
+ track_vis_mask: torch.Tensor (S, N) or None.
+ out_dir: folder to save visualizations.
+ image_format: "CHW" or "HWC".
+ normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
+ cmap_name: a matplotlib colormap name for color_from_xy.
+ frames_per_row: number of frames to display in each row of the grid.
+ save_grid: whether to save all frames in one grid image.
+
+ Returns:
+ None (saves images in out_dir).
+ """
+
+ if len(tracks.shape) == 4:
+ tracks = tracks.squeeze(0)
+ images = images.squeeze(0)
+ if track_vis_mask is not None:
+ track_vis_mask = track_vis_mask.squeeze(0)
+
+ import matplotlib
+
+ matplotlib.use("Agg") # for non-interactive (optional)
+
+ os.makedirs(out_dir, exist_ok=True)
+
+ S = images.shape[0]
+ _, N, _ = tracks.shape # (S, N, 2)
+
+ # Move to CPU
+ images = images.cpu().clone()
+ tracks = tracks.cpu().clone()
+ if track_vis_mask is not None:
+ track_vis_mask = track_vis_mask.cpu().clone()
+
+ # Infer H, W from images shape
+ if image_format == "CHW":
+ # e.g. images[s].shape = (3, H, W)
+ H, W = images.shape[2], images.shape[3]
+ else:
+ # e.g. images[s].shape = (H, W, 3)
+ H, W = images.shape[1], images.shape[2]
+
+ # Pre-compute the color for each track i based on first visible position
+ track_colors_rgb = get_track_colors_by_position(
+ tracks, # shape (S, N, 2)
+ vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
+ image_width=W,
+ image_height=H,
+ cmap_name=cmap_name,
+ )
+
+ # We'll accumulate each frame's drawn image in a list
+ frame_images = []
+
+ for s in range(S):
+ # shape => either (3, H, W) or (H, W, 3)
+ img = images[s]
+
+ # Convert to (H, W, 3)
+ if image_format == "CHW":
+ img = img.permute(1, 2, 0) # (H, W, 3)
+ # else "HWC", do nothing
+
+ img = img.numpy().astype(np.float32)
+
+ # Scale to [0,255] if needed
+ if normalize_mode == "[0,1]":
+ img = np.clip(img, 0, 1) * 255.0
+ elif normalize_mode == "[-1,1]":
+ img = (img + 1.0) * 0.5 * 255.0
+ img = np.clip(img, 0, 255.0)
+ # else no normalization
+
+ # Convert to uint8
+ img = img.astype(np.uint8)
+
+ # For drawing in OpenCV, convert to BGR
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+ # Draw each visible track
+ cur_tracks = tracks[s] # shape (N, 2)
+ if track_vis_mask is not None:
+ valid_indices = torch.where(track_vis_mask[s])[0]
+ else:
+ valid_indices = range(N)
+
+ cur_tracks_np = cur_tracks.numpy()
+ for i in valid_indices:
+ x, y = cur_tracks_np[i]
+ pt = (int(round(x)), int(round(y)))
+
+ # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
+ R, G, B = track_colors_rgb[i]
+ color_bgr = (int(B), int(G), int(R))
+ cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
+
+ # Convert back to RGB for consistent final saving:
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+ # Save individual frame
+ frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
+ # Convert to BGR for OpenCV imwrite
+ frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(frame_path, frame_bgr)
+
+ frame_images.append(img_rgb)
+
+ # Only create and save the grid image if save_grid is True
+ if save_grid:
+ # Calculate grid dimensions
+ num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division
+
+ # Create a grid of images
+ grid_img = None
+ for row in range(num_rows):
+ start_idx = row * frames_per_row
+ end_idx = min(start_idx + frames_per_row, S)
+
+ # Concatenate this row horizontally
+ row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
+
+ # If this row has fewer than frames_per_row images, pad with black
+ if end_idx - start_idx < frames_per_row:
+ padding_width = (frames_per_row - (end_idx - start_idx)) * W
+ padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
+ row_img = np.concatenate([row_img, padding], axis=1)
+
+ # Add this row to the grid
+ if grid_img is None:
+ grid_img = row_img
+ else:
+ grid_img = np.concatenate([grid_img, row_img], axis=0)
+
+ out_path = os.path.join(out_dir, "tracks_grid.png")
+ # Convert back to BGR for OpenCV imwrite
+ grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(out_path, grid_img_bgr)
+ print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
+
+ print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..260f39d22cbbb7d447359e81d60730349bb10fbc
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,34 @@
+gradio==5.31.0
+pydantic==2.10.6
+opencv-python-headless
+einops
+einx
+plotly
+prettytable
+decord
+easydict
+flow_vis
+moviepy==1.0.0
+safetensors
+scikit-learn
+hydra-core
+omegaconf
+albumentations
+matplotlib
+mediapy
+scikit-image
+pycolmap==3.11.1
+git+https://github.com/facebookresearch/segment-anything.git
+git+https://github.com/EasternJournalist/utils3d.git#egg=utils3d
+huggingface_hub
+pyceres
+kornia
+xformers
+timm
+PyJWT
+gdown
+rich
+decord
+ray[default]
+jaxtyping
+transformers
\ No newline at end of file
diff --git a/tapip3d_viz.py b/tapip3d_viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..c47f9b33add62ed61cc5f6c68c091a8add429caf
--- /dev/null
+++ b/tapip3d_viz.py
@@ -0,0 +1,207 @@
+# Copyright (c) TAPIP3D team(https://tapip3d.github.io/)
+
+import os
+import numpy as np
+import cv2
+import json
+import struct
+import zlib
+import argparse
+from einops import rearrange
+from pathlib import Path
+import shutil
+from tempfile import TemporaryDirectory
+import http.server
+import socketserver
+import socket
+import sys
+from http.server import SimpleHTTPRequestHandler
+from socketserver import ThreadingTCPServer
+import base64
+
+viz_html_path = Path(__file__).parent / "viz.html"
+DEFAULT_PORT = 8000
+
+def compress_and_write(filename, header, blob):
+ header_bytes = json.dumps(header).encode("utf-8")
+ header_len = struct.pack(" T H W C") * 255).astype(np.uint8)
+ rgb_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_AREA)
+ for frame in rgb_video])
+
+ depth_video = data["depths"].astype(np.float32)
+ depth_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_NEAREST)
+ for frame in depth_video])
+
+ scale_x = fixed_size[0] / W
+ scale_y = fixed_size[1] / H
+ intrinsics = intrinsics.copy()
+ intrinsics[:, 0, :] *= scale_x
+ intrinsics[:, 1, :] *= scale_y
+
+ min_depth = float(depth_video.min()) * 0.8
+ max_depth = float(depth_video.max()) * 1.5
+
+ depth_normalized = (depth_video - min_depth) / (max_depth - min_depth)
+ depth_int = (depth_normalized * ((1 << 16) - 1)).astype(np.uint16)
+
+ depths_rgb = np.zeros((T, fixed_size[1], fixed_size[0], 3), dtype=np.uint8)
+ depths_rgb[:, :, :, 0] = (depth_int & 0xFF).astype(np.uint8)
+ depths_rgb[:, :, :, 1] = ((depth_int >> 8) & 0xFF).astype(np.uint8)
+
+ first_frame_inv = np.linalg.inv(extrinsics[0])
+ normalized_extrinsics = np.array([first_frame_inv @ ext for ext in extrinsics])
+
+ normalized_trajs = np.zeros_like(trajs)
+ for t in range(T):
+ homogeneous_trajs = np.concatenate([trajs[t], np.ones((trajs.shape[1], 1))], axis=1)
+ transformed_trajs = (first_frame_inv @ homogeneous_trajs.T).T
+ normalized_trajs[t] = transformed_trajs[:, :3]
+
+ # Get conf data from npz file
+ conf_data = data["conf"].item() if "conf" in data else {}
+
+ arrays = {
+ "rgb_video": rgb_video,
+ "depths_rgb": depths_rgb,
+ "intrinsics": intrinsics,
+ "extrinsics": normalized_extrinsics,
+ "inv_extrinsics": np.linalg.inv(normalized_extrinsics),
+ "trajectories": normalized_trajs.astype(np.float32),
+ "cameraZ": 0.0,
+ "visibs": data["visibs"] if "visibs" in data else None,
+ "confs": data["confs"] if "confs" in data else None
+ }
+
+ header = {}
+ blob_parts = []
+ offset = 0
+ for key, arr in arrays.items():
+ if arr is not None:
+ arr = np.ascontiguousarray(arr)
+ arr_bytes = arr.tobytes()
+ header[key] = {
+ "dtype": str(arr.dtype),
+ "shape": arr.shape,
+ "offset": offset,
+ "length": len(arr_bytes)
+ }
+ blob_parts.append(arr_bytes)
+ offset += len(arr_bytes)
+
+ raw_blob = b"".join(blob_parts)
+ compressed_blob = zlib.compress(raw_blob, level=9)
+
+ header["meta"] = {
+ "depthRange": [min_depth, max_depth],
+ "totalFrames": int(T),
+ "resolution": fixed_size,
+ "baseFrameRate": fps,
+ "numTrajectoryPoints": normalized_trajs.shape[1],
+ "fov": float(fov_y),
+ "fov_x": float(fov_x),
+ "original_aspect_ratio": float(original_aspect_ratio),
+ "fixed_aspect_ratio": float(fixed_size[0]/fixed_size[1]),
+ "depthFilter": conf_data.get("depthFilter", {})
+ }
+
+ compress_and_write(output_file, header, compressed_blob)
+
+ if static_html_file is not None:
+ # encode the .bin file to a base64 string
+ with open(output_file, "rb") as f:
+ encoded_blob = base64.b64encode(f.read()).decode("ascii")
+
+ with open(viz_html_path, "r", encoding="utf-8") as f:
+ html_template = f.read()
+
+ injected_html = html_template.replace(
+ "",
+ f"\n"
+ )
+
+ with open(static_html_file, "w", encoding="utf-8") as f:
+ f.write(injected_html)
+
+
+ return None
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('input_file', help='Path to the input .result.npz file')
+ parser.add_argument('--width', '-W', type=int, default=256, help='Target width')
+ parser.add_argument('--height', '-H', type=int, default=192, help='Target height')
+ parser.add_argument('--fps', type=int, default=4, help='Base frame rate for playback')
+ parser.add_argument('--port', '-p', type=int, default=DEFAULT_PORT, help=f'Port to serve the visualization (default: {DEFAULT_PORT})')
+ parser.add_argument('--static-html', '-s', type=str, default=None, help='Path to the static HTML file')
+
+ args = parser.parse_args()
+
+ with TemporaryDirectory() as temp_dir:
+ temp_path = Path(temp_dir)
+ process_point_cloud_data(
+ args.input_file,
+ temp_path / "data.bin",
+ args.static_html,
+ width=args.width,
+ height=args.height,
+ fps=args.fps
+ )
+ if args.static_html is not None:
+ return
+ shutil.copy(viz_html_path, temp_path / "index.html")
+
+ os.chdir(temp_path)
+
+ host = "0.0.0.0"
+ port = int(args.port)
+
+ Handler = SimpleHTTPRequestHandler
+ httpd = None
+
+ try:
+ httpd = ThreadingTCPServer((host, port), Handler)
+ except OSError as e:
+ if e.errno == socket.errno.EADDRINUSE:
+ print(f"Port {port} is already in use, trying a random port...")
+ try:
+ httpd = ThreadingTCPServer((host, 0), Handler)
+ port = httpd.server_address[1] # Get the assigned port
+ except OSError as e2:
+ print(f"Failed to bind to a random port: {e2}", file=sys.stderr)
+ sys.exit(1)
+ else:
+ print(f"Failed to start server: {e}", file=sys.stderr)
+ sys.exit(1)
+
+ if httpd:
+ print(f"Serving at http://{host}:{port}")
+ try:
+ httpd.serve_forever()
+ except KeyboardInterrupt:
+ print("\nServer stopped.")
+ finally:
+ httpd.server_close()
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file