import base64
import faster_whisper
import tempfile
import torch
import requests

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the model from Hugging Face
model_name = 'ivrit-ai/faster-whisper-v2-d4'
model = faster_whisper.WhisperModel(model_name, device=device)

# Maximum data size: 200MB
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024


def download_file(url, max_size_bytes, output_filename, api_key=None):
    """
    Download a file from a given URL with size limit and optional API key.

    Args:
    url (str): The URL of the file to download.
    max_size_bytes (int): Maximum allowed file size in bytes.
    output_filename (str): The name of the file to save the download as.
    api_key (str, optional): API key to be used as a bearer token.

    Returns:
    bool: True if download was successful, False otherwise.
    """
    try:
        headers = {}
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        response = requests.get(url, stream=True, headers=headers)
        response.raise_for_status()

        file_size = int(response.headers.get('Content-Length', 0))

        if file_size > max_size_bytes:
            print(f"File size ({file_size} bytes) exceeds the maximum allowed size ({max_size_bytes} bytes).")
            return False

        downloaded_size = 0
        with open(output_filename, 'wb') as file:
            for chunk in response.iter_content(chunk_size=8192):
                downloaded_size += len(chunk)
                if downloaded_size > max_size_bytes:
                    print(f"Download stopped: Size limit exceeded ({max_size_bytes} bytes).")
                    return False
                file.write(chunk)

        print(f"File downloaded successfully: {output_filename}")
        return True

    except requests.RequestException as e:
        print(f"Error downloading file: {e}")
        return False
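
# A minimal usage sketch of download_file (not part of the handler flow): the URL and
# local path below are hypothetical placeholders. The function returns True on success,
# so callers typically branch on its return value before transcribing.
#
#   ok = download_file('https://example.com/sample.mp3', MAX_PAYLOAD_SIZE, '/tmp/sample.mp3')
#   if ok:
#       print(transcribe_core('/tmp/sample.mp3'))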


def transcribe(job):
    datatype = job['input'].get('type', None)
    if not datatype:
        return {"error": "datatype field not provided. Should be 'blob' or 'url'."}

    if datatype not in ['blob', 'url']:
        return {"error": f"datatype should be 'blob' or 'url', but is {datatype} instead."}

    api_key = job['input'].get('api_key', None)

    with tempfile.TemporaryDirectory() as d:
        audio_file = f'{d}/audio.mp3'

        if datatype == 'blob':
            mp3_bytes = base64.b64decode(job['input']['data'])
            with open(audio_file, 'wb') as file:
                file.write(mp3_bytes)
        elif datatype == 'url':
            success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
            if not success:
                return {"error": f"Error downloading data from {job['input']['url']}"}

        result = transcribe_core(audio_file)
        return {'result': result}


def transcribe_core(audio_file):
    print('Transcribing...')

    ret = {'segments': []}

    segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
    for s in segs:
        words = []
        for w in s.words:
            words.append({
                'start': w.start,
                'end': w.end,
                'word': w.word,
                'probability': w.probability
            })

        seg = {
            'id': s.id,
            'seek': s.seek,
            'start': s.start,
            'end': s.end,
            'text': s.text,
            'avg_logprob': s.avg_logprob,
            'compression_ratio': s.compression_ratio,
            'no_speech_prob': s.no_speech_prob,
            'words': words
        }

        print(seg)
        ret['segments'].append(seg)

    return ret


# The script can be run directly for local testing, or the transcribe() handler can be served via an app or API (e.g. a Gradio app)
if __name__ == "__main__":
    # For testing purposes, you can define a sample job and call the transcribe function
    test_job = {
        "input": {
            "type": "url",
            "url": "https://github.com/metaldaniel/HebrewASR-Comparison/raw/main/HaTankistiot_n12-mp3.mp3",
            "api_key": "your_api_key_here"  # Optional, replace with actual key if needed
        }
    }
    print(transcribe(test_job))
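
    # A minimal sketch of the alternative 'blob' input path (the local file 'sample.mp3'
    # is a hypothetical placeholder for illustration): the audio is base64-encoded and
    # passed inline instead of via a URL. Kept commented out so it does not run by default.
    #
    # with open('sample.mp3', 'rb') as f:
    #     blob_job = {
    #         "input": {
    #             "type": "blob",
    #             "data": base64.b64encode(f.read()).decode('utf-8')
    #         }
    #     }
    # print(transcribe(blob_job))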