Salman11223 commited on
Commit
7d66980
·
1 Parent(s): f57ff51

Create gender_prediction.py

Browse files
Files changed (1) hide show
  1. gender_prediction.py +126 -0
gender_prediction.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tqdm
3
+ import torch
4
+ import torchaudio
5
+ import numpy as np
6
+ from torch.utils.data import DataLoader
7
+ from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, Wav2Vec2Processor
8
+ from torch.nn import functional as F
9
+
10
+ class CustomDataset(torch.utils.data.Dataset):
11
+ def __init__(self, dataset, basedir=None, sampling_rate=16000, max_audio_len=5):
12
+ self.dataset = dataset
13
+ self.basedir = basedir
14
+ self.sampling_rate = sampling_rate
15
+ self.max_audio_len = max_audio_len
16
+
17
+ def __len__(self):
18
+ return len(self.dataset)
19
+
20
+ def _cutorpad(self, audio):
21
+ effective_length = self.sampling_rate * self.max_audio_len
22
+ len_audio = len(audio)
23
+
24
+ if len_audio > effective_length:
25
+ audio = audio[:effective_length]
26
+
27
+ return audio
28
+
29
+ def __getitem__(self, index):
30
+ if self.basedir is None:
31
+ filepath = self.dataset[index]
32
+ else:
33
+ filepath = os.path.join(self.basedir, self.dataset[index])
34
+
35
+ speech_array, sr = torchaudio.load(filepath)
36
+
37
+ if speech_array.shape[0] > 1:
38
+ speech_array = torch.mean(speech_array, dim=0, keepdim=True)
39
+
40
+ if sr != self.sampling_rate:
41
+ transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
42
+ speech_array = transform(speech_array)
43
+ sr = self.sampling_rate
44
+
45
+ speech_array = speech_array.squeeze().numpy()
46
+ speech_array = self._cutorpad(speech_array)
47
+
48
+ return {"input_values": speech_array, "attention_mask": None}
49
+
50
+ class CollateFunc:
51
+ def __init__(self, processor, max_length=None, padding=True, pad_to_multiple_of=None, sampling_rate=16000):
52
+ self.padding = padding
53
+ self.processor = processor
54
+ self.max_length = max_length
55
+ self.sampling_rate = sampling_rate
56
+ self.pad_to_multiple_of = pad_to_multiple_of
57
+
58
+ def __call__(self, batch):
59
+ input_features = []
60
+
61
+ for audio in batch:
62
+ input_tensor = self.processor(audio["input_values"], sampling_rate=self.sampling_rate).input_values
63
+ input_tensor = np.squeeze(input_tensor)
64
+ input_features.append({"input_values": input_tensor})
65
+
66
+ batch = self.processor.pad(
67
+ input_features,
68
+ padding=self.padding,
69
+ max_length=self.max_length,
70
+ pad_to_multiple_of=self.pad_to_multiple_of,
71
+ return_tensors="pt",
72
+ )
73
+
74
+ return batch
75
+
76
+ def predict(test_dataloader, model, device):
77
+ model.to(device)
78
+ model.eval()
79
+ preds = []
80
+
81
+ with torch.no_grad():
82
+ for batch in tqdm.tqdm(test_dataloader):
83
+ input_values = batch['input_values'].to(device)
84
+
85
+ logits = model(input_values).logits
86
+ scores = F.softmax(logits, dim=-1)
87
+
88
+ pred = torch.argmax(scores, dim=1).cpu().detach().numpy()
89
+ preds.extend(pred)
90
+
91
+ return preds
92
+
93
+ def get_gender(model_name_or_path, audio_paths, device):
94
+ num_labels = 2
95
+
96
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
97
+ model = AutoModelForAudioClassification.from_pretrained(
98
+ pretrained_model_name_or_path=model_name_or_path,
99
+ num_labels=num_labels,
100
+ )
101
+
102
+ test_dataset = CustomDataset(audio_paths)
103
+ data_collator = CollateFunc(
104
+ processor=feature_extractor,
105
+ padding=True,
106
+ sampling_rate=16000,
107
+ )
108
+
109
+ test_dataloader = DataLoader(
110
+ dataset=test_dataset,
111
+ batch_size=16,
112
+ collate_fn=data_collator,
113
+ shuffle=False,
114
+ num_workers=10
115
+ )
116
+
117
+ preds = predict(test_dataloader=test_dataloader, model=model, device=device)
118
+
119
+ # Map class indices to labels
120
+ label_mapping = {0: "female", 1: "male"}
121
+
122
+ # Determine the most common predicted label
123
+ most_common_label = max(set(preds), key=preds.count)
124
+ predicted_label = label_mapping[most_common_label]
125
+
126
+ return predicted_label