# ivrit-ai-streaming / infer.py
import base64
import faster_whisper
import tempfile
import torch
import time
import requests
import logging
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
import websockets
from pydantic import BaseModel
from typing import Optional
import sys
import asyncio
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s',
handlers=[logging.StreamHandler(sys.stdout)], force=True)
#logging.getLogger("asyncio").setLevel(logging.DEBUG)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Device selected: {device}')
model_name = 'ivrit-ai/faster-whisper-v2-d4'
logging.info(f'Loading model: {model_name}')
model = faster_whisper.WhisperModel(model_name, device=device)
logging.info('Model loaded successfully')
# Maximum data size: 200MB
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
logging.info(f'Max payload size set to: {MAX_PAYLOAD_SIZE} bytes')
app = FastAPI()
class InputData(BaseModel):
type: str
data: Optional[str] = None # Used for blob input
url: Optional[str] = None # Used for url input
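# Example payloads for InputData (illustrative only):
#   {"type": "blob", "data": "<base64-encoded audio>"}
#   {"type": "url",  "url": "https://example.com/audio.mp3"}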
def download_file(url, max_size_bytes, output_filename, api_key=None):
"""
Download a file from a given URL with size limit and optional API key.
"""
logging.debug(f'Starting file download from URL: {url}')
try:
headers = {}
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
logging.debug('API key provided, added to headers')
response = requests.get(url, stream=True, headers=headers)
response.raise_for_status()
file_size = int(response.headers.get('Content-Length', 0))
logging.info(f'File size: {file_size} bytes')
if file_size > max_size_bytes:
logging.error(f'File size exceeds limit: {file_size} > {max_size_bytes}')
return False
downloaded_size = 0
with open(output_filename, 'wb') as file:
for chunk in response.iter_content(chunk_size=8192):
downloaded_size += len(chunk)
logging.debug(f'Downloaded {downloaded_size} bytes')
if downloaded_size > max_size_bytes:
logging.error('Downloaded size exceeds maximum allowed payload size')
return False
file.write(chunk)
logging.info(f'File downloaded successfully: {output_filename}')
return True
except requests.RequestException as e:
logging.error(f"Error downloading file: {e}")
return False
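# Example use (illustrative URL and filename):
#   download_file("https://example.com/audio.mp3", MAX_PAYLOAD_SIZE, "audio.mp3")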
@app.get("/")
async def read_root():
return {"message": "This is the Ivrit AI Streaming service."}
async def transcribe_core_ws(audio_file, last_transcribed_time):
"""
Transcribe the audio file and return only the segments that have not been processed yet.
:param audio_file: Path to the growing audio file.
:param last_transcribed_time: The last time (in seconds) that was transcribed.
:return: Newly transcribed segments and the updated last transcribed time.
"""
logging.info(f"Starting transcription for file: {audio_file} from {last_transcribed_time} seconds.")
ret = {'new_segments': []}
new_last_transcribed_time = last_transcribed_time
try:
# Transcribe the entire audio file
logging.debug(f"Initiating model transcription for file: {audio_file}")
segs, _ = await asyncio.to_thread(model.transcribe, audio_file, language='he', word_timestamps=True)
logging.info('Transcription completed successfully.')
except Exception as e:
logging.error(f"Error during transcription: {e}")
raise e
# Track the new segments and update the last transcribed time
for s in segs:
logging.info(f"Processing segment with start time: {s.start} and end time: {s.end}")
# Only process segments that start after the last transcribed time
if s.start >= last_transcribed_time:
logging.info(f"New segment found starting at {s.start} seconds.")
words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
seg = {
'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text,
'avg_logprob': s.avg_logprob, 'compression_ratio': s.compression_ratio,
'no_speech_prob': s.no_speech_prob, 'words': words
}
logging.info(f'Adding new transcription segment: {seg}')
ret['new_segments'].append(seg)
# Update the last transcribed time to the end of the current segment
new_last_transcribed_time = s.end
logging.debug(f"Updated last transcribed time to: {new_last_transcribed_time} seconds")
#logging.info(f"Returning {len(ret['new_segments'])} new segments and updated last transcribed time.")
return ret, new_last_transcribed_time
@app.websocket("/wtranscribe")
async def websocket_transcribe(websocket: WebSocket):
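    """
    WebSocket endpoint for streaming transcription: receives raw audio chunks
    from the client, writes each chunk to a temporary file, transcribes it with
    transcribe_core_ws, and sends the newly recognized segments together with
    all segments processed so far back to the client as JSON.
    """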
logging.info("New WebSocket connection request received.")
await websocket.accept()
logging.info("WebSocket connection established successfully.")
try:
processed_segments = [] # Keeps track of the segments already transcribed
accumulated_audio_size = 0 # Track how much audio data has been buffered
accumulated_audio_time = 0 # Track the total audio duration accumulated
last_transcribed_time = 0.0
#min_transcription_time = 5.0 # Minimum duration of audio in seconds before transcription starts
        # A new temporary file is created for each received audio chunk inside the loop below
while True:
try:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
logging.info(f"Temporary audio file created at {temp_audio_file.name}")
# Receive the next chunk of audio data
audio_chunk = await websocket.receive_bytes()
if not audio_chunk:
logging.warning("Received empty audio chunk, skipping processing.")
continue
# Write audio chunk to file and accumulate size and time
temp_audio_file.write(audio_chunk)
temp_audio_file.flush()
accumulated_audio_size += len(audio_chunk)
# Estimate the duration of the chunk based on its size (e.g., 16kHz audio)
chunk_duration = len(audio_chunk) / (16000 * 2) # Assuming 16kHz mono WAV (2 bytes per sample)
accumulated_audio_time += chunk_duration
                    # transcribe_core_ws is a coroutine and must be awaited
                    partial_result, last_transcribed_time = await transcribe_core_ws(temp_audio_file.name, last_transcribed_time)
accumulated_audio_time = 0 # Reset the accumulated audio time
processed_segments.extend(partial_result['new_segments'])
# Reset the accumulated audio size after transcription
accumulated_audio_size = 0
# Send the transcription result back to the client with both new and all processed segments
response = {
"new_segments": partial_result['new_segments'],
"processed_segments": processed_segments
}
logging.info(f"Sending {len(partial_result['new_segments'])} new segments to the client.")
await websocket.send_json(response)
except WebSocketDisconnect:
logging.info("WebSocket connection closed by the client.")
break
except Exception as e:
logging.error(f"Unexpected error during WebSocket transcription: {e}")
await websocket.send_json({"error": str(e)})
finally:
logging.info("Cleaning up and closing WebSocket connection.")