from models.SpaTrackV2.models.predictor import Predictor
import yaml
import easydict
import os
import numpy as np
import cv2
import torch
import torchvision.transforms as T
from PIL import Image
import io
import moviepy.editor as mp
from models.SpaTrackV2.utils.visualizer import Visualizer
import tqdm
from models.SpaTrackV2.models.utils import get_points_on_a_grid
import glob
from rich import print
import argparse
import decord
from huggingface_hub import hf_hub_download
config = {
"ckpt_dir": "Yuxihenry/SpatialTrackerCkpts", # HuggingFace repo ID
"cfg_dir": "config/magic_infer_moge.yaml",
}
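# "ckpt_dir" accepts either a local checkpoint path or a HuggingFace repo ID;
# get_tracker_predictor() checks which one it is before loading.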
def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=None):
"""
Initialize and return the tracker predictor and visualizer
Args:
output_dir: Directory to save visualization results
vo_points: Number of points for visual odometry
Returns:
Tuple of (tracker_predictor, visualizer)
"""
viz = True
os.makedirs(output_dir, exist_ok=True)
with open(config["cfg_dir"], "r") as f:
cfg = yaml.load(f, Loader=yaml.FullLoader)
cfg = easydict.EasyDict(cfg)
cfg.out_dir = output_dir
cfg.model.track_num = vo_points
# Check if it's a local path or HuggingFace repo
if tracker_model is not None:
model = tracker_model
model.spatrack.track_num = vo_points
else:
if os.path.exists(config["ckpt_dir"]):
# Local file
model = Predictor.from_pretrained(config["ckpt_dir"], model_cfg=cfg["model"])
else:
# HuggingFace repo - download the model
print(f"Downloading model from HuggingFace: {config['ckpt_dir']}")
checkpoint_path = hf_hub_download(
repo_id=config["ckpt_dir"],
repo_type="model",
filename="SpaTrack3_offline.pth"
)
model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
model.eval()
model.to("cuda")
viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
fps=10, pad_value=0, tracks_leave_trace=5)
return model, viser
def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3):
"""
Run tracking on a video sequence
Args:
model: Tracker predictor instance
viser: Visualizer instance
temp_dir: Directory containing temporary files
video_name: Name of the video file (without extension)
grid_size: Size of the tracking grid
vo_points: Number of points for visual odometry
        fps: Temporal stride for subsampling the input frames (every fps-th frame is kept)
"""
# Setup paths
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
mask_path = os.path.join(temp_dir, f"{video_name}.png")
out_dir = os.path.join(temp_dir, "results")
os.makedirs(out_dir, exist_ok=True)
# Load video using decord
video_reader = decord.VideoReader(video_path)
video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2) # Convert to tensor and permute to (N, C, H, W)
    # resize so that the shortest side is at most 336
h, w = video_tensor.shape[2:]
scale = max(336 / h, 336 / w)
if scale < 1:
new_h, new_w = int(h * scale), int(w * scale)
video_tensor = T.Resize((new_h, new_w))(video_tensor)
video_tensor = video_tensor[::fps].float()
depth_tensor = None
intrs = None
extrs = None
data_npz_load = {}
# Load and process mask
if os.path.exists(mask_path):
mask = cv2.imread(mask_path)
mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
mask = mask.sum(axis=-1)>0
else:
mask = np.ones_like(video_tensor[0,0].numpy())>0
# Get frame dimensions and create grid points
frame_H, frame_W = video_tensor.shape[2:]
grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
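    # grid_pts holds pixel coordinates in (x, y) order for a regular grid over the frame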
# Sample mask values at grid points and filter out points where mask=0
if os.path.exists(mask_path):
grid_pts_int = grid_pts[0].long()
mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
grid_pts = grid_pts[:, mask_values]
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
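    # query_xyt rows are (t, x, y): frame index 0 followed by the pixel
    # coordinates of the grid points that survived the mask filtering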
    # Run VGGT (if the VGGT_DIR environment variable is set) to estimate camera
    # poses and depth maps. This assumes VGGT, preprocess_image and
    # pose_encoding_to_extri_intri are importable from the VGGT package used by this repo.
    vggt_dir = os.environ.get("VGGT_DIR", None)
    if vggt_dir is not None:
        vggt_model = VGGT()
        vggt_model.load_state_dict(torch.load(vggt_dir))
        vggt_model.eval()
        vggt_model = vggt_model.to("cuda")
# process the image tensor
video_tensor = preprocess_image(video_tensor)[None]
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
# Predict attributes including cameras, depth maps, and point maps.
aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_tensor.cuda()/255)
pose_enc = vggt_model.camera_head(aggregated_tokens_list)[-1]
# Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, video_tensor.shape[-2:])
# Predict Depth Maps
depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_tensor.cuda()/255, ps_idx)
# clear the cache
del vggt_model, aggregated_tokens_list, ps_idx, pose_enc
torch.cuda.empty_cache()
depth_tensor = depth_map.squeeze().cpu().numpy()
extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
extrs[:, :3, :4] = extrinsic.squeeze().cpu().numpy()
intrs = intrinsic.squeeze().cpu().numpy()
video_tensor = video_tensor.squeeze()
        # NOTE: treat depth estimates with confidence <= 0.5 as unreliable
# threshold = depth_conf.squeeze().view(-1).quantile(0.5)
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
# Run model inference
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
(
c2w_traj, intrs, point_map, conf_depth,
track3d_pred, track2d_pred, vis_pred, conf_pred, video
) = model.forward(video_tensor, depth=depth_tensor,
intrs=intrs, extrs=extrs,
queries=query_xyt,
fps=1, full_point=False, iters_track=4,
query_no_BA=True, fixed_cam=False, stage=1,
support_frame=len(video_tensor)-1, replace_ratio=0.2)
    # Resize results to limit the I/O burden
max_size = 336
h, w = video.shape[2:]
scale = min(max_size / h, max_size / w)
if scale < 1:
new_h, new_w = int(h * scale), int(w * scale)
video = T.Resize((new_h, new_w))(video)
video_tensor = T.Resize((new_h, new_w))(video_tensor)
point_map = T.Resize((new_h, new_w))(point_map)
track2d_pred[...,:2] = track2d_pred[...,:2] * scale
intrs[:,:2,:] = intrs[:,:2,:] * scale
        if depth_tensor is not None:
            # depth_tensor is a NumPy array at this point; convert it to a tensor before resizing
            depth_tensor = T.Resize((new_h, new_w))(torch.from_numpy(depth_tensor))
        conf_depth = T.Resize((new_h, new_w))(conf_depth)
# Visualize tracks
viser.visualize(video=video[None],
tracks=track2d_pred[None][...,:2],
visibility=vis_pred[None],filename="test")
# Save in tapip3d format
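    # track3d_pred is expressed in camera coordinates; applying c2w_traj below
    # transforms the tracks into world coordinates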
data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
data_npz_load["intrinsics"] = intrs.cpu().numpy()
data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
data_npz_load["video"] = (video_tensor).cpu().numpy()/255
data_npz_load["visibs"] = vis_pred.cpu().numpy()
data_npz_load["confs"] = conf_pred.cpu().numpy()
data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
    np.savez(os.path.join(out_dir, 'result.npz'), **data_npz_load)
print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")