import argparse
import os

import matplotlib.pyplot as plt
import pandas as pd
import wandb

"""
Running plotting scripts over key metrics and key runs
export MODEL=40dataset_waction_add_gpu_8_nodes_1
python common/plot/plot_from_wandb.py --field teacher_force_psnr --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_psnr_delta --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_ssim --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_pred_lpips --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_loss --run_id $MODEL

"""
# Initialize the wandb API client
api = wandb.Api()

# Replace with your specific project and entity
entity = "latent-mage"
project = "video_val"

# List of datasets to process
datasets = [
    "bridge_data_v2",
    "fractal20220817_data",
    "language_table",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "utokyo_xarm_bimanual_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "austin_sirius_dataset_converted_externally_to_rlds",
    "berkeley_fanuc_manipulation",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "cmu_play_fusion",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "qut_dexterous_manpulation",
    "robo_net",
    "furniture_bench_dataset_converted_externally_to_rlds",
    "dlr_sara_grid_clamp_converted_externally_to_rlds",
    "cmu_stretch",
    "spoc",
    "columbia_cairlab_pusht_real",
    "droid",
    "toto",
    "io_ai_tech",
    "conq_hose_manipulation",
    "dobbe",
    "berkeley_gnm_cory_hall",
    "plex_robosuite",
    "usc_cloth_sim_converted_externally_to_rlds",
    "berkeley_cable_routing",
    "imperial_wrist_dataset",
    "bc_z",
    "kuka",
    "roboturk",
    "metaworld",
    "robomimic",
    "epic_kitchen",
    "ego4d",
    "nyu_door_opening_surprising_effectiveness"
]

# Set up the command-line interface
parser = argparse.ArgumentParser(description='Plot key metrics from a W&B run.')
parser.add_argument('--field', type=str, default='teacher_force_psnr', help='The metric field to process')
parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1', help='The W&B run ID to process')

# Parse arguments
args = parser.parse_args()

run_id = args.run_id

run = api.run(f"{entity}/{project}/runs/{run_id}")

# Get the history dataframe of a run
history = run.history(pandas=True)
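# Note: by default run.history() returns a sampled subset of the logged rows
# (roughly 500 samples); if every logged row is needed, run.scan_history()
# iterates the complete history at the cost of more API calls.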
model_step = 0
summary_metrics = run.summary

# All key fields are processed in one pass, so a single invocation reports
# every metric; the per-field invocations in the docstring are equivalent,
# and --field is kept for compatibility with them.
fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta', 'teacher_force_ssim', 'teacher_force_pred_lpips', 'teacher_force_loss']

for field in fields:
    metrics_data = []
    if not history.empty:
        # Collect the last logged value of this field for each dataset
        for dataset in datasets:
            field_col = f"{dataset}/{field}"
            step_col = f"{dataset}/model_step"
            if field_col in history.columns:
                # Optionally normalize by the number of logged history rows (uncomment if needed)
                # history[field_col] = history[field_col] / history.shape[0]
                valid_field = history[field_col].dropna()
                if not valid_field.empty:
                    last_valid_value = valid_field.iloc[-1]  # last logged (non-NaN) value
                    metrics = pd.DataFrame({field_col: [last_valid_value]})
                    metrics['dataset'] = dataset
                    metrics_data.append(metrics)

            # run.summary keeps the final logged value for each key, so this
            # recovers the dataset's last model_step.
            if step_col in summary_metrics:
                model_step = summary_metrics[step_col]

    # Combine the per-dataset frames into one wide frame: each frame carries
    # its own "{dataset}/{field}" column, so the concatenated frame is mostly NaN.
    if metrics_data:
        all_metrics_df = pd.concat(metrics_data, ignore_index=True)

        # Grouping by dataset and averaging across columns (NaNs ignored)
        # recovers a single value per dataset for this field.
        aggregated_stats = all_metrics_df.groupby('dataset').mean()
        aggregated_stats[field] = aggregated_stats.mean(axis=1)

        # Plot the per-dataset values for this field
        plt.figure(figsize=(12, 8))
        aggregated_stats[field].plot(kind='bar')

        print(f"run: {run_id} field: {field} steps: {model_step} num_datasets: {len(metrics_data)}")
        print(f"{field}: {aggregated_stats[field].mean():.2f}+-{aggregated_stats[field].std():.2f}")

        # plt.title(f"Mean {field} for Each Dataset")
        # plt.xlabel("Dataset")
        # plt.ylabel(f"Mean {field} ")
        # plt.xticks(rotation=90)
        # plt.tight_layout()

        # # Save the plot
        # import os
        # pwd = os.path.dirname(os.path.abspath(__file__))
        # plt.savefig(f"{pwd}/output/{run.id}_{field}_plot.png")

# Display aggregated statistics
# print(aggregated_stats)

# Save the aggregated statistics as a CSV if needed
# aggregated_stats.to_csv(f"{run_id}_{field}_stat.csv", index=True)