import base64
import faster_whisper
import tempfile

import numpy as np
import torch
import time
import requests
import logging
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
import websockets
from pydantic import BaseModel
from typing import Optional
import sys
import asyncio

# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s: %(message)s',
                    handlers=[logging.StreamHandler(sys.stdout)], force=True)
logger = logging.getLogger(__name__)
#logging.getLogger("asyncio").setLevel(logging.DEBUG)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Device selected: {device}')

model_name = 'ivrit-ai/faster-whisper-v2-d4'
logging.info(f'Loading model: {model_name}')
model = faster_whisper.WhisperModel(model_name, device=device)
logging.info('Model loaded successfully')

# Maximum data size: 200MB
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
logging.info(f'Max payload size set to: {MAX_PAYLOAD_SIZE} bytes')

app = FastAPI()
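
# Note: no launcher is included in this file. A minimal sketch for serving the
# app locally (the module name `main` and port 8000 are assumptions, not taken
# from this repo):
#
#     uvicorn main:app --host 0.0.0.0 --port 8000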


# class InputData(BaseModel):
#     type: str
#     data: Optional[str] = None  # Used for blob input
#     url: Optional[str] = None  # Used for url input
#
#
# def download_file(url, max_size_bytes, output_filename, api_key=None):
#     """
#     Download a file from a given URL with size limit and optional API key.
#     """
#     logging.debug(f'Starting file download from URL: {url}')
#     try:
#         headers = {}
#         if api_key:
#             headers['Authorization'] = f'Bearer {api_key}'
#             logging.debug('API key provided, added to headers')
#
#         response = requests.get(url, stream=True, headers=headers)
#         response.raise_for_status()
#
#         file_size = int(response.headers.get('Content-Length', 0))
#         logging.info(f'File size: {file_size} bytes')
#
#         if file_size > max_size_bytes:
#             logging.error(f'File size exceeds limit: {file_size} > {max_size_bytes}')
#             return False
#
#         downloaded_size = 0
#         with open(output_filename, 'wb') as file:
#             for chunk in response.iter_content(chunk_size=8192):
#                 downloaded_size += len(chunk)
#                 logging.debug(f'Downloaded {downloaded_size} bytes')
#                 if downloaded_size > max_size_bytes:
#                     logging.error('Downloaded size exceeds maximum allowed payload size')
#                     return False
#                 file.write(chunk)
#
#         logging.info(f'File downloaded successfully: {output_filename}')
#         return True
#
#     except requests.RequestException as e:
#         logging.error(f"Error downloading file: {e}")
#         return False


@app.get("/")
async def read_root():
    return {"message": "This is the Ivrit AI Streaming service."}


# async def transcribe_core_ws(audio_file):
#     ret = {'segments': []}
#
#     try:
#
#         logging.debug(f"Initiating model transcription for file: {audio_file}")
#
#         segs, _ = await asyncio.to_thread(model.transcribe, audio_file, language='he', word_timestamps=True)
#         logging.info('Transcription completed successfully.')
#     except Exception as e:
#         logging.error(f"Error during transcription: {e}")
#         raise e
#
#     # Track the new segments and update the last transcribed time
#     for s in segs:
#         logging.info(f"Processing segment with start time: {s.start} and end time: {s.end}")
#
#         # Only process segments that start after the last transcribed time
#         logging.info(f"New segment found starting at {s.start} seconds.")
#         words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
#
#         seg = {
#             'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text,
#             'avg_logprob': s.avg_logprob, 'compression_ratio': s.compression_ratio,
#             'no_speech_prob': s.no_speech_prob, 'words': words
#         }
#         logging.info(f'Adding new transcription segment: {seg}')
#         ret['segments'].append(seg)
#
#         # Update the last transcribed time to the end of the current segment
#
#
# #logging.info(f"Returning {len(ret['new_segments'])} new segments and updated last transcribed time.")
#     return ret



@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint to handle client connections."""
    await websocket.accept()
    client_ip = websocket.client.host
    logger.info(f"Client connected: {client_ip}")
    sys.stdout.flush()
    try:
        await process_audio_stream(websocket)
    except WebSocketDisconnect:
        logger.info(f"Client disconnected: {client_ip}")
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        await websocket.close()

async def process_audio_stream(websocket: WebSocket):
    """Continuously receive audio chunks and initiate transcription tasks."""
    sampling_rate = 16000
    min_chunk_size = 1  # in seconds
    audio_buffer = np.array([], dtype=np.float32)

    transcription_task = None
    chunk_counter = 0
    total_bytes_received = 0

    while True:
        try:
            # Receive audio data from client
            data = await websocket.receive_bytes()
            if not data:
                logger.info("No data received, closing connection")
                break
            chunk_counter += 1
            chunk_size = len(data)
            total_bytes_received += chunk_size
            logger.debug(f"Received chunk {chunk_counter}: {chunk_size} bytes")

            audio_chunk = process_received_audio(data)
            logger.debug(f"Processed audio chunk {chunk_counter}: {len(audio_chunk)} samples")

            audio_buffer = np.concatenate((audio_buffer, audio_chunk))
            logger.debug(f"Audio buffer size: {len(audio_buffer)} samples")
        except Exception as e:
            logger.error(f"Error receiving data: {e}")
            break

        # Check if enough audio has been buffered
        if len(audio_buffer) >= min_chunk_size * sampling_rate:
            if transcription_task is None or transcription_task.done():
                # Start a new transcription task
                logger.info(f"Starting transcription task for {len(audio_buffer)} samples")
                transcription_task = asyncio.create_task(
                    transcribe_and_send(websocket, audio_buffer.copy())
                )
                audio_buffer = np.array([], dtype=np.float32)
                logger.debug("Audio buffer reset after starting transcription task")

async def transcribe_and_send(websocket: WebSocket, audio_data):
    """Run transcription in a separate thread and send the result to the client."""
    logger.debug(f"Transcription task started for {len(audio_data)} samples")
    transcription_result = await asyncio.to_thread(sync_transcribe_audio, audio_data)
    if transcription_result:
        try:
            # Send the result as JSON
            await websocket.send_json(transcription_result)
            logger.info("Transcription JSON sent to client")
        except Exception as e:
            logger.error(f"Error sending transcription: {e}")
    else:
        logger.warning("No transcription result to send")

def sync_transcribe_audio(audio_data):
    """Synchronously transcribe audio data using the ASR model and format the result."""
    try:
        logger.info('Starting transcription...')
        segments, info = model.transcribe(
            audio_data, language="he", beam_size=5, word_timestamps=True
        )
        logger.info('Transcription completed')

        # Build the transcription result as per your requirement
        ret = {'segments': []}

        for s in segments:
            logger.debug(f"Processing segment {s.id} with start time: {s.start} and end time: {s.end}")

            # Process words in the segment
            words = [{
                'start': float(w.start),
                'end': float(w.end),
                'word': w.word,
                'probability': float(w.probability)
            } for w in s.words]

            seg = {
                'id': int(s.id),
                'seek': int(s.seek),
                'start': float(s.start),
                'end': float(s.end),
                'text': s.text,
                'avg_logprob': float(s.avg_logprob),
                'compression_ratio': float(s.compression_ratio),
                'no_speech_prob': float(s.no_speech_prob),
                'words': words
            }
            logger.debug(f'Adding new transcription segment: {seg}')
            ret['segments'].append(seg)

            logger.debug(f"Total segments in transcription result: {len(ret['segments'])}")
            return ret
    except Exception as e:
        logger.error(f"Transcription error: {e}")
        return {}

def process_received_audio(data):
    """Convert received bytes into normalized float32 NumPy array."""
    logger.debug(f"Processing received audio data of size {len(data)} bytes")
    audio_int16 = np.frombuffer(data, dtype=np.int16)
    logger.debug(f"Converted to int16 NumPy array with {len(audio_int16)} samples")

    audio_float32 = audio_int16.astype(np.float32) / 32768.0  # Normalize to [-1, 1]
    logger.debug(f"Normalized audio data to float32 with {len(audio_float32)} samples")

    return audio_float32
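

# A minimal client sketch for the /ws endpoint above. Assumptions not enforced
# anywhere in this file: the server listens on localhost:8000, and the client
# sends raw 16 kHz mono 16-bit PCM, which is the format process_received_audio()
# decodes. The function below is an illustration only and is not called by the
# service.
#
# async def example_stream_pcm(pcm_bytes: bytes, uri: str = "ws://localhost:8000/ws"):
#     """Send int16 PCM in ~1-second chunks and print each JSON transcription."""
#     chunk_bytes = 16000 * 2  # one second of 16 kHz mono int16 audio
#     async with websockets.connect(uri) as ws:
#         for i in range(0, len(pcm_bytes), chunk_bytes):
#             await ws.send(pcm_bytes[i:i + chunk_bytes])
#             try:
#                 # Replies arrive whenever a server-side transcription task finishes.
#                 print(await asyncio.wait_for(ws.recv(), timeout=0.1))
#             except asyncio.TimeoutError:
#                 pass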










# @app.websocket("/wtranscribe")
# async def websocket_transcribe(websocket: WebSocket):
#     logging.info("New WebSocket connection request received.")
#     await websocket.accept()
#     logging.info("WebSocket connection established successfully.")
#
#     try:
#         while True:
#             try:
#                 audio_chunk = await websocket.receive_bytes()
#                 if not audio_chunk:
#                     logging.warning("Received empty audio chunk, skipping processing.")
#                     continue
#                 with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file: ##new temp file for every chunk
#                     logging.info(f"Temporary audio file created at {temp_audio_file.name}")
#                     # Receive the next chunk of audio data
#
#
#
#                     partial_result = await transcribe_core_ws(temp_audio_file.name)
#                     await websocket.send_json(partial_result)
#
#             except WebSocketDisconnect:
#                 logging.info("WebSocket connection closed by the client.")
#                 break
#
#     except Exception as e:
#         logging.error(f"Unexpected error during WebSocket transcription: {e}")
#         await websocket.send_json({"error": str(e)})
#
#     finally:
#         logging.info("Cleaning up and closing WebSocket connection.")