PyTorch
ssl-aasist
custom_code
ssl-aasist / README.md
ash56's picture
Update README.md
539ffe8 verified
---
license: apache-2.0
---
## This repository contains the model checkpoints related to the paper: *[Less is More for Synthetic Speech Detection in the Wild](https://arxiv.org/abs/2502.05674)*
The ShiftySpeech dataset can be downloaded from [here](https://huggingface.co/datasets/ash56/ShiftySpeech/tree/main).
Model architecture: [SSL-AASIST](https://arxiv.org/pdf/2202.12233)
**Note: The model is trained on audio generated by the HiFiGAN vocoder from the LJSpeech source dataset. Both the real and the spoof samples are taken from the [WaveFake](https://arxiv.org/abs/2111.02813) dataset.**
## ⚙️ Usage
#### Install libraries
```bash
# Create and activate an isolated Python 3.10 environment.
conda create -n ssl-aasist python=3.10.14
conda activate ssl-aasist
# Pin pip below 24.1 (see the note below) before installing the pinned packages.
pip install pip==23
pip install omegaconf==2.0.6 pyarrow==19.0
```
*Note: these installs require pip < 24.1, which is why pip is pinned to version 23 above.*
```bash
# Remaining Python dependencies; most are imported directly by the snippets below.
pip install torch datasets transformers librosa numpy scikit-learn huggingface_hub
```
```bash
# fairseq pinned to a specific commit.
# NOTE(review): presumably required by the model's remote code (SSL front-end) — confirm.
pip install git+https://github.com/facebookresearch/fairseq.git@920a548ca770fb1a951f7f4289b4d3a0c1bc226f
```
#### Load Model and Dataset
```python
from transformers import AutoConfig, AutoModel
import torch
import librosa
from datasets import load_dataset
import numpy as np
from torch import Tensor
from sklearn.metrics import roc_auc_score

# Run on the GPU when one is visible, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# SSL-AASIST checkpoint. Its model code lives in the Hub repo, hence
# trust_remote_code=True; force_download=True re-fetches the files every run.
config = AutoConfig.from_pretrained("ash56/ssl-aasist", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "ash56/ssl-aasist",
    config=config,
    trust_remote_code=True,
    force_download=True,
).to(device)

# ShiftySpeech evaluation data: APNet2-vocoded AISHELL audio (spoof) and the
# real AISHELL audio (bona fide). Each archive is exposed as a "data" split.
spoof_data = load_dataset(
    "ash56/ShiftySpeech",
    data_files={"data": "Vocoders/apnet2/apnet2_aishell_flac.tar.gz"},
)["data"]
real_data = load_dataset(
    "ash56/ShiftySpeech",
    data_files={"data": "real_data_flac/real_data_aishell_flac.tar.gz"},
)["data"]

# Switch the model to evaluation mode before scoring.
model.eval()
```
#### Inference
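To score a whole dataset, first define a helper that repeat-pads (or truncates) each waveform to a fixed length of 64600 samples (about 4 s at 16 kHz):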
```python
def pad(x, max_len=64600):
    """Repeat-pad or truncate a 1-D waveform to exactly ``max_len`` samples.

    Signals at least ``max_len`` long are cut to their first ``max_len``
    samples. Shorter signals are repeated end to end and then cut at
    ``max_len``, so the model always receives a fixed-length input
    (64600 samples is about 4 s at 16 kHz).

    Args:
        x: 1-D array-like of audio samples.
        max_len: Target length in samples.

    Returns:
        A 1-D ``numpy.ndarray`` of length ``max_len`` with the dtype of ``x``.

    Raises:
        ValueError: If ``x`` is empty and ``max_len`` is positive; there is
            nothing to repeat.
    """
    x = np.asarray(x)
    x_len = x.shape[0]
    if x_len >= max_len:
        return x[:max_len]
    if x_len == 0:
        # The original tiling formula divides by x_len here.
        raise ValueError("cannot pad an empty signal")
    # np.resize fills the new length with cyclic repeats of x, which is the
    # same as tiling x and truncating to max_len.
    return np.resize(x, max_len)
```
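Compute countermeasure (CM) scores for the spoof and real samples and write them to a score file (one `<utterance> <label> <score>` line per sample):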
```python
output_file = "apnet2-aishell_scores.txt"


def _score_sample(sample):
    """Return the model's column-1 output for one dataset sample.

    The waveform is resampled to 16 kHz when needed, then repeat-padded or
    truncated to 64600 samples with ``pad``. It is fed to ``model`` on
    ``device`` as a batch of one. The EER step below treats bona fide trials
    as targets, so this score is expected to be higher for real speech.
    """
    audio = sample["flac"]["array"]
    sampling_rate = sample["flac"]["sampling_rate"]
    if sampling_rate != 16000:
        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
    x_inp = Tensor(pad(audio, 64600)).unsqueeze(0).to(device)
    with torch.no_grad():
        batch_out = model(x_inp)
    return batch_out[:, 1].cpu().numpy().ravel()[0]


# Open with "w", not "a": re-running this cell must not append a second copy
# of every score to the file the EER step reads back.
with open(output_file, "w") as f:
    # Spoof audios first, then real audios, each line "<key> <label> <score>".
    for label, dataset in (("spoof", spoof_data), ("bonafide", real_data)):
        for sample in dataset:
            f.write(f"{sample['__key__']} {label} {_score_sample(sample)}\n")
print(f"Scores saved in {output_file}")
```
#### Compute EER
```python
# helper functions to calculate EER
def compute_eer(target_scores, nontarget_scores):
    """Return the equal error rate (EER), its threshold, and the DET curves.

    Args:
        target_scores: NumPy array of scores for target trials (bona fide in
            this README's usage).
        nontarget_scores: NumPy array of scores for non-target trials (spoof).

    Returns:
        Tuple ``(eer, threshold, frr, far)``:
        - ``eer`` is the mean of FRR and FAR at the point where they are closest.
        - ``threshold`` is the score threshold at that point.
        - ``frr`` and ``far`` are the full curves from ``compute_det_curve``.
    """
    frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
    # Operating point where the two error rates are nearest to each other.
    crossing = int(np.argmin(np.abs(frr - far)))
    eer = (frr[crossing] + far[crossing]) / 2
    return eer, thresholds[crossing], frr, far
def compute_det_curve(target_scores, nontarget_scores):
    """Return ``(frr, far, thresholds)`` sweeping a threshold over all scores.

    Trials are ranked by score in stable ascending order, with ties keeping
    targets before non-targets. Entry ``i`` of each array corresponds to
    rejecting the ``i`` lowest-ranked trials:
    - ``frr[i]`` is the fraction of target trials rejected.
    - ``far[i]`` is the fraction of non-target trials still accepted.
    - ``thresholds[i]`` is the score of the last rejected trial.
    ``thresholds[0]`` sits 0.001 below the minimum score (nothing rejected).

    Args:
        target_scores: 1-D NumPy array of target-trial scores.
        nontarget_scores: 1-D NumPy array of non-target-trial scores.
    """
    n_target = target_scores.size
    n_nontarget = nontarget_scores.size

    scores = np.concatenate((target_scores, nontarget_scores))
    is_target = np.concatenate((np.ones(n_target), np.zeros(n_nontarget)))

    # A stable sort keeps the original (targets-first) order among equal scores.
    order = np.argsort(scores, kind="stable")
    sorted_scores = scores[order]
    sorted_labels = is_target[order]

    targets_rejected = np.cumsum(sorted_labels)
    ranks = np.arange(1, n_target + n_nontarget + 1)
    nontargets_accepted = n_nontarget - (ranks - targets_rejected)

    frr = np.concatenate(([0.0], targets_rejected / n_target))
    far = np.concatenate(([1.0], nontargets_accepted / n_nontarget))
    thresholds = np.concatenate(([sorted_scores[0] - 0.001], sorted_scores))
    return frr, far, thresholds
# get EER
def calculate_EER(cm_scores_file,
                  output_file,
                  printout=True):
    """Compute the EER (and partial AUC) of a countermeasure score file.

    The score file has one whitespace-separated line per utterance:
    ``<utt_id> <key> <score>``, where ``<key>`` is ``bonafide`` or ``spoof``
    (the format written by the scoring step). Bona fide trials are treated as
    targets, so higher scores are expected to mean "more likely real".

    Args:
        cm_scores_file: Path of the score file to read.
        output_file: Path of the text report written when ``printout`` is true.
        printout: If true, write the EER and partial AUC to ``output_file``,
            overwriting it.

    Returns:
        The EER as a fraction in [0, 1]; multiply by 100 for percent.

    Raises:
        ValueError: If the score file is empty, or lacks either bona fide or
            spoof trials (both are needed for an EER).
    """
    # atleast_2d: a score file with a single line would otherwise load as 1-D.
    cm_data = np.atleast_2d(np.genfromtxt(cm_scores_file, dtype=str))
    if cm_data.size == 0:
        raise ValueError(f"{cm_scores_file} contains no scores")
    cm_keys = cm_data[:, 1]
    cm_scores = cm_data[:, 2].astype(float)

    # Split bona fide (real human) and spoof scores.
    bona_cm = cm_scores[cm_keys == 'bonafide']
    spoof_cm = cm_scores[cm_keys == 'spoof']
    if bona_cm.size == 0 or spoof_cm.size == 0:
        raise ValueError(
            f"{cm_scores_file} must contain both 'bonafide' and 'spoof' "
            f"trials (found {bona_cm.size} and {spoof_cm.size})")

    all_scores = np.concatenate([bona_cm, spoof_cm])
    all_true_labels = np.concatenate(
        [np.ones_like(bona_cm), np.zeros_like(spoof_cm)])
    # Standardized (McClish-corrected) partial AUC over FPR in [0, 0.05].
    auc = roc_auc_score(all_true_labels, all_scores, max_fpr=0.05)
    eer_cm, _, _, _ = compute_eer(bona_cm, spoof_cm)

    if printout:
        with open(output_file, "w") as f_res:
            f_res.write('\nCM SYSTEM\n')
            f_res.write('\tEER\t\t= {:8.9f} % '
                        '(Equal error rate for countermeasure)\n'.format(
                            eer_cm * 100))
            f_res.write('\tpAUC\t\t= {:8.9f} '
                        '(standardized partial AUC, FPR <= 0.05)\n'.format(
                            auc))
    return eer_cm
# Read the score file produced above and write the EER report to
# apnet2_aishell_eer.txt.
eer_report_file = "apnet2_aishell_eer.txt"
eval_eer = calculate_EER(
    output_file=eer_report_file,
    cm_scores_file=output_file,
)
```
If you find the dataset or this resource helpful for your research, please cite our work:
```bibtex
@misc{garg2025syntheticspeechdetectionwild,
title={Less is More for Synthetic Speech Detection in the Wild},
author={Ashi Garg and Zexin Cai and Henry Li Xinyuan and Leibny Paola García-Perera and Kevin Duh and Sanjeev Khudanpur and Matthew Wiesner and Nicholas Andrews},
year={2025},
eprint={2502.05674},
archivePrefix={arXiv},
primaryClass={eess.AS},
url={https://arxiv.org/abs/2502.05674},
}
```