nurfarah57 committed
Commit e261ec2 · verified · Parent: 099fb86

Update app.py

Files changed (1):
  1. app.py +8 -4
app.py CHANGED
@@ -6,28 +6,32 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 import torch
 import torchaudio
 
+# ✅ Explicitly set the backend so .wav files load properly
+torchaudio.set_audio_backend("soundfile")
+
 app = FastAPI()
 
+# ✅ Load model and processor
 processor = Wav2Vec2Processor.from_pretrained("tacab/ASR_SOMALI")
 model = Wav2Vec2ForCTC.from_pretrained("tacab/ASR_SOMALI")
 model.to("cpu")
 
 @app.post("/transcribe")
 async def transcribe(file: UploadFile = File(...)):
+    # ✅ Save uploaded file to /tmp
     audio_bytes = await file.read()
-
     temp_path = "/tmp/temp.wav"
     with open(temp_path, "wb") as f:
         f.write(audio_bytes)
 
+    # ✅ Load audio file
     speech_array, sampling_rate = torchaudio.load(temp_path)
-    inputs = processor(speech_array.squeeze(), return_tensors="pt", sampling_rate=sampling_rate)
 
+    # ✅ Run through ASR model
+    inputs = processor(speech_array.squeeze(), return_tensors="pt", sampling_rate=sampling_rate)
     with torch.no_grad():
         logits = model(**inputs).logits
-
     predicted_ids = torch.argmax(logits, dim=-1)
     transcription = processor.batch_decode(predicted_ids)[0]
 
     return {"text": transcription}
-
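
For reference, a minimal client-side call against this endpoint could look like the sketch below. The base URL, port, and filename are placeholders (assumptions, not part of the commit); the multipart field name "file" matches the UploadFile parameter in app.py.

# Hypothetical client sketch: POST a local WAV file to the /transcribe endpoint.
# "http://localhost:8000" and "sample.wav" are placeholders; adjust to your deployment.
import requests

with open("sample.wav", "rb") as f:
    response = requests.post(
        "http://localhost:8000/transcribe",
        files={"file": ("sample.wav", f, "audio/wav")},
    )

print(response.json())  # expected shape: {"text": "<transcription>"}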
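One design note on the handler: it forwards the file's native sampling_rate to the processor, while Wav2Vec2 checkpoints are typically trained on 16 kHz audio. If uploads may arrive at other rates, a resampling step along these lines could sit just before building inputs (a sketch only; the target rate is taken from the processor's feature extractor and is an assumption, not part of this commit).

# Optional resampling sketch (assumption: the model expects the rate reported by
# processor.feature_extractor.sampling_rate, typically 16000 Hz for Wav2Vec2).
import torchaudio

target_rate = processor.feature_extractor.sampling_rate
if sampling_rate != target_rate:
    speech_array = torchaudio.functional.resample(
        speech_array, orig_freq=sampling_rate, new_freq=target_rate
    )
    sampling_rate = target_rate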