WajeehAzeemX committed on
Commit
eae5b83
·
1 Parent(s): 327ec66

tiny model

Browse files
Files changed (2) hide show
  1. app.py +19 -41
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,26 +1,21 @@
1
 
2
  from fastapi import FastAPI, Request, HTTPException
3
- import torch
4
- import torchaudio
5
- from transformers import AutoProcessor, pipeline
6
  import io
7
- from pydub import AudioSegment
8
- # from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
9
- import numpy as np
10
- import uvicorn
11
- import time
12
- app = FastAPI()
13
  from transformers import WhisperForConditionalGeneration, WhisperProcessor
14
 
 
15
  # Device configuration
16
- device = "cuda" if torch.cuda.is_available() else "cpu"
17
- print(device)
18
  # Load the model and processor
19
- model_id = "whitefox123/whisper-small-ar2"
20
  model = WhisperForConditionalGeneration.from_pretrained(
21
  model_id
22
  )
23
- processor = WhisperProcessor.from_pretrained(model_id)
 
 
 
24
 
25
 
26
  pipe = pipeline(
@@ -40,38 +35,21 @@ async def transcribe_audio(request: Request):
40
  audio_file = io.BytesIO(audio_data)
41
 
42
  # Load the audio file using pydub
43
- try:
44
- audio_segment = AudioSegment.from_file(audio_file, format="wav")
45
- except Exception as e:
46
- raise HTTPException(status_code=400, detail=f"Error loading audio file: {str(e)}")
47
-
48
- # Convert to mono if the audio is stereo (multi-channel)
49
- if audio_segment.channels > 1:
50
- audio_segment = audio_segment.set_channels(1)
51
 
52
- # Resample the audio to 16kHz
53
- target_sample_rate = 16000
54
- if audio_segment.frame_rate != target_sample_rate:
55
- audio_segment = audio_segment.set_frame_rate(target_sample_rate)
56
 
57
- # Convert audio to numpy array
58
- audio_array = np.array(audio_segment.get_array_of_samples())
59
- if audio_segment.sample_width == 2:
60
- audio_array = audio_array.astype(np.float32) / 32768.0
61
- else:
62
- raise HTTPException(status_code=400, detail="Unsupported sample width")
63
- start_time = time.time()
64
- # Convert to the format expected by the model
65
- inputs = processor(audio_array, sampling_rate=target_sample_rate, return_tensors="pt")
66
- inputs = inputs.to(device)
67
 
68
- # Get the transcription result
69
- result = pipe(audio_array)
70
- # Calculate time taken
71
- time_taken = time.time() - start_time
72
- transcription = result["text"]
73
 
74
- return {"transcription": transcription, "time_taken": time_taken}
 
 
 
75
  except Exception as e:
76
  raise HTTPException(status_code=500, detail=str(e))
77
 
 
1
 
2
  from fastapi import FastAPI, Request, HTTPException
3
+ from transformers import pipeline
 
 
4
  import io
5
+ import librosa
 
 
 
 
 
6
  from transformers import WhisperForConditionalGeneration, WhisperProcessor
7
 
8
+ app = FastAPI()
9
  # Device configuration
 
 
10
  # Load the model and processor
11
+ model_id = "WajeehAzeemX/whisper-tiny-ar-tashkeel"
12
  model = WhisperForConditionalGeneration.from_pretrained(
13
  model_id
14
  )
15
+
16
+ processor = WhisperProcessor.from_pretrained('openai/whisper-tiny')
17
+ model.config.forced_decoder_ids = None
18
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language="arabic", task="transcribe")
19
 
20
 
21
  pipe = pipeline(
 
35
  audio_file = io.BytesIO(audio_data)
36
 
37
  # Load the audio file using pydub
38
+ audio_array, sampling_rate = librosa.load(audio_file, sr=16000)
 
 
 
 
 
 
 
39
 
40
+ # Process the audio array
41
+ input_features = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").input_features
 
 
42
 
43
+ # Generate token ids
44
+ predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
 
 
 
 
 
 
 
 
45
 
46
+ # Decode token ids to text
47
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
 
 
 
48
 
49
+ # Print the transcription
50
+ print(transcription[0]) # Display the transcription
51
+
52
+ return {"transcription": transcription[0]}
53
  except Exception as e:
54
  raise HTTPException(status_code=500, detail=str(e))
55
 
requirements.txt CHANGED
@@ -10,4 +10,5 @@ numpy
10
  onnx
11
  optimum
12
  onnxruntime
13
- faster_whisper
 
 
10
  onnx
11
  optimum
12
  onnxruntime
13
+ faster_whisper
14
+ librosa