AshDavid12 commited on
Commit
460f073
·
1 Parent(s): cc6b80d

prev version worked but not accurate-changed mask

Browse files
Files changed (1) hide show
  1. infer.py +9 -3
infer.py CHANGED
@@ -25,12 +25,18 @@ audio_input, _ = sf.read(audio_data)
25
 
26
  # Preprocess the audio for Whisper
27
  inputs = processor(audio_input, return_tensors="pt", sampling_rate=16000)
 
 
 
28
  inputs = {key: value.to(device) for key, value in inputs.items()}
 
29
 
30
- # Generate the transcription
31
  with torch.no_grad():
32
- predicted_ids = model.generate(inputs["input_features"])
33
-
 
 
34
  # Decode the transcription
35
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
36
 
 
25
 
26
  # Preprocess the audio for Whisper
27
  inputs = processor(audio_input, return_tensors="pt", sampling_rate=16000)
28
+ attention_mask = inputs['input_features'].ne(processor.tokenizer.pad_token_id).long()
29
+
30
+ # Move inputs and attention mask to the correct device
31
  inputs = {key: value.to(device) for key, value in inputs.items()}
32
+ attention_mask = attention_mask.to(device)
33
 
34
+ # Generate the transcription with attention mask
35
  with torch.no_grad():
36
+ predicted_ids = model.generate(
37
+ inputs["input_features"],
38
+ attention_mask=attention_mask # Pass attention mask explicitly
39
+ )
40
  # Decode the transcription
41
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
42