{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "\n", "! pip install datasets transformers evaluate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load a dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by loading a small image classification dataset and taking a look at its structure." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "dataset_name = \"jonathan-roberts1/Satellite-Images-of-Hurricane-Damage\"\n", "\n", "def get_ds(seed: int = 42):\n", "    ds = load_dataset(dataset_name)\n", "    # Hold out half of the original 'train' split (fixed seed so the splits are reproducible)\n", "    ds = ds[\"train\"].train_test_split(test_size=0.5, seed=seed)\n", "    # Split the held-out half evenly into validation and test: 50% / 25% / 25% overall\n", "    ds_ = ds[\"test\"].train_test_split(test_size=0.5, seed=seed)\n", "    ds[\"validation\"] = ds_[\"train\"]\n", "    ds[\"test\"] = ds_[\"test\"]\n", "    return ds\n", "\n", "ds = get_ds()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's take a look at the example at index 400 of the `'train'` split. You'll notice each example in the dataset has 2 features:\n", "\n", "1. `image`: A PIL image\n", "1. `label`: A [`datasets.ClassLabel`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=classlabel#datasets.ClassLabel) feature, which we'll see as an integer representation of the label for a given example. 
(Later we'll see how to get the string class names, don't worry)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'image': ,\n", " 'label': 0}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ex = ds['train'][400]\n", "ex" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's take a look at the image 👀" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCACAAIADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDjIZ2VwN3I9atvELlcHiqCoFPXmrcdyiMEJ+auZNszTIZtNmjH7ol1NVhFj7ycj2rVXfI+I2YN3waWSxdlO4jce9VZgZBZV6E1PDcNkbeakNiR8pGT9Knt7MRnnAFFiVctQF3lU81YYneTuNNj2oPlHPrTs4JbPWqsaIjGdp5OPSjLBOuKdxjH60nGAOPxoHcQswK/M1GWLk7jS9+maQHBzj86ABSwBO5vwNBc7ByR+NL0BBNIBxjGaLAOLNuHJ/GlEjbj8xpuMkZ5xS9GJ55oAxLpGjkKFTn1quFcEAnPr610dxarKM7sEdCKzZ7MKODlvSpTsS1YS3nELAhuO9akVwkoBBFc7Kjq3PI7irlm+0g5quYVzaKqRzURC9j+lSRgNHkGmsuFGBzTuUMz2FLnOO9BA3YPSkwMnjge9Fxi5BODijBpByucVJtBUEE80AM2knimMwTqaljJEuAOnr3qvfXUcTbZI2x/eHalcCQMO+PxpSMnFULWXLFVceX2JPNXsjIAOR9aAF60DFICGbGKUY55pgTkFRxj8aWzhhuJT5/yehp8oGMdKpySlAQKybE3Ytz6BuYsHiVT2LVRl0v7IQDIh9ADms7UZG2qxZx7g1UG9l/du34Ek0KLFe50CZA27hSiZA2xnXd6Zrmn
guyeHmojguoXLqhdv9oVauJXOnPJzlc/WowvJIZcn3rnnju5H3lpFz/CvApUgulBAVz755oRWp0OCq8sMjuTTh+8UbXVh1yDXMvZ3bc5lI9C1SeTeNGIyCqjoF4oA6IEh8/Lj1zU15aLe6SwQpvXOCTXLi1uTwN+P96nfZLkA5Mn/fZqWNDfIEIPmBsjstLbXKh8LIQT2bjNNNrdZzzn69ab9kut25oulWhO5tws7cMpPoasKrc/L+VYDW12wHzOp7bTxSiC8DDLy/g1INjoDdpICyniqkj4VnC5I7VgC8nCkAkjuBU0N83llCMk9yalxE0OuriW8wDCBtPQHrUsDbCGC4YdvSqqOArYYg1LC24EgknuTTsNI2Vk3KCadn3qraEHK5yRzVoLkcDAoLEzkUcZBpdtG3NIBMAnNNJB/CpCpxTGXnFMLCbgT1pcj34pgXIyKULjnkiiwEkYDnaeD2pZ1EWAXGfSogOcjg9qrS7/ADiztuI6A0gLeR16UEd6amSARTs560wOZEpUYVht74FAGzJIGDUQBcZVSKlChuD19KdrGVxUY7iTjmpod4ABAPPUd6jEfrUsZ2dOKQJ2L9tKwlXgADrWyI2OODisvTwjSBmwMVvxnjpQzVFXyG/uNSeQ3pWgVzz1/GmEelIdin5Deh/CgwN/dq4M0o/SgCibVicgGg2zd1OfatAYzS4zSAzfsr/3arXduUUEjBrZKenSsy/R3Ybeg4NNCZBaxtJ8qgmrP2OXsOfrVq1hEUI4wTVgAdaVx2PPgrAYJ5HWpUBJ9a0Liye3flOvtVd49p6EU73MbPqMCgd808ICR3puMngVZhjJ4NDBFq0jORjA9q24M4GCPzrNggz2x/WraxbRwcfjTSLTL2GBzkD6mja3XH45qkRg4Jpdp/vH86dh3LmxvT8zTdrg/wBc1XIAGA7E0nOM7yaXKO5ZCk9f51PHDJJkAqPqazMnuxoDdst9M0WC5oujKcMR+BqMRgnO1f61V3g/xfrTCATncfzosK5dMcnoKUBxweD9aprt7uw9Dmo5yI03KGb1yaOULst3MCupzWVLZqHz61syTKucn9KqSsrc1KRLKqaWjLlSPelWyCZI/GrUEwAIb7vY0NMuSauwaDEXaBzxUjNjpUBlx0U4zSlySDsP40DZIM5oYkA+ooLpuyucd89qa7q2cZz607iG7iRknmmOzDoetKxPl/L+dQlujMcY7UmwuSK4XlzxSNOhPCiqksu4+1R5PrWbmRKVi+JlPVRila4TH3B+dUQT60c9annZHOy75y+n40plQrgoDVMfjS4JHBo52NTZWOoPISGMrA9iKYbu4ZBHHHIoH8XeoHvkIDIpJ/u1oW7GRQQ3XsK0bNWjPlkuoxuXzwp/jPSmNNJKmfNmLjsK2Lnd5JjlYgf3TWWJ4o+G4/DrRzMh3Kf266QhRNID71ML+VMF7iVPYioJmDS5jGT60OJJsMwDAegqtwWpNJqDNwbqQD6VEt9KG2JPIV96jcKPvKQ3amIxDDvz0A6UirFp7u66CWTFIDcyj5pHbHTmkM+GCsoA9c1IN6MGSTilZkXIJGuMjcXyOlOiu5futIV+oqSR9+W3Ek96rsMkMfvUcqHuaCNOVyJCffbUm+4C92J/2agt55IF6k1oRXXmnAHNS00DiZ++553uw9PlpDJcfd3P9cVqNuPXFVJmKtg84ppoSP/Z", "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "image = ex['image']\n", "image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the `'labels'` feature of this dataset is a `datasets.features.ClassLabel`, we can 
use it to look up the corresponding name for this example's label ID.\n", "\n", "First, let's print the `DatasetDict` to check the size of each split, then access the feature definition for the `'label'` column." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['image', 'label'],\n", " num_rows: 5000\n", " })\n", " test: Dataset({\n", " features: ['image', 'label'],\n", " num_rows: 2500\n", " })\n", " validation: Dataset({\n", " features: ['image', 'label'],\n", " num_rows: 2500\n", " })\n", "})" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ds" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ClassLabel(names=['flooded or damaged buildings', 'undamaged buildings'], id=None)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "labels = ds['train'].features['label']\n", "labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's print out the class label for our example. We'll do that with the [`int2str`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=classlabel#datasets.ClassLabel.int2str) method of `ClassLabel`, which, as the name implies, converts the integer representation of a class into its string label." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'flooded or damaged buildings'" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "labels.int2str(ex['label'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's write a function that displays a grid of examples from each class so we can get a better idea of what we're working with."
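] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That grid will show the same number of images for every class, so it's also worth checking how many examples each class actually has. The cell below is a small sketch, assuming `ds` from the cells above: it counts the label IDs in each split with `collections.Counter` and maps them to class names through the `ClassLabel` feature's `names` list." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from collections import Counter\n", "\n", "# Class names in label-ID order: label_names[0] is the name of label 0, and so on\n", "label_names = ds['train'].features['label'].names\n", "\n", "# Count the examples of each class in every split of the DatasetDict\n", "for split_name, split in ds.items():\n", "    counts = Counter(split['label'])\n", "    print(split_name, {label_names[i]: n for i, n in sorted(counts.items())})"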
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "from PIL import Image, ImageDraw, ImageFont\n", "\n", "def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):\n", "    w, h = size\n", "    labels = ds['train'].features['label'].names\n", "    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))\n", "    draw = ImageDraw.Draw(grid)\n", "    try:\n", "        # Liberation Mono path as found on Colab; it may not exist on other systems\n", "        font = ImageFont.truetype(\"/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf\", 24)\n", "    except OSError:\n", "        # Fall back to Pillow's built-in default font\n", "        font = ImageFont.load_default()\n", "\n", "    for label_id, label in enumerate(labels):\n", "        # Filter the dataset by a single label, shuffle it, and grab a few samples\n", "        ds_slice = ds['train'].filter(lambda ex: ex['label'] == label_id).shuffle(seed=seed).select(range(examples_per_class))\n", "\n", "        # Plot this label's examples along row `label_id`, one example per column\n", "        for i, example in enumerate(ds_slice):\n", "            box = (i * w, label_id * h)\n", "            grid.paste(example['image'].resize(size), box=box)\n", "            draw.text(box, label, (255, 255, 255), font=font)\n", "\n", "    return grid\n", "\n", "show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 2 }