import base64
import faster_whisper
import tempfile
import torch
import requests

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the model from Hugging Face
model_name = 'ivrit-ai/faster-whisper-v2-d4'
model = faster_whisper.WhisperModel(model_name, device=device)

# Maximum data size: 200MB
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024

def download_file(url, max_size_bytes, output_filename, api_key=None):
"""
Download a file from a given URL with size limit and optional API key.
Args:
url (str): The URL of the file to download.
max_size_bytes (int): Maximum allowed file size in bytes.
output_filename (str): The name of the file to save the download as.
api_key (str, optional): API key to be used as a bearer token.
Returns:
bool: True if download was successful, False otherwise.
"""
try:
headers = {}
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
response = requests.get(url, stream=True, headers=headers)
response.raise_for_status()
file_size = int(response.headers.get('Content-Length', 0))
if file_size > max_size_bytes:
print(f"File size ({file_size} bytes) exceeds the maximum allowed size ({max_size_bytes} bytes).")
return False
downloaded_size = 0
with open(output_filename, 'wb') as file:
for chunk in response.iter_content(chunk_size=8192):
downloaded_size += len(chunk)
if downloaded_size > max_size_bytes:
print(f"Download stopped: Size limit exceeded ({max_size_bytes} bytes).")
return False
file.write(chunk)
print(f"File downloaded successfully: {output_filename}")
return True
except requests.RequestException as e:
print(f"Error downloading file: {e}")
return False
def transcribe(job):
    datatype = job['input'].get('type', None)
    if not datatype:
        return {"error": "datatype field not provided. Should be 'blob' or 'url'."}

    if datatype not in ['blob', 'url']:
        return {"error": f"datatype should be 'blob' or 'url', but is {datatype} instead."}

    api_key = job['input'].get('api_key', None)

    with tempfile.TemporaryDirectory() as d:
        audio_file = f'{d}/audio.mp3'

        if datatype == 'blob':
            mp3_bytes = base64.b64decode(job['input']['data'])
            with open(audio_file, 'wb') as file:
                file.write(mp3_bytes)
        elif datatype == 'url':
            success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
            if not success:
                return {"error": f"Error downloading data from {job['input']['url']}"}

        result = transcribe_core(audio_file)
        return {'result': result}

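# Illustrative sketch (not part of the original handler): one way a 'blob' job
# for transcribe() could be built from a local mp3 file. The helper name and
# file path are placeholders; the payload mirrors the fields transcribe()
# reads above ('type' and base64-encoded 'data').
def build_blob_job(mp3_path):
    with open(mp3_path, 'rb') as f:
        encoded = base64.b64encode(f.read()).decode('utf-8')
    return {'input': {'type': 'blob', 'data': encoded}}
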
def transcribe_core(audio_file):
    print('Transcribing...')

    ret = {'segments': []}

    segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
    for s in segs:
        words = []
        for w in s.words:
            words.append({
                'start': w.start,
                'end': w.end,
                'word': w.word,
                'probability': w.probability
            })

        seg = {
            'id': s.id,
            'seek': s.seek,
            'start': s.start,
            'end': s.end,
            'text': s.text,
            'avg_logprob': s.avg_logprob,
            'compression_ratio': s.compression_ratio,
            'no_speech_prob': s.no_speech_prob,
            'words': words
        }

        print(seg)
        ret['segments'].append(seg)

    return ret

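# Illustrative sketch (not part of the original script): the per-segment output
# of transcribe_core() can be collapsed into a single transcript string for
# quick inspection. Assumes the result dict shape built above; the helper name
# is hypothetical.
def segments_to_text(result):
    return ' '.join(seg['text'].strip() for seg in result['segments'])
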
# The script can be run directly or served using Hugging Face's Gradio app or API
if __name__ == "__main__":
    # For testing purposes, you can define a sample job and call the transcribe function
    test_job = {
        "input": {
            "type": "url",
            "url": "https://github.com/metaldaniel/HebrewASR-Comparison/raw/main/HaTankistiot_n12-mp3.mp3",
            "api_key": "your_api_key_here"  # Optional, replace with actual key if needed
        }
    }
    print(transcribe(test_job))
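
# Illustrative sketch (assumption, not part of the original script): the comment
# above mentions serving via Gradio. A minimal interface could look like the
# function below; it is never called here, it assumes the 'gradio' package is
# installed, and the component choices ('text' input, 'json' output) are one
# option among many.
def serve_gradio():
    import gradio as gr  # assumed extra dependency

    def transcribe_url(url):
        # Wrap the existing handler so a URL typed into the UI becomes a 'url' job.
        return transcribe({'input': {'type': 'url', 'url': url}})

    gr.Interface(fn=transcribe_url, inputs='text', outputs='json').launch()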