ksang committed
Commit ad16c3d · 1 Parent(s): 5ffb6b7

Upload app.py

Files changed (1):
  app.py  +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
# %%
import gradio as gr
import torchaudio
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import librosa
import torch

# %%
# Small pickle helpers used to persist/restore the label maps.
def dump_pickle(file_path: str, file, mode: str = "wb"):
    import pickle

    with open(file_path, mode=mode) as f:
        pickle.dump(file, f)


def load_pickle(file_path: str, mode: str = "rb", encoding: str = "ASCII"):  # "ASCII" is pickle.load's own default; "" is not a valid codec
    import pickle

    with open(file_path, mode=mode) as f:
        return pickle.load(f, encoding=encoding)

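# Hypothetical sketch of how the two label maps below could have been written
# at training time with dump_pickle(); the class names are placeholders, not
# taken from this commit:
#
#     labels = ["classA", "classB", "classC"]
#     dump_pickle("label2id.pkl", {name: i for i, name in enumerate(labels)})
#     dump_pickle("id2label.pkl", {str(i): name for i, name in enumerate(labels)})
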
# %%
# Label maps saved during training; predict() below indexes id2label with str keys.
label2id = load_pickle('/data/audio-classification-pytorch/wav2vec2/results/best/label2id.pkl')
id2label = load_pickle('/data/audio-classification-pytorch/wav2vec2/results/best/id2label.pkl')

# %%
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base", num_labels=len(label2id), label2id=label2id, id2label=id2label
)

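# Note: from_pretrained() attaches a randomly initialized classification head
# here; the fine-tuned weights are restored by load_state_dict() below.
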
# %%
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

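# The wav2vec2 feature extractor does no spectrogram conversion; it only
# normalizes the raw 16 kHz waveform into "input_values".
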
# %%
checkpoint = torch.load('/data/audio-classification-pytorch/wav2vec2/results/best/pytorch_model.bin')

# %%
model.load_state_dict(checkpoint)
model.eval()  # disable dropout for inference

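# If the app runs on a CPU-only host, the same checkpoint can be loaded with an
# explicit map_location (a minimal variant, not part of the original commit):
#
#     checkpoint = torch.load('/data/audio-classification-pytorch/wav2vec2/results/best/pytorch_model.bin',
#                             map_location='cpu')
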
# %%
def predict(filepath):
    # Load audio as mono float32, then resample to the 16 kHz wav2vec2 expects.
    waveform, sr = librosa.load(filepath)
    waveform = torch.from_numpy(waveform).unsqueeze(0)
    waveform = torchaudio.transforms.Resample(sr, 16_000)(waveform)
    # The feature extractor wants a 1-D array, so drop the channel dimension.
    inputs = feature_extractor(waveform.squeeze(0).numpy(), sampling_rate=feature_extractor.sampling_rate,
                               max_length=16000, truncation=True)
    tensor = torch.tensor(inputs['input_values'][0])
    with torch.no_grad():
        output = model(tensor.unsqueeze(0))  # restore the batch dimension the model expects
    logits = output['logits'][0]
    label_id = torch.argmax(logits).item()
    label_name = id2label[str(label_id)]

    return label_name

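# Quick sanity check before wiring up the UI (path is hypothetical):
#
#     print(predict("/tmp/example.wav"))
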
# %%
demo = gr.Interface(
    fn=predict,
    # Record audio; Gradio saves it to a temp file that is passed to the inference function.
    inputs=gr.Audio(source="microphone", type="filepath", label="Speak to classify your voice!"),
    outputs="text",
)

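# With this (older) gr.Audio API, source="upload" would accept pre-recorded
# files instead of live microphone input; type="filepath" keeps predict()
# receiving a path either way.
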
# %%
demo.launch()
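
# For testing outside the Space, launch() can also expose a temporary public
# URL (a standard Gradio option, not used in this commit):
#
#     demo.launch(share=True)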