{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "output_type = \"png\" # or \"pdf\"\n", "timevis = \"noB_tnn\"\n", "dvi = \"parametricUmap_step2_A\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "DATASET = \"mnist\"\n", "CONTENT_PATH = \"/home/xianglin/projects/DVI_data/resnet18_{}\".format(DATASET)\n", "content_path = CONTENT_PATH" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "train_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_corrs.npy\".format(timevis)))\n", "train_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_ps.npy\".format(timevis)))\n", "train_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_5_tnn.npy\".format(timevis)))\n", "test_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_corrs.npy\".format(timevis)))\n", "test_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_ps.npy\".format(timevis)))\n", "test_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_5_tnn.npy\".format(timevis)))\n", "\n", "\n", "dvi_train_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_corrs.npy\".format(dvi)))\n", "dvi_train_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_ps.npy\".format(dvi)))\n", "dvi_train_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_5_tnn.npy\".format(dvi)))\n", "dvi_test_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_corrs.npy\".format(dvi)))\n", "dvi_test_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_ps.npy\".format(dvi)))\n", "dvi_test_tnn = np.load(os.path.join(CONTENT_PATH, 
\"Model\", \"DVI_test_{}_3_5_tnn.npy\".format(dvi)))\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "selected_idxs = np.argsort(train_corrs[19])[-100:]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(,\n", " )" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAARtklEQVR4nO3db4xlB1nH8e+PbitoK2xlW9d1NwWsCCFScEAsaICCLn1TMGBFhA1WtwQxIITYwAs1vkGjSPwT7AINq0EoQrFFsVBKoZJCYSGlbF2ggEDXbrpTQKmagFseX9zTOA6zu3e3c+5zZ+b7SW7uveeeO+fpZO63Z8/ccydVhSRp9h7QPYAkbVQGWJKaGGBJamKAJamJAZakJpu6B5jGzp0769prr+0eQ5JOVlZauCb2gO++++7uESRp1a2JAEvSemSAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCajBTjJA5N8IslnktyW5PeH5WcmuS7J7cP15rFmkKR5NuYe8LeBp1fVY4HzgJ1JngRcBlxfVecC1w/3JWnDGS3ANfGfw91Th0sBFwF7h+V7gWePNYMkzbNRjwEnOSXJLcBh4Lqquhk4u6oOAQzXZx3lubuT7Euyb3FxccwxJc2xbdt3kGQuLtu271jV/7ZRP5C9qu4FzkvyEOA9SR5zAs/dA+wBWFhYqHEmlDTv7jx4BxdfflP3GABceen5q/r1ZvIuiKr6d+DDwE7griRbAYbrw7OYQZLmzZjvgtgy7PmS5EHAM4DPAdcAu4bVdgFXjzWDJM2zMQ9BbAX2JjmFSejfWVX/kORjwDuTXAJ8DXjeiDNI0twaLcBVdSvwuBWWfx24YKztStJa4ZlwktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1KT0QKcZHuSG5IcSHJbkpcPy38vyb8luWW4XDjWDJI0zzaN+LWPAK+qqk8nOQP4VJLrhsf+tKr+eMRtS9LcGy3AVXUIODTcvifJAWDbWNuTpLVmJseAk5wDPA64eVj0siS3JrkiyeajPGd3kn1J9i0uLs5iTEmaqdEDnOR04N3AK6rqW8AbgUcA5zHZQ/6TlZ5XVXuqaqGqFrZs2TL2mJI0c6MGOMmpTOL7tqq6CqCq7qqqe6vqu8CbgCeOOYMkzasx3wUR4C3Agap6/ZLlW5es9hxg/1gzSNI8G/NdEE8GXgh8Nsktw7LXAM9Pch5QwFeAS0ecQZLm1pjvgvgokBUeet9Y25SktcQz4SSpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJg
ZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJqMFOMn2JDckOZDktiQvH5afmeS6JLcP15vHmkGS5tmYe8BHgFdV1aOAJwG/meTRwGXA9VV1LnD9cF+SNpzRAlxVh6rq08Pte4ADwDbgImDvsNpe4NljzSBJ82wmx4CTnAM8DrgZOLuqDsEk0sBZR3nO7iT7kuxbXFycxZiSNFOjBzjJ6cC7gVdU1bemfV5V7amqhapa2LJly3gDSlKTUQOc5FQm8X1bVV01LL4rydbh8a3A4TFnkKR5Nea7IAK8BThQVa9f8tA1wK7h9i7g6rFmkKR5tmnEr/1k4IXAZ5PcMix7DfA64J1JLgG+BjxvxBkkaW6NFuCq+iiQozx8wVjblaS1wjPhJKmJAZakJgZYkpoYYElqYoC14W3bvoMkc3HZtn1H97dDMzTm29CkNeHOg3dw8eU3dY8BwJWXnt89gmbIPWBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmUwU4yZOnWSZJmt60e8B/PuUySdKUjvlXkZP8DHA+sCXJK5c89IPAKWMOJknr3fH+LP1pwOnDemcsWf4t4LljDSVJG8ExA1xVHwE+kuStVfXVGc0kSRvC8faA7/N9SfYA5yx9TlU9fYyhJGkjmDbAfwf8FfBm4N7xxpGkjWPaAB+pqjeOOokkbTDTvg3tvUlemmRrkjPvu4w6mSStc9PuAe8arl+9ZFkBD1/dcSRp45gqwFX1sLEHkaSNZqoAJ3nRSsur6q9XdxxJ2jimPQTxhCW3HwhcAHwaMMCSdJKmPQTxW0vvJ3kw8DejTCRJG8TJfhzlfwPnruYgkrTRTHsM+L1M3vUAkw/heRTwzrGGkqSNYNpjwH+85PYR4KtVdXCEeSRpw5jqEMTwoTyfY/KJaJuB7xzvOUmuSHI4yf4ly34vyb8luWW4XHiyg0vSWjftX8T4JeATwPOAXwJuTnK8j6N8K7BzheV/WlXnDZf3nciwkrSeTHsI4rXAE6rqMECSLcAHgXcd7QlVdWOSc+73hJK0Tk37LogH3BffwddP4LnLvSzJrcMhis1HWynJ7iT7kuxbXFw8yU1Ja8wDNpFkLi7btu/o/m6se9PuAV+b5P3A24f7FwMnc/jgjcAfMHlHxR8AfwL82korVtUeYA/AwsJCrbSOtO589wgXX35T9xQAXHnp+d0jrHvH+5twPwacXVWvTvKLwFOAAB8D3naiG6uqu5Z87TcB/3CiX0OS1ovjHUZ4A3APQFVdVVWvrKrfZrL3+4YT3ViSrUvuPgfYf7R1JWm9O94hiHOq6tblC6tq3/F+wZbk7cBTgYcmOQj8LvDUJOcxOQTxFeDSEx9ZktaH4wX4gcd47EHHemJVPX+FxW857kSStEEc7xDEJ5P8xvKFSS4BPjXOSJK0MRxvD/gVwHuSvID/C+4CcBqTY7iSpJN0zAAP71o4P8nTgMcMi/+xqj40+mSStM5N+3nANwA3jDyLJG0o056IIWmjGc7K03gMsKSVzclZeev5jLyT/TwHSdL9ZIAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaeCKG2mzbvoM7D97RPYbUxgCrzZ0H7/BMK21oHoKQpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKajBbgJFckOZxk/5JlZya5Lsntw/XmsbYvSfNuzD3gtwI7ly27DLi+qs4Frh/uS9KGNFqAq+
pG4BvLFl8E7B1u7wWePdb2JWnezfoY8NlVdQhguD7raCsm2Z1kX5J9i4uLMxtQkmZlbn8JV1V7qmqhqha2bNnSPY4krbpZB/iuJFsBhuvDM96+JM2NWQf4GmDXcHsXcPWMty9Jc2PMt6G9HfgY8MgkB5NcArwOeGaS24FnDvclaUPaNNYXrqrnH+WhC8bapiStJXP7SzhJWu8MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDXZ1LHRJF8B7gHuBY5U1ULHHJLUqSXAg6dV1d2N25ekVh6CkKQmXQEu4ANJPpVk90orJNmdZF+SfYuLiye1kW3bd5Ck/bJt+477872StE51HYJ4clXdmeQs4Lokn6uqG5euUFV7gD0ACwsLdTIbufPgHVx8+U33f9r76cpLz+8eQdIcatkDrqo7h+vDwHuAJ3bMIUmdZh7gJD+Q5Iz7bgM/D+yf9RyS1K3jEMTZwHuS3Lf9v62qaxvmkKRWMw9wVX0ZeOystytJ88a3oUlSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktRkU/cAG8IDNpGkewoATjn1+7j3f77dPYYkDPBsfPcIF19+U/cUAFx56flzNYu0kXkIQpKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJi0BTrIzyeeTfDHJZR0zSFK3mQc4ySnAXwLPAh4NPD/Jo2c9hyR169gDfiLwxar6clV9B3gHcFHDHJLUKlU12w0mzwV2VtWvD/dfCPx0Vb1s2Xq7gd3D3UcCn7+fm34ocPf9/Bod1uLca3FmcO5ZWoszw8nPfXdV7Vy+sOOPcq7054G/5/8CVbUH2LNqG032VdXCan29WVmLc6/FmcG5Z2ktzgyrP3fHIYiDwPYl938UuLNhDklq1RHgTwLnJnlYktOAXwauaZhDklrN/BBEVR1J8jLg/cApwBVVddsMNr1qhzNmbC3OvRZnBueepbU4M6zy3DP/JZwkacIz4SSpiQGWpCbrNsBJzkxyXZLbh+vNK6yzPckNSQ4kuS3Jy5tmPeap2Zn4s+HxW5M8vmPO5aaY+wXDvLcmuSnJYzvmXG7aU+GTPCHJvcN711tNM3OSpya5ZfhZ/sisZ1zJFD8jD07y3iSfGeZ+ccecy2a6IsnhJPuP8vjqvR6ral1egD8CLhtuXwb84QrrbAUeP9w+A/gC8OgZz3kK8CXg4cBpwGeWzwBcCPwTk/dQPwm4eQ6+v9PMfT6webj9rLUy95L1PgS8D3juvM8MPAT4F2DHcP+stfC9Bl5z32sT2AJ8Azitee6fAx4P7D/K46v2ely3e8BMTm/eO9zeCzx7+QpVdaiqPj3cvgc4AGyb1YCDaU7Nvgj465r4OPCQJFtnPOdyx527qm6qqm8Odz/O5D3f3aY9Ff63gHcDh2c53FFMM/OvAFdV1dcAqmqtzF3AGUkCnM4kwEdmO+aygapuHOY4mlV7Pa7nAJ9dVYdgElrgrGOtnOQc4HHAzeOP9v9sA+5Ycv8g3/s/gWnWmbUTnekSJnsN3Y47d5JtwHOAv5rhXMcyzff6x4HNST6c5FNJXjSz6Y5umrn/AngUk5OxPgu8vKq+O5vxTtqqvR47TkVeNUk+CPzwCg+99gS/zulM9nZeUVXfWo3ZTmTzKyxb/t7AqU7fnr
GpZ0ryNCYBfsqoE01nmrnfAPxOVd072TFrN83Mm4CfAi4AHgR8LMnHq+oLYw93DNPM/QvALcDTgUcA1yX554bX4YlYtdfjmg5wVT3jaI8luSvJ1qo6NPzzYMV/kiU5lUl831ZVV4006rFMc2r2PJ6+PdVMSX4SeDPwrKr6+oxmO5Zp5l4A3jHE96HAhUmOVNXfz2TC7zXtz8jdVfVfwH8luRF4LJPfa3SZZu4XA6+rycHVLyb5V+AngE/MZsSTsmqvx/V8COIaYNdwexdw9fIVhuNObwEOVNXrZzjbUtOcmn0N8KLht69PAv7jvsMrjY47d5IdwFXAC5v3xJY67txV9bCqOqeqzgHeBby0Mb4w3c/I1cDPJtmU5PuBn2byO41O08z9NSZ77SQ5m8knH355plOeuNV7PXb+tnHk32T+EHA9cPtwfeaw/EeA9w23n8Lknw63Mvln0C3AhQ2zXshkT+VLwGuHZS8BXjLcDpMPsf8Sk+NkC93f3ynnfjPwzSXf233dM08z97J130rzuyCmnRl4NZN3Quxncjht7r/Xw+vxA8PP9X7gV+dg5rcDh4D/YbK3e8lYr0dPRZakJuv5EIQkzTUDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1+V8M8r9g8JjwfQAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAayElEQVR4nO3df7Bc5X3f8fcnMsa0NgaKIEKCgaRyG2BiHBSihrTjYE8taFrhTuzKTQzj4ighEIObpgHnj6TT0YzbOia1qUkU7EGkibGaOEVxwQSDHdcTflh2MUKAbSUQLNAgyYEapzMKkr/9Y4+GRVpdrdA9+9y99/2a2dmz3z3n7lfA/XD0nOc8m6pCkjR539e6AUlaqAxgSWrEAJakRgxgSWrEAJakRl7VuoG+rFq1qj772c+2bkOSADKqOG/PgHfv3t26BUma0bwNYEma6wxgSWrEAJakRgxgSWrEAJakRgxgSWrEAJakRgxgSWrEAJakRnoP4CSLkvyfJJ/pXp+U5O4k3+yeTxza9/ok25J8PcnbhurnJ9nSvfeRJCNv65OkaTKJM+BrgMeGXl8H3FNVy4F7utckORtYA5wDrAI+lmRRd8xNwFpgefdYNYG+JalXvQZwkmXAPwNuHiqvBjZ02xuAS4fqt1XVnqp6AtgGXJBkCXB8Vd1Xg+9PunXoGEmaWn2fAf8W8O+B7w3VTq2qHQDd8yldfSnwraH9tne1pd32gfWDJFmbZHOSzbt27ZqVP4Ak9aW3AE7yU8DOqvrKuIeMqNUM9YOLVeurakVVrVi8ePGYHytJbfS5HvCFwL9IcgnwGuD4JP8deDbJkqra0Q0v7Oz23w6cPnT8MuCZrr5sRF3SAvPeq9/P07uff1lt6ckncPONN7Rp6Cj1FsBVdT1wPUCSNwP/rqp+Nsl/AS4HPtg9394dsgn4gyQfBk5jcLHtwaral+SFJCuBB4DLgI/21bekuevp3c/z+ovWvrx27/pG3Ry9Ft+I8UFgY5IrgKeAdwBU1dYkG4FHgb3AVVW1rzvmSuAW4Djgzu4hSVNtIgFcVV8AvtBtfxt4yyH2WwesG1HfDJzbX4eSNHneCSdJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktTIq1o3IGn+ee/V7+fp3c8fVF968gncfOMNk29ojjKAJc26p3c/z+svWntw/d71DbqZuxyCkKRGDGBJasQAlqRGegvgJK9J8mCSryXZmuQ/dPXfSPJ0koe6xyVDx1yfZFuSryd521D9/CRbuvc+kiR99S1Jk9LnRbg9wEVV9d0kxwBfSnJn994NVfWh4Z2TnA2sAc4BTgM+l+QNVbUPuAlYC9wP3AGsAu5EkqZYb2fANfDd7uUx3aNmOGQ1cFtV7amqJ4BtwAVJlgDHV9V9VVXArcClffUtSZPS6xhwkkVJHgJ2AndX1QPdW1cneTjJJ5Kc2NWWAt8aOnx7V1vabR9YH/V5a5NsTrJ5165ds/lHkaRZ12sAV9W+qjoPWMbgbPZcBsMJPwicB+wAfrPbfdS4bs1QH/V566tqRVWtWLx48VF2L0n9msgsiKp6HvgCsKqqnu2C+XvA7wIXdLttB04fOmwZ8ExXXzaiLklTrc9ZEIuTnNBtHwe8FXi8G9Pd7+3AI932JmBNkmOTnAUsBx6sqh3AC0lWdrMfLgNu76tvSZqUPmdBLAE2JFnEIOg3VtVnkvxekvMYDCM8Cfw8QFVtTbIReBTYC1zVzYAAuBK4BTiOwewHZ0BImnq9BXBVPQy8aUT93TMcsw5YN6K+G
Th3VhuUpMa8E06SGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJamRV/X1g5O8BvgicGz3OX9YVb+e5CTgU8CZwJPAO6vque6Y64ErgH3A+6rqrq5+PnALcBxwB3BNVVVfvUuaHlsf2cLFa95zUH3pySdw8403jPUz3nv1+3l69/Ov+PhXqrcABvYAF1XVd5McA3wpyZ3AvwTuqaoPJrkOuA741SRnA2uAc4DTgM8leUNV7QNuAtYC9zMI4FXAnT32LmlKvFiLeP1Faw+qP33v+rF/xtO7nz/oZxzJ8a9Ub0MQNfDd7uUx3aOA1cCGrr4BuLTbXg3cVlV7quoJYBtwQZIlwPFVdV931nvr0DGSNLV6HQNOsijJQ8BO4O6qegA4tap2AHTPp3S7LwW+NXT49q62tNs+sD7q89Ym2Zxk865du2b1zyJJs63XAK6qfVV1HrCMwdnsuTPsnlE/Yob6qM9bX1UrqmrF4sWLj7hfSZqkicyCqKrngS8wGLt9thtWoHve2e22HTh96LBlwDNdfdmIuiRNtd4COMniJCd028cBbwUeBzYBl3e7XQ7c3m1vAtYkOTbJWcBy4MFumOKFJCuTBLhs6BhJmlp9zoJYAmxIsohB0G+sqs8kuQ/YmOQK4CngHQBVtTXJRuBRYC9wVTcDAuBKXpqGdifOgJA0D/QWwFX1MPCmEfVvA285xDHrgHUj6puBmcaPJWnqeCecJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDXS59fSS/Pee69+P0/vfv5ltaUnn8DNN97QpiFNFQNYOgpP736e11+09uW1e9c36kbTxiEISWrEAJakRgxgSWrEMWBpHvLi4HQwgKV5yIuD08EhCElqpLcATnJ6ks8neSzJ1iTXdPXfSPJ0koe6xyVDx1yfZFuSryd521D9/CRbuvc+kiR99S1Jk9LnEMRe4Jer6qtJXgd8Jcnd3Xs3VNWHhndOcjawBjgHOA34XJI3VNU+4CZgLXA/cAewCrizx94lqXe9BXBV7QB2dNsvJHkMWDrDIauB26pqD/BEkm3ABUmeBI6vqvsAktwKXIoBrHlq1AU0mOxFNC/iTcZELsIlORN4E/AAcCFwdZLLgM0MzpKfYxDO9w8dtr2rvdhtH1gf9TlrGZwpc8YZZ8zuH0KakFEX0GCyF9G8iDcZvV+ES/Ja4I+Aa6vqOwyGE34QOI/BGfJv7t91xOE1Q/3gYtX6qlpRVSsWL158tK1LUq96DeAkxzAI39+vqk8DVNWzVbWvqr4H/C5wQbf7duD0ocOXAc909WUj6pI01fqcBRHg48BjVfXhofqSod3eDjzSbW8C1iQ5NslZwHLgwW4s+YUkK7ufeRlwe199S9Kk9DkGfCHwbmBLkoe62geAdyU5j8EwwpPAzwNU1dYkG4FHGcyguKqbAQFwJXALcByDi29egJM09fqcBfElRo/f3jHDMeuAdSPqm4FzZ687SWrPO+EkqRHXgpA05xxqLvRjj3+DlRdNvp++GMCSxrL1kS1cvOY9B9X7uEHjUHOh92y5dlY/pzUDWNJYXqxFzW8QmW8cA5akRgxgSWrEAJakRgxgSWrEAJakRgxgSWrEaWiSJmbUXOKFvNC7ASxpYkbNJV7I84gNYEkLxly7xdkAlrRgzLVbnMe6CJfkwnFqkqTxjTsL4qNj1iRJY5pxCCLJPwJ+HFic5N8OvXU8sKjPxiRpvjvcGPCrgdd2+71uqP4d4Kf7akrS5Iy6MDXf1t2dq2YM4Kr6M+DPktxSV
X81oZ4kTdCoC1Pzbd3duWrcWRDHJlkPnDl8TFX5/0hJeoXGDeD/Afw2cDOw7zD7SpLGMG4A762qm3rtRJIWmHGnof1Jkl9MsiTJSfsfvXYmSfPcuGfAl3fPvzJUK+AHZrcdSVo4xgrgqjqr70akPh1qDYCFvBKX2hsrgJNcNqpeVbfObjtSPw61BsBCXolL7Y07BPGjQ9uvAd4CfBUwgCXpFRp3COKXhl8neT3we710JEkLxCv9SqL/ByyfzUYkaaEZdwz4TxjMeoDBIjw/BGzsqylpPhp1IdCLgAvbuGPAHxra3gv8VVVtn+mAJKczGCP+fuB7wPqq+q/d/OFPMbit+UngnVX1XHfM9cAVDO62e19V3dXVzwduAY4D7gCuqapCmiKjLgR6EXBhG2sIoluU53EGK6KdCPztGIftBX65qn4IWAlcleRs4DrgnqpaDtzTvaZ7bw1wDrAK+FiS/Ute3gSsZTDssbx7X5Km2rjfiPFO4EHgHcA7gQeSzLgcZVXtqKqvdtsvAI8BS4HVwIZutw3Apd32auC2qtpTVU8A24ALkiwBjq+q+7qz3luHjpGkqTXuEMSvAT9aVTsBkiwGPgf84TgHJzkTeBPwAHBqVe2AQUgnOaXbbSlw/9Bh27vai932gfVRn7OWwZkyZ5xxxjitSVIz486C+L794dv59rjHJnkt8EfAtVX1nZl2HVGrGeoHF6vWV9WKqlqxePHicdqTpGbGPQP+bJK7gE92r/8Vg4thM0pyDIPw/f2q+nRXfjbJku7sdwmwP9i3A6cPHb4MeKarLxtRl6SpNuNZbJK/n+TCqvoV4HeAHwbeCNwHzHj5NkmAjwOPVdWHh97axEuL+1wO3D5UX5Pk2CRnMbjY9mA3XPFCkpXdz7xs6BhJmlqHOwP+LeADAN0Z7KcBkqzo3vvnMxx7IfBuYEuSh7raB4APAhuTXAE8xeDCHlW1NclG4FEGMyiuqqr9i79fyUvT0O7sHpI01Q4XwGdW1cMHFqtqc3dh7ZCq6kuMHr+FwVoSo45ZB6wb9XnAuYfpVZKmyuEupL1mhveOm81GJGmhOVwAfznJzx1Y7IYPvtJPS5K0MBxuCOJa4I+T/AwvBe4K4NXA23vsS5LmvRkDuKqeBX48yU/y0hjs/6qqe3vvTJLmuXHXA/488Pmee5GkBeWVrgcsSTpKBrAkNWIAS1Ij464FIWkOGvUtGwCPPf4NVl40+X50ZAxgaYqN+pYNgD1brp18MzpiDkFIUiMGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ14q3I0hhcc0F9MIClMbjmgvrgEIQkNWIAS1IjBrAkNWIAS1IjBrAkNWIAS1IjBrAkNdLbPOAknwB+CthZVed2td8Afg7Y1e32gaq6o3vveuAKYB/wvqq6q6ufD9wCHAfcAVxTVdVX31pYtj6yhYvXvOdltaUnn8DNN97QqCMtJH3eiHELcCNw6wH1G6rqQ8OFJGcDa4BzgNOAzyV5Q1XtA24C1gL3MwjgVcCdPfatBeTFWnTQDRZP37u+UTdaaHobgqiqLwJ/Pebuq4HbqmpPVT0BbAMuSLIEOL6q7uvOem8FLu2lYUmasBZjwFcneTjJJ5Kc2NWWAt8a2md7V1vabR9YHynJ2iSbk2zetWvXoXaTpDlh0mtB3AT8R6C6598E/g2QEfvWDPWRqmo9sB5gxYoVjhMvUKMWznHRHM1FEw3gqnp2/3aS3wU+073cDpw+tOsy4JmuvmxEXTqkUQvnuGiO5qKJDkF0Y7r7vR14pNveBKxJcmySs4DlwINVtQN4IcnKJAEuA26fZM+S1Jc+p6F9EngzcHKS7cCvA29Och6DYYQngZ8HqKqtSTYCjwJ7gau6GRAAV/LSNLQ7cQaEpHmitwCuqneNKH98hv3XAetG1DcD585ia5I0J3gnnCQ1YgBLUiMGsCQ1YgBLUiMGsCQ14rcia0451Ne/u0KZ5iMDWHPKo
b7+3RXKNB8ZwFJDo9Yjdt2KhcMAlhoatR6x61YsHF6Ek6RGDGBJasQAlqRGDGBJasQAlqRGDGBJasQAlqRGDGBJasQAlqRGvBNOs8rFdKTxGcCaVS6mI43PIQhJasQAlqRGDGBJasQAlqRGvAg3z4yaheAMBGluMoDnmVGzEJyBIM1NDkFIUiMGsCQ1YgBLUiOOAesVG3XBz2/0lcbXWwAn+QTwU8DOqjq3q50EfAo4E3gSeGdVPde9dz1wBbAPeF9V3dXVzwduAY4D7gCuqarqq2+Nb9QFP7/RVxpfn0MQtwCrDqhdB9xTVcuBe7rXJDkbWAOc0x3zsSSLumNuAtYCy7vHgT9TkqZSbwFcVV8E/vqA8mpgQ7e9Abh0qH5bVe2pqieAbcAFSZYAx1fVfd1Z761Dx0jSVJv0RbhTq2oHQPd8SldfCnxraL/tXW1pt31gfaQka5NsTrJ5165ds9q4JM22uTILIiNqNUN9pKpaX1UrqmrF4sWLZ605SerDpAP42W5Yge55Z1ffDpw+tN8y4JmuvmxEXZKm3qSnoW0CLgc+2D3fPlT/gyQfBk5jcLHtwaral+SFJCuBB4DLgI9OuGdJC9DWR7Zw8Zr3vKw22+uq9DkN7ZPAm4GTk2wHfp1B8G5McgXwFPAOgKrammQj8CiwF7iqqvZ1P+pKXpqGdmf3kKRevViLel9XpbcArqp3HeKttxxi/3XAuhH1zcC5s9iaGpjE2YQ0bbwTThMxibMJadrMlVkQkrTgeAa8gI1ay+HJb36DM5e/4WU1hwqkfhjAC9iotRye33KtQwXShDgEIUmNeAasqeAsCs1HBrCmgrMoNB85BCFJjRjAktSIQxCNjJoC5pimtLAYwI2MmgLmmKa0sDgEIUmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1Ii3Ik+pUWtJADz2+DdYedHk+5F05AzgKTVqLQmAPVuunXwzkl4RhyAkqREDWJIaMYAlqREDWJIaMYAlqRFnQcwho756HfyqImm+ahLASZ4EXgD2AXurakWSk4BPAWcCTwLvrKrnuv2vB67o9n9fVd3VoO3ejfrqdfCriqT5quUQxE9W1XlVtaJ7fR1wT1UtB+7pXpPkbGANcA6wCvhYkkUtGpak2TSXhiBWA2/utjcAXwB+tavfVlV7gCeSbAMuAO5r0OMrMuquNe9Yk9QqgAv40yQF/E5VrQdOraodAFW1I8kp3b5LgfuHjt3e1Q6SZC2wFuCMM87oq/cjNuquNe9Yk9QqgC+sqme6kL07yeMz7JsRtRq1Yxfk6wFWrFgxch9JmiuajAFX1TPd807gjxkMKTybZAlA97yz2307cPrQ4cuAZybXrST1Y+IBnOTvJnnd/m3gnwKPAJuAy7vdLgdu77Y3AWuSHJvkLGA58OBku5ak2ddiCOJU4I+T7P/8P6iqzyb5MrAxyRXAU8A7AKpqa5KNwKPAXuCqqtrXoG9JmlUTD+Cq+kvgjSPq3wbecohj1gHrem5NkibKW5ElqREDWJIaMYAlqREDWJIaMYAlqREDWJIamUuL8UyVUQvsuG6vpCNhAL9CoxbYcd1eSUfCIQhJasQz4MMYNdQArucr6egZwIcxaqgBXM9X0tFzCEKSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRF2QfMurbL/zmC0l9MYCHjPr2C7/5QlJfHIKQpEamJoCTrEry9STbklzXuh9JOlpTEcBJFgH/DbgYOBt4V5Kz23YlSUdnKgIYuADYVlV/WVV/C9wGrG7ckyQdlVRV6x4OK8lPA6uq6r3d63cDP1ZVVx+w31pg/1W0fwB8vce2TgZ29/jzZ4t9zr5p6XVa+oTp6fWV9rm7qlYdWJyWWRAZUTvo/xxVtR5Y3387kGRzVa2YxGcdDfucfdPS67T0CdPT6
2z3OS1DENuB04deLwOeadSLJM2KaQngLwPLk5yV5NXAGmBT454k6ahMxRBEVe1NcjVwF7AI+ERVbW3c1kSGOmaBfc6+ael1WvqE6el1VvuciotwkjQfTcsQhCTNOwawJDViAI8pyUlJ7k7yze75xBH7nJ7k80keS7I1yTUT7G/GW7Uz8JHu/YeT/MikejvCPn+m6+/hJH+e5I0t+ux6Gev29yQ/mmRfN1994sbpM8mbkzzU/Xf5Z5PusevhcP/uX5/kT5J8revzPY36/ESSnUkeOcT7s/e7VFU+xngA/xm4rtu+DvhPI/ZZAvxIt/064BvA2RPobRHwF8APAK8Gvnbg5wKXAHcymFO9EnigwT/Dcfr8ceDEbvviFn2O2+vQfvcCdwA/PRf7BE4AHgXO6F6fMkf7/MD+3ytgMfDXwKsb9PpPgB8BHjnE+7P2u+QZ8PhWAxu67Q3ApQfuUFU7quqr3fYLwGPA0gn0Ns6t2quBW2vgfuCEJEsm0NsR9VlVf15Vz3Uv72cw57uFcW9//yXgj4Cdk2xuyDh9/mvg01X1FEBVteh1nD4LeF2SAK9lEMB7J9smVNUXu88+lFn7XTKAx3dqVe2AQdACp8y0c5IzgTcBD/TfGkuBbw293s7BwT/OPn070h6uYHCm0cJhe02yFHg78NsT7OtA4/wzfQNwYpIvJPlKkssm1t1LxunzRuCHGNxktQW4pqq+N5n2jsis/S5NxTzgSUnyOeD7R7z1a0f4c17L4Kzo2qr6zmz0driPHFE7cH7hWLdz92zsHpL8JIMA/oleOzq0cXr9LeBXq2rf4KStiXH6fBVwPvAW4DjgviT3V9U3+m5uyDh9vg14CLgI+EHg7iT/e0K/Q0di1n6XDOAhVfXWQ72X5NkkS6pqR/fXjZF/jUtyDIPw/f2q+nRPrR5onFu158Lt3GP1kOSHgZuBi6vq2xPq7UDj9LoCuK0L35OBS5Lsrar/OZEOB8b9d7+7qv4G+JskXwTeyOAaxaSM0+d7gA/WYKB1W5IngH8IPDiZFsc2a79LDkGMbxNwebd9OXD7gTt0Y1cfBx6rqg9PsLdxbtXeBFzWXcFdCfzf/UMqc6nPJGcAnwbePeEztAMdtteqOquqzqyqM4E/BH5xwuE7Vp8M/lv9x0leleTvAD/G4PrEXOvzKQZn6SQ5lcGKhn850S7HM3u/S5O+wjitD+DvAfcA3+yeT+rqpwF3dNs/weCvIg8z+KvUQ8AlE+rvEgZnNH8B/FpX+wXgF7rtMFjU/i8YjK+taPTP8XB93gw8N/TPb3PDf+cz9nrAvrfQYBbEuH0Cv8JgJsQjDIbG5lyf3e/Sn3b/fT4C/GyjPj8J7ABeZHC2e0Vfv0veiixJjTgEIUmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmN/H8Z8DWgBMVBpwAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sns.displot(train_corrs[0][selected_idxs]),sns.displot(train_corrs[0])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
 "# Long-format frame for seaborn: one row per (method, split, epoch, sample).\n",
 "# Each array is (n_epochs, n_samples); using each array's own shape lets train/test and\n",
 "# TimeVis/DVI runs differ in size instead of leaving zero-filled rows or failing.\n",
 "frames = []\n",
 "for method_name, split_name, values in [\n",
 "    (\"TimeVis\", \"Train\", train_corrs),\n",
 "    (\"TimeVis\", \"Test\", test_corrs),\n",
 "    (\"DVI\", \"Train\", dvi_train_corrs),\n",
 "    (\"DVI\", \"Test\", dvi_test_corrs),\n",
 "]:\n",
 "    n_epochs, n_samples = values.shape\n",
 "    frames.append(pd.DataFrame({\n",
 "        \"corr\": values.reshape(-1).astype(float),  # epoch-major, like the original loop\n",
 "        \"epoch\": np.repeat(np.arange(1, n_epochs + 1), n_samples),  # epochs are 1-based\n",
 "        \"type\": split_name,\n",
 "        \"method\": method_name,\n",
 "    }))\n",
 "df3 = pd.concat(frames, ignore_index=True)\n"
 ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
 "plt.rcParams['figure.dpi'] = 100  # figure resolution\n",
 "sns.set_theme(style='darkgrid')\n",
 "plt.style.use('ggplot')\n",
 "fig, ax = plt.subplots()\n",
 "# NOTE: seaborn>=0.12 deprecates ci=95 in favour of errorbar=(\"ci\", 95)\n",
 "sns.lineplot(x=\"epoch\", y=\"corr\", hue=\"method\", style=\"type\", markers=False, ci=95, data=df3, ax=ax)\n",
 "ax.set_title(\"MNIST\")\n",
 "# savefig does not create missing directories\n",
 "os.makedirs(\"./plot_results\", exist_ok=True)\n",
 "fig.savefig(\n",
 "    \"./plot_results/corr_3_{}.{}\".format(\"mnist\", output_type),\n",
 "    dpi=300,\n",
 "    bbox_inches=\"tight\",\n",
 "    pad_inches=0.0,\n",
 ")\n",
 "# close (not clf) so the inline backend does not render an empty figure\n",
 "plt.close(fig)"
 ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
 "# Long-format frame for seaborn, built from each (n_epochs, n_samples) array's own shape.\n",
 "frames = []\n",
 "for method_name, split_name, values in [\n",
 "    (\"TimeVis\", \"Train\", train_tnn),\n",
 "    (\"TimeVis\", \"Test\", test_tnn),\n",
 "    (\"DVI\", \"Train\", dvi_train_tnn),\n",
 "    (\"DVI\", \"Test\", dvi_test_tnn),\n",
 "]:\n",
 "    n_epochs, n_samples = values.shape\n",
 "    frames.append(pd.DataFrame({\n",
 "        \"tnn\": values.reshape(-1).astype(float),\n",
 "        \"epoch\": np.repeat(np.arange(1, n_epochs + 1), n_samples),\n",
 "        \"type\": split_name,\n",
 "        \"method\": method_name,\n",
 "    }))\n",
 "df3 = pd.concat(frames, ignore_index=True)\n",
 "\n",
 "plt.rcParams['figure.dpi'] = 100  # figure resolution\n",
 "sns.set_theme(style='darkgrid')\n",
 "plt.style.use('ggplot')\n",
 "fig, ax = plt.subplots()\n",
 "# NOTE: seaborn>=0.12 deprecates ci=95 in favour of errorbar=(\"ci\", 95)\n",
 "sns.lineplot(x=\"epoch\", y=\"tnn\", hue=\"method\", style=\"type\", markers=False, ci=95, data=df3, ax=ax)\n",
 "ax.set_title(\"MNIST\")\n",
 "# savefig does not create missing directories\n",
 "os.makedirs(\"./plot_results\", exist_ok=True)\n",
 "fig.savefig(\n",
 "    \"./plot_results/tnn_{}.{}\".format(\"mnist\", output_type),\n",
 "    dpi=300,\n",
 "    bbox_inches=\"tight\",\n",
 "    pad_inches=0.0,\n",
 ")\n",
 "# close (not clf) so the inline backend does not render an empty figure\n",
 "plt.close(fig)"
 ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "DATASET = \"fmnist\"\n", "CONTENT_PATH = \"/home/xianglin/projects/DVI_data/resnet18_{}\".format(DATASET)\n", "content_path = CONTENT_PATH" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "train_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_corrs.npy\".format(timevis)))\n", "train_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_ps.npy\".format(timevis)))\n", "train_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_5_tnn.npy\".format(timevis)))\n", "test_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_corrs.npy\".format(timevis)))\n", "test_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_ps.npy\".format(timevis)))\n", "test_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_5_tnn.npy\".format(timevis)))\n", "\n", "\n", "dvi_train_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_corrs.npy\".format(dvi)))\n", "dvi_train_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_ps.npy\".format(dvi)))\n", "dvi_train_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_5_tnn.npy\".format(dvi)))\n", "dvi_test_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_corrs.npy\".format(dvi)))\n", "dvi_test_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_ps.npy\".format(dvi)))\n", "dvi_test_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_5_tnn.npy\".format(dvi)))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
 "# Long-format frame for seaborn: one row per (method, split, epoch, sample).\n",
 "# Each array is (n_epochs, n_samples); using each array's own shape lets train/test and\n",
 "# TimeVis/DVI runs differ in size instead of leaving zero-filled rows or failing.\n",
 "frames = []\n",
 "for method_name, split_name, values in [\n",
 "    (\"TimeVis\", \"Train\", train_corrs),\n",
 "    (\"TimeVis\", \"Test\", test_corrs),\n",
 "    (\"DVI\", \"Train\", dvi_train_corrs),\n",
 "    (\"DVI\", \"Test\", dvi_test_corrs),\n",
 "]:\n",
 "    n_epochs, n_samples = values.shape\n",
 "    frames.append(pd.DataFrame({\n",
 "        \"corr\": values.reshape(-1).astype(float),  # epoch-major, like the original loop\n",
 "        \"epoch\": np.repeat(np.arange(1, n_epochs + 1), n_samples),  # epochs are 1-based\n",
 "        \"type\": split_name,\n",
 "        \"method\": method_name,\n",
 "    }))\n",
 "df3 = pd.concat(frames, ignore_index=True)"
 ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
 "plt.rcParams['figure.dpi'] = 100  # figure resolution\n",
 "sns.set_theme(style='darkgrid')\n",
 "plt.style.use('ggplot')\n",
 "fig, ax = plt.subplots()\n",
 "# NOTE: seaborn>=0.12 deprecates ci=95 in favour of errorbar=(\"ci\", 95)\n",
 "sns.lineplot(x=\"epoch\", y=\"corr\", hue=\"method\", style=\"type\", markers=False, ci=95, data=df3, ax=ax)\n",
 "ax.set_title(\"FMNIST\")\n",
 "\n",
 "# savefig does not create missing directories\n",
 "os.makedirs(\"./plot_results\", exist_ok=True)\n",
 "fig.savefig(\n",
 "    \"./plot_results/corr_3_{}.{}\".format(\"fmnist\", output_type),\n",
 "    dpi=300,\n",
 "    bbox_inches=\"tight\",\n",
 "    pad_inches=0.0,\n",
 ")\n",
 "# close (not clf) so the inline backend does not render an empty figure\n",
 "plt.close(fig)"
 ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
 "# Long-format frame for seaborn, built from each (n_epochs, n_samples) array's own shape.\n",
 "frames = []\n",
 "for method_name, split_name, values in [\n",
 "    (\"TimeVis\", \"Train\", train_tnn),\n",
 "    (\"TimeVis\", \"Test\", test_tnn),\n",
 "    (\"DVI\", \"Train\", dvi_train_tnn),\n",
 "    (\"DVI\", \"Test\", dvi_test_tnn),\n",
 "]:\n",
 "    n_epochs, n_samples = values.shape\n",
 "    frames.append(pd.DataFrame({\n",
 "        \"tnn\": values.reshape(-1).astype(float),\n",
 "        \"epoch\": np.repeat(np.arange(1, n_epochs + 1), n_samples),\n",
 "        \"type\": split_name,\n",
 "        \"method\": method_name,\n",
 "    }))\n",
 "df3 = pd.concat(frames, ignore_index=True)\n",
 "\n",
 "plt.rcParams['figure.dpi'] = 100  # figure resolution\n",
 "sns.set_theme(style='darkgrid')\n",
 "plt.style.use('ggplot')\n",
 "fig, ax = plt.subplots()\n",
 "# NOTE: seaborn>=0.12 deprecates ci=95 in favour of errorbar=(\"ci\", 95)\n",
 "sns.lineplot(x=\"epoch\", y=\"tnn\", hue=\"method\", style=\"type\", markers=False, ci=95, data=df3, ax=ax)\n",
 "ax.set_title(\"FMNIST\")\n",
 "\n",
 "# savefig does not create missing directories\n",
 "os.makedirs(\"./plot_results\", exist_ok=True)\n",
 "fig.savefig(\n",
 "    \"./plot_results/tnn_{}.{}\".format(\"fmnist\", output_type),\n",
 "    dpi=300,\n",
 "    bbox_inches=\"tight\",\n",
 "    pad_inches=0.0,\n",
 ")\n",
 "# close (not clf) so the inline backend does not render an empty figure\n",
 "plt.close(fig)"
 ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "DATASET = \"cifar10\"\n", "CONTENT_PATH = \"/home/xianglin/projects/DVI_data/resnet18_{}\".format(DATASET)\n", "content_path = CONTENT_PATH" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "train_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_corrs.npy\".format(timevis)))\n", "train_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_ps.npy\".format(timevis)))\n", "train_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_5_tnn.npy\".format(timevis)))\n", "test_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_corrs.npy\".format(timevis)))\n", "test_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_ps.npy\".format(timevis)))\n", "test_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_5_tnn.npy\".format(timevis)))\n", "\n", "\n", "dvi_train_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_corrs.npy\".format(dvi)))\n", "dvi_train_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_ps.npy\".format(dvi)))\n", "dvi_train_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_5_tnn.npy\".format(dvi)))\n", "dvi_test_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_corrs.npy\".format(dvi)))\n", "dvi_test_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_ps.npy\".format(dvi)))\n", "dvi_test_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_5_tnn.npy\".format(dvi)))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
 "# Long-format frame for seaborn: one row per (method, split, epoch, sample).\n",
 "# Each array is (n_epochs, n_samples); using each array's own shape lets train/test and\n",
 "# TimeVis/DVI runs differ in size instead of leaving zero-filled rows or failing.\n",
 "frames = []\n",
 "for method_name, split_name, values in [\n",
 "    (\"TimeVis\", \"Train\", train_corrs),\n",
 "    (\"TimeVis\", \"Test\", test_corrs),\n",
 "    (\"DVI\", \"Train\", dvi_train_corrs),\n",
 "    (\"DVI\", \"Test\", dvi_test_corrs),\n",
 "]:\n",
 "    n_epochs, n_samples = values.shape\n",
 "    frames.append(pd.DataFrame({\n",
 "        \"corr\": values.reshape(-1).astype(float),  # epoch-major, like the original loop\n",
 "        \"epoch\": np.repeat(np.arange(1, n_epochs + 1), n_samples),  # epochs are 1-based\n",
 "        \"type\": split_name,\n",
 "        \"method\": method_name,\n",
 "    }))\n",
 "df3 = pd.concat(frames, ignore_index=True)"
 ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.rcParams['figure.dpi'] = 100\n", "plt.style.use('ggplot')\n", "plt.title(\"CIFAR10\")\n", "sns.lineplot(x=\"epoch\", y=\"corr\", hue=\"method\", style=\"type\", markers=False, ci=95, data=df3)\n", "plt.savefig(\n", " \"./plot_results/corr_3_{}.{}\".format(\"cifar10\", output_type),\n", " dpi=300,\n", " bbox_inches=\"tight\",\n", " pad_inches=0.0,\n", ")\n", "plt.clf()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "epoch_num = train_tnn.shape[0]\n", "train_num = train_tnn.shape[1]\n", "test_num = test_tnn.shape[1]\n", "\n", "train_data = np.zeros((epoch_num*train_num, 2))\n", "for i in range(len(train_tnn)):\n", " train_data[i*train_num:(i+1)*train_num][:,0] = train_tnn[i]\n", " train_data[i*train_num:(i+1)*train_num][:,1] = i+1\n", "test_data = np.zeros((epoch_num*test_num, 2))\n", "for i in range(len(test_tnn)):\n", " test_data[i*test_num:(i+1)*test_num][:,0] = test_tnn[i]\n", " test_data[i*test_num:(i+1)*test_num][:,1] = i+1\n", "data = np.concatenate((train_data, test_data), axis=0)\n", "type = [\"Train\" for _ in range(len(train_data))] + [\"Test\" for _ in range(len(test_data))]\n", "method = [\"TimeVis\" for _ in range(len(data))]\n", "\n", "dvi_train_data = np.zeros((epoch_num*train_num, 2))\n", "for i in range(len(dvi_train_tnn)):\n", " dvi_train_data[i*train_num:(i+1)*train_num][:,0] = dvi_train_tnn[i]\n", " dvi_train_data[i*train_num:(i+1)*train_num][:,1] = i+1\n", "dvi_test_data = np.zeros((epoch_num*test_num, 2))\n", "for i in range(len(dvi_test_tnn)):\n", " dvi_test_data[i*test_num:(i+1)*test_num][:,0] = dvi_test_tnn[i]\n", " dvi_test_data[i*test_num:(i+1)*test_num][:,1] = i+1\n", "dvi_data = np.concatenate((dvi_train_data, dvi_test_data), axis=0)\n", "dvi_type = [\"Train\" for _ in range(len(dvi_train_data))]+[\"Test\" for _ in range(len(dvi_test_data))]\n", "dvi_method = [\"DVI\" for _ in range(len(dvi_data))]\n", "\n", "data = np.concatenate((data, dvi_data), axis=0)\n", "type = type + dvi_type\n", "method = method + dvi_method\n", "\n", "df = pd.DataFrame(data,columns=[\"tnn\", \"epoch\"])\n", "df2 = df.assign(type = type)\n", "df3 = df2.assign(method = method)\n", "df3[[\"epoch\"]] = df[[\"epoch\"]].astype(int)\n", "plt.rcParams['figure.dpi'] = 100\n", "plt.style.use('ggplot')\n", "plt.title(\"CIFAR10\")\n", "sns.lineplot(x=\"epoch\", y=\"tnn\", hue=\"method\", style=\"type\", 
markers=False, ci=95, data=df3)\n", "plt.savefig(\n", " \"./plot_results/tnn_{}.{}\".format(\"cifar10\", output_type),\n", " dpi=300,\n", " bbox_inches=\"tight\",\n", " pad_inches=0.0,\n", ")\n", "plt.clf()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# simple\n", "def draw(corrs, ps, corrs2, ps2, title):\n", " fig, axs = plt.subplots(2)\n", " fig.suptitle(title)\n", "\n", " epochs = [i for i in range(1, len(corrs)+1, 1)]\n", " mean_corr1 = np.mean(corrs, axis=1)\n", " var_corr1 = np.std(corrs, axis=1)\n", " mean_p1 = np.mean(ps, axis=1)\n", " var_p1 = np.std(ps, axis=1)\n", "\n", " mean_corr2 = np.mean(corrs2, axis=1)\n", " var_corr2 = np.std(corrs2, axis=1)\n", " mean_p2 = np.mean(ps2, axis=1)\n", " var_p2 = np.std(ps2, axis=1)\n", "\n", " a11 = axs[0].plot(epochs, mean_corr1, \"b.-\", epochs, mean_p1, \"r+-\")\n", " a12 = axs[0].fill_between(epochs, mean_corr1-var_corr1, mean_corr1+var_corr1)\n", " a13 = axs[0].fill_between(epochs, mean_p1-var_p1, mean_p1+var_p1)\n", "\n", " a21 = axs[1].plot(epochs, mean_corr2, \"b.-\", epochs, mean_p2, \"r+-\")\n", " a22 = axs[1].fill_between(epochs, mean_corr2-var_corr2, mean_corr2+var_corr2)\n", " a23 = axs[1].fill_between(epochs, mean_p2-var_p2, mean_p2+var_p2)\n", "\n", "\n", " plt.show()\n", " plt.clf()" ] } ], "metadata": { "interpreter": { "hash": "aa7a9f36e1a1e240450dbe9cc8f6d8df1d5301f36681fb271c44fdd883236b60" }, "kernelspec": { "display_name": "Python 3.7.11 ('SV': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }