WajeehAzeemX committed
Commit e942829 · verified · 1 Parent(s): 273575f

Update app.py

Files changed (1):
  1. app.py +27 -20

app.py CHANGED
@@ -3,49 +3,56 @@ from transformers import pipeline
  import io
  import librosa
  from transformers import WhisperForConditionalGeneration, WhisperProcessor
- from faster_whisper import WhisperModel
- import multiprocessing
  
  app = FastAPI()
  # Device configuration
  # Load the model and processor
-
- import torch
-
  model_id = "WajeehAzeemX/whisper-smal-ar-testing-kale-5000"
  model = WhisperForConditionalGeneration.from_pretrained(
      model_id
  )
- processor = WhisperProcessor.from_pretrained(model_id)
+ import torch
+
+ processor = WhisperProcessor.from_pretrained('WajeehAzeemX/whisper-smal-ar-testing-kale-5000')
+ model.config.forced_decoder_ids = None
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language="Arabic", task="transcribe")
+ model.generation_config.cache_implementation = "static"
+ from transformers import GenerationConfig, WhisperForConditionalGeneration
+ generation_config = GenerationConfig.from_pretrained("openai/whisper-small") # if you are using a multilingual model
+ model.generation_config = generation_config
+
  pipe = pipeline(
      "automatic-speech-recognition",
      model=model,
      tokenizer=processor.tokenizer,
-     feature_extractor=processor.feature_extractor
+     feature_extractor=processor.feature_extractor,
+
  )
  
-
  @app.post("/transcribe/")
  async def transcribe_audio(request: Request):
      try:
          # Read binary data from the request
          audio_data = await request.body()
+
          # Convert binary data to a file-like object
          audio_file = io.BytesIO(audio_data)
-         # # Load the audio file using pydub
+
+         # Load the audio file using pydub
          audio_array, sampling_rate = librosa.load(audio_file, sr=16000)
-         # # Process the audio array
-         # input_features = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").input_features
-         # # Generate token ids
-         # predicted_ids = model.generate(input_features)
-         # # Decode token ids to text
-         # transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
-         transcription = pipe(audio_array)
+
+         # Process the audio array
+         input_features = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").input_features
+
+         # Generate token ids
+         predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids, return_timestamps=True)
+
+         # Decode token ids to text
+         transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
          # Print the transcription
-         print(transcription)
          print(transcription[0]) # Display the transcriptiontry:
+
          return {"transcription": transcription[0]}
      except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
-
-
+         raise HTTPException(status_code=500, detail=str(e))
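
For reference, a minimal client sketch (not part of this commit) for exercising the updated endpoint. It assumes the app is served locally, for example with "uvicorn app:app --port 8000", and sample.wav is a placeholder path; the handler reads the raw request body, so the audio bytes are posted directly rather than as a multipart upload.

import requests  # plain HTTP client, used here only for illustration

# Read a local audio file as raw bytes; sample.wav is a placeholder path.
with open("sample.wav", "rb") as f:
    audio_bytes = f.read()

# POST the raw bytes; the server decodes them with librosa at 16 kHz.
response = requests.post(
    "http://127.0.0.1:8000/transcribe/",  # assumes a local uvicorn server on port 8000
    data=audio_bytes,
    headers={"Content-Type": "application/octet-stream"},
)
response.raise_for_status()
print(response.json()["transcription"])  # transcription returned by the endpoint

Sending the bytes in the request body (rather than as form data) matches the handler above, which calls await request.body() and wraps the result in io.BytesIO before loading it with librosa.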