import wandb
import pandas as pd
import matplotlib.pyplot as plt
import sys
import argparse
import os

"""
Run the plotting script over key metrics and key runs:

export MODEL=final2_40dataset_waction_concat_gpu_8_nodes_1
python common/plot/plot_from_wandb.py --run_id $MODEL
python common/plot/plot_from_wandb.py --run_id final2_40dataset_noaction_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_modulate_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_attn_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_add_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_d64_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_forward_dynamics_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_full_dynamics_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj100000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj10000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj100_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj1000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_5dataset_waction_gpu_8_nodes_1_step24k_v5
python common/plot/plot_from_wandb.py --run_id final2_30dataset_waction_gpu_8_nodes_1_step24k_v5
python common/plot/plot_from_wandb.py --run_id final2_10dataset_waction_gpu_8_nodes_1_step24k_v5
"""

# Initialize the wandb API client
api = wandb.Api()
pwd = os.path.dirname(os.path.abspath(__file__))

# Replace with your specific project and entity
entity = "latent-mage"
project = "video_val"

# List of datasets to process
datasets = [
    "bridge_data_v2",
    "fractal20220817_data",
    "language_table",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "utokyo_xarm_bimanual_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "austin_sirius_dataset_converted_externally_to_rlds",
    "berkeley_fanuc_manipulation",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "cmu_play_fusion",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "qut_dexterous_manpulation",
    "robo_net",
    "furniture_bench_dataset_converted_externally_to_rlds",
    "dlr_sara_grid_clamp_converted_externally_to_rlds",
    "cmu_stretch",
    "spoc",
    "columbia_cairlab_pusht_real",
    "droid",
    "toto",
    "io_ai_tech",
    "conq_hose_manipulation",
    "dobbe",
    "berkeley_gnm_cory_hall",
    "plex_robosuite",
    "usc_cloth_sim_converted_externally_to_rlds",
    "berkeley_cable_routing",
    "imperial_wrist_dataset",
    "bc_z",
    "kuka",
    "roboturk",
    "metaworld",
    "robomimic",
    "epic_kitchen",
    "ego4d",
    "nyu_door_opening_surprising_effectiveness",
]


def normalize_dataset(metric, runs):
    """
    Figure out the best and worst values for a metric across all runs and use
    them for normalization. Unimplemented; a hedged min_max_normalize sketch is
    provided at the bottom of this file.
    """
    pass


# List to store dataframes of PSNR metrics for each dataset
metrics_data = []

# Get runs based on a path
# Set up argument parser
parser = argparse.ArgumentParser(description='Aggregate per-dataset validation metrics for a single wandb run')
parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1',
                    help='The run ID to process')

# Parse arguments
args = parser.parse_args()

fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta', 'teacher_force_ssim',
          'teacher_force_pred_lpips', 'teacher_force_loss']
num_fields = len(fields)

run_id = args.run_id
runs_path = f"{entity}/{project}/runs"
run = api.run(f"{entity}/{project}/runs/{run_id}")

# Get the history dataframe of a run
history = run.history(pandas=True)
model_step = 0
summary_metrics = run.summary
num_datasets = 0

# output the field into csv
# csv_output = f"{pwd}/aggregated_output.csv"
csv_output = "aggregated_output.csv"

# initialize the csv file
if not os.path.exists(csv_output):
    with open(csv_output, 'w') as f:
        field_str = "name,"
        for dataset in datasets:
            for field in fields:
                field_str += f"{dataset}/{field},"
        f.write(field_str.rstrip(",") + "\n")

results = [run_id] + [None] * len(datasets) * num_fields

for field_idx, field in enumerate(fields):
    if not history.empty:
        # Filter the history to only include PSNR metrics for the specified datasets
        for dataset_idx, dataset in enumerate(datasets):
            field_col = f"{dataset}/{field}"
            col_idx = dataset_idx * num_fields + field_idx + 1
            if field == "num_examples":
                if f"{dataset}/num_examples" in summary_metrics:
                    results[col_idx] = summary_metrics[f"{dataset}/num_examples"]
                continue
            if field_col in history.columns:
                # Calculate PSNR divided by the number of examples (uncomment if needed)
                # history[field_col] = history[field_col] / history.shape[0]
                valid_field = history[field_col].dropna()
                if not valid_field.empty:
                    last_valid_value = valid_field.iloc[-1]  # Get the last non-NaN value
                    num_datasets += 1
                    metrics = pd.DataFrame({field_col: [last_valid_value]})
                    metrics['dataset'] = dataset
                    results[col_idx] = last_valid_value
                    metrics_data.append(metrics)
            else:
                pass
                # print("missing dataset:", dataset)
            if f"{dataset}/model_step" in summary_metrics:
                model_step = summary_metrics[f"{dataset}/model_step"]

    # Combine all the metric dataframes into one
    if metrics_data:
        all_metrics_df = pd.concat(metrics_data, ignore_index=True)

        # # Compute aggregated statistics (mean, median, std, etc.) for PSNR
        # aggregated_stats = all_metrics_df.groupby('dataset').mean()
        #
        # # Plot the mean PSNR for each dataset
        # plt.figure(figsize=(12, 8))
        # aggregated_stats[f'{field}'] = aggregated_stats.mean(axis=1)
        # aggregated_stats[f'{field}'].plot(kind='bar')
        #
        # # print number of steps in the wandb run
        # print(f"run: {run_id} field: {field} steps: {model_step} num of dataset: {len(metrics_data)}")
        # print(f"{field}: {aggregated_stats[field].mean():.2f}+-{aggregated_stats[field].std():.2f}")
        #
        # plt.title(f"Mean {field} for Each Dataset")
        # plt.xlabel("Dataset")
        # plt.ylabel(f"Mean {field}")
        # plt.xticks(rotation=90)
        # plt.tight_layout()
        #
        # # Save the plot
        # plt.savefig(f"{pwd}/output/{run.id}_{field}_plot.png")

# write the results into csv
with open(csv_output, 'a+') as f:
    f.write(",".join([str(x) for x in results]) + "\n")

# Display aggregated statistics
# print(aggregated_stats)

# Save the aggregated statistics as a CSV if needed
# aggregated_stats.to_csv(f"{run_id}_{field}_stat.csv", index=True)
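
# ---------------------------------------------------------------------------
# The normalize_dataset stub above is left unimplemented in this script. The
# helper below is a hedged sketch of the min-max normalization its docstring
# describes, not part of the original pipeline: the name min_max_normalize and
# the reliance on run.summary are assumptions for illustration only. It is
# defined but never called, so it does not change the script's behavior.
# ---------------------------------------------------------------------------
def min_max_normalize(metric, runs):
    """Hypothetical sketch: scale one summary metric to [0, 1] across a set of wandb runs."""
    # Collect the metric from each run's summary, skipping runs that never logged it
    values = {r.id: r.summary.get(metric) for r in runs}
    values = {rid: v for rid, v in values.items() if v is not None}
    if not values:
        return {}
    lo, hi = min(values.values()), max(values.values())
    span = (hi - lo) or 1.0  # guard against all runs sharing the same value
    return {rid: (v - lo) / span for rid, v in values.items()}

# Example (hypothetical) usage against the same project:
#   runs = api.runs(f"{entity}/{project}")
#   normalized = min_max_normalize("bridge_data_v2/teacher_force_psnr", runs)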