AshDavid12 committed
Commit 40cde13 · Parent(s): e909e6b
added logging
infer.py CHANGED
@@ -3,17 +3,25 @@ import faster_whisper
 import tempfile
 import torch
 import requests
+import logging
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from typing import Optional
 
+# Configure logging
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
+logging.info(f'Device selected: {device}')
 
 model_name = 'ivrit-ai/faster-whisper-v2-d4'
+logging.info(f'Loading model: {model_name}')
 model = faster_whisper.WhisperModel(model_name, device=device)
+logging.info('Model loaded successfully')
 
 # Maximum data size: 200MB
 MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
+logging.info(f'Max payload size set to: {MAX_PAYLOAD_SIZE} bytes')
 
 app = FastAPI()
 
@@ -29,58 +37,76 @@ def download_file(url, max_size_bytes, output_filename, api_key=None):
     """
     Download a file from a given URL with size limit and optional API key.
     """
+    logging.debug(f'Starting file download from URL: {url}')
     try:
         headers = {}
         if api_key:
             headers['Authorization'] = f'Bearer {api_key}'
+            logging.debug('API key provided, added to headers')
 
         response = requests.get(url, stream=True, headers=headers)
         response.raise_for_status()
 
         file_size = int(response.headers.get('Content-Length', 0))
+        logging.info(f'File size: {file_size} bytes')
 
         if file_size > max_size_bytes:
+            logging.error(f'File size exceeds limit: {file_size} > {max_size_bytes}')
             return False
 
         downloaded_size = 0
         with open(output_filename, 'wb') as file:
             for chunk in response.iter_content(chunk_size=8192):
                 downloaded_size += len(chunk)
+                logging.debug(f'Downloaded {downloaded_size} bytes')
                 if downloaded_size > max_size_bytes:
+                    logging.error('Downloaded size exceeds maximum allowed payload size')
                     return False
                 file.write(chunk)
 
+        logging.info(f'File downloaded successfully: {output_filename}')
         return True
 
     except requests.RequestException as e:
-
+        logging.error(f"Error downloading file: {e}")
         return False
 
 
 @app.post("/transcribe")
 async def transcribe(input_data: InputData):
+    logging.debug(f'Received transcription request with data: {input_data}')
     datatype = input_data.type
     if not datatype:
+        logging.error('datatype field not provided')
         raise HTTPException(status_code=400, detail="datatype field not provided. Should be 'blob' or 'url'.")
 
     if datatype not in ['blob', 'url']:
+        logging.error(f'Invalid datatype: {datatype}')
         raise HTTPException(status_code=400, detail=f"datatype should be 'blob' or 'url', but is {datatype} instead.")
 
     api_key = input_data.api_key
+    logging.debug(f'API key: {api_key}')
 
     with tempfile.TemporaryDirectory() as d:
         audio_file = f'{d}/audio.mp3'
+        logging.debug(f'Created temporary directory: {d}')
 
         if datatype == 'blob':
             if not input_data.data:
+                logging.error("Missing 'data' for 'blob' input")
                 raise HTTPException(status_code=400, detail="Missing 'data' for 'blob' input.")
+            logging.info('Decoding base64 blob data')
             mp3_bytes = base64.b64decode(input_data.data)
             open(audio_file, 'wb').write(mp3_bytes)
+            logging.info(f'Audio file written: {audio_file}')
         elif datatype == 'url':
             if not input_data.url:
+                logging.error("Missing 'url' for 'url' input")
                 raise HTTPException(status_code=400, detail="Missing 'url' for 'url' input.")
+            logging.info(f'Downloading file from URL: {input_data.url}')
             success = download_file(input_data.url, MAX_PAYLOAD_SIZE, audio_file, api_key)
             if not success:
+                logging.error(f"Error downloading data from {input_data.url}")
                 raise HTTPException(status_code=400, detail=f"Error downloading data from {input_data.url}")
 
         result = transcribe_core(audio_file)
@@ -88,20 +114,19 @@ async def transcribe(input_data: InputData):
 
 
 def transcribe_core(audio_file):
-
-
+    logging.info('Starting transcription...')
     ret = {'segments': []}
+
     segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
+    logging.info('Transcription completed')
+
     for s in segs:
         words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
         seg = {
             'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
             'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words
        }
-
+        logging.debug(f'Transcription segment: {seg}')
         ret['segments'].append(seg)
 
     return ret
-
-
-
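The commit configures the root logger at DEBUG level, so every message added here is emitted, including the per-chunk logging.debug call in the download loop (one line per 8 KB chunk). With the chosen format, a record renders along the lines of "2024-01-01 12:00:00,000 - INFO - Device selected: cuda", where the timestamp and values are illustrative.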
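The InputData model the endpoint receives sits outside the changed hunks, so it does not appear in this diff. A plausible reconstruction from the fields the handler reads (type, data, url, api_key) would be the following; this is inferred from usage, not the file's actual definition:

from typing import Optional
from pydantic import BaseModel

# Reconstructed from usage in transcribe(); the real definition lives in the
# unchanged part of infer.py and may differ.
class InputData(BaseModel):
    type: Optional[str] = None      # 'blob' or 'url'; the handler returns 400 when missing
    data: Optional[str] = None      # base64-encoded MP3 payload for 'blob' input
    url: Optional[str] = None       # source location for 'url' input
    api_key: Optional[str] = None   # forwarded as a Bearer token by download_file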
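For a quick check of the endpoint, a minimal client sketch follows; the localhost:8000 address is an assumed uvicorn default and sample.mp3 is a placeholder, neither is specified by this commit:

# Minimal client sketch; host, port, and file name are assumptions.
import base64
import requests

with open('sample.mp3', 'rb') as f:
    payload = {'type': 'blob', 'data': base64.b64encode(f.read()).decode('ascii')}

resp = requests.post('http://localhost:8000/transcribe', json=payload)
resp.raise_for_status()
print(resp.json())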