AshDavid12 committed on
Commit
40cde13
·
1 Parent(s): e909e6b

added logging

Browse files
Files changed (1) hide show
  1. infer.py +32 -7
infer.py CHANGED
@@ -3,17 +3,25 @@ import faster_whisper
3
  import tempfile
4
  import torch
5
  import requests
 
6
  from fastapi import FastAPI, HTTPException
7
  from pydantic import BaseModel
8
  from typing import Optional
9
 
 
 
 
# Select the GPU when one is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hebrew-tuned faster-whisper checkpoint.
model_name = 'ivrit-ai/faster-whisper-v2-d4'
model = faster_whisper.WhisperModel(model_name, device=device)

# Maximum data size: 200MB
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024

app = FastAPI()
 
def download_file(url, max_size_bytes, output_filename, api_key=None):
    """
    Download a file from a given URL with size limit and optional API key.

    Args:
        url: Source URL to fetch.
        max_size_bytes: Hard cap on payload size; download aborts beyond it.
        output_filename: Local path the payload is written to.
        api_key: Optional bearer token sent in the Authorization header.

    Returns:
        True on success; False when the size limit is exceeded or the
        request fails.
    """
    try:
        headers = {}
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        # Bound the request with a timeout so a stalled server cannot hang
        # the worker forever, and close the response via the context manager.
        with requests.get(url, stream=True, headers=headers, timeout=30) as response:
            response.raise_for_status()

            # Fast pre-check against the declared length. Content-Length may
            # be absent (-> 0); the per-chunk check below still protects us.
            file_size = int(response.headers.get('Content-Length', 0))
            if file_size > max_size_bytes:
                return False

            downloaded_size = 0
            with open(output_filename, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    downloaded_size += len(chunk)
                    if downloaded_size > max_size_bytes:
                        return False
                    file.write(chunk)

        return True

    except requests.RequestException as e:
        print(f"Error downloading file: {e}")
        return False
60
  @app.post("/transcribe")
61
  async def transcribe(input_data: InputData):
 
62
  datatype = input_data.type
63
  if not datatype:
 
64
  raise HTTPException(status_code=400, detail="datatype field not provided. Should be 'blob' or 'url'.")
65
 
66
  if datatype not in ['blob', 'url']:
 
67
  raise HTTPException(status_code=400, detail=f"datatype should be 'blob' or 'url', but is {datatype} instead.")
68
 
69
  api_key = input_data.api_key
 
70
 
71
  with tempfile.TemporaryDirectory() as d:
72
  audio_file = f'{d}/audio.mp3'
 
73
 
74
  if datatype == 'blob':
75
  if not input_data.data:
 
76
  raise HTTPException(status_code=400, detail="Missing 'data' for 'blob' input.")
 
77
  mp3_bytes = base64.b64decode(input_data.data)
78
  open(audio_file, 'wb').write(mp3_bytes)
 
79
  elif datatype == 'url':
80
  if not input_data.url:
 
81
  raise HTTPException(status_code=400, detail="Missing 'url' for 'url' input.")
 
82
  success = download_file(input_data.url, MAX_PAYLOAD_SIZE, audio_file, api_key)
83
  if not success:
 
84
  raise HTTPException(status_code=400, detail=f"Error downloading data from {input_data.url}")
85
 
86
  result = transcribe_core(audio_file)
def transcribe_core(audio_file):
    """Transcribe *audio_file* in Hebrew and return {'segments': [...]}.

    Each segment entry carries id/seek/start/end/text, quality metrics,
    and per-word timestamps.
    """
    print('Transcribing...')

    segments, _ = model.transcribe(audio_file, language='he', word_timestamps=True)

    collected = []
    for segment in segments:
        words = [
            {'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability}
            for w in segment.words
        ]
        seg = {
            'id': segment.id, 'seek': segment.seek, 'start': segment.start, 'end': segment.end,
            'text': segment.text, 'avg_logprob': segment.avg_logprob,
            'compression_ratio': segment.compression_ratio, 'no_speech_prob': segment.no_speech_prob,
            'words': words
        }
        print(seg)
        collected.append(seg)

    return {'segments': collected}
 
3
  import tempfile
4
  import torch
5
  import requests
6
+ import logging
7
  from fastapi import FastAPI, HTTPException
8
  from pydantic import BaseModel
9
  from typing import Optional
10
 
11
# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# Select the GPU when one is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Lazy %-style args (instead of f-strings) defer formatting to the
# logging framework; the rendered messages are identical.
logging.info('Device selected: %s', device)

# Hebrew-tuned faster-whisper checkpoint.
model_name = 'ivrit-ai/faster-whisper-v2-d4'
logging.info('Loading model: %s', model_name)
model = faster_whisper.WhisperModel(model_name, device=device)
logging.info('Model loaded successfully')

# Maximum data size: 200MB
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
logging.info('Max payload size set to: %s bytes', MAX_PAYLOAD_SIZE)

app = FastAPI()
 
37
  """
38
  Download a file from a given URL with size limit and optional API key.
39
  """
40
+ logging.debug(f'Starting file download from URL: {url}')
41
  try:
42
  headers = {}
43
  if api_key:
44
  headers['Authorization'] = f'Bearer {api_key}'
45
+ logging.debug('API key provided, added to headers')
46
 
47
  response = requests.get(url, stream=True, headers=headers)
48
  response.raise_for_status()
49
 
50
  file_size = int(response.headers.get('Content-Length', 0))
51
+ logging.info(f'File size: {file_size} bytes')
52
 
53
  if file_size > max_size_bytes:
54
+ logging.error(f'File size exceeds limit: {file_size} > {max_size_bytes}')
55
  return False
56
 
57
  downloaded_size = 0
58
  with open(output_filename, 'wb') as file:
59
  for chunk in response.iter_content(chunk_size=8192):
60
  downloaded_size += len(chunk)
61
+ logging.debug(f'Downloaded {downloaded_size} bytes')
62
  if downloaded_size > max_size_bytes:
63
+ logging.error('Downloaded size exceeds maximum allowed payload size')
64
  return False
65
  file.write(chunk)
66
 
67
+ logging.info(f'File downloaded successfully: {output_filename}')
68
  return True
69
 
70
  except requests.RequestException as e:
71
+ logging.error(f"Error downloading file: {e}")
72
  return False
73
 
74
 
75
  @app.post("/transcribe")
76
  async def transcribe(input_data: InputData):
77
+ logging.debug(f'Received transcription request with data: {input_data}')
78
  datatype = input_data.type
79
  if not datatype:
80
+ logging.error('datatype field not provided')
81
  raise HTTPException(status_code=400, detail="datatype field not provided. Should be 'blob' or 'url'.")
82
 
83
  if datatype not in ['blob', 'url']:
84
+ logging.error(f'Invalid datatype: {datatype}')
85
  raise HTTPException(status_code=400, detail=f"datatype should be 'blob' or 'url', but is {datatype} instead.")
86
 
87
  api_key = input_data.api_key
88
+ logging.debug(f'API key: {api_key}')
89
 
90
  with tempfile.TemporaryDirectory() as d:
91
  audio_file = f'{d}/audio.mp3'
92
+ logging.debug(f'Created temporary directory: {d}')
93
 
94
  if datatype == 'blob':
95
  if not input_data.data:
96
+ logging.error("Missing 'data' for 'blob' input")
97
  raise HTTPException(status_code=400, detail="Missing 'data' for 'blob' input.")
98
+ logging.info('Decoding base64 blob data')
99
  mp3_bytes = base64.b64decode(input_data.data)
100
  open(audio_file, 'wb').write(mp3_bytes)
101
+ logging.info(f'Audio file written: {audio_file}')
102
  elif datatype == 'url':
103
  if not input_data.url:
104
+ logging.error("Missing 'url' for 'url' input")
105
  raise HTTPException(status_code=400, detail="Missing 'url' for 'url' input.")
106
+ logging.info(f'Downloading file from URL: {input_data.url}')
107
  success = download_file(input_data.url, MAX_PAYLOAD_SIZE, audio_file, api_key)
108
  if not success:
109
+ logging.error(f"Error downloading data from {input_data.url}")
110
  raise HTTPException(status_code=400, detail=f"Error downloading data from {input_data.url}")
111
 
112
  result = transcribe_core(audio_file)
 
114
 
115
 
116
  def transcribe_core(audio_file):
117
+ logging.info('Starting transcription...')
 
118
  ret = {'segments': []}
119
+
120
  segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
121
+ logging.info('Transcription completed')
122
+
123
  for s in segs:
124
  words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
125
  seg = {
126
  'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
127
  'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words
128
  }
129
+ logging.debug(f'Transcription segment: {seg}')
130
  ret['segments'].append(seg)
131
 
132
  return ret