"""Plot per-dataset validation metrics pulled from a wandb run.

Run the plotting script over key metrics and key runs:

export MODEL=40dataset_waction_add_gpu_8_nodes_1
python common/plot/plot_from_wandb.py --field teacher_force_psnr --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_psnr_delta --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_ssim --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_pred_lpips --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_loss --run_id $MODEL
"""

import argparse

import matplotlib.pyplot as plt
import pandas as pd
import wandb

# Initialize the wandb API client
api = wandb.Api()

# Replace with your specific entity and project
entity = "latent-mage"
project = "video_val"

# List of datasets to process
datasets = [
    "bridge_data_v2",
    "fractal20220817_data",
    "language_table",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "utokyo_xarm_bimanual_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "austin_sirius_dataset_converted_externally_to_rlds",
    "berkeley_fanuc_manipulation",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "cmu_play_fusion",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "qut_dexterous_manpulation",
    "robo_net",
    "furniture_bench_dataset_converted_externally_to_rlds",
    "dlr_sara_grid_clamp_converted_externally_to_rlds",
    "cmu_stretch",
    "spoc",
    "columbia_cairlab_pusht_real",
    "droid",
    "toto",
    "io_ai_tech",
    "conq_hose_manipulation",
    "dobbe",
    "berkeley_gnm_cory_hall",
    "plex_robosuite",
    "usc_cloth_sim_converted_externally_to_rlds",
    "berkeley_cable_routing",
    "imperial_wrist_dataset",
    "bc_z",
    "kuka",
    "roboturk",
    "metaworld",
    "robomimic",
    "epic_kitchen",
    "ego4d",
    "nyu_door_opening_surprising_effectiveness",
]

# Set up the argument parser
parser = argparse.ArgumentParser(description='Plot per-dataset metrics from a wandb run.')
parser.add_argument('--field', type=str, default='teacher_force_psnr', help='The field to process')
parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1', help='The run ID to process')

# Parse arguments
args = parser.parse_args()
field = args.field
run_id = args.run_id

# Get the run based on its path
run = api.run(f"{entity}/{project}/runs/{run_id}")

# Get the history dataframe of the run
history = run.history(pandas=True)
model_step = 0
summary_metrics = run.summary
num_datasets = 0

# NOTE: the loop below sweeps all fields, so the --field argument only sets the default
# and does not restrict which metrics are plotted.
fields = [
    'num_examples',
    'teacher_force_psnr',
    'teacher_force_psnr_delta',
    'teacher_force_ssim',
    'teacher_force_pred_lpips',
    'teacher_force_loss',
]

for field in fields:
    # List to store one dataframe of metrics per dataset
    metrics_data = []
    if not history.empty:
        # Filter the history to only include metrics for the specified datasets
        for dataset in datasets:
            field_col = f"{dataset}/{field}"
            step_col = f"{dataset}/model_step"
            if field_col in history.columns:
                # Normalize by the number of examples (uncomment if needed)
                # history[field_col] = history[field_col] / history.shape[0]
                valid_field = history[field_col].dropna()
                if not valid_field.empty:
                    last_valid_value = valid_field.iloc[-1]  # Get the last non-NaN value
                    num_datasets += 1
                    metrics = pd.DataFrame({field_col: [last_valid_value]})
                    metrics['dataset'] = dataset
                    metrics_data.append(metrics)
                if step_col in summary_metrics:
                    model_step = summary_metrics[step_col]

        # Combine all the per-dataset metric dataframes into one
        if metrics_data:
            all_metrics_df = pd.concat(metrics_data, ignore_index=True)

            # Compute aggregated statistics (mean of the logged values) per dataset
            aggregated_stats = all_metrics_df.groupby('dataset').mean()

            # Plot the mean value of the field for each dataset
            plt.figure(figsize=(12, 8))
            aggregated_stats[field] = aggregated_stats.mean(axis=1)
            aggregated_stats[field].plot(kind='bar')

            # Print the number of steps in the wandb run and the aggregate statistics
            print(f"run: {run_id} field: {field} steps: {model_step} num of datasets: {len(metrics_data)}")
            print(f"{field}: {aggregated_stats[field].mean():.2f}+-{aggregated_stats[field].std():.2f}")

            # plt.title(f"Mean {field} for Each Dataset")
            # plt.xlabel("Dataset")
            # plt.ylabel(f"Mean {field}")
            # plt.xticks(rotation=90)
            # plt.tight_layout()

            # # Save the plot
            # import os
            # pwd = os.path.dirname(os.path.abspath(__file__))
            # plt.savefig(f"{pwd}/output/{run.id}_{field}_plot.png")

            # Display aggregated statistics
            # print(aggregated_stats)

            # Save the aggregated statistics as a CSV if needed
            # aggregated_stats.to_csv(f"{run_id}_{field}_stat.csv", index=True)
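
# A minimal sketch of persisting the last plotted field and its aggregated stats to
# disk, mirroring the commented-out block above. The output/ directory next to this
# script is an assumption; adjust the path if the repo layout differs.
import os

if metrics_data:
    out_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
    os.makedirs(out_dir, exist_ok=True)
    plt.title(f"Mean {field} for Each Dataset")
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, f"{run.id}_{field}_plot.png"))
    aggregated_stats.to_csv(os.path.join(out_dir, f"{run_id}_{field}_stat.csv"), index=True)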