diff --git "a/aclanthology_visualization.ipynb" "b/aclanthology_visualization.ipynb"
deleted file mode 100644--- "a/aclanthology_visualization.ipynb"
+++ /dev/null
@@ -1,1570 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": [],
- "gpuType": "T4"
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- },
- "accelerator": "GPU",
- "widgets": {
- "application/vnd.jupyter.widget-state+json": {
- "1619b254fcbb4cb880d1be5685c74dbc": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HBoxModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HBoxModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HBoxView",
- "box_style": "",
- "children": [
- "IPY_MODEL_607c048fb1634a7689e355036c144984",
- "IPY_MODEL_869501d4d38e46f184a66423d93a2745",
- "IPY_MODEL_c03dc1381a0c430182fe86d8a100b249"
- ],
- "layout": "IPY_MODEL_b6d31f4cebc84ef0a563d41482b14cc2"
- }
- },
- "607c048fb1634a7689e355036c144984": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_9ebec18dbdff4913a4902429a726b9e0",
- "placeholder": "",
- "style": "IPY_MODEL_c9dc6fbcf53a4c9fb53716a18db6ffbe",
- "value": "Map: 100%"
- }
- },
- "869501d4d38e46f184a66423d93a2745": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "FloatProgressModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "FloatProgressModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "ProgressView",
- "bar_style": "success",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_c9ef5bf8ff3e44358c4557f74c3e379e",
- "max": 1249,
- "min": 0,
- "orientation": "horizontal",
- "style": "IPY_MODEL_0cc5f439950e49eaa4d417396e21e2c4",
- "value": 1249
- }
- },
- "c03dc1381a0c430182fe86d8a100b249": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_1479cc60b4ac4864b46b592dc1050157",
- "placeholder": "",
- "style": "IPY_MODEL_a9e75caedfbf46e0bd0effe1e60065cd",
- "value": " 1249/1249 [00:02<00:00, 454.98 examples/s]"
- }
- },
- "b6d31f4cebc84ef0a563d41482b14cc2": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "9ebec18dbdff4913a4902429a726b9e0": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "c9dc6fbcf53a4c9fb53716a18db6ffbe": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "c9ef5bf8ff3e44358c4557f74c3e379e": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "0cc5f439950e49eaa4d417396e21e2c4": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ProgressStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ProgressStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "bar_color": null,
- "description_width": ""
- }
- },
- "1479cc60b4ac4864b46b592dc1050157": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "a9e75caedfbf46e0bd0effe1e60065cd": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "4c7b67b7151e4c9fb47eaae2f39a21b8": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HBoxModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HBoxModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HBoxView",
- "box_style": "",
- "children": [
- "IPY_MODEL_5bd552c8824e407c934978e35e7de980",
- "IPY_MODEL_d052c01440db4dafb5d699eb57a9d613",
- "IPY_MODEL_e1bbc114c9054a28a48762831a44ef11"
- ],
- "layout": "IPY_MODEL_591bf12de23c41c6aa510f6d6702b30e"
- }
- },
- "5bd552c8824e407c934978e35e7de980": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_5b379ade011143b9bf21c2aedaaf9149",
- "placeholder": "",
- "style": "IPY_MODEL_26062d5edbee4879a66829962199ca43",
- "value": "encoding: 100%"
- }
- },
- "d052c01440db4dafb5d699eb57a9d613": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "FloatProgressModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "FloatProgressModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "ProgressView",
- "bar_style": "success",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_c65d6c4a6d0a44d2a9fb8ca75cc5f790",
- "max": 20,
- "min": 0,
- "orientation": "horizontal",
- "style": "IPY_MODEL_3bb8adf35cf74c3cbd3d2c58912041a3",
- "value": 20
- }
- },
- "e1bbc114c9054a28a48762831a44ef11": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_cc2125fcf9ab49eb9e2be054a4c3fc18",
- "placeholder": "",
- "style": "IPY_MODEL_2ff93c21f097436f9ccd61a8c9c8010d",
- "value": " 20/20 [00:32<00:00, 1.47s/it]"
- }
- },
- "591bf12de23c41c6aa510f6d6702b30e": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "5b379ade011143b9bf21c2aedaaf9149": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "26062d5edbee4879a66829962199ca43": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "c65d6c4a6d0a44d2a9fb8ca75cc5f790": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "3bb8adf35cf74c3cbd3d2c58912041a3": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ProgressStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ProgressStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "bar_color": null,
- "description_width": ""
- }
- },
- "cc2125fcf9ab49eb9e2be054a4c3fc18": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "2ff93c21f097436f9ccd61a8c9c8010d": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- }
- }
- }
- },
- "cells": [
- {
- "cell_type": "markdown",
- "source": [
- "In this notebook, we provide the steps to reproduce a plot similar to https://huggingface.co/spaces/gwf-uwaterloo/aclscatter2d\n",
- "\n",
- "**Before running this colab, make sure the runtime type is set to GPU.** You can double check this in the \"Checks\" section.\n",
- "\n",
- "The plot will be generated using [plotly](https://plotly.com/python/getting-started/)."
- ],
- "metadata": {
- "id": "AeaHYgzwgyOF"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# @title XML file name to download from acl-anthology github page\n",
- "FILE_NAME = '2023.acl.xml' # @param {type:\"string\"}"
- ],
- "metadata": {
- "cellView": "form",
- "id": "mQ31dArhTOmd"
- },
- "execution_count": 1,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "# @title Model name from huggingface\n",
- "MODEL_NAME = 'allenai/specter2_base' # @param {type:\"string\"}\n",
- "\n",
- "ADAPTER_NAME = \"\" # @param {type:\"string\"}"
- ],
- "metadata": {
- "cellView": "form",
- "id": "jSt0Jpueanvn"
- },
- "execution_count": 2,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "# @title Inference args\n",
- "BATCH_SIZE = 64 # @param {type:\"integer\"}"
- ],
- "metadata": {
- "cellView": "form",
- "id": "HryCbmPBcw5V"
- },
- "execution_count": 3,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "# @title Visualization args\n",
- "NUM_CLUSTERS = 50 # @param {type:\"integer\"}"
- ],
- "metadata": {
- "cellView": "form",
- "id": "qyedQTz5ezl4"
- },
- "execution_count": 4,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Setup"
- ],
- "metadata": {
- "id": "jXbz3X1sUHcr"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Install dependencies"
- ],
- "metadata": {
- "id": "O9n1VhtvUQxS"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "!pip install datasets\n",
- "!pip install transformers\n",
- "!pip install adapter-transformers==3.0.1"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "d0XchP9jUOhb",
- "outputId": "133dcd54-f647-44bf-e5d4-f91383be6640"
- },
- "execution_count": 5,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
- "\u001b[0mRequirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.14.5)\n",
- "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n",
- "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n",
- "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n",
- "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n",
- "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n",
- "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n",
- "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.3.0)\n",
- "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n",
- "Requirement already satisfied: fsspec[http]<2023.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
- "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.5)\n",
- "Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.17.2)\n",
- "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.1)\n",
- "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n",
- "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n",
- "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.2.0)\n",
- "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n",
- "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
- "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n",
- "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n",
- "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (3.12.2)\n",
- "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (4.5.0)\n",
- "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n",
- "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.4)\n",
- "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n",
- "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
- "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n",
- "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
- "\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
- "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
- "\u001b[0mRequirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.33.2)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)\n",
- "Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.17.2)\n",
- "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n",
- "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n",
- "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
- "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n",
- "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
- "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n",
- "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.3)\n",
- "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n",
- "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers) (2023.6.0)\n",
- "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers) (4.5.0)\n",
- "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.2.0)\n",
- "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n",
- "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.4)\n",
- "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n",
- "\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
- "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
- "\u001b[0mRequirement already satisfied: adapter-transformers==3.0.1 in /usr/local/lib/python3.10/dist-packages (3.0.1)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (3.12.2)\n",
- "Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (0.17.2)\n",
- "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (1.23.5)\n",
- "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (23.1)\n",
- "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (6.0.1)\n",
- "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (2023.6.3)\n",
- "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (2.31.0)\n",
- "Requirement already satisfied: sacremoses in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (0.0.53)\n",
- "Requirement already satisfied: tokenizers!=0.11.3,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (0.13.3)\n",
- "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (4.66.1)\n",
- "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.1.0->adapter-transformers==3.0.1) (2023.6.0)\n",
- "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.1.0->adapter-transformers==3.0.1) (4.5.0)\n",
- "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->adapter-transformers==3.0.1) (3.2.0)\n",
- "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->adapter-transformers==3.0.1) (3.4)\n",
- "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->adapter-transformers==3.0.1) (2.0.4)\n",
- "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->adapter-transformers==3.0.1) (2023.7.22)\n",
- "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from sacremoses->adapter-transformers==3.0.1) (1.16.0)\n",
- "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from sacremoses->adapter-transformers==3.0.1) (8.1.7)\n",
- "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from sacremoses->adapter-transformers==3.0.1) (1.3.2)\n",
- "\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
- "\u001b[0m"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Imports"
- ],
- "metadata": {
- "id": "c0MMhYc_UKfG"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "import json\n",
- "import os\n",
- "import re\n",
- "from functools import partial\n",
- "from tqdm.auto import tqdm\n",
- "from typing import Any, Iterable, Mapping\n",
- "\n",
- "import datasets\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "import torch\n",
- "from torch.utils.data import DataLoader\n",
- "from transformers import DataCollatorWithPadding, AutoModel, AutoTokenizer, AutoConfig\n",
- "from sklearn.cluster import KMeans\n",
- "from sklearn.manifold import TSNE\n",
- "\n",
- "import plotly.express as px"
- ],
- "metadata": {
- "id": "AJULv3wPUG0z"
- },
- "execution_count": 6,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Checks"
- ],
- "metadata": {
- "id": "BY2W1tBTUVWN"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@markdown **Check GPU type**\n",
- "!nvidia-smi -L\n",
- "\n",
- "#@markdown **Check PyTorch version**\n",
- "print(\"PyTorch version:\", torch.__version__)\n",
- "print(\"CUDA version:\", torch.version.cuda)\n",
- "print(\"#GPUs:\", torch.cuda.device_count())"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "cellView": "form",
- "id": "jtYjxTfuUXUb",
- "outputId": "4f62a4ba-8b8b-462d-caa6-e002ec2d7b1b"
- },
- "execution_count": 7,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "GPU 0: Tesla T4 (UUID: GPU-5e2802f0-3a72-ee6b-56ce-fc17d7e725c4)\n",
- "PyTorch version: 2.0.1+cu118\n",
- "CUDA version: 11.8\n",
- "#GPUs: 1\n"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Load Huggingface Stuff"
- ],
- "metadata": {
- "id": "osH8mbM4aCw0"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n",
- "\n",
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
- "\n",
- "config = AutoConfig.from_pretrained(MODEL_NAME, return_dict=True, output_hidden_states=True)\n",
- "\n",
- "model = AutoModel.from_pretrained(MODEL_NAME, config=config)\n",
- "if ADAPTER_NAME:\n",
- " model.load_adapter(\n",
- " ADAPTER_NAME,\n",
- " source=\"hf\",\n",
- " set_active=True,\n",
- " )\n",
- "\n",
- "model.eval()\n",
- "model.to(\"cuda\")"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "6j9EGcCSZ8Z_",
- "outputId": "1edfabc5-35b0-47d6-8c58-8cf1e35ca5fe"
- },
- "execution_count": 8,
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "BertModel(\n",
- " (shared_parameters): ModuleDict()\n",
- " (invertible_adapters): ModuleDict()\n",
- " (embeddings): BertEmbeddings(\n",
- " (word_embeddings): Embedding(31090, 768, padding_idx=0)\n",
- " (position_embeddings): Embedding(512, 768)\n",
- " (token_type_embeddings): Embedding(2, 768)\n",
- " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " (encoder): BertEncoder(\n",
- " (layer): ModuleList(\n",
- " (0-11): 12 x BertLayer(\n",
- " (attention): BertAttention(\n",
- " (self): BertSelfAttention(\n",
- " (query): Linear(in_features=768, out_features=768, bias=True)\n",
- " (key): Linear(in_features=768, out_features=768, bias=True)\n",
- " (value): Linear(in_features=768, out_features=768, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (prefix_tuning): PrefixTuningShim(\n",
- " (pool): PrefixTuningPool(\n",
- " (prefix_tunings): ModuleDict()\n",
- " )\n",
- " )\n",
- " )\n",
- " (output): BertSelfOutput(\n",
- " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
- " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (adapters): ModuleDict()\n",
- " (adapter_fusion_layer): ModuleDict()\n",
- " )\n",
- " )\n",
- " (intermediate): BertIntermediate(\n",
- " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
- " (intermediate_act_fn): GELUActivation()\n",
- " )\n",
- " (output): BertOutput(\n",
- " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
- " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (adapters): ModuleDict()\n",
- " (adapter_fusion_layer): ModuleDict()\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (pooler): BertPooler(\n",
- " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
- " (activation): Tanh()\n",
- " )\n",
- " (prefix_tuning): PrefixTuningPool(\n",
- " (prefix_tunings): ModuleDict()\n",
- " )\n",
- ")"
- ]
- },
- "metadata": {},
- "execution_count": 8
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Preparing Data"
- ],
- "metadata": {
- "id": "v9olGFFaP6Un"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Downloading from acl-anthology github"
- ],
- "metadata": {
- "id": "YvFxyYEpP_wj"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "The paper information can be downloaded from `acl-anthology` github page in the XML format: https://github.com/acl-org/acl-anthology/tree/master/data/xml/"
- ],
- "metadata": {
- "id": "Vm022cIzSorc"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "!rm -f $FILE_NAME\n",
- "!wget \"https://raw.githubusercontent.com/acl-org/acl-anthology/master/data/xml/$FILE_NAME\"\n",
- "\n",
- "assert os.path.exists(FILE_NAME), \"Downloaded file exists\""
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "knMDRgK8Sfl_",
- "outputId": "ea0abab7-fe9f-4ffa-e627-1b3d4f5a8953"
- },
- "execution_count": 9,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "--2023-09-20 03:28:48-- https://raw.githubusercontent.com/acl-org/acl-anthology/master/data/xml/2023.acl.xml\n",
- "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
- "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
- "HTTP request sent, awaiting response... 200 OK\n",
- "Length: 2597735 (2.5M) [text/plain]\n",
- "Saving to: ‘2023.acl.xml’\n",
- "\n",
- "2023.acl.xml 100%[===================>] 2.48M --.-KB/s in 0.02s \n",
- "\n",
- "2023-09-20 03:28:49 (142 MB/s) - ‘2023.acl.xml’ saved [2597735/2597735]\n",
- "\n"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "download the xml file from this [link](https://github.com/acl-org/acl-anthology/tree/006c7247a6bf0ff859bfd3aab6ea6a19452580ad/data/xml). \n",
- "Convert the xml files to jsonl files by running the following code"
- ],
- "metadata": {
- "id": "2KFobPmUbu7j"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Parsing"
- ],
- "metadata": {
- "id": "CUD4LOJlUmMj"
- }
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "id": "WXQgTZQ103g7",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "4edc4fd1-0a7f-4419-ffa3-e1d9f259a139"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "#papers founds in 2023.acl.xml: 1249\n"
- ]
- }
- ],
- "source": [
- "import xml.etree.ElementTree as ET\n",
- "\n",
- "URL_MAPPINGS = dict(\n",
- " D=\"emnlp\",\n",
- " N=\"naacl\",\n",
- " P=\"acl\",\n",
- " Q=\"tacl\",\n",
- ")\n",
- "\n",
- "def xml_to_jsonl(xml_file: os.PathLike) -> Iterable[Mapping[str, Any]]:\n",
- " tree = ET.parse(xml_file)\n",
- " root = tree.getroot()\n",
- " papers = root.findall(\".//paper\")\n",
- "\n",
- " for paper in papers:\n",
- " paper_dict = {}\n",
- " paper_dict[\"title\"] = \"\".join(paper.find(\"title\").itertext())\n",
- "\n",
- " authors = []\n",
- " for author in paper.findall(\"author\"):\n",
- " first_name = author.findtext(\"first\")\n",
- " last_name = author.findtext(\"last\")\n",
- " authors.append(f\"{first_name} {last_name}\")\n",
- " paper_dict[\"authors\"] = authors\n",
- "\n",
- " paper_dict[\"abstract\"] = \"\" if paper.find(\"abstract\")==None else \"\".join(paper.find(\"abstract\").itertext())\n",
- " paper_dict[\"pages\"] = paper.findtext(\"pages\")\n",
- " paper_dict[\"url\"] = paper.findtext(\"url\")\n",
- " paper_dict[\"bibkey\"] = paper.findtext(\"bibkey\")\n",
- " paper_dict[\"doi\"] = paper.findtext(\"doi\")\n",
- "\n",
- " conference, paper_type = None, None\n",
- " matched = re.match(r\"(\\d+)\\.(\\w+)-(\\w+)\\.\\d+\", paper_dict[\"url\"])\n",
- " if matched:\n",
- " year = int(matched.group(1))\n",
- " conference = matched.group(2)\n",
- " paper_type = matched.group(3)\n",
- " else:\n",
- " bibs = paper_dict[\"bibkey\"].split(\"-\")\n",
- " for b in range(len(bibs) - 1, -1, -1):\n",
- " try:\n",
- " year = int(bibs[b])\n",
- " break\n",
- " except ValueError:\n",
- " pass\n",
- "\n",
- " conference = URL_MAPPINGS.get(paper_dict[\"url\"][0], None)\n",
- "\n",
- " paper_dict[\"source\"] = conference\n",
- " paper_dict[\"year\"] = year\n",
- " paper_dict[\"publication_type\"] = paper_type\n",
- "\n",
- " yield paper_dict\n",
- "\n",
- "papers = list(xml_to_jsonl(FILE_NAME))\n",
- "\n",
- "print(f\"#papers founds in {FILE_NAME}: {len(papers)}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Encode"
- ],
- "metadata": {
- "id": "3yXoFyHhdd25"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Creating DataLoader"
- ],
- "metadata": {
- "id": "ml0g17tYX2jP"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "dataset = datasets.Dataset.from_list(\n",
- " [{\"text\": p[\"title\"] + tokenizer.sep_token + (p[\"abstract\"] or \"\"), \"idx\": i + 1} for i, p in enumerate(papers)]\n",
- ")\n",
- "\n",
- "tokenize_fn = lambda batch: tokenizer(batch[\"text\"], padding=True, truncation=True, max_length=512)\n",
- "dataset = dataset.map(tokenize_fn, batched=True)\n",
- "\n",
- "columns = [\"idx\", \"input_ids\", \"attention_mask\"]\n",
- "if \"token_type_ids\" in dataset.column_names:\n",
- " columns.append(\"token_type_ids\")\n",
- "\n",
- "data_loader = DataLoader(\n",
- " dataset.with_format(\"torch\", columns=columns),\n",
- " collate_fn=DataCollatorWithPadding(tokenizer),\n",
- " batch_size=BATCH_SIZE,\n",
- ")"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 153,
- "referenced_widgets": [
- "1619b254fcbb4cb880d1be5685c74dbc",
- "607c048fb1634a7689e355036c144984",
- "869501d4d38e46f184a66423d93a2745",
- "c03dc1381a0c430182fe86d8a100b249",
- "b6d31f4cebc84ef0a563d41482b14cc2",
- "9ebec18dbdff4913a4902429a726b9e0",
- "c9dc6fbcf53a4c9fb53716a18db6ffbe",
- "c9ef5bf8ff3e44358c4557f74c3e379e",
- "0cc5f439950e49eaa4d417396e21e2c4",
- "1479cc60b4ac4864b46b592dc1050157",
- "a9e75caedfbf46e0bd0effe1e60065cd"
- ]
- },
- "id": "sCG1iVa4X7ye",
- "outputId": "a287df82-3448-4b24-9e26-582bd7b4b180"
- },
- "execution_count": 11,
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "Map: 0%| | 0/1249 [00:00, ? examples/s]"
- ],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "1619b254fcbb4cb880d1be5685c74dbc"
- }
- },
- "metadata": {}
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Running Inference"
- ],
- "metadata": {
- "id": "1KtBlNdMdnQQ"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "embeds = []\n",
- "for batch in tqdm(data_loader, desc=\"encoding\"):\n",
- " indices = batch.pop(\"idx\", None)\n",
- " if isinstance(indices, torch.Tensor):\n",
- " indices = indices.cpu().tolist()\n",
- "\n",
- " batch = {k: v.to(\"cuda\") if v is not None else v for k, v in batch.items()}\n",
- "\n",
- " with torch.no_grad():\n",
- " output = model(**batch)\n",
- " encoded = output.last_hidden_state[:, 0].cpu().numpy()\n",
- "\n",
- " embeds.append(encoded)\n",
- "\n",
- "embeds = np.concatenate(embeds, axis=0)\n",
- "\n",
- "print(f\"Embeddings size:\", embeds.shape)"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 143,
- "referenced_widgets": [
- "4c7b67b7151e4c9fb47eaae2f39a21b8",
- "5bd552c8824e407c934978e35e7de980",
- "d052c01440db4dafb5d699eb57a9d613",
- "e1bbc114c9054a28a48762831a44ef11",
- "591bf12de23c41c6aa510f6d6702b30e",
- "5b379ade011143b9bf21c2aedaaf9149",
- "26062d5edbee4879a66829962199ca43",
- "c65d6c4a6d0a44d2a9fb8ca75cc5f790",
- "3bb8adf35cf74c3cbd3d2c58912041a3",
- "cc2125fcf9ab49eb9e2be054a4c3fc18",
- "2ff93c21f097436f9ccd61a8c9c8010d"
- ]
- },
- "id": "QbqEJIgWdr2o",
- "outputId": "644df647-2880-417e-be90-e12492a5c3b7"
- },
- "execution_count": 12,
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "encoding: 0%| | 0/20 [00:00, ?it/s]"
- ],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "4c7b67b7151e4c9fb47eaae2f39a21b8"
- }
- },
- "metadata": {}
- },
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Embeddings size: (1249, 768)\n"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Housekeeping prior to Visualization\n",
- "\n",
- "To plot the embeddings, we first cluster the points and then reduce the number of dimensions to 2-d using t-SNE."
- ],
- "metadata": {
- "id": "agDmw5DPefij"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Clustering"
- ],
- "metadata": {
- "id": "bbLVfaIufuSu"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "clusterer = KMeans(n_clusters=NUM_CLUSTERS, n_init=\"auto\")\n",
- "clusters = clusterer.fit(embeds).labels_\n",
- "\n",
- "print(\"Clustering done\")"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "iUEvI_OaeoIf",
- "outputId": "6bc3c072-7fdc-4349-cd57-7e9205f77c01"
- },
- "execution_count": 13,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Clustering done\n"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Applying t-SNE\n",
- "\n",
- "We changed perplexity and number of iterations from their default value because the scatter plot would look nicer."
- ],
- "metadata": {
- "id": "sqfCHjTAfwXF"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "reducer = TSNE(n_jobs=12, perplexity=10, n_iter=3000)\n",
- "reduced_embeds = reducer.fit_transform(embeds)"
- ],
- "metadata": {
- "id": "AT2fWBc4fyFl"
- },
- "execution_count": 14,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Visualize"
- ],
- "metadata": {
- "id": "t28XwXvNgrBo"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# @title\n",
- "def to_string_authors(list_of_authors):\n",
- " if len(list_of_authors) > 5:\n",
- " return \", \".join(list_of_authors[:5]) + \", et al.\"\n",
- " elif len(list_of_authors) > 2:\n",
- " return \", \".join(list_of_authors[:-1]) + \", and \" + list_of_authors[-1]\n",
- " else:\n",
- " return \" and \".join(list_of_authors)\n",
- "\n",
- "\n",
- "for i, (point, c, p) in enumerate(zip(reduced_embeds, clusters, papers)):\n",
- " p[\"x\"] = point[0]\n",
- " p[\"y\"] = point[1]\n",
- " p[\"cluster\"] = c\n",
- " p[\"authors_trimmed\"] = [(x[x.index(\",\") + 1 :].strip() + \" \" + x.split(\",\")[0].strip()) if \",\" in x else x for x in p[\"authors\"]]\n",
- " if \"publication_type\" in p:\n",
- " p[\"type\"] = p.pop(\"publication_type\")\n",
- "\n",
- "df = pd.DataFrame(papers)\n",
- "\n",
- "fig = px.scatter(\n",
- " df,\n",
- " x=\"x\",\n",
- " y=\"y\",\n",
- " color=\"cluster\",\n",
- " width=1000,\n",
- " height=800,\n",
- " custom_data=(\"title\", \"authors_trimmed\", \"year\", \"source\", \"type\"),\n",
- " color_continuous_scale=\"fall\",\n",
- ")\n",
- "fig.update_traces(\n",
- " hovertemplate=\"%{customdata[0]}
%{customdata[1]}
%{customdata[2]}
%{customdata[3]}\"\n",
- ")\n",
- "fig.update_layout(\n",
- " showlegend=False,\n",
- " font=dict(\n",
- " family=\"Times New Roman\",\n",
- " size=30,\n",
- " ),\n",
- " hoverlabel=dict(\n",
- " align=\"left\",\n",
- " font_size=14,\n",
- " font_family=\"Rockwell\",\n",
- " namelength=-1,\n",
- " ),\n",
- ")\n",
- "fig.update_xaxes(title=\"\")\n",
- "fig.update_yaxes(title=\"\")\n",
- "\n",
- "a = fig.show()"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 817
- },
- "cellView": "form",
- "id": "B-TwYJM5gtF-",
- "outputId": "99a5d7d7-2e49-43af-be93-7677c50effba"
- },
- "execution_count": 15,
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/html": [
- "\n",
- "