AshDavid12 committed
Commit 9bd82d6 · 1 Parent(s): 1bd5368

origin infer wo runpod

Files changed (3):
  1. Dockerfile +14 -59
  2. infer.py +41 -141
  3. whisper_online.py +7 -4
Dockerfile CHANGED
@@ -1,65 +1,20 @@
- FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04
-
- ENV PYTHON_VERSION=3.11
-
- RUN export DEBIAN_FRONTEND=noninteractive \
-     && apt-get -qq update \
-     && apt-get -qq install --no-install-recommends \
-        python${PYTHON_VERSION} \
-        python${PYTHON_VERSION}-venv \
-        python3-pip \
-        libcublas11 \
-     && rm -rf /var/lib/apt/lists/*
-
- # Set up Python environment
- RUN python3 -m pip install --upgrade pip
-
- # Copy the requirements file and install Python packages
- COPY requirements.txt .
- RUN pip install -r requirements.txt
-
- # Install the specific model using faster-whisper
- #RUN python3 -c 'import faster_whisper; m = faster_whisper.WhisperModel("ivrit-ai/faster-whisper-v2-d3-e3")'
- # Set the SENTENCE_TRANSFORMERS_HOME environment variable to a writable directory
- # Set environment variables for cache directories
- ENV SENTENCE_TRANSFORMERS_HOME="/tmp/.cache/sentence_transformers"
- ENV HF_HOME="/tmp/.cache/huggingface"
-
- # Ensure the cache directories exist
- RUN mkdir -p $SENTENCE_TRANSFORMERS_HOME $HF_HOME
-
- # Add your Python scripts
- COPY infer.py .
- COPY whisper_online.py .
-
- EXPOSE 7860
- # Run the infer.py script when the container starts
- CMD ["python3", "-u", "/infer.py"]
-
- # Include Python
- #from python:3.11.1-buster
- #
- ## Define your working directory
- #WORKDIR /
- #
- ## Install runpod
- #COPY requirements.txt .
- #RUN pip install -r requirements.txt
- #
- #RUN python3 -c 'import faster_whisper; m = faster_whisper.WhisperModel("ivrit-ai/faster-whisper-v2-d3-e3")'
- #
- ## Add your file
- #ADD infer.py .
- #ADD whisper_online.py .
- #
- #ENV LD_LIBRARY_PATH="/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib:/usr/local/lib/python3.11/site-packages/nvidia/cublas/lib"
- #
- ## Call your file when your container starts
- #CMD [ "python", "-u", "/infer.py" ]
+ # Include Python
+ FROM python:3.11.1-buster
+
+ # Define your working directory
+ WORKDIR /
+
+ # Install runpod
+ RUN pip install runpod
+ RUN pip install torch==2.3.1
+ RUN pip install faster-whisper
+
+ RUN python3 -c 'import faster_whisper; m = faster_whisper.WhisperModel("ivrit-ai/faster-whisper-v2-d4")'
+
+ # Add your file
+ ADD infer.py .
+
+ ENV LD_LIBRARY_PATH="/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib:/usr/local/lib/python3.11/site-packages/nvidia/cublas/lib"
+
+ # Call your file when your container starts
+ CMD [ "python", "-u", "/infer.py" ]
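The new image bakes the model weights into a layer at build time: the RUN python3 -c line above constructs the model once during docker build, so workers do not re-download the weights on every cold start. A minimal sketch of what that one-liner does, assuming network access at build time and the default Hugging Face cache location:

    # Build-time warm-up: constructing WhisperModel downloads the weights
    # from the Hugging Face Hub into the image's cache directory, so a
    # running container never has to fetch them again.
    import faster_whisper

    model = faster_whisper.WhisperModel("ivrit-ai/faster-whisper-v2-d4")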
infer.py CHANGED
@@ -1,5 +1,3 @@
- import runpod
-
  import base64
  import faster_whisper
  import tempfile
@@ -11,195 +9,97 @@ import os

  import whisper_online

+ # Set up logging
  logger = logging.getLogger(__name__)
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)])

- # Try to import the module
- try:
-     logging.info("attempting to load whisper online")
-     from whisper_online import *  # Replace 'some_module' with the actual module name
-
-     logging.info("Successfully imported whisper_online.")
- except ImportError as e:
-     logging.error(f"Failed to import whisper_online: {e}", exc_info=True)
- except Exception as e:
-     logging.error(f"Unknown from exception- error to import whisper_online: {e}", exc_info=True)
-
- if torch.cuda.is_available():
-     logging.info(f"CUDA is available.")
- else:
-     logging.info("CUDA is not available. Using CPU.")
-
+ # Load the FasterWhisper model
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
  model_name = 'ivrit-ai/faster-whisper-v2-d3-e3'
- logging.info(f"Selected model name: {model_name}")
- #model = faster_whisper.WhisperModel(model_name, device=device)
+
  try:
      lan = 'he'
      logging.info(f"Attempting to initialize FasterWhisperASR with device: {device}")
-     logging.info(f"Cache directory before: {tempfile.gettempdir()}")  # Log the temp directory
-     cache_dir = os.environ.get('XDG_CACHE_HOME', tempfile.gettempdir())
-     logging.info(f"Cache directory after: {tempfile.gettempdir()}")  # Log the temp directory
-     model = whisper_online.FasterWhisperASR(lan=lan, modelsize=model_name, cache_dir=cache_dir, model_dir=None)
+     model = whisper_online.FasterWhisperASR(lan=lan, modelsize=model_name)
      logging.info("FasterWhisperASR model initialized successfully.")
  except Exception as e:
-     logging.error(f"Falied to inilialize faster whisper model {e}")
+     logging.error(f"Failed to initialize FasterWhisperASR model: {e}")

  # Maximum data size: 200MB
  MAX_PAYLOAD_SIZE = 200 * 1024 * 1024

-
  def download_file(url, max_size_bytes, output_filename, api_key=None):
-     """
-     Download a file from a given URL with size limit and optional API key.
-
-     Args:
-         url (str): The URL of the file to download.
-         max_size_bytes (int): Maximum allowed file size in bytes.
-         output_filename (str): The name of the file to save the download as.
-         api_key (str, optional): API key to be used as a bearer token.
-
-     Returns:
-         bool: True if download was successful, False otherwise.
-     """
+     """Download a file from a given URL with size limit and optional API key."""
      try:
-         # Prepare headers
          headers = {}
          if api_key:
              headers['Authorization'] = f'Bearer {api_key}'
-
-         # Send a GET request
          response = requests.get(url, stream=True, headers=headers)
-         response.raise_for_status()  # Raises an HTTPError for bad requests
-
-         # Get the file size if possible
+         response.raise_for_status()
          file_size = int(response.headers.get('Content-Length', 0))
-
          if file_size > max_size_bytes:
-             print(f"File size ({file_size} bytes) exceeds the maximum allowed size ({max_size_bytes} bytes).")
+             print(f"File size exceeds the limit: {file_size} bytes.")
              return False
-
-         # Download and write the file
          downloaded_size = 0
          with open(output_filename, 'wb') as file:
              for chunk in response.iter_content(chunk_size=8192):
                  downloaded_size += len(chunk)
                  if downloaded_size > max_size_bytes:
-                     print(f"Download stopped: Size limit exceeded ({max_size_bytes} bytes).")
+                     print(f"Download stopped: size limit exceeded.")
                      return False
                  file.write(chunk)
-
          print(f"File downloaded successfully: {output_filename}")
          return True
-
      except requests.RequestException as e:
          print(f"Error downloading file: {e}")
          return False

-
- def transcribe(job):
-     datatype = job['input'].get('type', None)
-     if not datatype:
-         return {"error": "datatype field not provided. Should be 'blob' or 'url'."}
-
-     if not datatype in ['blob', 'url']:
-         return {"error": f"datatype should be 'blob' or 'url', but is {datatype} instead."}
-
-     # Get the API key from the job input
-     api_key = job['input'].get('api_key', None)
-
-     with tempfile.TemporaryDirectory() as d:
-         audio_file = f'{d}/audio.mp3'
-
-         if datatype == 'blob':
-             mp3_bytes = base64.b64decode(job['input']['data'])
-             open(audio_file, 'wb').write(mp3_bytes)
-         elif datatype == 'url':
-             success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
-             if not success:
-                 return {"error": f"Error downloading data from {job['input']['url']}"}
-
-     result = transcribe_core(audio_file)
-     return {'result': result}
-
-
- def transcribe_core(audio_file):
-     print('Transcribing...')
-
+ def transcribe_core_whisper(audio_file):
+     """Transcribe the audio file using FasterWhisper."""
+     logging.info(f"Transcribing audio file: {audio_file}")
      ret = {'segments': []}
-
-     segs, dummy = model.transcribe(audio_file, language='he', word_timestamps=True)
-     for s in segs:
-         words = []
-         for w in s.words:
-             words.append({'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability})
-
-         seg = {'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
-                'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words}
-
-         print(seg)
-         ret['segments'].append(seg)
-
+     try:
+         segs, dummy = model.transcribe(audio_file, language='he', word_timestamps=True)
+         for s in segs:
+             words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
+             seg = {'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
+                    'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words}
+             ret['segments'].append(seg)
+         logging.info("Transcription completed successfully.")
+     except Exception as e:
+         logging.error(f"Error during transcription: {e}", exc_info=True)
      return ret

-
- #runpod.serverless.start({"handler": transcribe})
-
  def transcribe_whisper(job):
-     logging.info(f"in triscribe-whisper")
-     datatype = job['input'].get('type', None)
+     """Main transcription handler."""
+     logging.info(f"Processing job: {job}")
+     datatype = job.get('input', {}).get('type')
      if not datatype:
          return {"error": "datatype field not provided. Should be 'blob' or 'url'."}
-
-     if not datatype in ['blob', 'url']:
-         return {"error": f"datatype should be 'blob' or 'url', but is {datatype} instead."}
-
-     # Get the API key from the job input
-     api_key = job['input'].get('api_key', None)
-
+     if datatype not in ['blob', 'url']:
+         return {"error": f"Invalid datatype: {datatype}."}

+     api_key = job.get('input', {}).get('api_key')
      with tempfile.TemporaryDirectory() as d:
          audio_file = f'{d}/audio.mp3'
-
          if datatype == 'blob':
              mp3_bytes = base64.b64decode(job['input']['data'])
-             open(audio_file, 'wb').write(mp3_bytes)
+             with open(audio_file, 'wb') as f:
+                 f.write(mp3_bytes)
          elif datatype == 'url':
              success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
              if not success:
-                 return {"error": f"Error downloading data from {job['input']['url']}"}
-     logging.info("Starting transcription process using transcribe_core_whisper.")
+                 return {"error": f"Failed to download from {job['input']['url']}"}
+
      result = transcribe_core_whisper(audio_file)
-     logging.info(f"DONE: in triscribe-whisper")
      return {'result': result}

- def transcribe_core_whisper(audio_file):
-     print('Transcribing...')
-
-     ret = {'segments': []}
-
-     try:
-         logging.debug(f"Transcribing audio file: {audio_file}")
-
-         segs = model.transcribe(audio_file, init_prompt="")
-         logging.info("Transcription completed successfully.")
-         for s in segs:
-             words = []
-             for w in s.words:
-                 words.append({'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability})
-
-             seg = {'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
-                    'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words}
-             logging.debug(f"All segments processed. Final transcription result: {ret}")
-             print(seg)
-             ret['segments'].append(seg)
-
-     except Exception as e:
-         # Log any exception that occurs during the transcription process
-         logging.error(f"Error during transcribe_core_whisper: {e}", exc_info=True)
-         return {"error": str(e)}
-     # Return the final result
-     logging.info("Transcription core function completed.")
-     return ret
-
- #runpod.serverless.start({"handler": transcribe_whisper})
+ # Example job input to test locally
+ if __name__ == "__main__":
+     test_job = {
+         "input": {
+             "type": "url",
+             "url": "https://github.com/metaldaniel/HebrewASR-Comparison/raw/main/HaTankistiot_n12-mp3.mp3",
+         }
+     }
+     print(transcribe_whisper(test_job))
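The rewritten infer.py drops the runpod handler registration (the commit title, "origin infer wo runpod", reads as "infer without runpod") and calls transcribe_whisper directly, with a __main__ block that exercises the 'url' path. A companion sketch for the 'blob' path, assuming a local sample.mp3 (a hypothetical placeholder file):

    # Local test of the 'blob' branch: base64-encode an mp3 and wrap it
    # in the job shape transcribe_whisper expects. Importing infer also
    # triggers the module-level model load, so this needs the same
    # environment as the container.
    import base64
    import infer

    with open("sample.mp3", "rb") as f:  # "sample.mp3" is a placeholder
        data = base64.b64encode(f.read()).decode("ascii")

    job = {"input": {"type": "blob", "data": data}}
    print(infer.transcribe_whisper(job))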
whisper_online.py CHANGED
@@ -105,15 +105,15 @@ class FasterWhisperASR(ASRBase):

      sep = ""

-     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+     def load_model(self, modelsize=None, cache_dir="/tmp/.cache/huggingface", model_dir=None):
          from faster_whisper import WhisperModel
          # logging.getLogger("faster_whisper").setLevel(logger.level)

          logging.info("Starting model loading process...")
-         logging.debug(f"Model loading parameters - modelsize: {modelsize}, cache_dir: {cache_dir}, model_dir: {model_dir}")
+         logging.info(f"Model loading parameters - modelsize: {modelsize}, cache_dir: {cache_dir}, model_dir: {model_dir}")

          if model_dir is not None:
-             logger.debug(
+             logger.info(
                  f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.")
              model_size_or_path = model_dir
          elif modelsize is not None:
@@ -123,7 +123,10 @@ class FasterWhisperASR(ASRBase):

          try:
              logging.info(f"Loading WhisperModel on device: ")
-             logging.info(f"Cache directory in online: {tempfile.gettempdir()}")  # Log the temp directory
+             os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/tmp/.cache/sentence_transformers'
+             os.environ['HF_HOME'] = '/tmp/.cache/huggingface'
+             # Ensure the cache directory exists
+             os.makedirs(cache_dir, exist_ok=True)
              model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
              logging.info("Model loaded successfully.")
          except Exception as e:
  except Exception as e: