# Source: hma/common/plot/plot_from_wandb_singledataset.py
# (exported from a Hugging Face file view; author: LeroyWaa, commit 246c106 "draft", 5.27 kB)
import argparse
import os
import sys

import matplotlib.pyplot as plt
import pandas as pd
import wandb
"""
Running plotting scripts over key metrics and key runs
export MODEL=40dataset_waction_add_gpu_8_nodes_1
python common/plot/plot_from_wandb.py --field teacher_force_psnr --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_psnr_delta --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_ssim --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_pred_lpips --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_loss --run_id $MODEL
"""
# Initialize the wandb public-API client. It is used further down to look up the
# run by its "<entity>/<project>/runs/<run_id>" path.
api = wandb.Api()
# wandb entity (team) and project that hold the validation runs; edit these to
# point the script at a different workspace.
entity, project = "latent-mage", "video_val"

# Dataset prefixes to report on. For each name the script looks for
# "<name>/<metric>" and "<name>/model_step" keys in the run. Names are kept
# verbatim (including upstream spellings) because they must match logged keys.
datasets = """
    bridge_data_v2
    fractal20220817_data
    language_table
    ucsd_pick_and_place_dataset_converted_externally_to_rlds
    kaist_nonprehensile_converted_externally_to_rlds
    ucsd_kitchen_dataset_converted_externally_to_rlds
    utokyo_xarm_bimanual_converted_externally_to_rlds
    stanford_hydra_dataset_converted_externally_to_rlds
    austin_sirius_dataset_converted_externally_to_rlds
    berkeley_fanuc_manipulation
    berkeley_mvp_converted_externally_to_rlds
    berkeley_rpt_converted_externally_to_rlds
    cmu_play_fusion
    iamlab_cmu_pickup_insert_converted_externally_to_rlds
    qut_dexterous_manpulation
    robo_net
    furniture_bench_dataset_converted_externally_to_rlds
    dlr_sara_grid_clamp_converted_externally_to_rlds
    cmu_stretch
    spoc
    columbia_cairlab_pusht_real
    droid
    toto
    io_ai_tech
    conq_hose_manipulation
    dobbe
    berkeley_gnm_cory_hall
    plex_robosuite
    usc_cloth_sim_converted_externally_to_rlds
    berkeley_cable_routing
    imperial_wrist_dataset
    bc_z
    kuka
    roboturk
    metaworld
    robomimic
    epic_kitchen
    ego4d
    nyu_door_opening_surprising_effectiveness
""".split()
# ---------------------------------------------------------------------------
# Command-line options.
#   --field    report a single metric. When omitted, every metric in `fields`
#              is reported. Previously --field was parsed but then overwritten
#              by the loop variable, so it was silently ignored.
#   --run_id   wandb run id inside `entity`/`project`.
#   --save_dir optional directory for a per-dataset bar plot (PNG) and the
#              per-dataset values (CSV) of each reported metric.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='Report the last logged value of evaluation metrics for each dataset of a wandb run.')
parser.add_argument('--field', type=str, default=None,
                    help='The field to process (default: every field in the built-in list)')
parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1', help='The run ID to process')
parser.add_argument('--save_dir', type=str, default=None,
                    help='If set, save <run_id>_<field>_plot.png and <run_id>_<field>_stat.csv into this directory')
args = parser.parse_args()
field = args.field
run_id = args.run_id

runs_path = f"{entity}/{project}/runs"
run = api.run(f"{runs_path}/{run_id}")
# Logged history of the run as a DataFrame. Per-dataset metrics are read from
# "<dataset>/<field>" columns below.
# NOTE(review): Run.history() returns a *sampled* history by default
# (samples=500); confirm the sampled rows always include each column's final
# logged value, otherwise switch to run.scan_history().
history = run.history(pandas=True)
summary_metrics = run.summary

fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta', 'teacher_force_ssim', 'teacher_force_pred_lpips', 'teacher_force_loss']
fields_to_report = fields if field is None else [field]

# Taken from the run summary for every dataset that has a value. It carries
# over between metrics, so it reflects the last dataset seen so far.
model_step = 0
for metric in fields_to_report:
    # dataset -> last non-NaN value of its "<dataset>/<metric>" column.
    # An empty history has no such columns or no rows, so nothing is collected.
    last_values = {}
    for dataset in datasets:
        column = f"{dataset}/{metric}"
        if column not in history.columns:
            continue
        logged = history[column].dropna()
        if logged.empty:
            continue
        last_values[dataset] = logged.iloc[-1]
        step_key = f"{dataset}/model_step"
        if step_key in summary_metrics:
            model_step = summary_metrics[step_key]

    if not last_values:
        # Make a missing or misspelled metric visible instead of skipping it silently.
        print(f"run: {run_id} field: {metric}: no values logged for any dataset", file=sys.stderr)
        continue

    # One value per dataset, sorted by dataset name (the same order the former
    # groupby('dataset') produced).
    per_dataset = pd.Series(last_values, name=metric, dtype=float).sort_index()
    per_dataset.index.name = 'dataset'

    plt.figure(figsize=(12, 8))
    per_dataset.plot(kind='bar')
    plt.title(f"{metric} for each dataset")
    plt.xlabel("Dataset")
    plt.ylabel(metric)
    plt.xticks(rotation=90)
    plt.tight_layout()

    print(f"run: {run_id} field: {metric} steps: {model_step} num of dataset: {len(per_dataset)}")
    # Mean +- sample std (ddof=1) across datasets; std is NaN when only one
    # dataset logged the metric.
    print(f"{metric}: {per_dataset.mean():.2f}+-{per_dataset.std():.2f}")

    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
        stem = os.path.join(args.save_dir, f"{run_id}_{metric}")
        plt.savefig(f"{stem}_plot.png")
        per_dataset.to_csv(f"{stem}_stat.csv", index=True)