import wandb
import pandas as pd
import matplotlib.pyplot as plt  # used by the (currently commented-out) plotting block below
import argparse
import os

"""
Run this plotting script over key metrics and key runs, e.g.:
export MODEL=final2_40dataset_waction_concat_gpu_8_nodes_1
python common/plot/plot_from_wandb.py  --run_id $MODEL
python common/plot/plot_from_wandb.py  --run_id final2_40dataset_noaction_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py  --run_id final2_40dataset_waction_modulate_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py  --run_id final2_40dataset_waction_attn_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py  --run_id final2_40dataset_waction_add_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py  --run_id final2_40dataset_waction_d64_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py  --run_id final2_40dataset_forward_dynamics_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py  --run_id final2_40dataset_full_dynamics_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py  --run_id final2_40dataset_waction_traj100000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py  --run_id final2_40dataset_waction_traj10000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py  --run_id final2_40dataset_waction_traj100_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py  --run_id final2_40dataset_waction_traj1000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py  --run_id final2_5dataset_waction_gpu_8_nodes_1_step24k_v5
python common/plot/plot_from_wandb.py  --run_id final2_30dataset_waction_gpu_8_nodes_1_step24k_v5
python common/plot/plot_from_wandb.py  --run_id final2_10dataset_waction_gpu_8_nodes_1_step24k_v5

"""
# Initialize the wandb API client
api = wandb.Api()
pwd = os.path.dirname(os.path.abspath(__file__))

# Replace with your specific project and entity
entity = "latent-mage"
project = "video_val"

# List of datasets to process
datasets = [
    "bridge_data_v2",
    "fractal20220817_data",
    "language_table",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "utokyo_xarm_bimanual_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "austin_sirius_dataset_converted_externally_to_rlds",
    "berkeley_fanuc_manipulation",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "cmu_play_fusion",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "qut_dexterous_manpulation",
    "robo_net",
    "furniture_bench_dataset_converted_externally_to_rlds",
    "dlr_sara_grid_clamp_converted_externally_to_rlds",
    "cmu_stretch",
    "spoc",
    "columbia_cairlab_pusht_real",
    "droid",
    "toto",
    "io_ai_tech",
    "conq_hose_manipulation",
    "dobbe",
    "berkeley_gnm_cory_hall",
    "plex_robosuite",
    "usc_cloth_sim_converted_externally_to_rlds",
    "berkeley_cable_routing",
    "imperial_wrist_dataset",
    "bc_z",
    "kuka",
    "roboturk",
    "metaworld",
    "robomimic",
    "epic_kitchen",
    "ego4d",
    "nyu_door_opening_surprising_effectiveness"
]
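
# Each name above doubles as the W&B metric-key prefix: per-dataset metrics
# are logged as "<dataset>/<field>" (see `fields` below).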

def normalize_dataset(metric, runs):
    """
    Min-max normalize `metric` across `runs`: take the best and worst
    summary values over all runs and rescale each run's value to [0, 1].
    (A minimal sketch; this helper is not called anywhere yet.)
    """
    values = {r.id: r.summary[metric] for r in runs if metric in r.summary}
    if not values:
        return {}
    lo, hi = min(values.values()), max(values.values())
    span = (hi - lo) or 1.0  # guard against all-equal values
    return {rid: (v - lo) / span for rid, v in values.items()}

# List to store single-row dataframes of metrics for each dataset
metrics_data = []
# Set up the argument parser
parser = argparse.ArgumentParser(description='Aggregate per-dataset validation metrics from a W&B run into one CSV row.')
parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1', help='The run ID to process')

# Parse arguments
args = parser.parse_args()

fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta', 'teacher_force_ssim', 'teacher_force_pred_lpips', 'teacher_force_loss']
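# One CSV row per run: a "name" column followed by
# len(datasets) * len(fields) metric columns, laid out dataset-major
# to match the header written below.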
num_fields = len(fields)
run_id = args.run_id

run = api.run(f"{entity}/{project}/runs/{run_id}")

# Get the history dataframe of a run
history = run.history(pandas=True)
model_step = 0
summary_metrics = run.summary
num_datasets = 0

# Output the aggregated fields into a CSV (one row per run)
# csv_output = f"{pwd}/aggregated_output.csv"
csv_output = "aggregated_output.csv"

# initialize the csv file
if not os.path.exists(csv_output):
    with open(csv_output, 'w') as f:
        field_str = "name,"
        for dataset in datasets:
            for field in fields:
                field_str += f"{dataset}/{field},"
        f.write(field_str.rstrip(",") + "\n")
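# The header is laid out as, e.g.:
#   name,bridge_data_v2/num_examples,bridge_data_v2/teacher_force_psnr,...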

results = [run_id] + [None] * len(datasets) * num_fields
for field_idx, field in enumerate(fields):
    if not history.empty:
        # Pull the latest logged value of this field for each dataset
        for dataset_idx, dataset in enumerate(datasets):
            field_col = f"{dataset}/{field}"
            col_idx = dataset_idx * num_fields + field_idx + 1
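            # num_examples is logged once to the run summary rather than to
            # the per-step history, so read it from summary_metrics directly.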
            if field == "num_examples":
                if f"{dataset}/num_examples" in summary_metrics:
                    results[col_idx] = summary_metrics[f"{dataset}/num_examples"]

                continue
            if field_col in history.columns:
                # Calculate PSNR divided by the number of examples (uncomment if needed)
                # history[field_col] = history[field_col] / history.shape[0]
                valid_field = history[field_col].dropna()
                if not valid_field.empty:
                    last_valid_value = valid_field.iloc[-1]  # Get the last non-NaN value
                    num_datasets += 1
                    metrics = pd.DataFrame({field_col: [last_valid_value]})
                    metrics['dataset'] = dataset
                    results[col_idx] = last_valid_value
                    metrics_data.append(metrics)
            else:
                # This run never logged this dataset/field column.
                # print("missing dataset:", dataset)
                pass

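            # Remember which training step this validation snapshot came from;
            # it is only used by the commented-out summary print below.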
            if f"{dataset}/model_step" in summary_metrics:
                model_step = summary_metrics[f"{dataset}/model_step"]

    # Combine all the metric dataframes into one
    if metrics_data:
        all_metrics_df = pd.concat(metrics_data, ignore_index=True)
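        # all_metrics_df is only consumed by the commented-out plotting block
        # below; it is rebuilt each iteration since metrics_data keeps growing.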

        # # Compute aggregated statistics (mean, median, std, etc.) for PSNR
        # aggregated_stats = all_metrics_df.groupby('dataset').mean()
        #
        # # Plot the mean PSNR for each dataset
        # plt.figure(figsize=(12, 8))
        # aggregated_stats[f'{field}'] = aggregated_stats.mean(axis=1)
        # aggregated_stats[f'{field}'].plot(kind='bar')
        # # print number of steps in the wandb run
        # print(f"run: {run_id} field: {field} steps: {model_step} num of dataset: {len(metrics_data)}")
        # print(f"{field}: {aggregated_stats[field].mean():.2f}+-{aggregated_stats[field].std():.2f}", )
        #
        # plt.title(f"Mean {field} for Each Dataset")
        # plt.xlabel("Dataset")
        # plt.ylabel(f"Mean {field} ")
        # plt.xticks(rotation=90)
        # plt.tight_layout()
        #
        # # Save the plot
        # plt.savefig(f"{pwd}/output/{run.id}_{field}_plot.png")

# Write the results row into the CSV
with open(csv_output, 'a') as f:
    f.write(",".join([str(x) for x in results]) + "\n")

# Display aggregated statistics
# print(aggregated_stats)

# Save the aggregated statistics as a CSV if needed
# aggregated_stats.to_csv(f"{run_id}_{field}_stat.csv", index=True)