adityajain07 committed
Commit d6c6696 · verified · 1 Parent(s): 1655813

Upload folder using huggingface_hub

README.md CHANGED
@@ -1,12 +1,28 @@
  ---
- title: Mila Global Moth Classifier
- emoji: 👁
- colorFrom: green
- colorTo: green
  sdk: gradio
  sdk_version: 4.42.0
- app_file: app.py
- pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: Mila_Global_Moth_Classifier
+ app_file: gradio_demo.py
  sdk: gradio
  sdk_version: 4.42.0
  ---
+ # Global Moth Model
+ Research related to the development of a global moth species classification model for automated moth monitoring.
+
+ ## Process
+ The steps below are carried out to train a global model.
+
+ ### Checklist preparation
+ 1. **Fetch Leps Checklist**: Download the Lepidoptera taxonomy from GBIF ([DOI](https://www.gbif.org/occurrence/download/)).
+ 2. **Fetch DwC-A**: Fetch the Darwin Core Archive from GBIF for the order Lepidoptera ([DOI](https://doi.org/10.15468/dl.6j5bzj)).
+ 3. **Curate Moth Checklist** (`prepare_gbif_checklist.py`): Clean and curate the Lepidoptera checklist so that it contains only moth species, removing all non-species taxa and the butterfly families; a condensed sketch follows this list. A curated list is [here](https://docs.google.com/spreadsheets/d/1E6Zn2hXbHGMMAiPhtDXFO9_hDtl68lG5fx2vg0jyBvg/edit?usp=sharing).
+
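+ The sketch below condenses the curation logic in `prepare_gbif_checklist.py`; the file names here are placeholders:
+
+ ```python
+ import pandas as pd
+
+ # Butterfly families to exclude (all remaining Lepidoptera are treated as moths)
+ BUTTERFLY_FAMILIES = [
+     "Hesperiidae", "Lycaenidae", "Nymphalidae", "Papilionidae",
+     "Pieridae", "Riodinidae", "Hedylidae",
+ ]
+
+ checklist = pd.read_csv("gbif_leps_checklist.csv")  # placeholder path
+ checklist = checklist.loc[checklist["taxonRank"] == "SPECIES"]  # drop non-species taxa
+ checklist = checklist.loc[~checklist["family"].isin(BUTTERFLY_FAMILIES)]
+ checklist.to_csv("gbif_moth_checklist.csv", index=False)
+ ```
+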
+ ### Dataset download and curation
+ The next steps, to download and curate the data, follow the dataset tools documented [here](https://github.com/RolnickLab/ami-ml/tree/main/src/dataset_tools).
+
+ 1. **Fetch GBIF images**: Download the images from GBIF using the command `ami-dataset fetch-images`. An example slurm script with the argument options is provided (`job_fetch_images.sh`). Loading the DwC-A file requires about 300GB of RAM; there are likely smarter ways to load the archive in multiple smaller chunks of memory, but we haven't explored them ourselves (one possible approach is sketched after this list).
+ 2. **Verify images**: Verify the downloaded images for corruption (`job_verify_images.sh`).
+ 3. **Delete corrupted images**: Remove images that failed verification (`job_delete_images.sh`).
+ 4. **Lifestage prediction**: Run the lifestage prediction model on images without a lifestage tag, in order to remove non-adult moth images from the dataset (`job_predict_lifestage.sh`).
+ 5. **Final clean dataset**: Create the final list of images retained after image verification and lifestage prediction (`job_clean_dataset.sh`).
+ 6. **Dataset splits**: Create dataset splits for model training (`job_split_dataset.sh`).
+
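+ One way to avoid loading the whole occurrence table at once would be to stream it in chunks with pandas. This is an untested sketch, assuming the archive's `occurrence.txt` has been extracted and is tab-separated:
+
+ ```python
+ import pandas as pd
+
+ # Stream the DwC-A occurrence table in fixed-size chunks instead of all at once
+ for chunk in pd.read_csv("occurrence.txt", sep="\t", chunksize=1_000_000):
+     species_rows = chunk[chunk["taxonRank"] == "SPECIES"]
+     # ... collect image URLs / taxon keys for this chunk ...
+ ```
+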
+ ### Model training
__pycache__/gradio_demo.cpython-311.pyc ADDED
Binary file (1.74 kB)
 
__pycache__/model_inference.cpython-311.pyc ADDED
Binary file (7.93 kB)
 
analyze_checklist.ipynb ADDED
@@ -0,0 +1,155 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "True"
+       ]
+      },
+      "execution_count": 1,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "# System packages\n",
+     "import sys\n",
+     "import os\n",
+     "\n",
+     "# 3rd party packages\n",
+     "import pandas as pd\n",
+     "import dotenv\n",
+     "import json\n",
+     "\n",
+     "# Our main package (coming soon!)\n",
+     "# import ami_ml\n",
+     "\n",
+     "# Local development packages not yet in the main package\n",
+     "sys.path.append(\"./\")\n",
+     "\n",
+     "# Auto reload your development packages\n",
+     "%load_ext autoreload\n",
+     "%autoreload 2\n",
+     "\n",
+     "# Load secrets and config from optional .env file\n",
+     "dotenv.load_dotenv()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "No. of accepted moth species: 46983\n",
+       "No. of unique genera: 9413\n",
+       "No. of unique families: 124\n"
+      ]
+     }
+    ],
+    "source": [
+     "# Read the global moth checklist\n",
+     "moth_checklist_df = pd.read_csv(os.getenv(\"GLOBAL_MOTH_CHECKLIST\"))\n",
+     "\n",
+     "# Get statistics regarding accepted moth species\n",
+     "accepted_moths = moth_checklist_df[moth_checklist_df[\"taxonomicStatus\"] == \"ACCEPTED\"]\n",
+     "num_genus = set(accepted_moths[\"genus\"])\n",
+     "num_family = set(accepted_moths[\"family\"])\n",
+     "print(f\"No. of accepted moth species: {accepted_moths.shape[0]}\")\n",
+     "print(f\"No. of unique genera: {len(num_genus)}\")\n",
+     "print(f\"No. of unique families: {len(num_family)}\")\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 29,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Save the accepted taxon keys to json file\n",
+     "unique_accepted_keys = list(accepted_moths[\"acceptedTaxonKey\"])\n",
+     "file_path = os.getenv(\"ACCEPTED_KEY_LIST\")\n",
+     "with open(file_path, \"w\") as file:\n",
+     "    json.dump(unique_accepted_keys, file)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 30,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Test the json file read\n",
+     "with open(os.getenv(\"ACCEPTED_KEY_LIST\")) as f:\n",
+     "    keys_list = json.load(f)\n",
+     "    keys_list = [int(x) for x in keys_list]"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "Calculate the total occurrences for all accepted taxon keys, with a cap of 1000"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 8,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "The total occurrences with a cap of thousand images is 3898528.\n"
+      ]
+     }
+    ],
+    "source": [
+     "num_occ = list(accepted_moths[\"numberOfOccurrences\"])\n",
+     "num_occ_limit = []\n",
+     "for count in num_occ:\n",
+     "    if count <= 1000: num_occ_limit.append(count)\n",
+     "    else: num_occ_limit.append(1000)\n",
+     "\n",
+     "print(f\"The total occurrences with a cap of thousand images is {sum(num_occ_limit)}.\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.9"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
gbif_data_stats.sh ADDED
@@ -0,0 +1,16 @@
+ #!/bin/bash
+
+ # Load absolute data paths
+ set -o allexport
+ source .env
+ set +o allexport
+
+ # Calculate data statistics
+ datasets_count=$(ls $GLOBAL_MODEL_DATASET_PATH | wc -l)
+ num_images=$(find $GLOBAL_MODEL_DATASET_PATH -type f | wc -l)
+ dataset_size=$(du -sh $GLOBAL_MODEL_DATASET_PATH)
+
+ # Print statistics
+ echo "Number of dataset sources: $datasets_count"
+ echo "Number of images: $num_images"
+ echo "Dataset size: $dataset_size"
gradio_demo.py ADDED
@@ -0,0 +1,40 @@
+ import os
+
+ import gradio as gr
+ import PIL
+ import torch
+ from dotenv import load_dotenv
+ from model_inference import ModelInference
+
+ # Load secrets and config from optional .env file
+ load_dotenv()
+ GLOBAL_MODEL = os.getenv("GLOBAL_MODEL")
+ CATEGORY_MAP = os.getenv("CATEGORY_MAP_JSON")
+ CATEG_TO_NAME_MAP = os.getenv("CATEG_TO_NAME_MAP")
+
+
+ # Model prediction function
+ def predict_species(image: PIL.Image.Image) -> dict[str, float]:
+     """Moth species prediction"""
+
+     # Build the model class
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     fgrained_classifier = ModelInference(
+         GLOBAL_MODEL, "timm_resnet50", CATEGORY_MAP, CATEG_TO_NAME_MAP, device, topk=5
+     )
+
+     # Predict on image
+     sp_pred = fgrained_classifier.predict(image)
+
+     return sp_pred
+
+
+ demo = gr.Interface(
+     fn=predict_species,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Label(),
+     title="Mila Global Moth Species Classifier",
+ )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
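Note that `predict_species` constructs `ModelInference` (and therefore reloads the model weights) on every request. If startup cost matters, a cached variant along the following lines may help; this is a sketch, not part of the committed code:

```python
from functools import lru_cache

@lru_cache(maxsize=1)
def get_classifier() -> ModelInference:
    """Build the classifier once and reuse it across requests."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return ModelInference(
        GLOBAL_MODEL, "timm_resnet50", CATEGORY_MAP, CATEG_TO_NAME_MAP, device, topk=5
    )

def predict_species_cached(image: PIL.Image.Image) -> dict[str, float]:
    return get_classifier().predict(image)
```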
job_clean_dataset.sh ADDED
@@ -0,0 +1,31 @@
+ #!/bin/bash
+ #SBATCH --job-name=clean_dataset
+ #SBATCH --ntasks=1
+ #SBATCH --time=3:00:00
+ #SBATCH --partition=long-cpu   # Ask for long-cpu job
+ #SBATCH --cpus-per-task=2      # Ask for 2 CPUs
+ #SBATCH --mem=300G             # Ask for 300 GB of RAM
+ #SBATCH --output=clean_dataset_%j.out
+
+ # 1. Load the required modules
+ module load miniconda/3
+
+ # 2. Load your environment
+ conda activate ami-ml
+
+ # 3. Load the environment variables outside of python script
+ set -o allexport
+ source .env
+ set +o allexport
+
+ # Keep track of time
+ SECONDS=0
+
+ # 4. Launch your script
+ ami-dataset clean-dataset \
+     --dwca-file $DWCA_FILE \
+     --verified-data-csv $VERIFICATION_RESULTS \
+     --life-stage-predictions $LIFESTAGE_RESULTS
+
+ # Print time taken to execute the script
+ echo "Time taken to clean the dataset: $SECONDS seconds"
job_copy_data_to_object_storage.sh ADDED
@@ -0,0 +1,28 @@
+ #!/bin/bash
+ #SBATCH --job-name=upload_dataset
+ #SBATCH --ntasks=1
+ #SBATCH --time=96:00:00
+ #SBATCH --partition=long-cpu   # Ask for long-cpu job
+ #SBATCH --cpus-per-task=2      # Ask for 2 CPUs
+ #SBATCH --mem=4G               # Ask for 4 GB of RAM
+ #SBATCH --output=upload_dataset_%j.out
+
+ # 1. Load the required modules
+ module load miniconda/3
+
+ # 2. Load your environment
+ conda activate ami-ml
+
+ # 3. Load the environment variables outside of python script
+ set -o allexport
+ source .env
+ set +o allexport
+
+ # Keep track of time
+ SECONDS=0
+
+ # 4. Launch your script
+ aws s3 sync $GLOBAL_MODEL_DIR $GLOBAL_MODEL_OBJECT_STORE
+
+ # Print time taken to execute the script
+ echo "Time taken to upload the dataset: $SECONDS seconds"
job_create_webdataset.sh ADDED
@@ -0,0 +1,40 @@
+ #!/bin/bash
+ #SBATCH --job-name=create_webdataset
+ #SBATCH --ntasks=1
+ #SBATCH --time=72:00:00
+ #SBATCH --partition=long-cpu   # Ask for long-cpu job
+ #SBATCH --cpus-per-task=4      # Ask for 4 CPUs
+ #SBATCH --mem=10G              # Ask for 10 GB of RAM
+ #SBATCH --output=create_webdataset_%j.out
+
+ # 1. Load the required modules
+ module load miniconda/3
+
+ # 2. Load your environment
+ conda activate ami-ml
+
+ # 3. Load the environment variables outside of python script
+ set -o allexport
+ source .env
+ set +o allexport
+
+ # Keep track of time
+ SECONDS=0
+
+ # 4. Launch your script
+ ami-dataset create-webdataset \
+     --annotations-csv $SAMPLE_TRAIN_CSV \
+     --webdataset-pattern $SAMPLE_TRAIN_WBDS \
+     --wandb-run wbds_train_sample \
+     --dataset-path $GLOBAL_MODEL_DATASET_PATH \
+     --image-path-column image_path \
+     --label-column acceptedTaxonKey \
+     --columns-to-json $COLUMNS_TO_JSON \
+     --resize-min-size 450 \
+     --wandb-entity $WANDB_ENTITY \
+     --wandb-project $WANDB_PROJECT
+     # --save-category-map-json $CATEGORY_MAP_JSON
+
+ # Print time taken to execute the script
+ echo "Time taken to create the webdataset: $SECONDS seconds"
job_delete_images.sh ADDED
@@ -0,0 +1,30 @@
+ #!/bin/bash
+ #SBATCH --job-name=delete_corrupted_images
+ #SBATCH --ntasks=1
+ #SBATCH --time=4:00:00
+ #SBATCH --partition=long-cpu   # Ask for long-cpu job
+ #SBATCH --cpus-per-task=2      # Ask for 2 CPUs
+ #SBATCH --mem=4G               # Ask for 4 GB of RAM
+ #SBATCH --output=delete_corrupted_images_%j.out
+
+ # 1. Load the required modules
+ module load miniconda/3
+
+ # 2. Load your environment
+ conda activate ami-ml
+
+ # 3. Load the environment variables outside of python script
+ set -o allexport
+ source .env
+ set +o allexport
+
+ # Keep track of time
+ SECONDS=0
+
+ # 4. Launch your script
+ ami-dataset delete-images \
+     --error-images-csv $VERIFICATION_ERROR_RESULTS \
+     --base-path $GLOBAL_MODEL_DATASET_PATH
+
+ # Print time taken to execute the script
+ echo "Time taken to delete the corrupted images: $SECONDS seconds"
job_fetch_images.sh ADDED
@@ -0,0 +1,31 @@
+ #!/bin/bash
+ #SBATCH --job-name=fetch_gbif_images
+ #SBATCH --partition=long-cpu   # Ask for long-cpu job
+ #SBATCH --cpus-per-task=1      # Ask for 1 CPU
+ #SBATCH --mem=300G             # Ask for 300 GB of RAM
+ #SBATCH --output=fetch_gbif_images_%j.out
+
+ # 1. Load the required modules
+ module load miniconda/3
+
+ # 2. Load your environment
+ conda activate ami-ml
+
+ # 3. Load the environment variables outside of python script
+ set -o allexport
+ source .env
+ set +o allexport
+
+ # Keep track of time
+ SECONDS=0
+
+ # 4. Launch your script
+ ami-dataset fetch-images \
+     --dataset-path $GLOBAL_MODEL_DATASET_PATH \
+     --dwca-file $DWCA_FILE \
+     --num-images-per-category 1000 \
+     --num-workers 4 \
+     --subset-list $ACCEPTED_KEY_LIST
+
+ # Print time taken to execute the script
+ echo "Time taken: $SECONDS seconds"
job_gradio_demo.sh ADDED
@@ -0,0 +1,17 @@
+ #!/bin/bash
+ #SBATCH --job-name=gradio_demo
+ #SBATCH --ntasks=1
+ #SBATCH --time=120:00:00
+ #SBATCH --partition=long-cpu   # Ask for long-cpu job
+ #SBATCH --cpus-per-task=1      # Ask for 1 CPU
+ #SBATCH --mem=5G               # Ask for 5 GB of RAM
+ #SBATCH --output=gradio_demo_%j.out
+
+ # 1. Load the required modules
+ module load miniconda/3
+
+ # 2. Load your environment
+ conda activate ami-ml
+
+ # 3. Run the demo
+ gradio global_moth_model/gradio_demo.py
job_predict_lifestage.sh ADDED
@@ -0,0 +1,40 @@
+ #!/bin/bash
+ #SBATCH --job-name=lifestage_prediction
+ #SBATCH --ntasks=1
+ #SBATCH --time=24:00:00
+ #SBATCH --mem=16G
+ #SBATCH --partition=long       # Ask for long job
+ #SBATCH --cpus-per-task=4      # Ask for 4 CPUs
+ #SBATCH --gres=gpu:1           # Ask for 1 GPU
+ #SBATCH --output=lifestage_prediction_%j.out
+
+ # 1. Load the required modules
+ module load miniconda/3
+
+ # 2. Load your environment
+ conda activate ami-ml
+
+ # 3. Load the environment variables outside of python script
+ set -o allexport
+ source .env
+ set +o allexport
+
+ # Keep track of time
+ SECONDS=0
+
+ # 4. Launch your script
+ ami-dataset predict-lifestage \
+     --verified-data-csv $VERIFICATION_RESULTS_P2 \
+     --results-csv $LIFESTAGE_RESULTS_P2 \
+     --wandb-run lifestage_prediction_p2 \
+     --dataset-path $GLOBAL_MODEL_DATASET_PATH \
+     --model-path $LIFESTAGE_MODEL \
+     --category-map-json $LIFESTAGE_CATEGORY_MAP \
+     --wandb-entity $WANDB_ENTITY \
+     --wandb-project $WANDB_PROJECT \
+     --log-frequence 25 \
+     --batch-size 1024 \
+     --num-classes 2
+
+ # Print time taken to execute the script
+ echo "Time taken to run life stage prediction: $SECONDS seconds"
job_split_dataset.sh ADDED
@@ -0,0 +1,32 @@
+ #!/bin/bash
+ #SBATCH --job-name=split_dataset
+ #SBATCH --ntasks=1
+ #SBATCH --time=2:00:00
+ #SBATCH --partition=long-cpu   # Ask for long-cpu job
+ #SBATCH --cpus-per-task=2      # Ask for 2 CPUs
+ #SBATCH --mem=6G               # Ask for 6 GB of RAM
+ #SBATCH --output=split_dataset_%j.out
+
+ # 1. Load the required modules
+ module load miniconda/3
+
+ # 2. Load your environment
+ conda activate ami-ml
+
+ # 3. Load the environment variables outside of python script
+ set -o allexport
+ source .env
+ set +o allexport
+
+ # Keep track of time
+ SECONDS=0
+
+ # 4. Launch your script
+ ami-dataset split-dataset \
+     --dataset-csv $FINAL_CLEAN_DATASET \
+     --split-prefix $SPLIT_PREFIX \
+     --max-instances 1000 \
+     --min-instances 4
+
+ # Print time taken to execute the script
+ echo "Time taken to split the dataset: $SECONDS seconds"
job_verify_images.sh ADDED
@@ -0,0 +1,34 @@
+ #!/bin/bash
+ #SBATCH --job-name=verify_gbif_images
+ #SBATCH --ntasks=1
+ #SBATCH --time=24:00:00
+ #SBATCH --partition=long-cpu   # Ask for long-cpu job
+ #SBATCH --cpus-per-task=16     # Ask for 16 CPUs
+ #SBATCH --mem=300G             # Ask for 300 GB of RAM
+ #SBATCH --output=verify_gbif_images_%j.out
+
+ # 1. Load the required modules
+ module load miniconda/3
+
+ # 2. Load your environment
+ conda activate ami-ml
+
+ # 3. Load the environment variables outside of python script
+ set -o allexport
+ source .env
+ set +o allexport
+
+ # Keep track of time
+ SECONDS=0
+
+ # 4. Launch your script
+ ami-dataset verify-images \
+     --dataset-path $GLOBAL_MODEL_DATASET_PATH \
+     --dwca-file $DWCA_FILE \
+     --num-workers 16 \
+     --results-csv $VERIFICATION_RESULTS \
+     --resume-from-ckpt $VERIFICATION_RESULTS \
+     --subset-list $ACCEPTED_KEY_LIST
+
+ # Print time taken to execute the script
+ echo "Time taken to verify images: $SECONDS seconds"
key_to_name_map.py ADDED
@@ -0,0 +1,31 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ """Create a mapping from taxon keys to species names"""
+
+ # System packages
+ import json
+ import os
+ from pathlib import Path
+
+ # 3rd party packages
+ import pandas as pd
+ from dotenv import load_dotenv
+
+ # Load secrets and config from optional .env file
+ load_dotenv()
+
+ # Variable definitions
+ GLOBAL_MODEL_DIR = os.getenv("GLOBAL_MODEL_DIR")
+ moth_list = pd.read_csv(Path(GLOBAL_MODEL_DIR) / "gbif_moth_checklist_07242024.csv")
+ map_dict = {}
+ map_file = Path(GLOBAL_MODEL_DIR) / "categ_to_name_map.json"
+
+ # Build the dict
+ for _, row in moth_list.iterrows():
+     map_dict[int(row["acceptedTaxonKey"])] = row["species"]
+
+ # Save the dict
+ with open(map_file, "w") as file:
+     json.dump(map_dict, file, indent=2)
model_inference.py ADDED
@@ -0,0 +1,153 @@
+ import json
+
+ import PIL
+ import timm
+ import torch
+ from torchvision import transforms
+
+
+ class ModelInference:
+     """Model inference class definition"""
+
+     def __init__(
+         self,
+         model_path: str,
+         model_type: str,
+         category_map_json: str,
+         categ_to_name_map_json: str,
+         device: str,
+         input_size: int = 128,
+         topk: int = 10,
+     ):
+         self.device = device
+         self.topk = topk
+         self.input_size = input_size
+         self.model_type = model_type
+         self.image = None
+         self.id2categ = self._load_category_map(category_map_json)
+         self.categ2name = self._load_categ_to_name_map(categ_to_name_map_json)
+         self.model = self._load_model(model_path, num_classes=len(self.id2categ))
+         self.model.eval()
+
+     def _load_categ_to_name_map(self, categ_to_name_map_json: str):
+         with open(categ_to_name_map_json, "r") as f:
+             categ_to_name_map = json.load(f)
+
+         return categ_to_name_map
+
+     def _load_category_map(self, category_map_json: str):
+         with open(category_map_json, "r") as f:
+             categories_map = json.load(f)
+
+         id2categ = {categories_map[categ]: categ for categ in categories_map}
+         return id2categ
+
+     def _pad_to_square(self):
+         """Padding transformation to make the image square"""
+         width, height = self.image.size
+         if height < width:
+             return transforms.Pad(padding=[0, 0, 0, width - height])
+         elif height > width:
+             return transforms.Pad(padding=[0, 0, height - width, 0])
+         else:
+             return transforms.Pad(padding=[0, 0, 0, 0])
+
+     def get_transforms(self):
+         mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+         return transforms.Compose(
+             [
+                 self._pad_to_square(),
+                 transforms.ToTensor(),
+                 transforms.Resize((self.input_size, self.input_size), antialias=True),
+                 transforms.Normalize(mean, std),
+             ]
+         )
+
+     def _load_model(self, model_path: str, num_classes: int, pretrained: bool = True):
+         if self.model_type == "resnet50":
+             model = timm.create_model(
+                 "resnet50", pretrained=pretrained, num_classes=num_classes
+             )
+
+         elif self.model_type == "timm_resnet50":
+             model = timm.create_model(
+                 "resnet50", pretrained=pretrained, num_classes=num_classes
+             )
+
+         elif self.model_type == "timm_convnext-t":
+             model = timm.create_model(
+                 "convnext_tiny_in22k", pretrained=pretrained, num_classes=num_classes
+             )
+
+         elif self.model_type == "timm_convnext-b":
+             model = timm.create_model(
+                 "convnext_base_in22k", pretrained=pretrained, num_classes=num_classes
+             )
+
+         elif self.model_type == "efficientnetv2-b3":
+             model = timm.create_model(
+                 "tf_efficientnetv2_b3", pretrained=pretrained, num_classes=num_classes
+             )
+
+         elif self.model_type == "timm_mobilenetv3large":
+             model = timm.create_model(
+                 "mobilenetv3_large_100", pretrained=pretrained, num_classes=num_classes
+             )
+
+         elif self.model_type == "timm_vit-b16-128":
+             model = timm.create_model(
+                 "vit_base_patch16_224_in21k",
+                 pretrained=pretrained,
+                 img_size=128,
+                 num_classes=num_classes,
+             )
+
+         else:
+             raise RuntimeError(f"Model {self.model_type} not implemented")
+
+         # Load model weights
+         model.load_state_dict(
+             torch.load(model_path, map_location=torch.device(self.device))
+         )
+         # Parallelize inference if multiple GPUs available
+         if torch.cuda.device_count() > 1:
+             model = torch.nn.DataParallel(model)
+
+         model = model.to(self.device)
+         return model
+
+     def predict(self, image: PIL.Image.Image):
+         with torch.no_grad():
+             # Process the image for prediction
+             self.image = image
+             transforms = self.get_transforms()
+             image = transforms(image)
+             image = image.to(self.device)
+             image = image.unsqueeze_(0)
+
+             # Model prediction on the image
+             predictions = self.model(image)
+             predictions = torch.nn.functional.softmax(predictions, dim=1)
+             predictions = predictions.cpu()
+             if self.topk == 0 or self.topk > len(
+                 predictions[0]
+             ):  # topk=0 means get all predictions
+                 predictions = torch.topk(predictions, len(predictions[0]))
+             else:
+                 predictions = torch.topk(predictions, self.topk)
+
+             # Process the results
+             values, indices = (
+                 predictions.values.numpy()[0],
+                 predictions.indices.numpy()[0],
+             )
+             pred_results = {}
+
+             for i in range(len(indices)):
+                 idx, value = indices[i], values[i]
+                 categ = self.id2categ[idx]
+                 sp_name = self.categ2name[categ]
+                 pred_results[sp_name] = value
+                 # pred_results.append([sp_name, round(value*100, 2)])
+
+             return pred_results
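For reference, a minimal standalone use of `ModelInference`, mirroring how `gradio_demo.py` drives it; the file paths below are placeholders for the `.env`-configured artifacts:

```python
import PIL.Image

classifier = ModelInference(
    model_path="global_moth_model.pth",        # placeholder
    model_type="timm_resnet50",
    category_map_json="category_map.json",     # placeholder
    categ_to_name_map_json="categ_to_name_map.json",  # placeholder
    device="cpu",
    topk=5,
)
predictions = classifier.predict(PIL.Image.open("moth.jpg"))  # {species name: score}
```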
prepare_gbif_checklist.py ADDED
@@ -0,0 +1,69 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ """Prepare the GBIF checklist for the global moth model"""
+
+ # System packages
+ import os
+ from pathlib import Path
+
+ # 3rd party packages
+ import pandas as pd
+ from dotenv import load_dotenv
+
+ # Load secrets and config from optional .env file
+ load_dotenv()
+
+
+ def remove_non_species_taxon(checklist: pd.DataFrame) -> pd.DataFrame:
+     """
+     Remove all non-species taxa from the checklist
+     """
+
+     # Keep only rows where the taxon rank is "SPECIES"
+     checklist = checklist.loc[checklist["taxonRank"] == "SPECIES"]
+
+     return checklist
+
+
+ def remove_butterflies(checklist: pd.DataFrame) -> pd.DataFrame:
+     """
+     Remove all butterflies from the checklist
+     """
+
+     # List of butterfly families
+     butterfly_fm = [
+         "Hesperiidae",
+         "Lycaenidae",
+         "Nymphalidae",
+         "Papilionidae",
+         "Pieridae",
+         "Riodinidae",
+         "Hedylidae",
+     ]
+
+     # Remove butterfly families
+     checklist = checklist.loc[~checklist["family"].isin(butterfly_fm)]
+
+     return checklist
+
+
+ if __name__ == "__main__":
+     GLOBAL_MODEL_DIR = os.getenv("GLOBAL_MODEL_DIR")
+
+     # Remove non-species taxa
+     checklist = "gbif_leps_checklist_07242024_original.csv"
+     checklist_pd = pd.read_csv(Path(GLOBAL_MODEL_DIR) / checklist)
+     leps_checklist_pd = remove_non_species_taxon(checklist_pd)
+     leps_checklist_pd.to_csv(
+         Path(GLOBAL_MODEL_DIR) / "gbif_leps_checklist_07242024.csv", index=False
+     )
+
+     # Remove butterflies
+     checklist = "gbif_leps_checklist_07242024.csv"
+     checklist_pd = pd.read_csv(Path(GLOBAL_MODEL_DIR) / checklist)
+     moth_checklist_pd = remove_butterflies(checklist_pd)
+     moth_checklist_pd.to_csv(
+         Path(GLOBAL_MODEL_DIR) / "gbif_moth_checklist_07242024.csv", index=False
+     )
split_verification_list.py ADDED
@@ -0,0 +1,32 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ """Split the image verification list into multiple parts"""
+
+ # System packages
+ import os
+ from pathlib import Path
+
+ # 3rd party packages
+ import pandas as pd
+ from dotenv import load_dotenv
+
+ # Load secrets and config from optional .env file
+ load_dotenv()
+
+ # Load the list
+ img_verf_df = pd.read_csv(os.getenv("VERIFICATION_RESULTS"))
+ img_verf_lstage_nan_df = img_verf_df[img_verf_df.lifeStage.isnull()].copy()
+
+ # Slice the list
+ num_entries = img_verf_lstage_nan_df.shape[0]
+ half = int(num_entries / 2)
+ img_verf_lstage_nan_p1 = img_verf_lstage_nan_df.iloc[:half, :].copy()
+ img_verf_lstage_nan_p2 = img_verf_lstage_nan_df.iloc[half:, :].copy()
+
+ # Save the two splits
+ save_dir = os.getenv("GLOBAL_MODEL_DIR")
+ fname = Path(os.getenv("VERIFICATION_RESULTS")).stem
+ img_verf_lstage_nan_p1.to_csv(Path(save_dir) / str(fname + "_p1" + ".csv"), index=False)
+ img_verf_lstage_nan_p2.to_csv(Path(save_dir) / str(fname + "_p2" + ".csv"), index=False)
test.py ADDED
@@ -0,0 +1,28 @@
+ # import webdataset as wds
+
+ # dataset_path = "/home/mila/a/aditya.jain/scratch/global_model/webdataset/train/train450-000000.tar"
+
+ # # Create a WebDataset reader
+ # dataset = wds.WebDataset(dataset_path)
+
+ # for sample in dataset:
+ #     a = 2
+ #     for key, value in sample.items():
+ #         print(f"{key}: {type(value)}")
+
+ import json
+
+ categ_map_f = "/home/mila/a/aditya.jain/scratch/global_model/category_map.json"
+ new_categ_map = {}
+
+ with open(categ_map_f, "r") as f:
+     category_map = json.load(f)
+
+ for key in category_map.keys():
+     new_key = str(int(float(key)))
+     new_categ_map[new_key] = category_map[key]
+
+
+ with open("/home/mila/a/aditya.jain/scratch/global_model/category_map_v2.json", "w") as f:
+     json.dump(new_categ_map, f)