# hma/common/plot/plot_from_wandb_singledataset.py
# (source header: LeroyWaa, "draft", commit 246c106)
import wandb
import pandas as pd
import matplotlib.pyplot as plt
import sys
import argparse
"""
Running plotting scripts over key metrics and key runs
export MODEL=40dataset_waction_add_gpu_8_nodes_1
python common/plot/plot_from_wandb.py --field teacher_force_psnr --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_psnr_delta --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_ssim --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_pred_lpips --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_loss --run_id $MODEL
"""
# Initialize the wandb API client (authenticates from the local wandb config).
api = wandb.Api()
# wandb entity/project that holds the validation runs being analyzed.
entity = "latent-mage"
project = "video_val"
# Datasets to aggregate over. Each run logs per-dataset metrics under
# column names of the form "<dataset>/<metric>" (see the loop below).
datasets = [
"bridge_data_v2",
"fractal20220817_data",
"language_table",
"ucsd_pick_and_place_dataset_converted_externally_to_rlds",
"kaist_nonprehensile_converted_externally_to_rlds",
"ucsd_kitchen_dataset_converted_externally_to_rlds",
"utokyo_xarm_bimanual_converted_externally_to_rlds",
"stanford_hydra_dataset_converted_externally_to_rlds",
"austin_sirius_dataset_converted_externally_to_rlds",
"berkeley_fanuc_manipulation",
"berkeley_mvp_converted_externally_to_rlds",
"berkeley_rpt_converted_externally_to_rlds",
"cmu_play_fusion",
"iamlab_cmu_pickup_insert_converted_externally_to_rlds",
"qut_dexterous_manpulation",
"robo_net",
"furniture_bench_dataset_converted_externally_to_rlds",
"dlr_sara_grid_clamp_converted_externally_to_rlds",
"cmu_stretch",
"spoc",
"columbia_cairlab_pusht_real",
"droid",
"toto",
"io_ai_tech",
"conq_hose_manipulation",
"dobbe",
"berkeley_gnm_cory_hall",
"plex_robosuite",
"usc_cloth_sim_converted_externally_to_rlds",
"berkeley_cable_routing",
"imperial_wrist_dataset",
"bc_z",
"kuka",
"roboturk",
"metaworld",
"robomimic",
"epic_kitchen",
"ego4d",
"nyu_door_opening_surprising_effectiveness"
]
# ---------------------------------------------------------------------------
# CLI: choose which wandb run to analyze, then fetch its logged history.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='Aggregate per-dataset validation metrics from a wandb run and plot them.'
)
parser.add_argument('--field', type=str, default='teacher_force_psnr',
                    help='The field to process (NOTE: currently unused — every '
                         'metric in the hard-coded `fields` list below is processed)')
parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1',
                    help='The run ID to process')
args = parser.parse_args()
# NOTE(review): `field` is immediately shadowed by the `for field in fields:`
# loop below, so --field has no effect; kept for CLI backward compatibility.
field = args.field
run_id = args.run_id
run = api.run(f"{entity}/{project}/runs/{run_id}")
# Full metric history of the run: one DataFrame row per logged step,
# with NaN for metrics not logged at that step.
history = run.history(pandas=True)
model_step = 0  # last model step found in the run's summary (updated in the loop)
summary_metrics = run.summary
num_datasets = 0  # running count of (metric, dataset) pairs with valid data
# Metrics to aggregate; each is logged per dataset as "<dataset>/<metric>".
fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta', 'teacher_force_ssim', 'teacher_force_pred_lpips', 'teacher_force_loss']
# For each metric: grab the last valid per-dataset value from the history,
# aggregate across datasets, print summary stats, and draw a bar plot.
for field in fields:
    metrics_data = []  # one single-row DataFrame per dataset with valid data
    if not history.empty:
        for dataset in datasets:
            field_col = f"{dataset}/{field}"
            step_col = f"{dataset}/model_step"
            if field_col in history.columns:
                valid_field = history[field_col].dropna()
                if not valid_field.empty:
                    # Last non-NaN value = most recent evaluation of this metric.
                    last_valid_value = valid_field.iloc[-1]
                    num_datasets += 1
                    metrics = pd.DataFrame({field_col: [last_valid_value]})
                    metrics['dataset'] = dataset
                    metrics_data.append(metrics)
                if step_col in summary_metrics:
                    model_step = summary_metrics[step_col]
    if metrics_data:
        # Concatenating single-row frames with disjoint columns yields a wide
        # frame: one row per dataset, NaN everywhere except that dataset's own
        # "<dataset>/<field>" column.
        all_metrics_df = pd.concat(metrics_data, ignore_index=True)
        aggregated_stats = all_metrics_df.groupby('dataset').mean()
        plt.figure(figsize=(12, 8))
        # Row-wise mean skips NaNs, so this collapses the wide frame back to
        # each dataset's single metric value under a common column name.
        aggregated_stats[f'{field}'] = aggregated_stats.mean(axis=1)
        aggregated_stats[f'{field}'].plot(kind='bar')
        print(f"run: {run_id} field: {field} steps: {model_step} num of dataset: {len(metrics_data)}")
        print(f"{field}: {aggregated_stats[field].mean():.2f}+-{aggregated_stats[field].std():.2f}", )
        # Close the figure so repeated iterations don't accumulate open figures
        # (nothing is shown or saved here; re-enable savefig below if needed):
        # plt.savefig(f"output/{run.id}_{field}_plot.png")
        plt.close()
        # aggregated_stats.to_csv(f"{run_id}_{field}_stat.csv", index=True)