{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "collapsed_sections": [ "GirPusJtYPsP", "C-7gft4ddTzo" ], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "widgets": { "application/vnd.jupyter.widget-state+json": { "c46dc091acd34be2887c59bf95838529": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_af6cce709f8a478a87c4c89222193d8b", "IPY_MODEL_4c32916fe36e4466b9c8a96bbd7db71b", "IPY_MODEL_48d80beb4cc646328036336431b01278" ], "layout": "IPY_MODEL_755b9a30f525493382f771840e4b04f4" } }, "af6cce709f8a478a87c4c89222193d8b": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_481a5ffb8c394dfe88d9df74b3edd372", "placeholder": "​", "style": "IPY_MODEL_7ddb099532f44b8fbc77bfd94232ff8f", "value": "README.md: 100%" } }, "4c32916fe36e4466b9c8a96bbd7db71b": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", 
"_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ba7928e2ec0e459e870e7f8420fc26f6", "max": 360, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_2422979a81b0425fa82df01a1b4b170a", "value": 360 } }, "48d80beb4cc646328036336431b01278": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_93694c9083bc4a19a51b854829180dc9", "placeholder": "​", "style": "IPY_MODEL_8d3e47bf8478457d9db128a9927fc568", "value": " 360/360 [00:00<00:00, 16.6kB/s]" } }, "755b9a30f525493382f771840e4b04f4": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": 
null, "top": null, "visibility": null, "width": null } }, "481a5ffb8c394dfe88d9df74b3edd372": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7ddb099532f44b8fbc77bfd94232ff8f": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ba7928e2ec0e459e870e7f8420fc26f6": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", 
"_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2422979a81b0425fa82df01a1b4b170a": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "93694c9083bc4a19a51b854829180dc9": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": 
null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8d3e47bf8478457d9db128a9927fc568": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "6d93d4f16f7f409d895c928f4c091619": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_bdd1de9927bf4183a39c4eb417b4ee65", "IPY_MODEL_ed4e47b387a347f98567e87f4dce2dff", "IPY_MODEL_95b6f6c28acc4ff3a7da7d1ac5d1fc2d" ], "layout": "IPY_MODEL_dcda5b897d8c482f8ce32387af5fdb2b" } }, "bdd1de9927bf4183a39c4eb417b4ee65": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b2da78bd47b144f49a8202289cc6745a", "placeholder": "​", "style": 
"IPY_MODEL_1ed7259217474bcfa1e5f80071fb708e", "value": "train-00000-of-00001.parquet: 100%" } }, "ed4e47b387a347f98567e87f4dce2dff": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_2589e545cc864d4095becc8d1f75f263", "max": 306697640, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_c1584c81502c471da4c9d89c3e922813", "value": 306697640 } }, "95b6f6c28acc4ff3a7da7d1ac5d1fc2d": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_fd455eaad05b4614acaee95e03a44fa0", "placeholder": "​", "style": "IPY_MODEL_2df5051a2bc743258ef138f14173ccc2", "value": " 307M/307M [00:12<00:00, 22.8MB/s]" } }, "dcda5b897d8c482f8ce32387af5fdb2b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, 
"grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b2da78bd47b144f49a8202289cc6745a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1ed7259217474bcfa1e5f80071fb708e": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "2589e545cc864d4095becc8d1f75f263": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c1584c81502c471da4c9d89c3e922813": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "fd455eaad05b4614acaee95e03a44fa0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2df5051a2bc743258ef138f14173ccc2": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "e2ef3cf0e3ff4ea3a8a0dff3dd73a5f1": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_7bac50c73a644c9f9e3369b763cb5db7", "IPY_MODEL_24206922f4c64c8aadbaec122804aadf", "IPY_MODEL_c1f30aa01b434b0d8f9799503d9601f9" ], "layout": "IPY_MODEL_7b671494c6754864931f43c546578dcb" } }, "7bac50c73a644c9f9e3369b763cb5db7": { 
"model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_698b1bbf0fdc47e389e9d8eb5aca93d6", "placeholder": "​", "style": "IPY_MODEL_82a61b5594dd49f2b1ca5dea552b8d87", "value": "Generating train split: 100%" } }, "24206922f4c64c8aadbaec122804aadf": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_be75f5be99c246d5a01186a17181a3c3", "max": 100, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_80300b7040c349ed92ecefb4d3402a7b", "value": 100 } }, "c1f30aa01b434b0d8f9799503d9601f9": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_0a1996f8fe29482aa0b972c07040d97d", "placeholder": "​", "style": "IPY_MODEL_fa4c605349df4638a8c71e7aa52db1ad", "value": " 100/100 [00:01<00:00, 71.49 examples/s]" } }, "7b671494c6754864931f43c546578dcb": { "model_module": "@jupyter-widgets/base", 
"model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "698b1bbf0fdc47e389e9d8eb5aca93d6": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, 
"object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "82a61b5594dd49f2b1ca5dea552b8d87": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "be75f5be99c246d5a01186a17181a3c3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "80300b7040c349ed92ecefb4d3402a7b": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", 
"_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "0a1996f8fe29482aa0b972c07040d97d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "fa4c605349df4638a8c71e7aa52db1ad": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "cells": [ { "cell_type": "markdown", "source": [ "# Imports and Hugging Face Login" ], "metadata": { "id": "GirPusJtYPsP" } }, { "cell_type": "code", "execution_count": 1, 
"metadata": { "id": "hKdN-6CXXV12", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "a1d9a131-8a76-436c-ac24-c3467b5dcc01" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (0.26.5)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (3.16.1)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (2024.10.0)\n", "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (24.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (6.0.2)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (4.66.6)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (3.4.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (2.2.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (2024.8.30)\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0m" ] } ], "source": [ "!pip install huggingface-hub\n", "!pip install datasets > delete.txt" ] }, { "cell_type": "code", "source": [ "import torch\n", "import pickle\n", "from huggingface_hub import hf_hub_download\n", "from datasets import load_dataset, Image\n", "import torch\n", "from torch import nn, optim\n", "from torch.utils.data import DataLoader, Dataset\n", "import numpy as np\n", "from geopy.distance import geodesic\n", "import matplotlib.pyplot as plt\n", "from torchvision import models" ], "metadata": { "id": "SPzgZOzxYYiT" }, "execution_count": 2, "outputs": [] }, { "cell_type": "code", "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ], "metadata": { "id": "PJquO0g1YaMU", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "d82f4fdc-32ee-4f91-e6ce-558ad3e3c837" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "cuda\n" ] } ] }, { "cell_type": "code", "source": [ "!huggingface-cli login\n", "# use appropiate token" ], "metadata": { "id": "IcGfZSsoZgau", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "436dcc6f-a924-4be8-e9a8-39c197e5e1e1" }, "execution_count": 4, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", " _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|\n", " _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n", " _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|\n", " _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n", " _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|\n", "\n", " To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .\n", "Enter your token (input will 
not be visible): \n",
     "Add token as git credential? (Y/n) y\n",
     "Token is valid (permission: fineGrained).\n",
     "The token `CIS 5190 Project 3` has been saved to /root/.cache/huggingface/stored_tokens\n",
     "\u001b[1m\u001b[31mCannot authenticate through git-credential as no helper is defined on your machine.\n",
     "You might have to re-authenticate when pushing to the Hugging Face Hub.\n",
     "Run the following command in your terminal in case you want to set the 'store' credential helper as default.\n",
     "\n",
     "git config --global credential.helper store\n",
     "\n",
     "Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.\u001b[0m\n",
     "Token has not been saved to git credential helper.\n",
     "Your token has been saved to /root/.cache/huggingface/token\n",
     "Login successful.\n",
     "The current active token is: `CIS 5190 Project 3`\n"
    ]
   }
  ]
 },
 {
  "cell_type": "markdown",
  "source": [
   "# Models and Classes"
  ],
  "metadata": {
   "id": "LplsJ-PXXbtm"
  }
 },
 {
  "cell_type": "code",
  "source": [
   "class EnsembleModel(nn.Module):\n",
   "    \"\"\"Learnable weighted combination of several models' outputs.\n",
   "\n",
   "    Every sub-model is applied to the same input batch. Their outputs are\n",
   "    stacked along a new last axis and contracted against one learnable scalar\n",
   "    weight per model, initialised uniformly to ``1 / num_models``. The weights\n",
   "    are a raw ``nn.Parameter`` (no softmax), so after training they are not\n",
   "    constrained to be positive or to sum to 1.\n",
   "\n",
   "    Args:\n",
   "        models: iterable of ``nn.Module``s. Each must return a 2-D tensor of\n",
   "            the same shape -- the einsum in ``forward`` assumes\n",
   "            ``(batch, out_dim)``.\n",
   "        num_models: number of sub-models. Optional; defaults to\n",
   "            ``len(models)``. Still accepted for backward compatibility and\n",
   "            checked against the number of models actually given.\n",
   "\n",
   "    Raises:\n",
   "        ValueError: if ``num_models`` is given and does not match the number\n",
   "            of models.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(self, models, num_models=None):\n",
   "        super(EnsembleModel, self).__init__()\n",
   "        self.models = nn.ModuleList(models)\n",
   "        if num_models is None:\n",
   "            num_models = len(self.models)\n",
   "        elif num_models != len(self.models):\n",
   "            # Fail fast: a mismatch would otherwise only surface later as an\n",
   "            # opaque einsum shape error on the first forward pass.\n",
   "            raise ValueError(\n",
   "                f'num_models={num_models} does not match the '\n",
   "                f'{len(self.models)} models provided'\n",
   "            )\n",
   "        self.weights = nn.Parameter(torch.ones(num_models) / num_models)\n",
   "\n",
   "    def forward(self, x):\n",
   "        # (batch, out_dim, num_models): one trailing slice per sub-model.\n",
   "        outputs = torch.stack([model(x) for model in self.models], dim=-1)\n",
   "        # Contract the model axis with the learned weights -> (batch, out_dim).\n",
   "        weighted_output = torch.einsum('bij,j->bi', outputs, self.weights)\n",
   "        return weighted_output"
  ],
  "metadata": {
   "id": "ofOTpLIPcylC"
  },
  "execution_count": 9,
  "outputs": []
 },
 {
  "cell_type": "code",
  "source": [
   "class Model1(nn.Module):\n",
   "    \"\"\"AlexNet-style CNN with a 2-output regression head.\n",
   "\n",
   "    The two outputs presumably correspond to the normalised\n",
   "    ``[latitude, longitude]`` targets produced by ``GPSImageDataset`` --\n",
   "    TODO confirm against the training code.\n",
   "\n",
   "    Args:\n",
   "        dropout: dropout probability applied before each of the first two\n",
   "            fully connected layers of the classifier.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(self, dropout):\n",
   "        super(Model1, self).__init__()\n",
   "        self.features = nn.Sequential(\n",
   "            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
   "            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
   "            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
   "        )\n",
   "        self.classifier = nn.Sequential(\n",
   "            nn.Dropout(p=dropout),\n",
   "            nn.Linear(256 * 6 * 6, 1024),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.Dropout(p=dropout),\n",
   "            nn.Linear(1024, 512),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.Linear(512, 2),\n",
   "        )\n",
   "\n",
   "    def forward(self, x):\n",
   "        x = self.features(x)\n",
   "        # The classifier expects a 256 x 6 x 6 feature map, which is exactly\n",
   "        # what `features` yields for 224 x 224 inputs. Adaptive pooling to\n",
   "        # 6 x 6 is an exact no-op in that case and lets other input sizes\n",
   "        # work too (as torchvision's AlexNet does). It is applied\n",
   "        # functionally so no new submodule/attribute is added: existing\n",
   "        # state_dicts and pickled Model1 objects keep working.\n",
   "        x = nn.functional.adaptive_avg_pool2d(x, (6, 6))\n",
   "        x = torch.flatten(x, 1)\n",
   "        x = self.classifier(x)\n",
   "        return x\n",
   "\n",
   "\n",
   "def model_fn(dropout):\n",
   "    \"\"\"Factory returning a fresh ``Model1`` with the given dropout rate.\"\"\"\n",
   "    return Model1(dropout)"
  ],
  "metadata": {
   "id": "fbtZvQrlYGfU"
  },
  "execution_count": 10,
  "outputs": []
 },
 {
  "cell_type": "code",
  "source": [
   "class Model2(nn.Module):\n",
   "    \"\"\"ImageNet-pretrained ResNet-34 backbone with a 2-output MLP head.\n",
   "\n",
   "    Args:\n",
   "        num_blocks: despite its name, the number of leading backbone\n",
   "            *parameter tensors* to freeze\n",
   "            (``list(resnet.parameters())[:num_blocks]``), not a number of\n",
   "            residual blocks. With torchvision's ResNet-34 layout the default\n",
   "            of 3 freezes ``conv1.weight``, ``bn1.weight`` and ``bn1.bias``.\n",
   "        dropout_rate: dropout probability applied before both linear layers\n",
   "            of the head.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(self, num_blocks=3, dropout_rate=0.5):\n",
   "        super(Model2, self).__init__()\n",
   "\n",
   "        # `pretrained=True` is deprecated since torchvision 0.13; for\n",
   "        # ResNet-34 `DEFAULT` resolves to the same IMAGENET1K_V1 weights.\n",
   "        resnet = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)\n",
   "\n",
   "        for param in list(resnet.parameters())[:num_blocks]:\n",
   "            param.requires_grad = False\n",
   "\n",
   "        # Keep the backbone up to (but excluding) ResNet's own avgpool and fc.\n",
   "        self.features = nn.Sequential(*list(resnet.children())[:-2])\n",
   "        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
   "\n",
   "        self.classifier = nn.Sequential(\n",
   "            nn.Flatten(),\n",
   "            nn.Dropout(p=dropout_rate),\n",
   "            nn.Linear(resnet.fc.in_features, 512),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.Dropout(p=dropout_rate),\n",
   "            nn.Linear(512, 2)\n",
   "        )\n",
   "\n",
   "    def forward(self, x):\n",
   "        x = self.features(x)\n",
   "        x = self.avgpool(x)\n",
   "        x = self.classifier(x)\n",
   "        return x"
  ],
  "metadata": {
   "id": "iBssHEtGXdWi"
  },
  "execution_count": 11,
  "outputs": []
 },
 {
  "cell_type": "code",
  "source": [
   "class InceptionModule(nn.Module):\n",
   "    \"\"\"GoogLeNet-style Inception block (without batch norm).\n",
   "\n",
   "    Four parallel branches -- a 1x1 conv; a 1x1 reduction then a 3x3 conv;\n",
   "    a 1x1 reduction then a 5x5 conv; a 3x3 max-pool then a 1x1 projection --\n",
   "    are concatenated along the channel axis. Every branch preserves spatial\n",
   "    size, so the output keeps the input's height/width and has\n",
   "    ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(self, in_channels, ch1x1, ch3x3_reduce, ch3x3, ch5x5_reduce, ch5x5, pool_proj):\n",
   "        super(InceptionModule, self).__init__()\n",
   "\n",
   "        self.branch1 = nn.Sequential(\n",
   "            nn.Conv2d(in_channels, ch1x1, kernel_size=1),\n",
   "            nn.ReLU(inplace=True)\n",
   "        )\n",
   "        self.branch2 = nn.Sequential(\n",
   "            nn.Conv2d(in_channels, ch3x3_reduce, kernel_size=1),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.Conv2d(ch3x3_reduce, ch3x3, kernel_size=3, padding=1),\n",
   "            nn.ReLU(inplace=True)\n",
   "        )\n",
   "\n",
   "        self.branch3 = nn.Sequential(\n",
   "            nn.Conv2d(in_channels, ch5x5_reduce, kernel_size=1),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.Conv2d(ch5x5_reduce, ch5x5, kernel_size=5, padding=2),\n",
   "            nn.ReLU(inplace=True)\n",
   "        )\n",
   "\n",
   "        self.branch4 = nn.Sequential(\n",
   "            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),\n",
   "            nn.Conv2d(in_channels, pool_proj, kernel_size=1),\n",
   "            nn.ReLU(inplace=True)\n",
   "        )\n",
   "\n",
   "    def forward(self, x):\n",
   "        branch1 = self.branch1(x)\n",
   "        branch2 = self.branch2(x)\n",
   "        branch3 = self.branch3(x)\n",
   "        branch4 = self.branch4(x)\n",
   "        outputs = torch.cat([branch1, branch2, branch3, branch4], 1)\n",
   "        return outputs\n",
   "\n",
   "class Model4(nn.Module):\n",
   "    \"\"\"Small GoogLeNet-style CNN: conv stem, four Inception blocks, MLP head.\n",
   "\n",
   "    Channel bookkeeping: 192 -> 256 -> 480 -> (max-pool) -> 512 -> 512, which\n",
   "    matches the 512-unit input of the classifier. Global average pooling\n",
   "    means the head does not depend on the input's spatial size.\n",
   "\n",
   "    Args:\n",
   "        dropout_rate: dropout probability applied before the first two\n",
   "            linear layers of the head.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(self, dropout_rate=0.5):\n",
   "        super(Model4, self).__init__()\n",
   "\n",
   "        self.pre_layers = nn.Sequential(\n",
   "            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
   "            nn.Conv2d(64, 192, kernel_size=3, padding=1),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
   "        )\n",
   "\n",
   "        self.inception1 = InceptionModule(192, 64, 96, 128, 16, 32, 32)\n",
   "        self.inception2 = InceptionModule(256, 128, 128, 192, 32, 96, 64)\n",
   "\n",
   "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
   "\n",
   "        self.inception3 = InceptionModule(480, 192, 96, 208, 16, 48, 64)\n",
   "        self.inception4 = InceptionModule(512, 160, 112, 224, 24, 64, 64)\n",
   "\n",
   "        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
   "        self.classifier = nn.Sequential(\n",
   "            nn.Flatten(),\n",
   "            nn.Dropout(p=dropout_rate),\n",
   "            nn.Linear(512, 1024),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.Dropout(p=dropout_rate),\n",
   "            nn.Linear(1024, 512),\n",
   "            nn.ReLU(inplace=True),\n",
   "            nn.Linear(512, 2)\n",
   "        )\n",
   "\n",
   "    def forward(self, x):\n",
   "        x = self.pre_layers(x)\n",
   "        x = self.inception1(x)\n",
   "        x = self.inception2(x)\n",
   "        x = self.maxpool(x)\n",
   "        x = self.inception3(x)\n",
   "        x = self.inception4(x)\n",
   "        x = self.avgpool(x)\n",
   "        x = self.classifier(x)\n",
   "        return x"
  ],
  "metadata": {
   "id": "c4y6R0A3XjcI"
  },
  "execution_count": 12,
  "outputs": []
 },
 {
  "cell_type": "markdown",
  "source": [
   "# Load Test Dataset"
  ],
  "metadata": {
   "id": "ybwRXm3zYg_I"
  }
 },
 {
  "cell_type": "code",
  "source": [
   "from torch.utils.data import Dataset\n",
   "class GPSImageDataset(Dataset):\n",
   "    \"\"\"Torch ``Dataset`` pairing images with z-score-normalised GPS targets.\n",
   "\n",
   "    Wraps a Hugging Face dataset whose rows expose ``'image'``, ``'Latitude'``\n",
   "    and ``'Longitude'``. Targets are normalised with the supplied\n",
   "    ``lat_mean`` / ``lat_std`` / ``lon_mean`` / ``lon_std`` when given, and\n",
   "    otherwise with this dataset's own statistics -- pass the training-set\n",
   "    statistics when building an evaluation set so both are normalised the\n",
   "    same way. A column with zero standard deviation (e.g. a single-row\n",
   "    dataset) yields inf/nan targets.\n",
   "\n",
   "    ``__getitem__`` returns ``(image, tensor([lat_norm, lon_norm]))``, with\n",
   "    ``transform`` applied to the image when it is truthy.\n",
   "    \"\"\"\n",
   "    def __init__(self, hf_dataset, transform, lat_mean=None, lat_std=None, lon_mean=None, lon_std=None):\n",
   "        self.hf_dataset = hf_dataset\n",
   "        self.transform = transform\n",
   "\n",
   "        # Normalize the latitude and longitude\n",
   "        self.latitudes = np.array(hf_dataset['Latitude'])\n",
   "        self.longitudes = np.array(hf_dataset['Longitude'])\n",
   "        self.latitude_mean = lat_mean if lat_mean is not None else self.latitudes.mean()\n",
   "        self.latitude_std = lat_std if lat_std is not None else self.latitudes.std()\n",
   "        self.longitude_mean = lon_mean if lon_mean is not None else self.longitudes.mean()\n",
   "        self.longitude_std = lon_std if lon_std is not None else self.longitudes.std()\n",
   "\n",
   "        self.normalized_latitudes = (self.latitudes - self.latitude_mean) / self.latitude_std\n",
   "        self.normalized_longitudes = (self.longitudes - self.longitude_mean) / self.longitude_std\n",
   "\n",
   "    def __len__(self):\n",
   "        return len(self.hf_dataset)\n",
   "\n",
   "    def __getitem__(self, idx):\n",
   "        image = self.hf_dataset[idx]['image']\n",
   "        latitude = self.normalized_latitudes[idx]\n",
   "        longitude = self.normalized_longitudes[idx]\n",
   "\n",
   "        if self.transform:\n",
   "            image = self.transform(image)\n",
   "\n",
   "        return image, torch.tensor([latitude, longitude], dtype=torch.float)"
  ],
  "metadata": {
   "id": "EfCxgZxMY7b6"
}, "execution_count": 14, "outputs": [] }, { "cell_type": "code", "source": [
"from torchvision import transforms, models\n",
"\n",
"# ImageNet channel statistics, shared by both pipelines below.\n",
"IMAGENET_MEAN = [0.485, 0.456, 0.406]\n",
"IMAGENET_STD = [0.229, 0.224, 0.225]\n",
"\n",
"# Randomised augmentation pipeline (presumably used at training time).\n",
"transform = transforms.Compose([\n",
"    transforms.RandomResizedCrop(224),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n",
"])\n",
"\n",
"# Deterministic pipeline for evaluation: fixed 224x224 resize, no augmentation.\n",
"inference_transform = transforms.Compose([\n",
"    transforms.Resize((224, 224)),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n",
"])"
], "metadata": { "id": "P4Gx6KLQXz4E" }, "execution_count": 15, "outputs": [] }, { "cell_type": "code", "source": [ "dataset_test = load_dataset(\"gydou/released_img\")" ], "metadata": { "id": "NTFvFWpRYgcM", "colab": { "base_uri": "https://localhost:8080/", "height": 217, "referenced_widgets": [ "c46dc091acd34be2887c59bf95838529", "af6cce709f8a478a87c4c89222193d8b", "4c32916fe36e4466b9c8a96bbd7db71b", "48d80beb4cc646328036336431b01278", "755b9a30f525493382f771840e4b04f4", "481a5ffb8c394dfe88d9df74b3edd372", "7ddb099532f44b8fbc77bfd94232ff8f", "ba7928e2ec0e459e870e7f8420fc26f6", "2422979a81b0425fa82df01a1b4b170a", "93694c9083bc4a19a51b854829180dc9", "8d3e47bf8478457d9db128a9927fc568", "6d93d4f16f7f409d895c928f4c091619", "bdd1de9927bf4183a39c4eb417b4ee65", "ed4e47b387a347f98567e87f4dce2dff", "95b6f6c28acc4ff3a7da7d1ac5d1fc2d", "dcda5b897d8c482f8ce32387af5fdb2b", "b2da78bd47b144f49a8202289cc6745a", "1ed7259217474bcfa1e5f80071fb708e", "2589e545cc864d4095becc8d1f75f263", "c1584c81502c471da4c9d89c3e922813", "fd455eaad05b4614acaee95e03a44fa0", "2df5051a2bc743258ef138f14173ccc2", "e2ef3cf0e3ff4ea3a8a0dff3dd73a5f1", "7bac50c73a644c9f9e3369b763cb5db7", "24206922f4c64c8aadbaec122804aadf", "c1f30aa01b434b0d8f9799503d9601f9", "7b671494c6754864931f43c546578dcb", "698b1bbf0fdc47e389e9d8eb5aca93d6", "82a61b5594dd49f2b1ca5dea552b8d87",
"be75f5be99c246d5a01186a17181a3c3", "80300b7040c349ed92ecefb4d3402a7b", "0a1996f8fe29482aa0b972c07040d97d", "fa4c605349df4638a8c71e7aa52db1ad" ] }, "outputId": "877c8003-7541-4eb2-bfd5-92540f2d2381" }, "execution_count": 16, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "README.md: 0%| | 0.00/360 [00:00 list of models for each type).\n", " ensemble_weights: Numpy array of ensemble weights.\n", " \"\"\"\n", " # Load the pickle file\n", " with open(file_name, \"rb\") as f:\n", " ensemble_data = pickle.load(f)\n", "\n", " # Extract the ensemble weights\n", " ensemble_weights = ensemble_data[\"ensemble_weights\"]\n", "\n", " # Reload the individual models\n", " trained_models = {}\n", " for model_name, state_dicts in ensemble_data[\"models\"].items():\n", " trained_models[model_name] = []\n", " for state_dict in state_dicts:\n", " model = model_classes[model_name]()\n", " model.load_state_dict(state_dict)\n", " model = model.to(device)\n", " trained_models[model_name].append(model)\n", "\n", " return trained_models, ensemble_weights" ], "metadata": { "id": "1PygE9aMZ4xm" }, "execution_count": 43, "outputs": [] }, { "cell_type": "code", "source": [ "model_classes = {\n", " \"Model1\": lambda: Model1(dropout=0.5),\n", " \"Model2\": lambda: Model2(num_blocks=3, dropout_rate=0.5),\n", " \"Model4\": lambda: Model4(dropout_rate=0.5)\n", "}\n", "\n", "# 
Load the ensemble\n",
"trained_models, ensemble_weights = load_ensemble(pickle_file_path, model_classes, device=\"cuda\")\n",
"models_ensemble = []\n",
"for model_list in trained_models.values():\n",
"    models_ensemble.extend(model_list)\n",
"\n",
"# Combine every loaded model into one ensemble using the saved weights.\n",
"ensemble_model = EnsembleModel(models=models_ensemble, num_models=len(models_ensemble))\n",
"ensemble_model.weights.data = torch.tensor(ensemble_weights, dtype=torch.float32, device=\"cuda\")\n",
"ensemble_model = ensemble_model.to(\"cuda\")"
], "metadata": { "id": "WpGJ4SIrZ9G2" }, "execution_count": 44, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Evaluation" ], "metadata": { "id": "PN94YVq0dMX1" } }, { "cell_type": "code", "source": [
"def evaluate_final_rmse(ensemble_model, data_loader, lat_mean, lon_mean, lat_std, lon_std, device=None):\n",
"    \"\"\"\n",
"    Evaluate the ensemble model on a given dataset and compute final RMSE in meters.\n",
"\n",
"    Predictions and targets are de-normalised with the given statistics, and the\n",
"    geodesic distance between each predicted and actual (lat, lon) pair is used\n",
"    as that sample's error.\n",
"\n",
"    Args:\n",
"        ensemble_model: Model mapping an image batch to normalised (lat, lon) pairs.\n",
"        data_loader: Iterable of (images, targets) batches with normalised targets.\n",
"        lat_mean, lon_mean, lat_std, lon_std: Statistics used to normalise the targets.\n",
"        device: Device to run inference on; defaults to the device holding the\n",
"            model's parameters.\n",
"\n",
"    Returns:\n",
"        (mean squared distance in meters^2, RMSE in meters).\n",
"\n",
"    Raises:\n",
"        ValueError: If data_loader yields no samples.\n",
"    \"\"\"\n",
"    if device is None:\n",
"        device = next(ensemble_model.parameters()).device\n",
"    # Hoisted out of the loop: maps normalised (lat, lon) back to latitude/longitude.\n",
"    scale = np.array([lat_std, lon_std])\n",
"    offset = np.array([lat_mean, lon_mean])\n",
"\n",
"    ensemble_model.eval()\n",
"    total_loss = 0.0\n",
"    total_samples = 0\n",
"\n",
"    with torch.no_grad():\n",
"        for images, targets in data_loader:\n",
"            outputs = ensemble_model(images.to(device))\n",
"            preds_denorm = outputs.cpu().numpy() * scale + offset\n",
"            # Targets are only needed for the CPU-side distance computation.\n",
"            actuals_denorm = targets.cpu().numpy() * scale + offset\n",
"\n",
"            for pred, actual in zip(preds_denorm, actuals_denorm):\n",
"                distance = geodesic((actual[0], actual[1]), (pred[0], pred[1])).meters\n",
"                total_loss += distance ** 2\n",
"            # Counted once per batch, after all of its samples have been scored.\n",
"            total_samples += targets.size(0)\n",
"\n",
"    if total_samples == 0:\n",
"        raise ValueError(\"data_loader yielded no samples; cannot compute RMSE\")\n",
"\n",
"    final_loss = total_loss / total_samples\n",
"    final_rmse = np.sqrt(final_loss)\n",
"\n",
"    return final_loss, final_rmse"
], "metadata": { "id": "zUhrqOv5cNag" }, "execution_count": 47, "outputs": [] }, { "cell_type": "code", "source": [
"final_test_loss, final_test_rmse = evaluate_final_rmse(\n",
"    ensemble_model=ensemble_model,\n",
"    data_loader=test_dataloader,\n",
"    lat_mean=lat_mean,\n",
"    lon_mean=lon_mean,\n",
"    lat_std=lat_std,\n",
"    lon_std=lon_std\n",
")\n",
"\n",
"print(f\"Test Loss (meters^2): {final_test_loss:.2f}\")\n",
"print(f\"Test RMSE (meters): {final_test_rmse:.2f}\")"
], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-UZcLgmBcM-q", "outputId": "5ed71053-5017-48e5-d9ec-825ca01a8124" }, "execution_count": 48, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Test Loss (meters^2): 8089.13\n", "Test RMSE (meters): 89.94\n" ] } ] }, { "cell_type": "markdown", "source": [ "# Visualization" ], "metadata": { "id": "C-7gft4ddTzo" } }, { "cell_type": "code", "source": [
"def visualize_predictions(all_preds, all_actuals, lat_mean, lon_mean, lat_std, lon_std):\n",
"    \"\"\"\n",
"    Visualizes actual and predicted GPS coordinates on a scatter plot,\n",
"    including error lines connecting each prediction to its corresponding actual point.\n",
"\n",
"    Args:\n",
"        all_preds: Array of normalised predictions with columns (lat, lon).\n",
"        all_actuals: Array of normalised targets with columns (lat, lon).\n",
"        lat_mean, lon_mean, lat_std, lon_std: Statistics used to undo the normalisation.\n",
"    \"\"\"\n",
"    scale = np.array([lat_std, lon_std])\n",
"    offset = np.array([lat_mean, lon_mean])\n",
"    all_preds_denorm = all_preds * scale + offset\n",
"    all_actuals_denorm = all_actuals * scale + offset\n",
"\n",
"    fig, ax = plt.subplots(figsize=(10, 5))\n",
"\n",
"    # Longitude on x and latitude on y, so the plot reads like a map.\n",
"    ax.scatter(all_actuals_denorm[:, 1], all_actuals_denorm[:, 0], label='Actual', color='blue', alpha=0.6)\n",
"    ax.scatter(all_preds_denorm[:, 1], all_preds_denorm[:, 0], label='Predicted', color='red', alpha=0.6)\n",
"    for actual, pred in zip(all_actuals_denorm, all_preds_denorm):\n",
"        ax.plot([actual[1], pred[1]], [actual[0], pred[0]], color='gray', linewidth=0.5)\n",
"\n",
"    ax.legend()\n",
"    ax.set_xlabel('Longitude')\n",
"    ax.set_ylabel('Latitude')\n",
"    ax.set_title('Actual vs. Predicted GPS Coordinates with Error Lines')\n",
"    ax.grid(True)\n",
"    plt.show()"
], "metadata": { "id": "W1O4anKmd1o7" }, "execution_count": 49, "outputs": [] }, { "cell_type": "code", "source": [
"ensemble_model.eval()\n",
"\n",
"# Run inference on whatever device the ensemble's weights live on.\n",
"eval_device = next(ensemble_model.parameters()).device\n",
"\n",
"all_preds = []\n",
"all_actuals = []\n",
"\n",
"with torch.no_grad():\n",
"    for images, targets in test_dataloader:\n",
"        preds = ensemble_model(images.to(eval_device))\n",
"\n",
"        all_preds.append(preds.cpu().numpy())\n",
"        # Targets are only plotted, so they never need to go to the GPU.\n",
"        all_actuals.append(targets.cpu().numpy())\n",
"\n",
"all_preds = np.concatenate(all_preds, axis=0)\n",
"all_actuals = np.concatenate(all_actuals, axis=0)\n",
"\n",
"visualize_predictions(\n",
"    all_preds=all_preds,\n",
"    all_actuals=all_actuals,\n",
"    lat_mean=lat_mean,\n",
"    lon_mean=lon_mean,\n",
"    lat_std=lat_std,\n",
"    lon_std=lon_std\n",
")"
], "metadata": { "id": "m8IiYdxJdYy_" }, "execution_count": null, "outputs": [] } ] }