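"""Build a transforms.json for a folder of posed images, then create a train/test split.

For each frame, `subfolder` is expected to contain a `<name>.json` file holding a
normalized intrinsics matrix under "K" and a camera-to-world matrix under "c2w",
together with a matching `<name>.png` image. The camera poses are converted to the
OpenGL convention and written to transforms.json; running the script also produces a
K-means-based train_test_split_<n>.json.
"""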
import argparse
import json
import os
import numpy as np
from PIL import Image
try:
    from sklearn.cluster import KMeans  # type: ignore[import]
except ImportError:
    print("Please install scikit-learn (pip install scikit-learn) to use this script.")
    raise SystemExit(1)
# Define the folder containing the image and JSON files
subfolder = "/path/to/your/dataset"
output_file = os.path.join(subfolder, "transforms.json")
# List to hold the frames
frames = []
# Iterate over the files in the folder
for file in sorted(os.listdir(subfolder)):
    if file.endswith(".json"):
        # Read the JSON file containing camera extrinsics and intrinsics
        json_path = os.path.join(subfolder, file)
        with open(json_path, "r") as f:
            data = json.load(f)

        # Read the corresponding image file
        image_file = file.replace(".json", ".png")
        image_path = os.path.join(subfolder, image_file)
        if not os.path.exists(image_path):
            print(f"Image file not found for {file}, skipping...")
            continue

        with Image.open(image_path) as img:
            w, h = img.size
        # Convert the normalized intrinsics K to pixel units
        K = data["K"]
        fx = K[0][0] * w
        fy = K[1][1] * h
        cx = K[0][2] * w
        cy = K[1][2] * h

        # Extract the camera-to-world transformation matrix
        transform_matrix = np.array(data["c2w"])

        # Flip the y and z axes (columns 1 and 2) to follow the OpenGL camera convention
        transform_matrix[..., [1, 2]] *= -1
        # Add the frame data
        frames.append(
            {
                "fl_x": fx,
                "fl_y": fy,
                "cx": cx,
                "cy": cy,
                "w": w,
                "h": h,
                "file_path": f"./{os.path.relpath(image_path, subfolder)}",
                "transform_matrix": transform_matrix.tolist(),
            }
        )
# Create the output dictionary
transforms_data = {"orientation_override": "none", "frames": frames}
# Write to the transforms.json file
with open(output_file, "w") as f:
    json.dump(transforms_data, f, indent=4)
print(f"transforms.json generated at {output_file}")
# Train-test split function using K-means clustering with stride
def create_train_test_split(frames, n, output_path, stride):
    # Prepare the data for K-means: one 6-D vector per frame
    positions = []
    for frame in frames:
        transform_matrix = np.array(frame["transform_matrix"])
        position = transform_matrix[:3, 3]  # 3D camera position
        direction = transform_matrix[:3, 2] / np.linalg.norm(
            transform_matrix[:3, 2]
        )  # Normalized 3D viewing direction
        positions.append(np.concatenate([position, direction]))
    positions = np.array(positions)

    # Apply K-means clustering
    kmeans = KMeans(n_clusters=n, random_state=42)
    kmeans.fit(positions)
    centers = kmeans.cluster_centers_

    # Find the frame index closest to each cluster center
    train_ids = []
    for center in centers:
        distances = np.linalg.norm(positions - center, axis=1)
        train_ids.append(int(np.argmin(distances)))  # Convert to Python int

    # Remaining indices become test_ids, subsampled by stride
    all_indices = set(range(len(frames)))
    remaining_indices = sorted(all_indices - set(train_ids))
    test_ids = [int(idx) for idx in remaining_indices[::stride]]  # Convert to Python int

    # Create the split data
    split_data = {"train_ids": sorted(train_ids), "test_ids": test_ids}
    with open(output_path, "w") as f:
        json.dump(split_data, f, indent=4)

    print(f"Train-test split file generated at {output_path}")
# Parse arguments
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate a train-test split JSON file using K-means clustering."
    )
    parser.add_argument(
        "--n",
        type=int,
        required=True,
        help="Number of frames to include in the training set.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        help="Stride for subsampling the remaining frames into the test set.",
    )
    args = parser.parse_args()

    # Create the train-test split
    train_test_split_path = os.path.join(subfolder, f"train_test_split_{args.n}.json")
    create_train_test_split(frames, args.n, train_test_split_path, args.stride)
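
# Example invocation (the script filename and argument values are illustrative; edit
# `subfolder` above to point at your dataset before running):
#
#   python generate_transforms.py --n 10 --stride 5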