"""Plot key validation metrics for selected wandb runs.

Usage:
    export MODEL=40dataset_waction_add_gpu_8_nodes_1
    python common/plot/plot_from_wandb.py --field teacher_force_psnr --run_id $MODEL
    python common/plot/plot_from_wandb.py --field teacher_force_psnr_delta --run_id $MODEL
    python common/plot/plot_from_wandb.py --field teacher_force_ssim --run_id $MODEL
    python common/plot/plot_from_wandb.py --field teacher_force_pred_lpips --run_id $MODEL
    python common/plot/plot_from_wandb.py --field teacher_force_loss --run_id $MODEL
"""
import argparse

import matplotlib.pyplot as plt
import pandas as pd
import wandb
# Initialize the wandb API client.
api = wandb.Api()
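# Note: wandb.Api() assumes you are already authenticated, e.g. via
# `wandb login` or the WANDB_API_KEY environment variable.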
# Replace with your specific entity and project.
entity = "latent-mage"
project = "video_val"
# List of datasets to process.
datasets = [
    "bridge_data_v2",
    "fractal20220817_data",
    "language_table",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "utokyo_xarm_bimanual_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "austin_sirius_dataset_converted_externally_to_rlds",
    "berkeley_fanuc_manipulation",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "cmu_play_fusion",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "qut_dexterous_manpulation",
    "robo_net",
    "furniture_bench_dataset_converted_externally_to_rlds",
    "dlr_sara_grid_clamp_converted_externally_to_rlds",
    "cmu_stretch",
    "spoc",
    "columbia_cairlab_pusht_real",
    "droid",
    "toto",
    "io_ai_tech",
    "conq_hose_manipulation",
    "dobbe",
    "berkeley_gnm_cory_hall",
    "plex_robosuite",
    "usc_cloth_sim_converted_externally_to_rlds",
    "berkeley_cable_routing",
    "imperial_wrist_dataset",
    "bc_z",
    "kuka",
    "roboturk",
    "metaworld",
    "robomimic",
    "epic_kitchen",
    "ego4d",
    "nyu_door_opening_surprising_effectiveness",
]
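# The strings above must match the key prefixes logged to wandb exactly
# (history columns are named f"{dataset}/{field}"), so spellings such as
# "qut_dexterous_manpulation" are kept verbatim.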
# Set up the argument parser.
parser = argparse.ArgumentParser(description='Plot per-dataset validation metrics from a wandb run.')
parser.add_argument('--field', type=str, default='teacher_force_psnr', help='The metric field to process')
parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1', help='The run ID to process')

# Parse arguments.
args = parser.parse_args()
field = args.field
run_id = args.run_id
# Fetch the run and the history dataframe of its logged metrics.
run = api.run(f"{entity}/{project}/runs/{run_id}")
history = run.history(pandas=True)
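# Note: run.history() returns a sampled dataframe of the logged metrics
# (500 rows by default); run.scan_history() can be used instead if every
# logged row is needed.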
model_step = 0
summary_metrics = run.summary

# Metrics to aggregate. The loop below runs over all of these, so the
# --field argument sets a default but does not restrict the loop.
fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta',
          'teacher_force_ssim', 'teacher_force_pred_lpips', 'teacher_force_loss']
for field in fields:
    metrics_data = []  # Per-dataset dataframes holding the last logged value.
    if not history.empty:
        # Filter the history to the selected field for each of the specified datasets.
        for dataset in datasets:
            field_col = f"{dataset}/{field}"
            step_col = f"{dataset}/model_step"
            if field_col in history.columns:
                # Divide the metric by the number of history rows (uncomment if needed):
                # history[field_col] = history[field_col] / history.shape[0]
                valid_field = history[field_col].dropna()
                if not valid_field.empty:
                    last_valid_value = valid_field.iloc[-1]  # Get the last non-NaN value.
                    metrics = pd.DataFrame({field_col: [last_valid_value]})
                    metrics['dataset'] = dataset
                    metrics_data.append(metrics)
                if step_col in summary_metrics:
                    model_step = summary_metrics[step_col]
    # Combine all the per-dataset metric dataframes into one.
    if metrics_data:
        all_metrics_df = pd.concat(metrics_data, ignore_index=True)

        # Aggregate per dataset (each row has a single non-NaN metric column).
        aggregated_stats = all_metrics_df.groupby('dataset').mean()

        # Collapse the one non-NaN column per row into a single metric column
        # and plot the per-dataset means.
        plt.figure(figsize=(12, 8))
        aggregated_stats[field] = aggregated_stats.mean(axis=1)
        aggregated_stats[field].plot(kind='bar')

        # Print the run, field, model step, and dataset coverage.
        print(f"run: {run_id} field: {field} steps: {model_step} num datasets: {len(metrics_data)}")
        print(f"{field}: {aggregated_stats[field].mean():.2f}+-{aggregated_stats[field].std():.2f}")
        # plt.title(f"Mean {field} for Each Dataset")
        # plt.xlabel("Dataset")
        # plt.ylabel(f"Mean {field}")
        # plt.xticks(rotation=90)
        # plt.tight_layout()

        # # Save the plot
        # import os
        # pwd = os.path.dirname(os.path.abspath(__file__))
        # plt.savefig(f"{pwd}/output/{run.id}_{field}_plot.png")

        # # Display aggregated statistics
        # print(aggregated_stats)

        # # Save the aggregated statistics as a CSV if needed
        # aggregated_stats.to_csv(f"{run_id}_{field}_stat.csv", index=True)