Spaces:
Edmond98
/
Running on TPU v5e

Edmond7 committed on
Commit
9cf4194
1 Parent(s): cd03801

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -32
app.py CHANGED
@@ -2,12 +2,11 @@ import logging
2
  import math
3
  import time
4
  import base64
5
- import io
6
  import os
7
  from typing import Dict, Any
8
  from functools import wraps
9
 
10
- from fastapi import FastAPI, Depends, HTTPException, File, UploadFile
11
  from fastapi.encoders import jsonable_encoder
12
  from pydantic import BaseModel
13
  import jax.numpy as jnp
@@ -38,7 +37,7 @@ chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
38
  stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
39
  step = chunk_len - stride_left - stride_right
40
 
41
- # do a pre-compile step so that the first user to use the demo isn't hit with a long transcription time
42
  logger.debug("Compiling forward call...")
43
  start = time.time()
44
  random_inputs = {
@@ -51,7 +50,7 @@ compile_time = time.time() - start
51
  logger.debug(f"Compiled in {compile_time}s")
52
 
53
  class TranscribeAudioRequest(BaseModel):
54
- audio_base64: str = None
55
  task: str = "transcribe"
56
  return_timestamps: bool = False
57
 
@@ -69,40 +68,55 @@ def timeit(func):
69
  return result
70
  return wrapper
71
 
72
- def check_api_key():
73
  api_key = os.environ.get("WHISPER_API_KEY")
74
- if not api_key:
75
- raise HTTPException(status_code=401, detail="API key not found in environment variables")
76
- return api_key
77
 
78
- @app.post("/transcribe_audio")
79
  @timeit
80
- async def transcribe_chunked_audio(
81
- request: TranscribeAudioRequest = None,
82
- file: UploadFile = File(None),
 
83
  api_key: str = Depends(check_api_key)
84
  ) -> Dict[str, Any]:
85
- logger.debug("Starting transcribe_chunked_audio function")
86
- logger.debug(f"Received parameters - task: {request.task if request else 'transcribe'}, return_timestamps: {request.return_timestamps if request else False}")
87
 
88
  try:
89
- if file:
90
- logger.debug("Processing uploaded file")
91
- audio_data = await file.read()
92
- file_size = len(audio_data)
93
- elif request and request.audio_base64:
94
- logger.debug("Processing base64 encoded audio")
95
- audio_data = base64.b64decode(request.audio_base64)
96
- file_size = len(audio_data)
97
- else:
98
- raise HTTPException(status_code=400, detail="No audio data provided")
99
-
 
 
 
 
 
 
 
 
 
 
 
100
  file_size_mb = file_size / (1024 * 1024)
101
- logger.debug(f"Audio data size: {file_size} bytes ({file_size_mb:.2f}MB)")
102
  except Exception as e:
103
- logger.error(f"Error processing audio data: {str(e)}", exc_info=True)
104
- raise HTTPException(status_code=400, detail=f"Error processing audio data: {str(e)}")
 
 
105
 
 
106
  if file_size_mb > FILE_LIMIT_MB:
107
  logger.warning(f"Max file size exceeded: {file_size_mb:.2f}MB > {FILE_LIMIT_MB}MB")
108
  raise HTTPException(status_code=400, detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.")
@@ -118,15 +132,13 @@ async def transcribe_chunked_audio(
118
 
119
  logger.debug("Calling tqdm_generate to transcribe audio")
120
  try:
121
- task = request.task if request else "transcribe"
122
- return_timestamps = request.return_timestamps if request else False
123
  text, runtime, timing_info = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
124
  logger.debug(f"Transcription completed. Runtime: {runtime:.2f}s")
125
  except Exception as e:
126
  logger.error(f"Error in tqdm_generate: {str(e)}", exc_info=True)
127
  raise HTTPException(status_code=500, detail=f"Error transcribing audio: {str(e)}")
128
 
129
- logger.debug("Transcribe_chunked_audio function completed successfully")
130
  return jsonable_encoder({
131
  "text": text,
132
  "runtime": runtime,
@@ -211,4 +223,5 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
211
  return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
212
  else:
213
  # we have a malformed timestamp so just return it as is
214
- return seconds
 
 
2
  import math
3
  import time
4
  import base64
 
5
  import os
6
  from typing import Dict, Any
7
  from functools import wraps
8
 
9
+ from fastapi import FastAPI, Depends, HTTPException, File, UploadFile, Form, Header
10
  from fastapi.encoders import jsonable_encoder
11
  from pydantic import BaseModel
12
  import jax.numpy as jnp
 
37
  stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
38
  step = chunk_len - stride_left - stride_right
39
 
40
+ # Pre-compile step
41
  logger.debug("Compiling forward call...")
42
  start = time.time()
43
  random_inputs = {
 
50
  logger.debug(f"Compiled in {compile_time}s")
51
 
52
  class TranscribeAudioRequest(BaseModel):
53
+ audio_base64: str
54
  task: str = "transcribe"
55
  return_timestamps: bool = False
56
 
 
68
  return result
69
  return wrapper
70
 
71
+ def check_api_key(x_api_key: str = Header(...)):
72
  api_key = os.environ.get("WHISPER_API_KEY")
73
+ if not api_key or x_api_key != api_key:
74
+ raise HTTPException(status_code=401, detail="Invalid or missing API key")
75
+ return x_api_key
76
 
77
+ @app.post("/transcribe_audio_file")
78
  @timeit
79
+ async def transcribe_audio_file(
80
+ file: UploadFile = File(...),
81
+ task: str = Form("transcribe"),
82
+ return_timestamps: bool = Form(False),
83
  api_key: str = Depends(check_api_key)
84
  ) -> Dict[str, Any]:
85
+ logger.debug("Starting transcribe_audio_file function")
86
+ logger.debug(f"Received parameters - task: {task}, return_timestamps: {return_timestamps}")
87
 
88
  try:
89
+ audio_data = await file.read()
90
+ file_size = len(audio_data)
91
+ file_size_mb = file_size / (1024 * 1024)
92
+ logger.debug(f"Audio file size: {file_size} bytes ({file_size_mb:.2f}MB)")
93
+ except Exception as e:
94
+ logger.error(f"Error reading audio file: {str(e)}", exc_info=True)
95
+ raise HTTPException(status_code=400, detail=f"Error reading audio file: {str(e)}")
96
+
97
+ return await process_audio(audio_data, file_size_mb, task, return_timestamps)
98
+
99
+ @app.post("/transcribe_audio_base64")
100
+ @timeit
101
+ async def transcribe_audio_base64(
102
+ request: TranscribeAudioRequest,
103
+ api_key: str = Depends(check_api_key)
104
+ ) -> Dict[str, Any]:
105
+ logger.debug("Starting transcribe_audio_base64 function")
106
+ logger.debug(f"Received parameters - task: {request.task}, return_timestamps: {request.return_timestamps}")
107
+
108
+ try:
109
+ audio_data = base64.b64decode(request.audio_base64)
110
+ file_size = len(audio_data)
111
  file_size_mb = file_size / (1024 * 1024)
112
+ logger.debug(f"Decoded audio data size: {file_size} bytes ({file_size_mb:.2f}MB)")
113
  except Exception as e:
114
+ logger.error(f"Error decoding base64 audio data: {str(e)}", exc_info=True)
115
+ raise HTTPException(status_code=400, detail=f"Error decoding base64 audio data: {str(e)}")
116
+
117
+ return await process_audio(audio_data, file_size_mb, request.task, request.return_timestamps)
118
 
119
+ async def process_audio(audio_data: bytes, file_size_mb: float, task: str, return_timestamps: bool) -> Dict[str, Any]:
120
  if file_size_mb > FILE_LIMIT_MB:
121
  logger.warning(f"Max file size exceeded: {file_size_mb:.2f}MB > {FILE_LIMIT_MB}MB")
122
  raise HTTPException(status_code=400, detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.")
 
132
 
133
  logger.debug("Calling tqdm_generate to transcribe audio")
134
  try:
 
 
135
  text, runtime, timing_info = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
136
  logger.debug(f"Transcription completed. Runtime: {runtime:.2f}s")
137
  except Exception as e:
138
  logger.error(f"Error in tqdm_generate: {str(e)}", exc_info=True)
139
  raise HTTPException(status_code=500, detail=f"Error transcribing audio: {str(e)}")
140
 
141
+ logger.debug("Audio processing completed successfully")
142
  return jsonable_encoder({
143
  "text": text,
144
  "runtime": runtime,
 
223
  return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
224
  else:
225
  # we have a malformed timestamp so just return it as is
226
+ return seconds
227
+