MJobe committed
Commit d99db10
Parent: e6db199

Update main.py

Files changed (1): main.py (+29, -27)
main.py CHANGED
@@ -1,5 +1,5 @@
 import fitz
-from fastapi import FastAPI, File, UploadFile, Form
+from fastapi import FastAPI, File, UploadFile, Form, HTTPException
 from fastapi.responses import JSONResponse
 from transformers import pipeline
 from PIL import Image
@@ -12,6 +12,8 @@ import numpy as np
 import json
 import torchaudio
 import torch
+from pydub import AudioSegment
+import speech_recognition as sr
 
 app = FastAPI()
 
@@ -164,34 +166,34 @@ async def transcribe_and_answer(
     file: UploadFile = File(...),
     questions: str = Form(...)
 ):
+    # Check the file format and read it
+    if file.content_type not in ["audio/wav", "audio/mpeg", "audio/mp3"]:
+        raise HTTPException(status_code=400, detail="Unsupported audio format. Please upload a WAV or MP3 file.")
+
+    # Convert uploaded file to WAV if necessary (for SpeechRecognition compatibility)
+    audio_data = await file.read()
+    audio_file = io.BytesIO(audio_data)
+
     try:
-        # Step 1: Read and convert the audio file
-        contents = await file.read()
-        audio = AudioSegment.from_file(BytesIO(contents))
-
-        # Step 2: Ensure the audio is mono and resample if needed
-        audio = audio.set_channels(1)  # Convert to mono if it's not already
-        audio = audio.set_frame_rate(16000)  # Resample to 16000 Hz, commonly required by ASR models
-
-        # Step 3: Export to WAV format and load with torchaudio
-        wav_buffer = BytesIO()
-        audio.export(wav_buffer, format="wav")
-        wav_buffer.seek(0)
-
-        # Load audio using torchaudio
-        waveform, sample_rate = torchaudio.load(wav_buffer)
+        # Convert MP3 to WAV if needed
+        if file.content_type == "audio/mpeg" or file.content_type == "audio/mp3":
+            audio = AudioSegment.from_mp3(audio_file)
+            audio_wav = io.BytesIO()
+            audio.export(audio_wav, format="wav")
+            audio_wav.seek(0)
+        else:
+            audio_wav = audio_file
 
-        # Convert waveform to float32 and ensure it's a numpy array
-        waveform_np = waveform.numpy().astype(np.float32)
-
-        # Step 4: Transcribe the audio
-        transcription_result = nlp_speech_to_text(waveform_np)
-        transcription_text = transcription_result['text']
-
-        # Step 5: Parse the JSON-formatted questions
+        # Load audio into SpeechRecognition and transcribe
+        recognizer = sr.Recognizer()
+        with sr.AudioFile(audio_wav) as source:
+            audio = recognizer.record(source)
+        transcription_text = recognizer.recognize_google(audio)
+
+        # Parse the JSON-formatted questions
         questions_dict = json.loads(questions)
 
-        # Step 6: Answer each question using the transcribed text
+        # Answer each question based on the transcription text
        answers_dict = {}
         for key, question in questions_dict.items():
             QA_input = {
@@ -202,14 +204,14 @@ async def transcribe_and_answer(
             result = nlp_qa_v3(QA_input)
             answers_dict[key] = result['answer']
 
-        # Step 7: Return transcription + answers
+        # Return transcription + answers
         return {
             "transcription": transcription_text,
             "answers": answers_dict
         }
 
     except Exception as e:
-        return JSONResponse(content={"error": f"Error processing audio or answering questions: {str(e)}"}, status_code=500)
+        raise HTTPException(status_code=500, detail=f"Error during transcription or question answering: {str(e)}")
 
 # Set up CORS middleware
 origins = ["*"]  # or specify your list of allowed origins
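
For reference, a minimal client call against the updated endpoint could look like the sketch below. The route path and host are assumptions, since the @app.post decorator sits just outside the diff context; only the multipart field names file and questions are taken from the handler signature, and the sample questions are hypothetical.

# Hypothetical client for the endpoint above. The route
# "/transcribe-and-answer/" and the localhost address are assumptions;
# only the field names "file" and "questions" come from the handler.
import json
import requests

questions = {
    "q1": "What is the caller's name?",
    "q2": "What product is being discussed?",
}

with open("sample.wav", "rb") as audio_f:
    resp = requests.post(
        "http://localhost:8000/transcribe-and-answer/",
        files={"file": ("sample.wav", audio_f, "audio/wav")},
        data={"questions": json.dumps(questions)},
    )

print(resp.json())  # {"transcription": "...", "answers": {"q1": "...", "q2": "..."}}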
 
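The commit replaces the torchaudio-plus-pipeline transcription with pydub and the SpeechRecognition library. Two operational notes: recognize_google sends the audio to Google's free web Speech API, so the server needs outbound network access, and pydub requires ffmpeg on the system to decode MP3 input. The handler also references io.BytesIO, so an import io presumably already exists in the part of main.py's import block not shown in this diff. Below is a self-contained sketch of the same flow, with the library's specific exceptions surfaced instead of the handler's blanket except:

# Standalone sketch of the new transcription path. Assumes ffmpeg is
# installed (pydub needs it for MP3 decoding) and that the Google web
# Speech API is reachable over the network.
import io

import speech_recognition as sr
from pydub import AudioSegment

def transcribe_bytes(data: bytes, is_mp3: bool) -> str:
    buf = io.BytesIO(data)
    if is_mp3:
        # Normalize MP3 input to WAV, which sr.AudioFile reads directly
        wav = io.BytesIO()
        AudioSegment.from_mp3(buf).export(wav, format="wav")
        wav.seek(0)
    else:
        wav = buf
    recognizer = sr.Recognizer()
    with sr.AudioFile(wav) as source:
        recorded = recognizer.record(source)
    try:
        return recognizer.recognize_google(recorded)
    except sr.UnknownValueError:
        return ""  # audio was unintelligible
    # sr.RequestError (API unreachable or quota exceeded) is left to propagate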
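
The question-answering side of the handler is unchanged: nlp_qa_v3 is called with a dict of question and context keys, which matches the transformers question-answering pipeline interface. The pipeline itself is defined earlier in main.py, outside this diff, so the model below is a placeholder rather than the one the app actually loads:

# Sketch of how a pipeline like nlp_qa_v3 is typically constructed.
# The real model name lives elsewhere in main.py; this one is a placeholder.
from transformers import pipeline

nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")

QA_input = {
    "question": "What product is being discussed?",
    "context": "Hi, I am calling about the warranty on my espresso machine.",
}
result = nlp_qa_v3(QA_input)
print(result["answer"])  # e.g. "espresso machine"; result["score"] gives the confidence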