Upload folder using huggingface_hub
- README.md +23 -7
- __pycache__/gradio_demo.cpython-311.pyc +0 -0
- __pycache__/model_inference.cpython-311.pyc +0 -0
- analyze_checklist.ipynb +155 -0
- gbif_data_stats.sh +16 -0
- gradio_demo.py +40 -0
- job_clean_dataset.sh +31 -0
- job_copy_data_to_object_storage.sh +28 -0
- job_create_webdataset.sh +40 -0
- job_delete_images.sh +30 -0
- job_fetch_images.sh +31 -0
- job_gradio_demo.sh +17 -0
- job_predict_lifestage.sh +40 -0
- job_split_dataset.sh +32 -0
- job_verify_images.sh +34 -0
- key_to_name_map.py +31 -0
- model_inference.py +153 -0
- prepare_gbif_checklist.py +69 -0
- split_verification_list.py +32 -0
- test.py +28 -0
README.md
CHANGED
@@ -1,12 +1,28 @@
 ---
-title:
-
-colorFrom: green
-colorTo: green
+title: Mila_Global_Moth_Classifier
+app_file: gradio_demo.py
 sdk: gradio
 sdk_version: 4.42.0
-app_file: app.py
-pinned: false
 ---
-
+# Global Moth Model
+Research related to the development of a global moth species classification model for automated moth monitoring.
 
+## Process
+The steps below are carried out to train the global model.
+
+### Checklist preparation
+1. **Fetch Leps Checklist**: Download the Lepidoptera taxonomy from GBIF ([DOI](https://www.gbif.org/occurrence/download/)).
+2. **Fetch DwC-A**: Fetch the Darwin Core Archive from GBIF for the order Lepidoptera ([DOI](https://doi.org/10.15468/dl.6j5bzj)).
+3. **Curate Moth Checklist** (`prepare_gbif_checklist.py`): Clean and curate the Lepidoptera checklist to keep only moth species: remove all non-species taxa and the butterfly families. A curated list is available [here](https://docs.google.com/spreadsheets/d/1E6Zn2hXbHGMMAiPhtDXFO9_hDtl68lG5fx2vg0jyBvg/edit?usp=sharing).
+
+### Dataset download and curation
+The subsequent steps to download and curate the data follow the workflow described [here](https://github.com/RolnickLab/ami-ml/tree/main/src/dataset_tools).
+
+1. **Fetch GBIF images**: Download the images from GBIF using the command `ami-dataset fetch-images`. An example Slurm script with the argument options is provided (`job_fetch_images.sh`). Loading the DwC-A file requires about 300 GB of RAM; there are likely smarter ways to load the archive in smaller chunks of memory, but we have not explored them ourselves.
+2. **Verify images**: Verify the downloaded images for corruption (`job_verify_images.sh`).
+3. **Delete corrupted images**: `job_delete_images.sh`
+4. **Lifestage prediction**: Run the lifestage prediction model on images without a lifestage tag, in order to remove non-adult moth images from the dataset (`job_predict_lifestage.sh`).
+5. **Final clean dataset**: Create the final list of images retained after image verification and lifestage prediction (`job_clean_dataset.sh`).
+6. **Dataset splits**: Create dataset splits for model training (`job_split_dataset.sh`).
+
+### Model training
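The curation step described in the README reduces to two pandas filters. Below is a minimal sketch of that logic; the file names are placeholders, and the committed `prepare_gbif_checklist.py` later in this commit is the full version.

```python
import pandas as pd

# Placeholder input: the raw Lepidoptera checklist downloaded from GBIF.
checklist = pd.read_csv("gbif_leps_checklist.csv")

# Keep species-level taxa only, then drop the butterfly families so that
# only moth species remain.
butterfly_families = [
    "Hesperiidae", "Lycaenidae", "Nymphalidae", "Papilionidae",
    "Pieridae", "Riodinidae", "Hedylidae",
]
moth_checklist = checklist[
    (checklist["taxonRank"] == "SPECIES")
    & (~checklist["family"].isin(butterfly_families))
]
moth_checklist.to_csv("gbif_moth_checklist.csv", index=False)
```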
__pycache__/gradio_demo.cpython-311.pyc
ADDED
Binary file (1.74 kB).

__pycache__/model_inference.cpython-311.pyc
ADDED
Binary file (7.93 kB).
analyze_checklist.ipynb
ADDED
@@ -0,0 +1,155 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# System packages\n",
    "import sys\n",
    "import os\n",
    "\n",
    "# 3rd party packages\n",
    "import pandas as pd\n",
    "import dotenv\n",
    "import json\n",
    "\n",
    "# Our main package (coming soon!)\n",
    "# import ami_ml\n",
    "\n",
    "# Local development packages not yet in the main package\n",
    "sys.path.append(\"./\")\n",
    "\n",
    "# Auto reload your development packages\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Load secrets and config from optional .env file\n",
    "dotenv.load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No. of accepted moth species: 46983\n",
      "No. of unique genera: 9413\n",
      "No. of unique families: 124\n"
     ]
    }
   ],
   "source": [
    "# Read the global moth checklist\n",
    "moth_checklist_df = pd.read_csv(os.getenv(\"GLOBAL_MOTH_CHECKLIST\"))\n",
    "\n",
    "# Get statistics regarding accepted moth species\n",
    "accepted_moths = moth_checklist_df[moth_checklist_df[\"taxonomicStatus\"] == \"ACCEPTED\"]\n",
    "num_genus = set(accepted_moths[\"genus\"])\n",
    "num_family = set(accepted_moths[\"family\"])\n",
    "print(f\"No. of accepted moth species: {accepted_moths.shape[0]}\")\n",
    "print(f\"No. of unique genera: {len(num_genus)}\")\n",
    "print(f\"No. of unique families: {len(num_family)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the accepted taxon keys to json file\n",
    "unique_accepted_keys = list(accepted_moths[\"acceptedTaxonKey\"])\n",
    "file_path = os.getenv(\"ACCEPTED_KEY_LIST\")\n",
    "with open(file_path, \"w\") as file:\n",
    "    json.dump(unique_accepted_keys, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the json file read\n",
    "with open(os.getenv(\"ACCEPTED_KEY_LIST\")) as f:\n",
    "    keys_list = json.load(f)\n",
    "    keys_list = [int(x) for x in keys_list]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculate the total occurrences for all accepted taxon keys, with a cap of 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The total occurrences with a cap of thousand images is 3898528.\n"
     ]
    }
   ],
   "source": [
    "num_occ = list(accepted_moths[\"numberOfOccurrences\"])\n",
    "num_occ_limit = []\n",
    "for count in num_occ:\n",
    "    if count <= 1000: num_occ_limit.append(count)\n",
    "    else: num_occ_limit.append(1000)\n",
    "\n",
    "print(f\"The total occurrences with a cap of thousand images is {sum(num_occ_limit)}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
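The capped-occurrence total computed with an explicit loop in the notebook above can also be written as a single pandas expression. A minimal sketch, assuming the same `.env` variable and column names:

```python
import os

import pandas as pd

# Same setup as the notebook: read the curated checklist and keep accepted species.
moth_checklist_df = pd.read_csv(os.getenv("GLOBAL_MOTH_CHECKLIST"))
accepted_moths = moth_checklist_df[moth_checklist_df["taxonomicStatus"] == "ACCEPTED"]

# Cap each species' occurrence count at 1000, then sum over all species.
capped_total = accepted_moths["numberOfOccurrences"].clip(upper=1000).sum()
print(f"Total occurrences with a cap of 1000 images per species: {capped_total}")
```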
gbif_data_stats.sh
ADDED
@@ -0,0 +1,16 @@
#!/bin/bash

# Load absolute data paths
set -o allexport
source .env
set +o allexport

# Calculate data statistics
datasets_count=$(ls $GLOBAL_MODEL_DATASET_PATH | wc -l)
num_images=$(find $GLOBAL_MODEL_DATASET_PATH -type f | wc -l)
dataset_size=$(du -sh $GLOBAL_MODEL_DATASET_PATH)

# Print statistics
echo "Number of dataset sources: $datasets_count"
echo "Number of images: $num_images"
echo "Dataset size: $dataset_size"
gradio_demo.py
ADDED
@@ -0,0 +1,40 @@
import os

import gradio as gr
import PIL
import torch
from dotenv import load_dotenv
from model_inference import ModelInference

# Load secrets and config from optional .env file
load_dotenv()
GLOBAL_MODEL = os.getenv("GLOBAL_MODEL")
CATEGORY_MAP = os.getenv("CATEGORY_MAP_JSON")
CATEG_TO_NAME_MAP = os.getenv("CATEG_TO_NAME_MAP")


# Model prediction function
def predict_species(image: PIL.Image.Image) -> dict[str, float]:
    """Moth species prediction"""

    # Build the model class
    device = "cuda" if torch.cuda.is_available() else "cpu"
    fgrained_classifier = ModelInference(
        GLOBAL_MODEL, "timm_resnet50", CATEGORY_MAP, CATEG_TO_NAME_MAP, device, topk=5
    )

    # Predict on image
    sp_pred = fgrained_classifier.predict(image)

    return sp_pred


demo = gr.Interface(
    fn=predict_species,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title="Mila Global Moth Species Classifier",
)

if __name__ == "__main__":
    demo.launch(share=True)
job_clean_dataset.sh
ADDED
@@ -0,0 +1,31 @@
#!/bin/bash
#SBATCH --job-name=clean_dataset
#SBATCH --ntasks=1
#SBATCH --time=3:00:00
#SBATCH --partition=long-cpu      # Ask for long-cpu job
#SBATCH --cpus-per-task=2         # Ask for 2 CPUs
#SBATCH --mem=300G                # Ask for 300 GB of RAM
#SBATCH --output=clean_dataset_%j.out

# 1. Load the required modules
module load miniconda/3

# 2. Load your environment
conda activate ami-ml

# 3. Load the environment variables outside of the Python script
set -o allexport
source .env
set +o allexport

# Keep track of time
SECONDS=0

# 4. Launch your script
ami-dataset clean-dataset \
--dwca-file $DWCA_FILE \
--verified-data-csv $VERIFICATION_RESULTS \
--life-stage-predictions $LIFESTAGE_RESULTS

# Print time taken to execute the script
echo "Time taken to clean the dataset: $SECONDS seconds"
job_copy_data_to_object_storage.sh
ADDED
@@ -0,0 +1,28 @@
#!/bin/bash
#SBATCH --job-name=upload_dataset
#SBATCH --ntasks=1
#SBATCH --time=96:00:00
#SBATCH --partition=long-cpu      # Ask for long-cpu job
#SBATCH --cpus-per-task=2         # Ask for 2 CPUs
#SBATCH --mem=4G                  # Ask for 4 GB of RAM
#SBATCH --output=upload_dataset_%j.out

# 1. Load the required modules
module load miniconda/3

# 2. Load your environment
conda activate ami-ml

# 3. Load the environment variables outside of the Python script
set -o allexport
source .env
set +o allexport

# Keep track of time
SECONDS=0

# 4. Launch your script
aws s3 sync $GLOBAL_MODEL_DIR $GLOBAL_MODEL_OBJECT_STORE

# Print time taken to execute the script
echo "Time taken to upload the dataset: $SECONDS seconds"
job_create_webdataset.sh
ADDED
@@ -0,0 +1,40 @@
#!/bin/bash
#SBATCH --job-name=create_webdataset
#SBATCH --ntasks=1
#SBATCH --time=72:00:00
#SBATCH --partition=long-cpu      # Ask for long-cpu job
#SBATCH --cpus-per-task=4         # Ask for 4 CPUs
#SBATCH --mem=10G                 # Ask for 10 GB of RAM
#SBATCH --output=create_webdataset_%j.out

# 1. Load the required modules
module load miniconda/3

# 2. Load your environment
conda activate ami-ml

# 3. Load the environment variables outside of the Python script
set -o allexport
source .env
set +o allexport

# Keep track of time
SECONDS=0

# 4. Launch your script
ami-dataset create-webdataset \
--annotations-csv $SAMPLE_TRAIN_CSV \
--webdataset-pattern $SAMPLE_TRAIN_WBDS \
--wandb-run wbds_train_sample \
--dataset-path $GLOBAL_MODEL_DATASET_PATH \
--image-path-column image_path \
--label-column acceptedTaxonKey \
--columns-to-json $COLUMNS_TO_JSON \
--resize-min-size 450 \
--wandb-entity $WANDB_ENTITY \
--wandb-project $WANDB_PROJECT
# --save-category-map-json $CATEGORY_MAP_JSON


# Print time taken to execute the script
echo "Time taken to create the webdataset: $SECONDS seconds"
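To sanity-check the shards produced by `ami-dataset create-webdataset`, they can be iterated with the `webdataset` package; a minimal sketch, mirroring the commented-out snippet in `test.py` at the end of this commit (the shard path is a placeholder):

```python
import webdataset as wds

# Placeholder path to one shard written by the job above.
shard = "train450-000000.tar"

# Each sample is a dict mapping keys/extensions (e.g. "jpg", "json") to raw bytes.
dataset = wds.WebDataset(shard)
for sample in dataset:
    for key, value in sample.items():
        print(f"{key}: {type(value)}")
    break  # inspect only the first sample
```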
job_delete_images.sh
ADDED
@@ -0,0 +1,30 @@
#!/bin/bash
#SBATCH --job-name=delete_corrupted_images
#SBATCH --ntasks=1
#SBATCH --time=4:00:00
#SBATCH --partition=long-cpu      # Ask for long-cpu job
#SBATCH --cpus-per-task=2         # Ask for 2 CPUs
#SBATCH --mem=4G                  # Ask for 4 GB of RAM
#SBATCH --output=delete_corrupted_images_%j.out

# 1. Load the required modules
module load miniconda/3

# 2. Load your environment
conda activate ami-ml

# 3. Load the environment variables outside of the Python script
set -o allexport
source .env
set +o allexport

# Keep track of time
SECONDS=0

# 4. Launch your script
ami-dataset delete-images \
--error-images-csv $VERIFICATION_ERROR_RESULTS \
--base-path $GLOBAL_MODEL_DATASET_PATH

# Print time taken to execute the script
echo "Time taken to delete the corrupted images: $SECONDS seconds"
job_fetch_images.sh
ADDED
@@ -0,0 +1,31 @@
#!/bin/bash
#SBATCH --job-name=fetch_gbif_images
#SBATCH --partition=long-cpu      # Ask for long-cpu job
#SBATCH --cpus-per-task=1         # Ask for 1 CPU
#SBATCH --mem=300G                # Ask for 300 GB of RAM
#SBATCH --output=fetch_gbif_images_%j.out

# 1. Load the required modules
module load miniconda/3

# 2. Load your environment
conda activate ami-ml

# 3. Load the environment variables outside of the Python script
set -o allexport
source .env
set +o allexport

# Keep track of time
SECONDS=0

# 4. Launch your script
ami-dataset fetch-images \
--dataset-path $GLOBAL_MODEL_DATASET_PATH \
--dwca-file $DWCA_FILE \
--num-images-per-category 1000 \
--num-workers 4 \
--subset-list $ACCEPTED_KEY_LIST

# Print time taken to execute the script
echo "Time taken: $SECONDS seconds"
job_gradio_demo.sh
ADDED
@@ -0,0 +1,17 @@
#!/bin/bash
#SBATCH --job-name=gradio_demo
#SBATCH --ntasks=1
#SBATCH --time=120:00:00
#SBATCH --partition=long-cpu      # Ask for long-cpu job
#SBATCH --cpus-per-task=1         # Ask for 1 CPU
#SBATCH --mem=5G                  # Ask for 5 GB of RAM
#SBATCH --output=gradio_demo_%j.out

# 1. Load the required modules
module load miniconda/3

# 2. Load your environment
conda activate ami-ml

# 3. Run the demo
gradio global_moth_model/gradio_demo.py
job_predict_lifestage.sh
ADDED
@@ -0,0 +1,40 @@
#!/bin/bash
#SBATCH --job-name=lifestage_prediction
#SBATCH --ntasks=1
#SBATCH --time=24:00:00
#SBATCH --mem=16G
#SBATCH --partition=long          # Ask for long job
#SBATCH --cpus-per-task=4         # Ask for 4 CPUs
#SBATCH --gres=gpu:1              # Ask for 1 GPU
#SBATCH --output=lifestage_prediction_%j.out

# 1. Load the required modules
module load miniconda/3

# 2. Load your environment
conda activate ami-ml

# 3. Load the environment variables outside of the Python script
set -o allexport
source .env
set +o allexport

# Keep track of time
SECONDS=0

# 4. Launch your script
ami-dataset predict-lifestage \
--verified-data-csv $VERIFICATION_RESULTS_P2 \
--results-csv $LIFESTAGE_RESULTS_P2 \
--wandb-run lifestage_prediction_p2 \
--dataset-path $GLOBAL_MODEL_DATASET_PATH \
--model-path $LIFESTAGE_MODEL \
--category-map-json $LIFESTAGE_CATEGORY_MAP \
--wandb-entity $WANDB_ENTITY \
--wandb-project $WANDB_PROJECT \
--log-frequence 25 \
--batch-size 1024 \
--num-classes 2

# Print time taken to execute the script
echo "Time taken to run life stage prediction: $SECONDS seconds"
job_split_dataset.sh
ADDED
@@ -0,0 +1,32 @@
#!/bin/bash
#SBATCH --job-name=split_dataset
#SBATCH --ntasks=1
#SBATCH --time=2:00:00
#SBATCH --partition=long-cpu      # Ask for long-cpu job
#SBATCH --cpus-per-task=2         # Ask for 2 CPUs
#SBATCH --mem=6G                  # Ask for 6 GB of RAM
#SBATCH --output=split_dataset_%j.out

# 1. Load the required modules
module load miniconda/3

# 2. Load your environment
conda activate ami-ml

# 3. Load the environment variables outside of the Python script
set -o allexport
source .env
set +o allexport

# Keep track of time
SECONDS=0

# 4. Launch your script
ami-dataset split-dataset \
--dataset-csv $FINAL_CLEAN_DATASET \
--split-prefix $SPLIT_PREFIX \
--max-instances 1000 \
--min-instances 4

# Print time taken to execute the script
echo "Time taken to split the dataset: $SECONDS seconds"
job_verify_images.sh
ADDED
@@ -0,0 +1,34 @@
#!/bin/bash
#SBATCH --job-name=verify_gbif_images
#SBATCH --ntasks=1
#SBATCH --time=24:00:00
#SBATCH --partition=long-cpu      # Ask for long-cpu job
#SBATCH --cpus-per-task=16        # Ask for 16 CPUs
#SBATCH --mem=300G                # Ask for 300 GB of RAM
#SBATCH --output=verify_gbif_images_%j.out

# 1. Load the required modules
module load miniconda/3

# 2. Load your environment
conda activate ami-ml

# 3. Load the environment variables outside of the Python script
set -o allexport
source .env
set +o allexport

# Keep track of time
SECONDS=0

# 4. Launch your script
ami-dataset verify-images \
--dataset-path $GLOBAL_MODEL_DATASET_PATH \
--dwca-file $DWCA_FILE \
--num-workers 16 \
--results-csv $VERIFICATION_RESULTS \
--resume-from-ckpt $VERIFICATION_RESULTS \
--subset-list $ACCEPTED_KEY_LIST

# Print time taken to execute the script
echo "Time taken to verify images: $SECONDS seconds"
key_to_name_map.py
ADDED
@@ -0,0 +1,31 @@
#!/usr/bin/env python
# coding: utf-8

"""Create a mapping from taxon keys to species names"""

# System packages
import json
import os
from pathlib import Path

# 3rd party packages
import pandas as pd
from dotenv import load_dotenv

# Load secrets and config from optional .env file
load_dotenv()

# Variable definitions
GLOBAL_MODEL_DIR = os.getenv("GLOBAL_MODEL_DIR")
moth_list = pd.read_csv(Path(GLOBAL_MODEL_DIR) / "gbif_moth_checklist_07242024.csv")
map_dict = {}
map_file = Path(GLOBAL_MODEL_DIR) / "categ_to_name_map.json"

# Build the dict
for _, row in moth_list.iterrows():
    map_dict[int(row["acceptedTaxonKey"])] = row["species"]

# Save the dict
with open(map_file, "w") as file:
    json.dump(map_dict, file, indent=2)
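The resulting `categ_to_name_map.json` is the same file the inference code below consumes. A minimal sketch of reading it back, assuming the map was written as above; note that `json.dump` stores the integer taxon keys as strings:

```python
import json
import os
from pathlib import Path

from dotenv import load_dotenv

load_dotenv()

# Load the taxon-key -> species-name map written by key_to_name_map.py.
map_file = Path(os.getenv("GLOBAL_MODEL_DIR")) / "categ_to_name_map.json"
with open(map_file) as f:
    categ_to_name = json.load(f)

# JSON object keys are strings, so look up names with a string taxon key.
first_key = next(iter(categ_to_name))
print(first_key, "->", categ_to_name[first_key])
```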
model_inference.py
ADDED
@@ -0,0 +1,153 @@
import json

import PIL
import timm
import torch
from torchvision import transforms


class ModelInference:
    """Model inference class definition"""

    def __init__(
        self,
        model_path: str,
        model_type: str,
        category_map_json: str,
        categ_to_name_map_json: str,
        device: str,
        input_size: int = 128,
        topk: int = 10,
    ):
        self.device = device
        self.topk = topk
        self.input_size = input_size
        self.model_type = model_type
        self.image = None
        self.id2categ = self._load_category_map(category_map_json)
        self.categ2name = self._load_categ_to_name_map(categ_to_name_map_json)
        self.model = self._load_model(model_path, num_classes=len(self.id2categ))
        self.model.eval()

    def _load_categ_to_name_map(self, categ_to_name_map_json: str):
        with open(categ_to_name_map_json, "r") as f:
            categ_to_name_map = json.load(f)

        return categ_to_name_map

    def _load_category_map(self, category_map_json: str):
        with open(category_map_json, "r") as f:
            categories_map = json.load(f)

        id2categ = {categories_map[categ]: categ for categ in categories_map}
        return id2categ

    def _pad_to_square(self):
        """Padding transformation to make the image square"""
        width, height = self.image.size
        if height < width:
            return transforms.Pad(padding=[0, 0, 0, width - height])
        elif height > width:
            return transforms.Pad(padding=[0, 0, height - width, 0])
        else:
            return transforms.Pad(padding=[0, 0, 0, 0])

    def get_transforms(self):
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        return transforms.Compose(
            [
                self._pad_to_square(),
                transforms.ToTensor(),
                transforms.Resize((self.input_size, self.input_size), antialias=True),
                transforms.Normalize(mean, std),
            ]
        )

    def _load_model(self, model_path: str, num_classes: int, pretrained: bool = True):
        if self.model_type == "resnet50":
            model = timm.create_model(
                "resnet50", pretrained=pretrained, num_classes=num_classes
            )

        elif self.model_type == "timm_resnet50":
            model = timm.create_model(
                "resnet50", pretrained=pretrained, num_classes=num_classes
            )

        elif self.model_type == "timm_convnext-t":
            model = timm.create_model(
                "convnext_tiny_in22k", pretrained=pretrained, num_classes=num_classes
            )

        elif self.model_type == "timm_convnext-b":
            model = timm.create_model(
                "convnext_base_in22k", pretrained=pretrained, num_classes=num_classes
            )

        elif self.model_type == "efficientnetv2-b3":
            model = timm.create_model(
                "tf_efficientnetv2_b3", pretrained=pretrained, num_classes=num_classes
            )

        elif self.model_type == "timm_mobilenetv3large":
            model = timm.create_model(
                "mobilenetv3_large_100", pretrained=pretrained, num_classes=num_classes
            )

        elif self.model_type == "timm_vit-b16-128":
            model = timm.create_model(
                "vit_base_patch16_224_in21k",
                pretrained=pretrained,
                img_size=128,
                num_classes=num_classes,
            )

        else:
            raise RuntimeError(f"Model {self.model_type} not implemented")

        # Load model weights
        model.load_state_dict(
            torch.load(model_path, map_location=torch.device(self.device))
        )
        # Parallelize inference if multiple GPUs available
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)

        model = model.to(self.device)
        return model

    def predict(self, image: PIL.Image.Image):
        with torch.no_grad():
            # Process the image for prediction
            self.image = image
            transforms = self.get_transforms()
            image = transforms(image)
            image = image.to(self.device)
            image = image.unsqueeze_(0)

            # Model prediction on the image
            predictions = self.model(image)
            predictions = torch.nn.functional.softmax(predictions, dim=1)
            predictions = predictions.cpu()
            if self.topk == 0 or self.topk > len(
                predictions[0]
            ):  # topk=0 means get all predictions
                predictions = torch.topk(predictions, len(predictions[0]))
            else:
                predictions = torch.topk(predictions, self.topk)

            # Process the results
            values, indices = (
                predictions.values.numpy()[0],
                predictions.indices.numpy()[0],
            )
            pred_results = {}

            for i in range(len(indices)):
                idx, value = indices[i], values[i]
                categ = self.id2categ[idx]
                sp_name = self.categ2name[categ]
                pred_results[sp_name] = value
                # pred_results.append([sp_name, round(value*100, 2)])

        return pred_results
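For reference, `gradio_demo.py` earlier in this commit drives this class roughly as follows. A minimal sketch, assuming the same environment variables are set and using a placeholder image path:

```python
import os

import torch
from dotenv import load_dotenv
from PIL import Image

from model_inference import ModelInference

load_dotenv()

device = "cuda" if torch.cuda.is_available() else "cpu"
classifier = ModelInference(
    model_path=os.getenv("GLOBAL_MODEL"),
    model_type="timm_resnet50",
    category_map_json=os.getenv("CATEGORY_MAP_JSON"),
    categ_to_name_map_json=os.getenv("CATEG_TO_NAME_MAP"),
    device=device,
    topk=5,
)

# "some_moth.jpg" is a placeholder path to a local test image.
predictions = classifier.predict(Image.open("some_moth.jpg").convert("RGB"))
print(predictions)  # {species name: confidence, ...} for the top-5 classes
```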
prepare_gbif_checklist.py
ADDED
@@ -0,0 +1,69 @@
#!/usr/bin/env python
# coding: utf-8

"""Prepare the GBIF checklist for the global moth model"""

# System packages
import os
from pathlib import Path

# 3rd party packages
import pandas as pd
from dotenv import load_dotenv

# Load secrets and config from optional .env file
load_dotenv()


def remove_non_species_taxon(checklist: pd.DataFrame) -> pd.DataFrame:
    """
    Remove all non-species taxa from the checklist
    """

    # Keep only rows where the taxon rank is "SPECIES"
    checklist = checklist.loc[checklist["taxonRank"] == "SPECIES"]

    return checklist


def remove_butterflies(checklist: pd.DataFrame) -> pd.DataFrame:
    """
    Remove all butterflies from the checklist
    """

    # List of butterfly families
    butterfly_fm = [
        "Hesperiidae",
        "Lycaenidae",
        "Nymphalidae",
        "Papilionidae",
        "Pieridae",
        "Riodinidae",
        "Hedylidae",
    ]

    # Remove butterfly families
    checklist = checklist.loc[~checklist["family"].isin(butterfly_fm)]

    return checklist


if __name__ == "__main__":
    GLOBAL_MODEL_DIR = os.getenv("GLOBAL_MODEL_DIR")

    # Remove non-species taxa
    checklist = "gbif_leps_checklist_07242024_original.csv"
    checklist_pd = pd.read_csv(Path(GLOBAL_MODEL_DIR) / checklist)
    leps_checklist_pd = remove_non_species_taxon(checklist_pd)
    leps_checklist_pd.to_csv(
        Path(GLOBAL_MODEL_DIR) / "gbif_leps_checklist_07242024.csv", index=False
    )

    # Remove butterflies
    checklist = "gbif_leps_checklist_07242024.csv"
    checklist_pd = pd.read_csv(Path(GLOBAL_MODEL_DIR) / checklist)
    moth_checklist_pd = remove_butterflies(checklist_pd)
    moth_checklist_pd.to_csv(
        Path(GLOBAL_MODEL_DIR) / "gbif_moth_checklist_07242024.csv", index=False
    )
split_verification_list.py
ADDED
@@ -0,0 +1,32 @@
#!/usr/bin/env python
# coding: utf-8

"""Split the image verification list into multiple parts"""

# System packages
import os
from pathlib import Path

# 3rd party packages
import pandas as pd
from dotenv import load_dotenv

# Load secrets and config from optional .env file
load_dotenv()

# Load the list
img_verf_df = pd.read_csv(os.getenv("VERIFICATION_RESULTS"))
img_verf_lstage_nan_df = img_verf_df[img_verf_df.lifeStage.isnull()].copy()

# Slice the list
num_entries = img_verf_lstage_nan_df.shape[0]
half = int(num_entries / 2)
img_verf_lstage_nan_p1 = img_verf_lstage_nan_df.iloc[:half, :].copy()
img_verf_lstage_nan_p2 = img_verf_lstage_nan_df.iloc[half:, :].copy()

# Save the two parts
save_dir = os.getenv("GLOBAL_MODEL_DIR")
fname = Path(os.getenv("VERIFICATION_RESULTS")).stem
img_verf_lstage_nan_p1.to_csv(Path(save_dir) / str(fname + "_p1" + ".csv"), index=False)
img_verf_lstage_nan_p2.to_csv(Path(save_dir) / str(fname + "_p2" + ".csv"), index=False)
test.py
ADDED
@@ -0,0 +1,28 @@
# import webdataset as wds

# dataset_path = "/home/mila/a/aditya.jain/scratch/global_model/webdataset/train/train450-000000.tar"

# # Create a WebDataset reader
# dataset = wds.WebDataset(dataset_path)

# for sample in dataset:
#     a = 2
#     for key, value in sample.items():
#         print(f"{key}: {type(value)}")

import json

categ_map_f = "/home/mila/a/aditya.jain/scratch/global_model/category_map.json"
new_categ_map = {}

with open(categ_map_f, "r") as f:
    category_map = json.load(f)

for key in category_map.keys():
    new_key = str(int(float(key)))
    new_categ_map[new_key] = category_map[key]


with open("/home/mila/a/aditya.jain/scratch/global_model/category_map_v2.json", "w") as f:
    json.dump(new_categ_map, f)