|
import torch |
|
from pytorch3d.renderer import PerspectiveCameras |
|
|
|
import sys |
|
sys.path.append('./') |
|
from sparseags.cam_utils import normalize_cameras_with_up_axis |
|
|
|
sys.path[0] = sys.path[0] + '/dust3r' |
|
from dust3r.inference import inference |
|
from dust3r.utils.image import load_images |
|
from dust3r.image_pairs import make_pairs |
|
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode |
|
|
|
|
|
def infer_dust3r(dust3r_model, file_names, device='cuda'): |
|
batch_size = 1 |
|
schedule = 'cosine' |
|
lr = 0.01 |
|
niter = 300 |
|
|
|
images = load_images(file_names, size=224) |
|
pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True) |
|
output = inference(pairs, dust3r_model, device, batch_size=batch_size) |
|
|
|
scene = global_aligner(output, optimize_pp=True, device=device, mode=GlobalAlignerMode.PointCloudOptimizer) |
|
loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr) |
|
|
|
|
|
imgs = scene.imgs |
|
cams2world = scene.get_im_poses() |
|
w2c = torch.linalg.inv(cams2world) |
|
pps = scene.get_principal_points() * 256 / 224 |
|
focals = scene.get_focals() * 256 / 224 |
|
|
|
w2c[:, :2] *= -1 |
|
Rs = w2c[:, :3, :3].transpose(1, 2) |
|
Ts = w2c[:, :3, 3] |
|
|
|
cameras = PerspectiveCameras( |
|
focal_length=focals, |
|
principal_point=pps, |
|
in_ndc=False, |
|
R=Rs, |
|
T=Ts, |
|
) |
|
normalized_cameras, _, _, _, _, needs_checking = normalize_cameras_with_up_axis(cameras, None, in_ndc=False) |
|
|
|
if normalized_cameras is None: |
|
print("It seems something wrong...") |
|
return 0 |
|
|
|
data = {} |
|
base_names = [file_name.split('/')[-1].split('.')[0] for file_name in file_names] |
|
file_names = [file_name.replace('source', 'processed').replace('.png', '_rgba.png') for file_name in file_names] |
|
|
|
for idx, base_name in enumerate(base_names): |
|
data[base_name] = {} |
|
data[base_name]["R"] = normalized_cameras.R[idx].cpu().tolist() |
|
data[base_name]["T"] = normalized_cameras.T[idx].cpu().tolist() |
|
data[base_name]["needs_checking"] = needs_checking |
|
data[base_name]["principal_point"] = normalized_cameras.principal_point[idx].cpu().tolist() |
|
data[base_name]["focal_length"] = normalized_cameras.focal_length[idx].cpu().tolist() |
|
data[base_name]["flag"] = 1 |
|
data[base_name]["filepath"] = file_names[idx] |
|
|
|
return data |