BEATs
BEATs: Audio Pre-Training with Acoustic Tokenizers
Official PyTorch implementation and pretrained models of BEATs
Pre-Trained and Fine-Tuned Tokenizers and Models
| Iterations | Tokenizer | Pre-Trained Model | AudioSet Fine-Tuned Model 1 | AudioSet Fine-Tuned Model 2 |
|---|---|---|---|---|
| Iter1 | Random Projection | BEATs_iter1 | Fine-tuned BEATs_iter1 (cpt1) | Fine-tuned BEATs_iter1 (cpt2) |
| Iter2 | Tokenizer_iter2 | BEATs_iter2 | Fine-tuned BEATs_iter2 (cpt1) | Fine-tuned BEATs_iter2 (cpt2) |
| Iter3 | Tokenizer_iter3 | BEATs_iter3 | Fine-tuned BEATs_iter3 (cpt1) | Fine-tuned BEATs_iter3 (cpt2) |
| Iter3+ | Tokenizer_iter3+ (AS20K) | BEATs_iter3+ (AS20K) | Fine-tuned BEATs_iter3+ (AS20K) (cpt1) | Fine-tuned BEATs_iter3+ (AS20K) (cpt2) |
| Iter3+ | Tokenizer_iter3+ (AS2M) | BEATs_iter3+ (AS2M) | Fine-tuned BEATs_iter3+ (AS2M) (cpt1) | Fine-tuned BEATs_iter3+ (AS2M) (cpt2) |
Load Tokenizers
import torch
from Tokenizers import TokenizersConfig, Tokenizers
# load the pre-trained checkpoints
checkpoint = torch.load('/path/to/tokenizer.pt')
cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()
# tokenize the audio and generate the labels
audio_input_16khz = torch.randn(1, 10000)
padding_mask = torch.zeros(1, 10000).bool()
labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
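The random tensor above is just a stand-in for real audio. Below is a minimal sketch of feeding an actual recording through the tokenizer, assuming torchaudio is installed (torchaudio and the file path are our additions, not part of the BEATs code):

import torchaudio

# load a recording and resample it to the 16 kHz rate BEATs expects
waveform, sample_rate = torchaudio.load('/path/to/audio.wav')
waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=16000)
waveform = waveform.mean(dim=0, keepdim=True)  # mix down to mono: shape (1, num_samples)

# an all-False mask marks every sample as valid (no padding for a single clip)
padding_mask = torch.zeros_like(waveform).bool()
labels = BEATs_tokenizer.extract_labels(waveform, padding_mask=padding_mask)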
Load Pre-Trained Models
import torch
from BEATs import BEATs, BEATsConfig
# load the pre-trained checkpoints
checkpoint = torch.load('/path/to/model.pt')
cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()
# extract the audio representation
audio_input_16khz = torch.randn(1, 10000)
padding_mask = torch.zeros(1, 10000).bool()
representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
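The first element returned by extract_features is a sequence of frame-level features of shape (batch, frames, feature_dim). If you need a single vector per clip, for example as input to a lightweight downstream classifier, mean pooling over time is one common heuristic; this pooling step is our suggestion, not part of the BEATs API:

# average the frame-level features over the time axis to get one
# clip-level embedding per input (a common heuristic, not BEATs-specific)
clip_embedding = representation.mean(dim=1)  # shape: (batch, feature_dim)
print(clip_embedding.shape)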
Load Fine-Tuned Models
import torch
from BEATs import BEATs, BEATsConfig
# load the fine-tuned checkpoints
checkpoint = torch.load('/path/to/model.pt')
cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()
# predict the classification probability of each class
audio_input_16khz = torch.randn(3, 10000)
padding_mask = torch.zeros(3, 10000).bool()
probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
    top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
    print(f'Top 5 predicted labels of the {i}th audio are {top5_label} with probability of {top5_label_prob}')
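The clips in the example above all share one length; real batches rarely do. Here is a sketch of padding a variable-length batch, assuming the same convention as the all-False masks above, i.e. True in the mask marks padded (ignored) samples; the clip lengths are dummy values:

# dummy 16 kHz clips of different lengths
clips = [torch.randn(8000), torch.randn(10000), torch.randn(9000)]
max_len = max(clip.shape[0] for clip in clips)

# zero-pad every clip to the longest one and record the padded tail in the mask
audio_input_16khz = torch.zeros(len(clips), max_len)
padding_mask = torch.ones(len(clips), max_len).bool()  # True = padded
for i, clip in enumerate(clips):
    audio_input_16khz[i, :clip.shape[0]] = clip
    padding_mask[i, :clip.shape[0]] = False  # False = real samples

probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]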
Evaluation Results
Comparing with the SOTA Single Models
Comparing with the SOTA Ensemble Models
Comparing Different BEATs Tokenizers
Comparing Different Pre-Training Targets
License
This project is licensed under the license found in the LICENSE file in the root directory of this source tree. Portions of the source code are based on the FAIRSEQ and VQGAN projects.
Microsoft Open Source Code of Conduct
Reference
If you find our work useful in your research, please cite the following paper:
@article{Chen2022beats,
  title={BEATs: Audio Pre-Training with Acoustic Tokenizers},
  author={Sanyuan Chen and Yu Wu and Chengyi Wang and Shujie Liu and Daniel Tompkins and Zhuo Chen and Furu Wei},
  eprint={2212.09058},
  archivePrefix={arXiv},
  year={2022}
}
Contact Information
For help or issues using BEATs models, please submit a GitHub issue.
For other communications related to BEATs, please contact Yu Wu ([email protected]).