{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c2db-overview",
   "metadata": {},
   "source": [
    "# C2DB dynamical-stability benchmark\n",
    "\n",
    "For a seeded random sample of C2DB entries, this notebook gathers each model's stored phonon frequencies and elastic eigenvalues, turns them into a stable / unstable prediction, and scores that prediction against the `dyn_stab` label in the database.\n",
    "\n",
    "Outputs: one `<model>.parquet` cache per model, `c2db-confusion_matrices.pdf`, `c2db_summary_table.tex` and `c2db_f1_bar.pdf`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0625f0a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import phonopy\n",
    "from ase.db import connect\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.metrics import (\n",
    "    ConfusionMatrixDisplay,\n",
    "    classification_report,\n",
    "    confusion_matrix,\n",
    ")\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from mlip_arena.models import MLIPEnum\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2db-config",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = Path(\".\")\n",
    "SEED = 0\n",
    "N_SAMPLES = 1000  # number of database rows drawn for the benchmark\n",
    "\n",
    "# A structure is predicted unstable when its smallest elastic eigenvalue or its\n",
    "# smallest q=(0, 0, 0) phonon frequency lies strictly below this value.\n",
    "STABILITY_THRESHOLD = -1e-7\n",
    "\n",
    "# Models shown in the figures and in the summary table.\n",
    "SELECT_MODELS = [\n",
    "    \"ALIGNN\",\n",
    "    \"CHGNet\",\n",
    "    \"M3GNet\",\n",
    "    \"MACE-MP(M)\",\n",
    "    \"MACE-MPA\",\n",
    "    \"MatterSim\",\n",
    "    \"ORBv2\",\n",
    "    \"SevenNet\",\n",
    "]\n",
    "\n",
    "# Integer class codes shared by the confusion matrices and the F1 report.\n",
    "STABLE, UNSTABLE, MISSING = 1, 0, -1\n",
    "CLASS_CODES = [STABLE, UNSTABLE, MISSING]\n",
    "CLASS_NAMES = [\"stable\", \"unstable\", \"missing\"]\n",
    "\n",
    "# Columns of the per-model parquet caches.\n",
    "RESULT_COLUMNS = [\"model\", \"uid\", \"eigenvalues\", \"frequencies\"]\n",
    "\n",
    "# Font sizes applied to every figure (inside a style context, so they do not leak).\n",
    "SMALL_SIZE = 8\n",
    "MEDIUM_SIZE = 10\n",
    "BIGGER_SIZE = 12\n",
    "FONT_RC = {\n",
    "    \"font.size\": SMALL_SIZE,\n",
    "    \"axes.titlesize\": MEDIUM_SIZE,\n",
    "    \"axes.labelsize\": MEDIUM_SIZE,\n",
    "    \"xtick.labelsize\": MEDIUM_SIZE,\n",
    "    \"ytick.labelsize\": MEDIUM_SIZE,\n",
    "    \"legend.fontsize\": SMALL_SIZE,\n",
    "    \"figure.titlesize\": BIGGER_SIZE,\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2db-sample-md",
   "metadata": {},
   "source": [
    "## Sample the database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2db-sample",
   "metadata": {},
   "outputs": [],
   "source": [
    "db = connect(DATA_DIR / \"c2db.db\")\n",
    "\n",
    "# Seed immediately before sampling so re-running this cell gives the same draw.\n",
    "random.seed(SEED)\n",
    "\n",
    "# Row ids are assumed to run contiguously from 1 to len(db) -- TODO confirm.\n",
    "# The sample size is capped so a database with fewer than N_SAMPLES rows does\n",
    "# not raise, and the ids live in a set so the row filter is a constant-time lookup.\n",
    "random_indices = set(random.sample(range(1, len(db) + 1), min(N_SAMPLES, len(db))))\n",
    "\n",
    "# Query the sampled rows once; the following cells reuse this list.\n",
    "sample_rows = list(db.select(filter=lambda r: r[\"id\"] in random_indices))\n",
    "len(sample_rows)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2db-collect-md",
   "metadata": {},
   "source": [
    "## Collect per-model results\n",
    "\n",
    "Each model's phonon and elastic output is read from `DATA_DIR/<model>/<uid>/` and appended to `DATA_DIR/<model>.parquet`. Structures already present in the cache are skipped, and structures that cannot be read are recorded in `failures_df` instead of being dropped silently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2db-load-fn",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model_result(model_name: str, uid: str) -> dict:\n",
    "    \"\"\"Read one model's stored phonon and elastic output for one structure.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model_name : name of the model; also the name of its result directory.\n",
    "    uid : identifier of the structure; also the name of its sub-directory.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict with the keys listed in ``RESULT_COLUMNS``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    Whatever ``phonopy.load`` or ``np.load`` raise when the files are missing\n",
    "    or unreadable, and ``KeyError`` when the archive has no ``eigenvalues``.\n",
    "    \"\"\"\n",
    "    path = DATA_DIR / model_name / uid\n",
    "\n",
    "    phonon = phonopy.load(path / \"phonopy.yaml\")\n",
    "    frequencies = phonon.get_frequencies(q=(0, 0, 0))\n",
    "\n",
    "    # Close the archive right away; otherwise one file handle per\n",
    "    # (model, structure) pair stays open until it is garbage-collected.\n",
    "    with np.load(path / \"elastic.npz\") as data:\n",
    "        eigenvalues = data[\"eigenvalues\"]\n",
    "\n",
    "    return {\n",
    "        \"model\": model_name,\n",
    "        \"uid\": uid,\n",
    "        \"eigenvalues\": eigenvalues,\n",
    "        \"frequencies\": frequencies,\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005708b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "failures = []\n",
    "\n",
    "for model in MLIPEnum:\n",
    "    fpath = DATA_DIR / f\"{model.name}.parquet\"\n",
    "\n",
    "    # Read the model's cache once, not once per structure.\n",
    "    if fpath.exists():\n",
    "        results = pd.read_parquet(fpath)\n",
    "    else:\n",
    "        results = pd.DataFrame(columns=RESULT_COLUMNS)\n",
    "    done_uids = set(results[\"uid\"])\n",
    "\n",
    "    for row in tqdm(sample_rows, desc=model.name, leave=False):\n",
    "        uid = row[\"uid\"]\n",
    "        if uid in done_uids:\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            new_row = pd.DataFrame([load_model_result(model.name, uid)])\n",
    "            updated = (\n",
    "                new_row\n",
    "                if results.empty\n",
    "                else pd.concat([results, new_row], ignore_index=True)\n",
    "            )\n",
    "            updated = updated.drop_duplicates(subset=[\"model\", \"uid\"], keep=\"last\")\n",
    "            # Written after every structure so an interrupted run can resume,\n",
    "            # and so a row that cannot be serialised only loses itself.\n",
    "            updated.to_parquet(fpath, index=False)\n",
    "        except Exception as exc:\n",
    "            # Best effort: results are expected to be absent for some pairs.\n",
    "            # Keep going, but remember what went wrong.\n",
    "            failures.append({\"model\": model.name, \"uid\": uid, \"error\": repr(exc)})\n",
    "            continue\n",
    "\n",
    "        results = updated\n",
    "        done_uids.add(uid)\n",
    "\n",
    "failures_df = pd.DataFrame(failures, columns=[\"model\", \"uid\", \"error\"])\n",
    "failures_df[\"model\"].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2db-labels-md",
   "metadata": {},
   "source": [
    "## Reference labels\n",
    "\n",
    "The database label `dyn_stab` is mapped to `True` (`\"yes\"`), `None` (`\"unknown\"`) or `False` (any other value). The cell shows the number of stable, unstable and unknown structures in the sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8d87638",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_dyn_stab(value: str):\n",
    "    \"\"\"Map a ``dyn_stab`` string to True (\"yes\"), None (\"unknown\") or False (otherwise).\"\"\"\n",
    "    value = value.lower()\n",
    "    if value == \"unknown\":\n",
    "        return None\n",
    "    return value == \"yes\"\n",
    "\n",
    "\n",
    "uids = np.array([row.key_value_pairs[\"uid\"] for row in sample_rows])\n",
    "# dtype=object keeps None distinct from False even when no label is unknown.\n",
    "stabilities = np.array(\n",
    "    [parse_dyn_stab(row.key_value_pairs[\"dyn_stab\"]) for row in sample_rows],\n",
    "    dtype=object,\n",
    ")\n",
    "\n",
    "# Sort once into new names; the evaluation cells only read these.\n",
    "order = np.argsort(uids)\n",
    "uids_sorted = uids[order]\n",
    "stabilities_sorted = stabilities[order]\n",
    "# True where the database gives a definite yes/no label.\n",
    "known_mask = np.array([label is not None for label in stabilities_sorted], dtype=bool)\n",
    "\n",
    "n_stable = sum(label is True for label in stabilities)\n",
    "n_unstable = sum(label is False for label in stabilities)\n",
    "n_unknown = sum(label is None for label in stabilities)\n",
    "n_stable, n_unstable, n_unknown"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3c516a7",
   "metadata": {},
   "source": [
    "## Compare predictions with the reference\n",
    "\n",
    "A model predicts *unstable* when the smallest elastic eigenvalue or the smallest q=(0, 0, 0) phonon frequency is below `STABILITY_THRESHOLD`. Sampled structures without a stored result are counted as *missing*. Structures whose reference label is unknown are excluded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2db-helpers",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stability_targets(model_name: str):\n",
    "    \"\"\"Build aligned reference and predicted class codes for one model.\n",
    "\n",
    "    Reads ``DATA_DIR/<model_name>.parquet`` and aligns it with ``uids_sorted``,\n",
    "    ``stabilities_sorted`` and ``known_mask`` from the reference-label cell.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ``(y_true, y_pred)`` integer arrays restricted to structures with a known\n",
    "    reference label, using the codes STABLE / UNSTABLE / MISSING, or ``None``\n",
    "    when the model has no result file.\n",
    "    \"\"\"\n",
    "    fpath = DATA_DIR / f\"{model_name}.parquet\"\n",
    "    if not fpath.exists():\n",
    "        return None\n",
    "    results = pd.read_parquet(fpath)\n",
    "\n",
    "    def real_min(values):\n",
    "        # A non-real spectrum is pinned to the threshold itself, which the\n",
    "        # strict `<` below never flags as unstable.\n",
    "        # NOTE(review): confirm this is the intended treatment.\n",
    "        return values.min() if np.isreal(values).all() else STABILITY_THRESHOLD\n",
    "\n",
    "    eigval_min = results[\"eigenvalues\"].apply(real_min)\n",
    "    freq_min = results[\"frequencies\"].apply(real_min)\n",
    "    unstable = (eigval_min < STABILITY_THRESHOLD) | (freq_min < STABILITY_THRESHOLD)\n",
    "\n",
    "    predicted = pd.Series(\n",
    "        np.where(unstable, UNSTABLE, STABLE), index=results[\"uid\"].to_numpy()\n",
    "    )\n",
    "    # reindex raises on duplicate labels; keep the most recent entry per uid.\n",
    "    predicted = predicted[~predicted.index.duplicated(keep=\"last\")]\n",
    "\n",
    "    # Align with the sampled uids; structures without a result become MISSING.\n",
    "    y_pred = predicted.reindex(uids_sorted, fill_value=MISSING).to_numpy()[known_mask]\n",
    "    y_true = stabilities_sorted[known_mask].astype(int)\n",
    "    return y_true, y_pred\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0052d0ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_COLS = 4\n",
    "# Size the grid from the models that are actually plotted, so that the number\n",
    "# of rows and the figure height agree.\n",
    "n_rows = max(1, int(np.ceil(len(SELECT_MODELS) / N_COLS)))\n",
    "\n",
    "with plt.style.context(\"default\"):\n",
    "    plt.rcParams.update(FONT_RC)\n",
    "\n",
    "    fig, axs = plt.subplots(\n",
    "        nrows=n_rows,\n",
    "        ncols=N_COLS,\n",
    "        figsize=(6, 3 * n_rows),\n",
    "        sharey=True,\n",
    "        sharex=True,\n",
    "        layout=\"constrained\",\n",
    "    )\n",
    "    axs = axs.flatten()\n",
    "    plot_idx = 0\n",
    "\n",
    "    for model in MLIPEnum:\n",
    "        if model.name not in SELECT_MODELS:\n",
    "            continue\n",
    "        targets = stability_targets(model.name)\n",
    "        if targets is None:\n",
    "            continue\n",
    "        y_true, y_pred = targets\n",
    "\n",
    "        matrix = confusion_matrix(y_true, y_pred, labels=CLASS_CODES)\n",
    "\n",
    "        ax = axs[plot_idx]\n",
    "        ConfusionMatrixDisplay(matrix, display_labels=CLASS_NAMES).plot(\n",
    "            ax=ax, cmap=\"Blues\", colorbar=False\n",
    "        )\n",
    "        ax.set_title(model.name)\n",
    "        ax.set_xlabel(\"Predicted\")\n",
    "        ax.set_ylabel(\"True\")\n",
    "\n",
    "        plot_idx += 1\n",
    "\n",
    "    # Hide unused subplots\n",
    "    for unused_ax in axs[plot_idx:]:\n",
    "        fig.delaxes(unused_ax)\n",
    "\n",
    "    fig.savefig(\"c2db-confusion_matrices.pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2db-summary-md",
   "metadata": {},
   "source": [
    "## F1 summary\n",
    "\n",
    "F1 scores are computed for the *stable* and *unstable* classes only. Missing predictions match neither class, so they lower recall."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "573b3c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "SUMMARY_COLUMNS = [\"Model\", \"Stable F1\", \"Unstable F1\", \"Macro F1\", \"Weighted F1\"]\n",
    "\n",
    "summary_records = []\n",
    "\n",
    "for model in MLIPEnum:\n",
    "    if model.name not in SELECT_MODELS:\n",
    "        continue\n",
    "    targets = stability_targets(model.name)\n",
    "    if targets is None:\n",
    "        continue\n",
    "    y_true, y_pred = targets\n",
    "\n",
    "    report = classification_report(\n",
    "        y_true,\n",
    "        y_pred,\n",
    "        labels=[STABLE, UNSTABLE],\n",
    "        target_names=[\"stable\", \"unstable\"],\n",
    "        output_dict=True,\n",
    "    )\n",
    "\n",
    "    summary_records.append(\n",
    "        {\n",
    "            \"Model\": model.name,\n",
    "            \"Stable F1\": report[\"stable\"][\"f1-score\"],\n",
    "            \"Unstable F1\": report[\"unstable\"][\"f1-score\"],\n",
    "            \"Macro F1\": report[\"macro avg\"][\"f1-score\"],\n",
    "            \"Weighted F1\": report[\"weighted avg\"][\"f1-score\"],\n",
    "        }\n",
    "    )\n",
    "\n",
    "# Built in one go with explicit columns, so every metric column exists (in a\n",
    "# fixed order) even when no model has results yet.\n",
    "summary_df = pd.DataFrame(summary_records, columns=SUMMARY_COLUMNS)\n",
    "summary_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df660870",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary_sorted = summary_df.sort_values(by=[\"Macro F1\", \"Weighted F1\"], ascending=False)\n",
    "summary_sorted.to_latex(\"c2db_summary_table.tex\", index=False, float_format=\"%.3f\")\n",
    "summary_sorted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18f4a59b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics and bar settings\n",
    "metrics = [\"Stable F1\", \"Unstable F1\", \"Macro F1\", \"Weighted F1\"]\n",
    "bar_width = 0.2\n",
    "x = np.arange(len(summary_sorted))\n",
    "\n",
    "# The first len(metrics) colours of the tab20 colormap, one per metric.\n",
    "cmap = plt.get_cmap(\"tab20\")\n",
    "colors = {metric: cmap(i) for i, metric in enumerate(metrics)}\n",
    "\n",
    "# Offsets that centre each group of bars on its model tick\n",
    "# (-1.5, -0.5, +0.5, +1.5 bar widths for four metrics).\n",
    "offsets = (np.arange(len(metrics)) - (len(metrics) - 1) / 2) * bar_width\n",
    "\n",
    "with plt.style.context(\"default\"):\n",
    "    plt.rcParams.update(FONT_RC)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(4, 3), layout=\"constrained\")\n",
    "\n",
    "    for metric, offset in zip(metrics, offsets):\n",
    "        ax.bar(\n",
    "            x + offset,\n",
    "            summary_sorted[metric],\n",
    "            width=bar_width,\n",
    "            label=metric,\n",
    "            color=colors[metric],\n",
    "        )\n",
    "\n",
    "    ax.set_xlabel(\"Model\")\n",
    "    ax.set_ylabel(\"F1 Score\")\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(summary_sorted[\"Model\"], rotation=45, ha=\"right\")\n",
    "    ax.legend(ncols=2, bbox_to_anchor=(0.5, 1), loc=\"upper center\", fontsize=SMALL_SIZE)\n",
    "    ax.spines[[\"top\", \"right\"]].set_visible(False)\n",
    "    ax.set_ylim(0, 0.9)\n",
    "    ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.6)\n",
    "    # No tight_layout() here: the figure uses constrained layout, and calling\n",
    "    # tight_layout() would switch that layout engine off.\n",
    "\n",
    "    fig.savefig(\"c2db_f1_bar.pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlip-arena",
   "language": "python",
   "name": "mlip-arena"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}