Nechba commited on
Commit
9bb1bc6
·
verified ·
1 Parent(s): 419ab6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -5,7 +5,7 @@ import whisper
5
  import os
6
  import tempfile
7
  import io
8
-
9
 
10
  app = Flask(__name__)
11
 
@@ -24,24 +24,29 @@ ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer) # Rena
24
  @app.route('/transcribe', methods=['POST'])
25
  def transcribe():
26
  try:
27
- # Read the raw audio bytes
28
- audio_bytes = request.data # Get raw bytes from request
29
  if not audio_bytes:
30
  return jsonify({"error": "No audio data provided"}), 400
31
 
32
  # Convert bytes to a file-like object
33
  audio_file = io.BytesIO(audio_bytes)
34
 
 
 
 
 
 
 
35
  # Transcribe the audio
36
- result = whisper_model.transcribe(audio_file)
37
 
38
  return jsonify({"text": result["text"]})
39
 
40
  except Exception as e:
41
- print("Error:", str(e)) # Log the error
42
  return jsonify({"error": "Internal Server Error", "details": str(e)}), 500
43
 
44
-
45
 
46
 
47
  @app.route('/classify', methods=['POST'])
 
5
  import os
6
  import tempfile
7
  import io
8
+ import torchaudio
9
 
10
  app = Flask(__name__)
11
 
 
24
  @app.route('/transcribe', methods=['POST'])
25
  def transcribe():
26
  try:
27
+ # Read raw bytes from the request
28
+ audio_bytes = request.data
29
  if not audio_bytes:
30
  return jsonify({"error": "No audio data provided"}), 400
31
 
32
  # Convert bytes to a file-like object
33
  audio_file = io.BytesIO(audio_bytes)
34
 
35
+ # Load audio as a waveform using torchaudio
36
+ waveform, sample_rate = torchaudio.load(audio_file)
37
+
38
+ # Whisper expects a NumPy array, so we convert it
39
+ audio_numpy = waveform.squeeze().numpy()
40
+
41
  # Transcribe the audio
42
+ result = model.transcribe(audio_numpy)
43
 
44
  return jsonify({"text": result["text"]})
45
 
46
  except Exception as e:
47
+ print("Error:", str(e)) # Log error for debugging
48
  return jsonify({"error": "Internal Server Error", "details": str(e)}), 500
49
 
 
50
 
51
 
52
  @app.route('/classify', methods=['POST'])