|
""" |
|
Evaluate P-FID between two batches of point clouds. |
|
|
|
The point cloud batches should be saved to two npz files, each containing
an arr_0 key of shape [N x K x 3], where N is the number of clouds and K is
the number of points in each cloud (each point having 3 coordinates).
|
""" |
|
|
|
import argparse |
|
|
|
from point_e.evals.feature_extractor import PointNetClassifier, get_torch_devices |
|
from point_e.evals.fid_is import compute_statistics |
|
from point_e.evals.npz_stream import NpzStreamer |
|
|
|
|
|
def _batch_statistics(clf, path):
    """Extract classifier features for the npz batch at *path* and summarize them.

    Returns the result of ``compute_statistics`` on the features. The raw
    feature array is local to this helper, so it can be freed as soon as the
    statistics are computed, before the next batch is processed.
    """
    features, _ = clf.features_and_preds(NpzStreamer(path))
    return compute_statistics(features)


def main():
    """Compute and print the P-FID between two npz point-cloud batches.

    Command-line arguments:
        --cache_dir: optional directory forwarded to ``PointNetClassifier`` as
            ``cache_dir`` (presumably where model weights are cached — confirm).
        batch_1: path to the first npz batch, read through ``NpzStreamer``.
        batch_2: path to the second npz batch, read through ``NpzStreamer``.

    Features for each batch come from the PointNet classifier and are reduced
    with ``compute_statistics``. The Fréchet distance between the two sets of
    statistics is then printed as the P-FID.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate P-FID between two batches of point clouds."
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="cache directory passed to the PointNet classifier",
    )
    parser.add_argument("batch_1", type=str, help="path to the first npz batch")
    parser.add_argument("batch_2", type=str, help="path to the second npz batch")
    args = parser.parse_args()

    print("creating classifier...")
    clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.cache_dir)

    print("computing first batch activations")
    stats_1 = _batch_statistics(clf, args.batch_1)

    # Report progress for the second batch too; its extraction is as costly as the first.
    print("computing second batch activations")
    stats_2 = _batch_statistics(clf, args.batch_2)

    print(f"P-FID: {stats_1.frechet_distance(stats_2)}")
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":

    main()
|
|