Spaces:
Sleeping
Sleeping
AshDavid12
commited on
Commit
·
460f073
1
Parent(s):
cc6b80d
prev version worked but not accurate-changed mask
Browse files
infer.py
CHANGED
@@ -25,12 +25,18 @@ audio_input, _ = sf.read(audio_data)
|
|
25 |
|
26 |
# Preprocess the audio for Whisper
|
27 |
inputs = processor(audio_input, return_tensors="pt", sampling_rate=16000)
|
|
|
|
|
|
|
28 |
inputs = {key: value.to(device) for key, value in inputs.items()}
|
|
|
29 |
|
30 |
-
# Generate the transcription
|
31 |
with torch.no_grad():
|
32 |
-
predicted_ids = model.generate(
|
33 |
-
|
|
|
|
|
34 |
# Decode the transcription
|
35 |
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
36 |
|
|
|
25 |
|
26 |
# Preprocess the audio for Whisper
|
27 |
inputs = processor(audio_input, return_tensors="pt", sampling_rate=16000)
|
28 |
+
attention_mask = inputs['input_features'].ne(processor.tokenizer.pad_token_id).long()
|
29 |
+
|
30 |
+
# Move inputs and attention mask to the correct device
|
31 |
inputs = {key: value.to(device) for key, value in inputs.items()}
|
32 |
+
attention_mask = attention_mask.to(device)
|
33 |
|
34 |
+
# Generate the transcription with attention mask
|
35 |
with torch.no_grad():
|
36 |
+
predicted_ids = model.generate(
|
37 |
+
inputs["input_features"],
|
38 |
+
attention_mask=attention_mask # Pass attention mask explicitly
|
39 |
+
)
|
40 |
# Decode the transcription
|
41 |
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
42 |
|