Spaces:
Sleeping
Sleeping
AshDavid12
committed on
Commit
·
9bd82d6
1
Parent(s):
1bd5368
origin infer wo runpod
Browse files
- Dockerfile +14 -59
- infer.py +41 -141
- whisper_online.py +7 -4
Dockerfile
CHANGED
@@ -1,65 +1,20 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
ENV PYTHON_VERSION=3.11
|
4 |
-
|
5 |
-
RUN export DEBIAN_FRONTEND=noninteractive \
|
6 |
-
&& apt-get -qq update \
|
7 |
-
&& apt-get -qq install --no-install-recommends \
|
8 |
-
python${PYTHON_VERSION} \
|
9 |
-
python${PYTHON_VERSION}-venv \
|
10 |
-
python3-pip \
|
11 |
-
libcublas11 \
|
12 |
-
&& rm -rf /var/lib/apt/lists/*
|
13 |
-
|
14 |
-
|
15 |
-
# Set up Python environment
|
16 |
-
RUN python3 -m pip install --upgrade pip
|
17 |
-
|
18 |
-
# Copy the requirements file and install Python packages
|
19 |
-
COPY requirements.txt .
|
20 |
-
RUN pip install -r requirements.txt
|
21 |
-
|
22 |
-
# Install the specific model using faster-whisper
|
23 |
-
#RUN python3 -c 'import faster_whisper; m = faster_whisper.WhisperModel("ivrit-ai/faster-whisper-v2-d3-e3")'
|
24 |
-
# Set the SENTENCE_TRANSFORMERS_HOME environment variable to a writable directory
|
25 |
-
# Set environment variables for cache directories
|
26 |
-
ENV SENTENCE_TRANSFORMERS_HOME="/tmp/.cache/sentence_transformers"
|
27 |
-
ENV HF_HOME="/tmp/.cache/huggingface"
|
28 |
-
|
29 |
-
# Ensure the cache directories exist
|
30 |
-
RUN mkdir -p $SENTENCE_TRANSFORMERS_HOME $HF_HOME
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
# Add your Python scripts
|
35 |
-
COPY infer.py .
|
36 |
-
COPY whisper_online.py .
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
CMD ["python3", "-u", "/infer.py"]
|
41 |
|
|
|
|
|
|
|
|
|
42 |
|
|
|
43 |
|
|
|
|
|
44 |
|
|
|
45 |
|
46 |
-
#
|
47 |
-
|
48 |
-
#
|
49 |
-
## Define your working directory
|
50 |
-
#WORKDIR /
|
51 |
-
#
|
52 |
-
## Install runpod
|
53 |
-
#COPY requirements.txt .
|
54 |
-
#RUN pip install -r requirements.txt
|
55 |
-
#
|
56 |
-
#RUN python3 -c 'import faster_whisper; m = faster_whisper.WhisperModel("ivrit-ai/faster-whisper-v2-d3-e3")'
|
57 |
-
#
|
58 |
-
## Add your file
|
59 |
-
#ADD infer.py .
|
60 |
-
#ADD whisper_online.py .
|
61 |
-
#
|
62 |
-
#ENV LD_LIBRARY_PATH="/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib:/usr/local/lib/python3.11/site-packages/nvidia/cublas/lib"
|
63 |
-
#
|
64 |
-
## Call your file when your container starts
|
65 |
-
#CMD [ "python", "-u", "/infer.py" ]
|
|
|
1 |
+
# Include Python
|
2 |
+
from python:3.11.1-buster
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
+
# Define your working directory
|
5 |
+
WORKDIR /
|
|
|
6 |
|
7 |
+
# Install runpod
|
8 |
+
RUN pip install runpod
|
9 |
+
RUN pip install torch==2.3.1
|
10 |
+
RUN pip install faster-whisper
|
11 |
|
12 |
+
RUN python3 -c 'import faster_whisper; m = faster_whisper.WhisperModel("ivrit-ai/faster-whisper-v2-d4")'
|
13 |
|
14 |
+
# Add your file
|
15 |
+
ADD infer.py .
|
16 |
|
17 |
+
ENV LD_LIBRARY_PATH="/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib:/usr/local/lib/python3.11/site-packages/nvidia/cublas/lib"
|
18 |
|
19 |
+
# Call your file when your container starts
|
20 |
+
CMD [ "python", "-u", "/infer.py" ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
infer.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
import runpod
|
2 |
-
|
3 |
import base64
|
4 |
import faster_whisper
|
5 |
import tempfile
|
@@ -11,195 +9,97 @@ import os
|
|
11 |
|
12 |
import whisper_online
|
13 |
|
|
|
14 |
logger = logging.getLogger(__name__)
|
15 |
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)])
|
16 |
|
17 |
-
#
|
18 |
-
try:
|
19 |
-
logging.info("attempting to load whisper online")
|
20 |
-
from whisper_online import * # Replace 'some_module' with the actual module name
|
21 |
-
|
22 |
-
logging.info("Successfully imported whisper_online.")
|
23 |
-
except ImportError as e:
|
24 |
-
logging.error(f"Failed to import whisper_online: {e}", exc_info=True)
|
25 |
-
except Exception as e:
|
26 |
-
logging.error(f"Unknown from exception- error to import whisper_online: {e}", exc_info=True)
|
27 |
-
|
28 |
-
if torch.cuda.is_available():
|
29 |
-
logging.info(f"CUDA is available.")
|
30 |
-
else:
|
31 |
-
logging.info("CUDA is not available. Using CPU.")
|
32 |
-
|
33 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
34 |
-
|
35 |
model_name = 'ivrit-ai/faster-whisper-v2-d3-e3'
|
36 |
-
|
37 |
-
#model = faster_whisper.WhisperModel(model_name, device=device)
|
38 |
try:
|
39 |
lan = 'he'
|
40 |
logging.info(f"Attempting to initialize FasterWhisperASR with device: {device}")
|
41 |
-
|
42 |
-
cache_dir = os.environ.get('XDG_CACHE_HOME', tempfile.gettempdir())
|
43 |
-
logging.info(f"Cache directory after: {tempfile.gettempdir()}") # Log the temp directory
|
44 |
-
model = whisper_online.FasterWhisperASR(lan=lan, modelsize=model_name, cache_dir=cache_dir, model_dir=None)
|
45 |
logging.info("FasterWhisperASR model initialized successfully.")
|
46 |
except Exception as e:
|
47 |
-
logging.error(f"
|
48 |
|
49 |
# Maximum data size: 200MB
|
50 |
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
|
51 |
|
52 |
-
|
53 |
def download_file(url, max_size_bytes, output_filename, api_key=None):
|
54 |
-
"""
|
55 |
-
Download a file from a given URL with size limit and optional API key.
|
56 |
-
|
57 |
-
Args:
|
58 |
-
url (str): The URL of the file to download.
|
59 |
-
max_size_bytes (int): Maximum allowed file size in bytes.
|
60 |
-
output_filename (str): The name of the file to save the download as.
|
61 |
-
api_key (str, optional): API key to be used as a bearer token.
|
62 |
-
|
63 |
-
Returns:
|
64 |
-
bool: True if download was successful, False otherwise.
|
65 |
-
"""
|
66 |
try:
|
67 |
-
# Prepare headers
|
68 |
headers = {}
|
69 |
if api_key:
|
70 |
headers['Authorization'] = f'Bearer {api_key}'
|
71 |
-
|
72 |
-
# Send a GET request
|
73 |
response = requests.get(url, stream=True, headers=headers)
|
74 |
-
response.raise_for_status()
|
75 |
-
|
76 |
-
# Get the file size if possible
|
77 |
file_size = int(response.headers.get('Content-Length', 0))
|
78 |
-
|
79 |
if file_size > max_size_bytes:
|
80 |
-
print(f"File size
|
81 |
return False
|
82 |
-
|
83 |
-
# Download and write the file
|
84 |
downloaded_size = 0
|
85 |
with open(output_filename, 'wb') as file:
|
86 |
for chunk in response.iter_content(chunk_size=8192):
|
87 |
downloaded_size += len(chunk)
|
88 |
if downloaded_size > max_size_bytes:
|
89 |
-
print(f"Download stopped:
|
90 |
return False
|
91 |
file.write(chunk)
|
92 |
-
|
93 |
print(f"File downloaded successfully: {output_filename}")
|
94 |
return True
|
95 |
-
|
96 |
except requests.RequestException as e:
|
97 |
print(f"Error downloading file: {e}")
|
98 |
return False
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
if not datatype:
|
104 |
-
return {"error": "datatype field not provided. Should be 'blob' or 'url'."}
|
105 |
-
|
106 |
-
if not datatype in ['blob', 'url']:
|
107 |
-
return {"error": f"datatype should be 'blob' or 'url', but is {datatype} instead."}
|
108 |
-
|
109 |
-
# Get the API key from the job input
|
110 |
-
api_key = job['input'].get('api_key', None)
|
111 |
-
|
112 |
-
with tempfile.TemporaryDirectory() as d:
|
113 |
-
audio_file = f'{d}/audio.mp3'
|
114 |
-
|
115 |
-
if datatype == 'blob':
|
116 |
-
mp3_bytes = base64.b64decode(job['input']['data'])
|
117 |
-
open(audio_file, 'wb').write(mp3_bytes)
|
118 |
-
elif datatype == 'url':
|
119 |
-
success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
|
120 |
-
if not success:
|
121 |
-
return {"error": f"Error downloading data from {job['input']['url']}"}
|
122 |
-
|
123 |
-
result = transcribe_core(audio_file)
|
124 |
-
return {'result': result}
|
125 |
-
|
126 |
-
|
127 |
-
def transcribe_core(audio_file):
|
128 |
-
print('Transcribing...')
|
129 |
-
|
130 |
ret = {'segments': []}
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
print(seg)
|
142 |
-
ret['segments'].append(seg)
|
143 |
-
|
144 |
return ret
|
145 |
|
146 |
-
|
147 |
-
#runpod.serverless.start({"handler": transcribe})
|
148 |
-
|
149 |
def transcribe_whisper(job):
|
150 |
-
|
151 |
-
|
|
|
152 |
if not datatype:
|
153 |
return {"error": "datatype field not provided. Should be 'blob' or 'url'."}
|
|
|
|
|
154 |
|
155 |
-
|
156 |
-
return {"error": f"datatype should be 'blob' or 'url', but is {datatype} instead."}
|
157 |
-
|
158 |
-
# Get the API key from the job input
|
159 |
-
api_key = job['input'].get('api_key', None)
|
160 |
-
|
161 |
with tempfile.TemporaryDirectory() as d:
|
162 |
audio_file = f'{d}/audio.mp3'
|
163 |
-
|
164 |
if datatype == 'blob':
|
165 |
mp3_bytes = base64.b64decode(job['input']['data'])
|
166 |
-
open(audio_file, 'wb')
|
|
|
167 |
elif datatype == 'url':
|
168 |
success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
|
169 |
if not success:
|
170 |
-
return {"error": f"
|
171 |
-
|
172 |
result = transcribe_core_whisper(audio_file)
|
173 |
-
logging.info(f"DONE: in triscribe-whisper")
|
174 |
return {'result': result}
|
175 |
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
logging.info("Transcription completed successfully.")
|
186 |
-
for s in segs:
|
187 |
-
words = []
|
188 |
-
for w in s.words:
|
189 |
-
words.append({'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability})
|
190 |
-
|
191 |
-
seg = {'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
|
192 |
-
'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words}
|
193 |
-
logging.debug(f"All segments processed. Final transcription result: {ret}")
|
194 |
-
print(seg)
|
195 |
-
ret['segments'].append(seg)
|
196 |
-
|
197 |
-
except Exception as e:
|
198 |
-
# Log any exception that occurs during the transcription process
|
199 |
-
logging.error(f"Error during transcribe_core_whisper: {e}", exc_info=True)
|
200 |
-
return {"error": str(e)}
|
201 |
-
# Return the final result
|
202 |
-
logging.info("Transcription core function completed.")
|
203 |
-
return ret
|
204 |
-
|
205 |
-
#runpod.serverless.start({"handler": transcribe_whisper})
|
|
|
|
|
|
|
1 |
import base64
|
2 |
import faster_whisper
|
3 |
import tempfile
|
|
|
9 |
|
10 |
import whisper_online
|
11 |
|
12 |
+
# Set up logging
|
13 |
logger = logging.getLogger(__name__)
|
14 |
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)])
|
15 |
|
16 |
+
# Load the FasterWhisper model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
18 |
model_name = 'ivrit-ai/faster-whisper-v2-d3-e3'
|
19 |
+
|
|
|
20 |
try:
|
21 |
lan = 'he'
|
22 |
logging.info(f"Attempting to initialize FasterWhisperASR with device: {device}")
|
23 |
+
model = whisper_online.FasterWhisperASR(lan=lan, modelsize=model_name)
|
|
|
|
|
|
|
24 |
logging.info("FasterWhisperASR model initialized successfully.")
|
25 |
except Exception as e:
|
26 |
+
logging.error(f"Failed to initialize FasterWhisperASR model: {e}")
|
27 |
|
28 |
# Maximum data size: 200MB
|
29 |
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
|
30 |
|
|
|
31 |
def download_file(url, max_size_bytes, output_filename, api_key=None):
|
32 |
+
"""Download a file from a given URL with size limit and optional API key."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
try:
|
|
|
34 |
headers = {}
|
35 |
if api_key:
|
36 |
headers['Authorization'] = f'Bearer {api_key}'
|
|
|
|
|
37 |
response = requests.get(url, stream=True, headers=headers)
|
38 |
+
response.raise_for_status()
|
|
|
|
|
39 |
file_size = int(response.headers.get('Content-Length', 0))
|
|
|
40 |
if file_size > max_size_bytes:
|
41 |
+
print(f"File size exceeds the limit: {file_size} bytes.")
|
42 |
return False
|
|
|
|
|
43 |
downloaded_size = 0
|
44 |
with open(output_filename, 'wb') as file:
|
45 |
for chunk in response.iter_content(chunk_size=8192):
|
46 |
downloaded_size += len(chunk)
|
47 |
if downloaded_size > max_size_bytes:
|
48 |
+
print(f"Download stopped: size limit exceeded.")
|
49 |
return False
|
50 |
file.write(chunk)
|
|
|
51 |
print(f"File downloaded successfully: {output_filename}")
|
52 |
return True
|
|
|
53 |
except requests.RequestException as e:
|
54 |
print(f"Error downloading file: {e}")
|
55 |
return False
|
56 |
|
57 |
+
def transcribe_core_whisper(audio_file):
|
58 |
+
"""Transcribe the audio file using FasterWhisper."""
|
59 |
+
logging.info(f"Transcribing audio file: {audio_file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
ret = {'segments': []}
|
61 |
+
try:
|
62 |
+
segs, dummy = model.transcribe(audio_file, language='he', word_timestamps=True)
|
63 |
+
for s in segs:
|
64 |
+
words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
|
65 |
+
seg = {'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
|
66 |
+
'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words}
|
67 |
+
ret['segments'].append(seg)
|
68 |
+
logging.info("Transcription completed successfully.")
|
69 |
+
except Exception as e:
|
70 |
+
logging.error(f"Error during transcription: {e}", exc_info=True)
|
|
|
|
|
|
|
71 |
return ret
|
72 |
|
|
|
|
|
|
|
73 |
def transcribe_whisper(job):
|
74 |
+
"""Main transcription handler."""
|
75 |
+
logging.info(f"Processing job: {job}")
|
76 |
+
datatype = job.get('input', {}).get('type')
|
77 |
if not datatype:
|
78 |
return {"error": "datatype field not provided. Should be 'blob' or 'url'."}
|
79 |
+
if datatype not in ['blob', 'url']:
|
80 |
+
return {"error": f"Invalid datatype: {datatype}."}
|
81 |
|
82 |
+
api_key = job.get('input', {}).get('api_key')
|
|
|
|
|
|
|
|
|
|
|
83 |
with tempfile.TemporaryDirectory() as d:
|
84 |
audio_file = f'{d}/audio.mp3'
|
|
|
85 |
if datatype == 'blob':
|
86 |
mp3_bytes = base64.b64decode(job['input']['data'])
|
87 |
+
with open(audio_file, 'wb') as f:
|
88 |
+
f.write(mp3_bytes)
|
89 |
elif datatype == 'url':
|
90 |
success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
|
91 |
if not success:
|
92 |
+
return {"error": f"Failed to download from {job['input']['url']}"}
|
93 |
+
|
94 |
result = transcribe_core_whisper(audio_file)
|
|
|
95 |
return {'result': result}
|
96 |
|
97 |
+
# Example job input to test locally
|
98 |
+
if __name__ == "__main__":
|
99 |
+
test_job = {
|
100 |
+
"input": {
|
101 |
+
"type": "url",
|
102 |
+
"url": "https://github.com/metaldaniel/HebrewASR-Comparison/raw/main/HaTankistiot_n12-mp3.mp3",
|
103 |
+
}
|
104 |
+
}
|
105 |
+
print(transcribe_whisper(test_job))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
whisper_online.py
CHANGED
@@ -105,15 +105,15 @@ class FasterWhisperASR(ASRBase):
|
|
105 |
|
106 |
sep = ""
|
107 |
|
108 |
-
def load_model(self, modelsize=None, cache_dir=
|
109 |
from faster_whisper import WhisperModel
|
110 |
# logging.getLogger("faster_whisper").setLevel(logger.level)
|
111 |
|
112 |
logging.info("Starting model loading process...")
|
113 |
-
logging.
|
114 |
|
115 |
if model_dir is not None:
|
116 |
-
logger.
|
117 |
f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.")
|
118 |
model_size_or_path = model_dir
|
119 |
elif modelsize is not None:
|
@@ -123,7 +123,10 @@ class FasterWhisperASR(ASRBase):
|
|
123 |
|
124 |
try:
|
125 |
logging.info(f"Loading WhisperModel on device: ")
|
126 |
-
|
|
|
|
|
|
|
127 |
model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
|
128 |
logging.info("Model loaded successfully.")
|
129 |
except Exception as e:
|
|
|
105 |
|
106 |
sep = ""
|
107 |
|
108 |
+
def load_model(self, modelsize=None, cache_dir="/tmp/.cache/huggingface", model_dir=None):
|
109 |
from faster_whisper import WhisperModel
|
110 |
# logging.getLogger("faster_whisper").setLevel(logger.level)
|
111 |
|
112 |
logging.info("Starting model loading process...")
|
113 |
+
logging.info(f"Model loading parameters - modelsize: {modelsize}, cache_dir: {cache_dir}, model_dir: {model_dir}")
|
114 |
|
115 |
if model_dir is not None:
|
116 |
+
logger.info(
|
117 |
f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.")
|
118 |
model_size_or_path = model_dir
|
119 |
elif modelsize is not None:
|
|
|
123 |
|
124 |
try:
|
125 |
logging.info(f"Loading WhisperModel on device: ")
|
126 |
+
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/tmp/.cache/sentence_transformers'
|
127 |
+
os.environ['HF_HOME'] = '/tmp/.cache/huggingface'
|
128 |
+
# Ensure the cache directory exists
|
129 |
+
os.makedirs(cache_dir, exist_ok=True)
|
130 |
model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
|
131 |
logging.info("Model loaded successfully.")
|
132 |
except Exception as e:
|