WajeehAzeemX committed on
Commit 8f81f58 · verified · 1 Parent(s): 7f3077c

Update app.py

Files changed (1)
  1. app.py +2 -1
app.py CHANGED
@@ -14,6 +14,7 @@ model = WhisperForConditionalGeneration.from_pretrained(
 import torch
 
 processor = WhisperProcessor.from_pretrained('WajeehAzeemX/whisper-smal-ar-testing-kale-5000')
+forced_decoder_ids = processor.get_decoder_prompt_ids(language="arabic", task="transcribe")
 
 from transformers import GenerationConfig, WhisperForConditionalGeneration
 generation_config = GenerationConfig.from_pretrained("openai/whisper-small")  # if you are using a multilingual model
@@ -43,7 +44,7 @@ async def transcribe_audio(request: Request):
 input_features = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").input_features
 
 # Generate token ids
-predicted_ids = model.generate(input_features, return_timestamps=True)
+predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
 
 # Decode token ids to text
 transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
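
For context, the change pins generation to Arabic transcription via decoder prompt IDs instead of relying on return_timestamps / automatic language detection. Below is a minimal sketch of how the updated pieces fit together outside the FastAPI handler; it assumes the model is loaded from the same checkpoint as the processor and uses a placeholder silent buffer in place of the decoded request audio.

import numpy as np
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Assumption: model weights come from the same repo as the processor.
processor = WhisperProcessor.from_pretrained("WajeehAzeemX/whisper-smal-ar-testing-kale-5000")
model = WhisperForConditionalGeneration.from_pretrained("WajeehAzeemX/whisper-smal-ar-testing-kale-5000")

# Force Arabic transcription rather than letting Whisper detect the language.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="arabic", task="transcribe")

# Placeholder input: one second of 16 kHz mono silence (replace with the request's audio array).
audio_array = np.zeros(16000, dtype=np.float32)
sampling_rate = 16000

# Extract log-mel features, generate token ids with the forced prompt, and decode to text.
input_features = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").input_features
predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription[0])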