|
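"""Voice of Jungle: a Gradio demo for identifying bird species from audio.

The app downloads BirdAST / BirdAST_Seq ensemble checkpoints from the
Hugging Face Hub, band-pass filters and resamples the uploaded clip, and
reports an audio-type prediction plus the top-10 species, alongside
waveform and mel-spectrogram plots.
"""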
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
import os |
|
import numpy as np |
|
import pandas as pd |
|
from typing import Iterable |
|
|
|
import gradio as gr |
|
from gradio.themes.base import Base |
|
from gradio.themes.utils import colors, fonts, sizes |
|
import requests |
|
import torch |
|
import librosa |
|
|
|
|
|
|
from audio_class_predictor import predict_class |
|
from bird_ast_model import birdast_preprocess, birdast_inference |
|
from bird_ast_seq_model import birdast_seq_preprocess, birdast_seq_inference |
|
|
|
from utils import plot_wave, plot_mel, download_model, bandpass_filter |
|
|
|
|
|
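# Global configuration: local asset cache, default sample rate (Hz), and
# band-pass cut-offs (Hz) used to suppress out-of-band noise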
ASSET_DIR = "./assets" |
|
DEFAULT_SR = 16_000

DEFAULT_HIGH_CUT = 8_000

DEFAULT_LOW_CUT = 1_000

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
print(f"Using device: {DEVICE}")
|
|
|
os.makedirs(ASSET_DIR, exist_ok=True)
|
|
|
|
|
|
|
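# Assets per model: five GroupKFold checkpoint URLs, a label-mapping CSV,
# and the matching preprocess/inference functions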
birdast_assets = { |
|
"model_weights": [ |
|
f"https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_fold_{i}.pth" |
|
for i in range(5) |
|
], |
|
"label_mapping": "https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_label_map.csv", |
|
"preprocess_fn": birdast_preprocess, |
|
"inference_fn": birdast_inference, |
|
} |
|
|
|
birdast_seq_assets = { |
|
"model_weights": [ |
|
f"https://huggingface.co/shiyi-li/BirdAST_Seq/resolve/main/BirdAST_SeqPool_GroupKFold_fold_{i}.pth" |
|
for i in range(5) |
|
], |
|
"label_mapping": "https://huggingface.co/shiyi-li/BirdAST_Seq/resolve/main/BirdAST_SeqPool_GroupKFold_label_map.csv", |
|
"preprocess_fn": birdast_seq_preprocess, |
|
"inference_fn": birdast_seq_inference, |
|
} |
|
|
|
|
|
ASSET_DICT = { |
|
"BirdAST": birdast_assets, |
|
"BirdAST_Seq": birdast_seq_assets, |
|
} |
|
|
|
|
|
def run_inference_with_model(audio_clip, sr, model_name): |
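    """Run the selected model's five-fold ensemble on `audio_clip`.

    Downloads any missing checkpoints and the label map into ASSET_DIR, then
    returns the top-10 predictions as [scientific_name, score-in-%] rows.
    """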
|
|
|
|
|
assets = ASSET_DICT[model_name] |
|
model_weights_url = assets["model_weights"] |
|
label_map_url = assets["label_mapping"] |
|
preprocess_fn = assets["preprocess_fn"] |
|
inference_fn = assets["inference_fn"] |
|
|
|
|
|
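    # Fetch each fold's checkpoint into the local cache (skipped when already present)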
model_weights = [] |
|
for model_weight in model_weights_url: |
|
weight_file = os.path.join(ASSET_DIR, model_weight.split("/")[-1]) |
|
if not os.path.exists(weight_file): |
|
download_model(model_weight, weight_file) |
|
model_weights.append(weight_file) |
|
|
|
|
|
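    # Download the label-mapping CSV if it is not cached yet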
label_map_csv = os.path.join(ASSET_DIR, label_map_url.split("/")[-1]) |
|
if not os.path.exists(label_map_csv): |
|
download_model(label_map_url, label_map_csv) |
|
|
|
|
|
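    # Build a species_id -> scientific_name lookup from the label map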
label_mapping = pd.read_csv(label_map_csv) |
|
species_id_to_name = {row["species_id"]: row["scientific_name"] for _, row in label_mapping.iterrows()} |
|
|
|
|
|
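    # Convert the raw waveform into the model's expected spectrogram input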
spectrogram = preprocess_fn(audio_clip, sr=sr) |
|
|
|
|
|
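    # Run the ensemble: `predictions` stacks one row of class scores per fold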
predictions = inference_fn(model_weights, spectrogram, device=DEVICE) |
|
|
|
|
|
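    # Average the fold scores and keep the 10 highest-scoring classes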
final_predicts = predictions.mean(axis=0) |
|
topk_values, topk_indices = torch.topk(torch.from_numpy(final_predicts), 10) |
|
|
|
results = [] |
|
for idx, scores in zip(topk_indices, topk_values): |
|
species_name = species_id_to_name[idx.item()] |
|
probability = scores.item() * 100 |
|
results.append([species_name, probability]) |
|
|
|
return results |
|
|
|
|
|
|
|
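# Fetch the app description (rendered at the top of the page) from GitHub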
def load_markdown_from_url(url): |
|
    response = requests.get(url, timeout=30)
|
response.raise_for_status() |
|
return response.text |
|
|
|
markdown_url = 'https://github.com/AmroAbdrabo/amroa/raw/main/img/desc.md' |
|
markdown_content = load_markdown_from_url(markdown_url) |
|
|
|
def predict(audio, start, end, model_name="BirdAST_Seq"): |
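    """Classify the [start, end] window of `audio` and return the audio-type
    prediction, the top-10 species, and waveform/mel-spectrogram figures."""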
|
|
|
raw_sr, audio_array = audio |
|
|
|
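    # Collapse multi-channel recordings to mono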
if audio_array.ndim > 1: |
|
audio_array = audio_array.mean(axis=1) |
|
|
|
    print(f"Raw audio shape: {audio_array.shape}, sample rate: {raw_sr} Hz")
|
|
|
|
|
len_audio = audio_array.shape[0] / raw_sr |
|
if start >= end: |
|
        raise gr.Error(f"`start` ({start}s) must be smaller than `end` ({end}s)")
|
|
|
if audio_array.shape[0] < start * raw_sr: |
|
        raise gr.Error(f"`start` ({start}s) must be smaller than the audio duration ({len_audio:.0f}s)")
|
|
|
if audio_array.shape[0] < end * raw_sr: |
|
        end = len_audio  # clamp `end` to the actual audio duration
|
|
|
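    # Scale 16-bit PCM samples (as returned by gr.Audio) to float32 in [-1, 1],
    # then crop to the selected window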
audio_array = np.array(audio_array, dtype=np.float32) / 32768.0 |
|
audio_array = audio_array[int(start*raw_sr) : int(end*raw_sr)] |
|
|
|
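    # If the input rate differs from the model rate, suppress out-of-band
    # noise with a band-pass filter before resampling to DEFAULT_SR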
    if raw_sr != DEFAULT_SR:

        audio_array = bandpass_filter(audio_array, DEFAULT_LOW_CUT, DEFAULT_HIGH_CUT, raw_sr)

        audio_array = librosa.resample(audio_array, orig_sr=raw_sr, target_sr=DEFAULT_SR)

        print(f"Resampled audio shape: {audio_array.shape}")
|
|
|
audio_array = audio_array.astype(np.float32) |
|
|
|
|
|
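    # Coarse audio-type pre-classification of the clip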
audio_class = predict_class(audio_array) |
|
|
|
    fig_spectrogram = plot_mel(DEFAULT_SR, audio_array)

    fig_waveform = plot_wave(DEFAULT_SR, audio_array)
|
|
|
|
|
print(f"Running inference with model: {model_name}") |
|
    species_class = run_inference_with_model(audio_array, DEFAULT_SR, model_name)
|
|
|
return audio_class, species_class, fig_waveform, fig_spectrogram |
|
|
|
|
|
|
|
|
|
DESCRIPTION = markdown_content |
|
|
|
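# Custom CSS: title animation, logo sizing, and input-column layout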
css = """ |
|
#gradio-animation { |
|
font-size: 2em; |
|
font-weight: bold; |
|
text-align: center; |
|
margin-bottom: 20px; |
|
} |
|
|
|
.logo-container img { |
|
width: 14%; /* Adjust width as necessary */ |
|
display: block; |
|
margin: auto; |
|
} |
|
|
|
.number-input { |
|
height: 100%; |
|
    padding-bottom: 60px; /* Adjust the value as needed for more or less space */
|
} |
|
.full-height { |
|
height: 100%; |
|
} |
|
.column-container { |
|
height: 100%; |
|
} |
|
""" |
|
|
|
|
|
|
|
class Seafoam(Base): |
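    """Custom Gradio theme: emerald/blue palette with Quicksand and IBM Plex Mono fonts."""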
|
def __init__( |
|
self, |
|
*, |
|
primary_hue: colors.Color | str = colors.emerald, |
|
secondary_hue: colors.Color | str = colors.blue, |
|
neutral_hue: colors.Color | str = colors.gray, |
|
spacing_size: sizes.Size | str = sizes.spacing_md, |
|
radius_size: sizes.Size | str = sizes.radius_md, |
|
text_size: sizes.Size | str = sizes.text_lg, |
|
font: fonts.Font |
|
| str |
|
| Iterable[fonts.Font | str] = ( |
|
fonts.GoogleFont("Quicksand"), |
|
"ui-sans-serif", |
|
"sans-serif", |
|
), |
|
font_mono: fonts.Font |
|
| str |
|
| Iterable[fonts.Font | str] = ( |
|
fonts.GoogleFont("IBM Plex Mono"), |
|
"ui-monospace", |
|
"monospace", |
|
), |
|
): |
|
super().__init__( |
|
primary_hue=primary_hue, |
|
secondary_hue=secondary_hue, |
|
neutral_hue=neutral_hue, |
|
spacing_size=spacing_size, |
|
radius_size=radius_size, |
|
text_size=text_size, |
|
font=font, |
|
font_mono=font_mono, |
|
) |
|
|
|
|
|
seafoam = Seafoam() |
|
|
|
|
|
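# Front-end snippet that types the "Voice of Jungle" title letter by letter on page load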
js = """ |
|
function createGradioAnimation() { |
|
var container = document.getElementById('gradio-animation'); |
|
var text = 'Voice of Jungle'; |
|
for (var i = 0; i < text.length; i++) { |
|
(function(i){ |
|
setTimeout(function(){ |
|
var letter = document.createElement('span'); |
|
letter.style.opacity = '0'; |
|
letter.style.transition = 'opacity 0.5s'; |
|
letter.innerText = text[i]; |
|
container.appendChild(letter); |
|
setTimeout(function() { |
|
letter.style.opacity = '1'; |
|
}, 50); |
|
}, i * 250); |
|
})(i); |
|
} |
|
} |
|
""" |
|
|
|
REFERENCES = """ |
|
# Appendix |
|
|
|
We have applied the AudioMAE model to pre-classify the 23,000+ unlabelled audio clips collected from the Greater Manaus region of the Amazon rainforest. The audio-type classification results are available at the following [link](https://drive.google.com/file/d/1uOT88LDnBD-Z3YcFz1e9XjvW2ugCo6EI/view?usp=drive_link). We hope that these pre-classification results help researchers better explore the vast collection of audio recordings and facilitate the study of biodiversity in the Amazon rainforest.
|
|
|
# References |
|
|
|
[1] Torkington, S. (2023, February 7). 50% of the global economy is under threat from biodiversity loss. World Economic Forum. Retrieved from https://www.weforum.org/agenda/2023/02/biodiversity-nature-loss-cop15/. |
|
|
|
[2] Huang, P.-Y., Xu, H., Li, J., Baevski, A., Auli, M., Galuba, W., Metze, F., & Feichtenhofer, C. (2022). Masked Autoencoders that Listen. In NeurIPS. |
|
|
|
[3] Bird species by sound detection (Kaggle notebook). https://www.kaggle.com/code/dima806/bird-species-by-sound-detection
|
|
|
# Acknowledgements |
|
|
|
We would like to thank all organizers, mentors and participants of the AI+Environment EcoHackathon 2024 event for their unwavering support and collaboration. We extend our gratitude to ETH BiodivX, GainForest and the ETH AI Center for providing the data, facilities and resources that enabled us to analyse this rich dataset in so many different ways. Our special thanks to David Dao, Sarah Tariq and Alessandro Amodio for always being there to help us!
|
""" |
|
|
|
|
|
def handle_model_selection(model_name, download_status): |
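    """Pre-download the selected model's checkpoints and return a status
    message for the UI textbox."""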
|
|
|
|
|
    # Fall back to the default model when nothing is selected yet
    if model_name is None:

        model_name = "BirdAST"

    print(f"Downloading model weights for {model_name}...")
|
|
|
assets = ASSET_DICT[model_name] |
|
model_weights_url = assets["model_weights"] |
|
download_flag = True |
|
try: |
|
total_files = len(model_weights_url) |
|
for idx, model_weight in enumerate(model_weights_url): |
|
weight_file = os.path.join(ASSET_DIR, model_weight.split("/")[-1]) |
|
print(weight_file) |
|
if not os.path.exists(weight_file): |
|
download_status = f"Downloading {idx + 1} of {total_files}" |
|
download_model(model_weight, weight_file) |
|
|
|
if not os.path.exists(weight_file): |
|
download_flag = False |
|
break |
|
|
|
if download_flag: |
|
            download_status = f"Model <{model_name}> is ready!\nUsing device: {DEVICE.upper()}"
|
else: |
|
            download_status = "An error occurred while downloading model weights."
|
|
|
    except Exception as e:

        download_status = f"An error occurred while downloading model weights: {e}"
|
|
|
return download_status |
|
|
|
|
|
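# Assemble the Gradio interface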
with gr.Blocks(theme=seafoam, css=css, js=js) as demo:
|
|
|
gr.Markdown('<div class="logo-container"><img src="https://i.ibb.co/vcG9kr0/vojlogo.jpg" width="50px" alt="vojlogo"></div>') |
|
gr.Markdown('<div id="gradio-animation"></div>') |
|
gr.Markdown(DESCRIPTION) |
|
|
|
|
|
model_names = ['BirdAST', 'BirdAST_Seq'] |
|
model_dropdown = gr.Dropdown(label="Choose a model", choices=model_names) |
|
download_status = gr.Textbox(label="Model Status", lines=3, value='', interactive=False) |
|
|
|
model_dropdown.change(handle_model_selection, inputs=[model_dropdown, download_status], outputs=download_status) |
|
|
|
|
|
with gr.Row(): |
|
with gr.Column(elem_classes="column-container"): |
|
            start_time_input = gr.Number(label="Start Time [s]", value=0, elem_classes="number-input full-height")

            end_time_input = gr.Number(label="End Time [s]", value=10, elem_classes="number-input full-height")
|
with gr.Column(): |
|
audio_input = gr.Audio(label="Input Audio", elem_classes="full-height") |
|
|
|
with gr.Row(): |
|
raw_class_output = gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label="Class Prediction") |
|
        species_output = gr.Dataframe(headers=["Species", "Score [%]"], row_count=10, label="Species Prediction")
|
|
|
with gr.Row(): |
|
waveform_output = gr.Plot(label="Waveform") |
|
spectrogram_output = gr.Plot(label="Spectrogram") |
|
|
|
gr.Examples( |
|
examples=[ |
|
["XC226833-Chestnut-belted_20Chat-Tyrant_20A_2010989.mp3", 0, 10], |
|
["XC812290-Many-striped-Canastero_Teaben_Pe_1jul2022_FSchmitt_1.mp3", 0, 10], |
|
["XC763511-Synallaxis-maronica_Bagua-grande_MixPre-1746.mp3", 0, 10] |
|
], |
|
inputs=[audio_input, start_time_input, end_time_input] |
|
) |
|
|
|
gr.Button("Predict").click(predict, [audio_input, start_time_input, end_time_input, model_dropdown], [raw_class_output, species_output, waveform_output, spectrogram_output]) |
|
|
|
gr.Markdown(REFERENCES) |
|
|
|
if __name__ == "__main__":
    demo.launch(share=True)
|
|
|
|
|
|