import argparse
import os
import sys

import matplotlib.pyplot as plt
import pandas as pd
import wandb

"""
Example usage: summarize key validation metrics for a key run.
export MODEL=40dataset_waction_add_gpu_8_nodes_1
python common/plot/plot_from_wandb.py --field teacher_force_psnr --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_psnr_delta --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_ssim --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_pred_lpips --run_id $MODEL
python common/plot/plot_from_wandb.py --field teacher_force_loss --run_id $MODEL

"""
# Client for the W&B public API; the run analysed below is fetched through it.
api = wandb.Api()

# W&B entity/project that hold the validation runs -- change these to point
# the script at a different workspace.
entity, project = "latent-mage", "video_val"

# List of datasets to process
datasets = [
    "bridge_data_v2",
    "fractal20220817_data",
    "language_table",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "utokyo_xarm_bimanual_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "austin_sirius_dataset_converted_externally_to_rlds",
    "berkeley_fanuc_manipulation",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "cmu_play_fusion",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "qut_dexterous_manpulation",
    "robo_net",
    "furniture_bench_dataset_converted_externally_to_rlds",
    "dlr_sara_grid_clamp_converted_externally_to_rlds",
    "cmu_stretch",
    "spoc",
    "columbia_cairlab_pusht_real",
    "droid",
    "toto",
    "io_ai_tech",
    "conq_hose_manipulation",
    "dobbe",
    "berkeley_gnm_cory_hall",
    "plex_robosuite",
    "usc_cloth_sim_converted_externally_to_rlds",
    "berkeley_cable_routing",
    "imperial_wrist_dataset",
    "bc_z",
    "kuka",
    "roboturk",
    "metaworld",
    "robomimic",
    "epic_kitchen",
    "ego4d",
    "nyu_door_opening_surprising_effectiveness"
]

def _latest_model_step(summary, dataset_names):
    """Return the model step recorded in the run summary.

    Looks up ``"<dataset>/model_step"`` for every dataset and returns the value
    belonging to the *last* dataset in ``dataset_names`` that has one, or ``0``
    when none of them do.
    """
    for dataset in reversed(dataset_names):
        step_col = f"{dataset}/model_step"
        if step_col in summary:
            return summary[step_col]
    return 0


def _last_valid_per_dataset(history_df, metric, dataset_names):
    """Collect the last non-NaN value of ``metric`` for each dataset.

    Args:
        history_df: Run history as returned by ``run.history(pandas=True)``.
        metric: Metric suffix; the column read is ``"<dataset>/<metric>"``.
        dataset_names: Dataset prefixes to look up.

    Returns:
        A float Series named ``metric`` and indexed by dataset name (sorted).
        Datasets whose column is absent or entirely NaN are omitted, so the
        Series is empty when nothing was logged for ``metric``.
    """
    values = {}
    for dataset in dataset_names:
        col = f"{dataset}/{metric}"
        if col in history_df.columns:
            logged = history_df[col].dropna()
            if not logged.empty:
                values[dataset] = logged.iloc[-1]
    return pd.Series(values, name=metric, dtype=float).sort_index()


parser = argparse.ArgumentParser(
    description="Summarize the per-dataset validation metrics of a W&B run."
)
parser.add_argument(
    '--field',
    type=str,
    default=None,
    help='Metric to summarize, e.g. teacher_force_psnr (default: every metric in `fields`)',
)
parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1', help='The run ID to process')
parser.add_argument(
    '--save_plot',
    action='store_true',
    help='Also save a per-dataset bar plot to output/ next to this script',
)
args = parser.parse_args()

run_id = args.run_id
run = api.run(f"{entity}/{project}/runs/{run_id}")

# NOTE(review): run.history() returns *sampled* rows by default, so the
# "last valid value" taken below may not be the final logged value on long
# runs; consider run.scan_history() or run.summary if exact values matter.
history = run.history(pandas=True)
summary_metrics = run.summary

# Metrics summarized when --field is not given.
fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta', 'teacher_force_ssim', 'teacher_force_pred_lpips', 'teacher_force_loss']

# --field restricts the output to one metric; without it, every metric in
# `fields` is summarized.
fields_to_process = [args.field] if args.field else fields
# The step belongs to the run, not to a metric, so look it up only once.
model_step = _latest_model_step(summary_metrics, datasets)

for field in fields_to_process:
    per_dataset = _last_valid_per_dataset(history, field, datasets)
    if per_dataset.empty:
        if args.field:
            # An explicitly requested metric with no data should not fail silently.
            print(f"run: {run_id} field: {field}: no logged values for any dataset", file=sys.stderr)
        continue

    # Mean and sample std (ddof=1) across datasets of each dataset's last value;
    # std is NaN when only one dataset has data.
    print(f"run: {run_id} field: {field} steps: {model_step} num of dataset: {len(per_dataset)}")
    print(f"{field}: {per_dataset.mean():.2f}+-{per_dataset.std():.2f}")

    if args.save_plot:
        output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
        os.makedirs(output_dir, exist_ok=True)
        fig, ax = plt.subplots(figsize=(12, 8))
        per_dataset.plot(kind='bar', ax=ax)
        ax.set_title(f"{field} per dataset (run {run_id})")
        ax.set_xlabel("Dataset")
        ax.set_ylabel(field)
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, f"{run.id}_{field}_plot.png"))
        # Release the figure so that looping over several fields does not
        # accumulate open figures.
        plt.close(fig)