"""
Run plotting over key validation metrics for a given wandb run.

Usage:
export MODEL=40dataset_waction_add_gpu_8_nodes_1
python common/plot/plot_from_wandb.py --field teacher_force_psnr --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_psnr_delta --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_ssim --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_pred_lpips --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_loss --run_id $MODEL
"""
import argparse

import matplotlib.pyplot as plt
import pandas as pd
import wandb
# Initialize the wandb API client
api = wandb.Api()
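# Note: wandb.Api() picks up credentials from a prior `wandb login` or the
# WANDB_API_KEY environment variable; authenticate first, otherwise the
# api.run() call below will raise an authentication error.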
# Replace with your specific project and entity
entity = "latent-mage"
project = "video_val"
# List of datasets to process
datasets = [
"bridge_data_v2",
"fractal20220817_data",
"language_table",
"ucsd_pick_and_place_dataset_converted_externally_to_rlds",
"kaist_nonprehensile_converted_externally_to_rlds",
"ucsd_kitchen_dataset_converted_externally_to_rlds",
"utokyo_xarm_bimanual_converted_externally_to_rlds",
"stanford_hydra_dataset_converted_externally_to_rlds",
"austin_sirius_dataset_converted_externally_to_rlds",
"berkeley_fanuc_manipulation",
"berkeley_mvp_converted_externally_to_rlds",
"berkeley_rpt_converted_externally_to_rlds",
"cmu_play_fusion",
"iamlab_cmu_pickup_insert_converted_externally_to_rlds",
"qut_dexterous_manpulation",
"robo_net",
"furniture_bench_dataset_converted_externally_to_rlds",
"dlr_sara_grid_clamp_converted_externally_to_rlds",
"cmu_stretch",
"spoc",
"columbia_cairlab_pusht_real",
"droid",
"toto",
"io_ai_tech",
"conq_hose_manipulation",
"dobbe",
"berkeley_gnm_cory_hall",
"plex_robosuite",
"usc_cloth_sim_converted_externally_to_rlds",
"berkeley_cable_routing",
"imperial_wrist_dataset",
"bc_z",
"kuka",
"roboturk",
"metaworld",
"robomimic",
"epic_kitchen",
"ego4d",
"nyu_door_opening_surprising_effectiveness"
]
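# The names above appear to follow the RLDS / Open X-Embodiment dataset naming
# convention; each entry must match the prefix used when the metrics were
# logged to wandb (e.g. "bridge_data_v2/teacher_force_psnr").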
# Set up the command-line interface
parser = argparse.ArgumentParser(description='Plot key validation metrics for a wandb run.')
parser.add_argument('--field', type=str, default='teacher_force_psnr', help='The metric field to process')
parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1', help='The run ID to process')
# Parse arguments
args = parser.parse_args()
run_id = args.run_id

# Fetch the run, then get its logged history as a dataframe
run = api.run(f"{entity}/{project}/runs/{run_id}")
history = run.history(pandas=True)
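# run.history() returns a sampled history (500 points by default). If the
# sampling is too coarse for a long run, the full history can be streamed
# instead, at the cost of a slower query, e.g.:
# history = pd.DataFrame(run.scan_history())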
model_step = 0
summary_metrics = run.summary

# Note: this loop covers all key fields, so the --field argument is
# effectively superseded when every metric is plotted in one invocation.
fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta', 'teacher_force_ssim', 'teacher_force_pred_lpips', 'teacher_force_loss']
for field in fields:
    metrics_data = []
    if not history.empty:
        # Keep only this field's per-dataset metrics
        for dataset in datasets:
            field_col = f"{dataset}/{field}"
            step_col = f"{dataset}/model_step"
            if field_col in history.columns:
                # Normalize by the number of examples (uncomment if needed)
                # history[field_col] = history[field_col] / history.shape[0]
                valid_field = history[field_col].dropna()
                if not valid_field.empty:
                    last_valid_value = valid_field.iloc[-1]  # Get the last non-NaN value
                    metrics = pd.DataFrame({field_col: [last_valid_value]})
                    metrics['dataset'] = dataset
                    metrics_data.append(metrics)
                    if step_col in summary_metrics:
                        model_step = summary_metrics[step_col]
    # Combine all the per-dataset metric dataframes into one
    if metrics_data:
        all_metrics_df = pd.concat(metrics_data, ignore_index=True)
        # Aggregate per dataset (each dataset contributes a single row here)
        aggregated_stats = all_metrics_df.groupby('dataset').mean()
        # Collapse the per-dataset columns into a single column for this field;
        # each row has exactly one non-NaN entry, so the row mean recovers it
        plt.figure(figsize=(12, 8))
        aggregated_stats[f'{field}'] = aggregated_stats.mean(axis=1)
        aggregated_stats[f'{field}'].plot(kind='bar')
        # Print the number of steps in the wandb run and the dataset coverage
        print(f"run: {run_id} field: {field} steps: {model_step} num datasets: {len(metrics_data)}")
        print(f"{field}: {aggregated_stats[field].mean():.2f}+-{aggregated_stats[field].std():.2f}")
        # plt.title(f"Mean {field} for Each Dataset")
        # plt.xlabel("Dataset")
        # plt.ylabel(f"Mean {field}")
        # plt.xticks(rotation=90)
        # plt.tight_layout()
        # # Save the plot
        # import os
        # pwd = os.path.dirname(os.path.abspath(__file__))
        # plt.savefig(f"{pwd}/output/{run.id}_{field}_plot.png")
        # Display aggregated statistics
        # print(aggregated_stats)
        # Save the aggregated statistics as a CSV if needed
        # aggregated_stats.to_csv(f"{run_id}_{field}_stat.csv", index=True)