# File: hma/common/plot/plot_from_wandb.py
# Commit 246c106 ("draft") by LeroyWaa; 7.59 kB
import argparse
import csv
import os
import sys

import matplotlib.pyplot as plt
import pandas as pd
import wandb
"""
Example invocations: run this script over each key run to append its key metrics to aggregated_output.csv.
export MODEL=final2_40dataset_waction_concat_gpu_8_nodes_1
python common/plot/plot_from_wandb.py --run_id $MODEL
python common/plot/plot_from_wandb.py --run_id final2_40dataset_noaction_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_modulate_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_attn_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_add_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_d64_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_forward_dynamics_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_full_dynamics_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj100000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj10000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj100_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj1000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_5dataset_waction_gpu_8_nodes_1_step24k_v5
python common/plot/plot_from_wandb.py --run_id final2_30dataset_waction_gpu_8_nodes_1_step24k_v5
python common/plot/plot_from_wandb.py --run_id final2_5dataset_waction_gpu_8_nodes_1_step24k_v5
python common/plot/plot_from_wandb.py --run_id final2_10dataset_waction_gpu_8_nodes_1_step24k_v5
"""
# W&B entity and project that every run lookup in this script is scoped to.
entity = "latent-mage"
project = "video_val"

# Directory that contains this script file.
pwd = os.path.dirname(os.path.abspath(__file__))

# Public W&B API client used to fetch the run below.
api = wandb.Api()

# Dataset prefixes to collect. Metric columns are looked up as
# "<dataset>/<field>", and this order fixes the CSV column order, so do not
# reorder entries casually.
# NOTE(review): "qut_dexterous_manpulation" looks misspelled, but it must match
# the key actually logged to W&B — confirm before changing it.
datasets = [
    "bridge_data_v2",
    "fractal20220817_data",
    "language_table",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "utokyo_xarm_bimanual_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "austin_sirius_dataset_converted_externally_to_rlds",
    "berkeley_fanuc_manipulation",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "cmu_play_fusion",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "qut_dexterous_manpulation",
    "robo_net",
    "furniture_bench_dataset_converted_externally_to_rlds",
    "dlr_sara_grid_clamp_converted_externally_to_rlds",
    "cmu_stretch",
    "spoc",
    "columbia_cairlab_pusht_real",
    "droid",
    "toto",
    "io_ai_tech",
    "conq_hose_manipulation",
    "dobbe",
    "berkeley_gnm_cory_hall",
    "plex_robosuite",
    "usc_cloth_sim_converted_externally_to_rlds",
    "berkeley_cable_routing",
    "imperial_wrist_dataset",
    "bc_z",
    "kuka",
    "roboturk",
    "metaworld",
    "robomimic",
    "epic_kitchen",
    "ego4d",
    "nyu_door_opening_surprising_effectiveness",
]
def normalize_dataset(metric, runs):
    """Placeholder for normalizing ``metric`` across ``runs``.

    The intent is to find the best and worst value of ``metric`` over all
    ``runs`` and normalize with them. Nothing is implemented yet: the
    arguments are ignored and the function always returns None.
    """
    return None
# One-row DataFrames (one per dataset/field value found in the history) of
# every collected metric, not only PSNR; appended to further down.
metrics_data = []

# Command-line interface: which W&B run to aggregate.
parser = argparse.ArgumentParser(
    description="Append the last logged validation metrics of a W&B run to a CSV file."
)
parser.add_argument(
    '--run_id',
    type=str,
    default='40dataset_waction_add_gpu_8_nodes_1',
    help='The run ID to process',
)
args = parser.parse_args()

# Metric suffixes collected for every dataset; each CSV column is
# "<dataset>/<field>".
fields = [
    'num_examples',
    'teacher_force_psnr',
    'teacher_force_psnr_delta',
    'teacher_force_ssim',
    'teacher_force_pred_lpips',
    'teacher_force_loss',
]
num_fields = len(fields)

run_id = args.run_id
runs_path = f"{entity}/{project}/runs"
run = api.run(f"{runs_path}/{run_id}")

# History of the run as a pandas DataFrame.
# NOTE(review): the W&B public API documents `Run.history` as returning a
# *sampled* history (500 rows by default). If a run logs more rows than that,
# the "last non-NaN value" taken from it may not be the true last value;
# `run.scan_history()` returns every row if that matters.
history = run.history(pandas=True)
model_step = 0
summary_metrics = run.summary
num_datasets = 0

# Results are appended to this CSV, resolved relative to the current working
# directory (not to the script directory `pwd`).
csv_output = "aggregated_output.csv"

# Write the header row only when the file is created, so that repeated
# invocations append one row per run under a single header.
if not os.path.exists(csv_output):
    header = ["name"] + [f"{dataset}/{field}" for dataset in datasets for field in fields]
    with open(csv_output, 'w', encoding='utf-8') as f:
        f.write(",".join(header) + "\n")
# One CSV row: the run id, then one cell per (dataset, field) laid out
# dataset-major to match the header. Cells stay None when no value is found.
results = [run_id] + [None] * len(datasets) * num_fields
for field_idx, field in enumerate(fields):
    # Nothing is looked up at all when the run has no history.
    if history.empty:
        continue
    # Look up this metric for every dataset (all fields, not only PSNR).
    for dataset_idx, dataset in enumerate(datasets):
        col_idx = 1 + dataset_idx * num_fields + field_idx  # +1 skips the run-id cell
        field_col = f"{dataset}/{field}"

        if field == "num_examples":
            # The example count is read from the run summary, not the history.
            if f"{dataset}/num_examples" in summary_metrics:
                results[col_idx] = summary_metrics[f"{dataset}/num_examples"]
            continue

        # Take the last non-NaN logged value of this column, if any.
        if field_col in history.columns:
            valid_field = history[field_col].dropna()
            if not valid_field.empty:
                last_valid_value = valid_field.iloc[-1]
                num_datasets += 1  # counts values found, summed over all fields
                metrics = pd.DataFrame({field_col: [last_valid_value]})
                metrics['dataset'] = dataset
                results[col_idx] = last_valid_value
                metrics_data.append(metrics)

        # Remember the model step reported in the summary for this dataset.
        if f"{dataset}/model_step" in summary_metrics:
            model_step = summary_metrics[f"{dataset}/model_step"]
# Combine the one-row metric frames into a single DataFrame (currently only
# consumed by the disabled plotting code kept below).
if metrics_data:
    all_metrics_df = pd.concat(metrics_data, ignore_index=True)
    # Disabled plotting / summary code, kept for re-enabling:
    # # Compute aggregated statistics (mean, median, std, etc.) for PSNR
    # aggregated_stats = all_metrics_df.groupby('dataset').mean()
    #
    # # Plot the mean PSNR for each dataset
    # plt.figure(figsize=(12, 8))
    # aggregated_stats[f'{field}'] = aggregated_stats.mean(axis=1)
    # aggregated_stats[f'{field}'].plot(kind='bar')
    # # print number of steps in the wandb run
    # print(f"run: {run_id} field: {field} steps: {model_step} num of dataset: {len(metrics_data)}")
    # print(f"{field}: {aggregated_stats[field].mean():.2f}+-{aggregated_stats[field].std():.2f}", )
    #
    # plt.title(f"Mean {field} for Each Dataset")
    # plt.xlabel("Dataset")
    # plt.ylabel(f"Mean {field} ")
    # plt.xticks(rotation=90)
    # plt.tight_layout()
    #
    # # Save the plot
    # plt.savefig(f"{pwd}/output/{run.id}_{field}_plot.png")

# Append this run's row to the CSV. Values are stringified first so missing
# cells are still written as "None" and numpy scalars keep their str() form
# (csv.writer would otherwise write None as "" and use repr() for floats);
# csv.writer then quotes any cell containing a comma, quote or newline, which
# the previous manual ",".join did not.
with open(csv_output, 'a', newline='', encoding='utf-8') as f:
    csv.writer(f, lineterminator="\n").writerow([str(x) for x in results])

# Disabled: display / save aggregated statistics.
# print(aggregated_stats)
# aggregated_stats.to_csv(f"{run_id}_{field}_stat.csv", index=True)