File size: 4,425 Bytes
7a9eea4
 
09b7c54
8acf810
7a9eea4
 
 
 
159733c
 
 
 
 
 
 
 
 
 
 
 
 
 
7a9eea4
159733c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8acf810
 
 
 
 
 
 
159733c
 
 
 
8acf810
 
159733c
 
 
 
 
 
 
 
 
 
7a9eea4
159733c
 
 
 
 
 
 
 
8acf810
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample
from transformers import AutoFeatureExtractor, AutoModelForAudioXVector

# Inference device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

STYLE = """
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha256-YvdLHPgkqJ8DVUxjjnGVlMMJtNimJ6dYkowFFvp4kKs=" crossorigin="anonymous">
"""
OUTPUT_OK = (
    STYLE
    + """
    <div class="container">
        <div class="row"><h1 style="text-align: center">The speakers are</h1></div>
        <div class="row"><h1 class="display-1 text-success" style="text-align: center">{:.1f}%</h1></div>
        <div class="row"><h1 style="text-align: center">similar</h1></div>
        <div class="row"><h1 class="text-success" style="text-align: center">Welcome, human!</h1></div>
        <div class="row"><small style="text-align: center">(You must get at least 80% to be considered the same person)</small><div class="row">
    </div>
"""
)
OUTPUT_FAIL = (
    STYLE
    + """
    <div class="container">
        <div class="row"><h1 style="text-align: center">The speakers are</h1></div>
        <div class="row"><h1 class="display-1 text-danger" style="text-align: center">{:.1f}%</h1></div>
        <div class="row"><h1 style="text-align: center">similar</h1></div>
        <div class="row"><h1 class="text-danger" style="text-align: center">You shall not pass!</h1></div>
        <div class="row"><small style="text-align: center">(You must get at least 80% to be considered the same person)</small><div class="row">
    </div>
"""
)

THRESHOLD = 0.80

model_name = "microsoft/wavlm-base-plus-sv"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioXVector.from_pretrained(model_name).to(device)
cosine_sim = torch.nn.CosineSimilarity(dim=-1)


def preprocess_audio(file_path, target_sr=16000):
    """Load an audio file as a mono waveform at ``target_sr`` Hz.

    Args:
        file_path: Path to any audio file ``torchaudio.load`` can decode.
        target_sr: Desired sample rate in Hz; the audio is resampled when the
            file's native rate differs.

    Returns:
        A float tensor of shape ``(1, num_samples)``: always a single channel,
        so callers can ``squeeze(0)`` it into a 1-D signal.
    """
    # torchaudio.load returns (channels, frames) by default.
    wav, sr = torchaudio.load(file_path)
    # Stereo/multi-channel uploads would otherwise reach the feature extractor
    # as a 2-D tensor after the caller's squeeze(0); average the channels into
    # one so every file yields a single utterance.
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = Resample(orig_freq=sr, new_freq=target_sr)(wav)
    return wav


def _embed(path, sampling_rate):
    """Return the L2-normalised speaker embedding (on CPU) for one audio file."""
    wav = preprocess_audio(path, target_sr=sampling_rate)
    inputs = feature_extractor(
        wav.squeeze(0), return_tensors="pt", sampling_rate=sampling_rate
    ).input_values.to(device)
    with torch.no_grad():
        emb = model(inputs).embeddings
    return torch.nn.functional.normalize(emb, dim=-1).cpu()


def similarity_fn(path1, path2):
    """Compare two recordings and render the verdict as HTML.

    Args:
        path1: File path of the first speaker's recording (falsy if missing).
        path2: File path of the second speaker's recording (falsy if missing).

    Returns:
        An HTML snippet: an error message when either recording is missing,
        otherwise the accept/reject card with the similarity percentage.
        Acceptance means a cosine similarity of at least ``THRESHOLD``.
    """
    if not (path1 and path2):
        # Rendered by gr.HTML, so emphasis must be HTML, not Markdown.
        return '<b style="color:red">ERROR: Please record audio for <i>both</i> speakers!</b>'

    # Resample to the rate the feature extractor expects instead of a
    # separately hard-coded 16 kHz, so loading and extraction always agree.
    sampling_rate = feature_extractor.sampling_rate
    emb1 = _embed(path1, sampling_rate)
    emb2 = _embed(path2, sampling_rate)
    # One embedding per file, so the first (only) similarity is the score.
    similarity = cosine_sim(emb1, emb2)[0].item()

    template = OUTPUT_OK if similarity >= THRESHOLD else OUTPUT_FAIL
    return template.format(similarity * 100)


with gr.Blocks() as demo:
    gr.Markdown("# Voice Authentication with WavLM + X-Vectors")
    gr.Markdown(
        "This demo compares two speech samples to determine if they are from the same speaker. "
        "Try it with your own voice!"
    )

    # type="filepath" hands similarity_fn a path to each recording; an empty
    # input arrives falsy and is reported as an error by similarity_fn.
    with gr.Row():
        input1 = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Speaker #1")
        input2 = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Speaker #2")

    output = gr.HTML(label="Result")

    btn = gr.Button("Compare Speakers")
    btn.click(similarity_fn, inputs=[input1, input2], outputs=output)

    # Bundled samples: one identical pair, one same-speaker pair and two
    # different-speaker pairs.
    gr.Examples(
        examples=[
            ["samples/denzel_washington.mp3", "samples/denzel_washington.mp3"],
            ["samples/heath_ledger_2.mp3", "samples/heath_ledger_3.mp3"],
            ["samples/heath_ledger_3.mp3", "samples/denzel_washington.mp3"],
            ["samples/denzel_washington.mp3", "samples/heath_ledger_2.mp3"],
        ],
        inputs=[input1, input2],
    )

    gr.Markdown(
        "<p style='text-align: center'>"
        "<a href='https://huggingface.co/microsoft/wavlm-base-plus-sv' target='_blank'>🎙️ Learn more about WavLM</a> | "
        "<a href='https://arxiv.org/abs/2110.13900' target='_blank'>📚 WavLM paper</a> | "
        "<a href='https://www.danielpovey.com/files/2018_icassp_xvectors.pdf' target='_blank'>📚 X-Vector paper</a>"
        "</p>"
    )

demo.launch()