from fastapi import FastAPI, Request, HTTPException
import torch
import torchaudio
from transformers import AutoProcessor, pipeline
import io
from pydub import AudioSegment
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
import numpy as np
import uvicorn

app = FastAPI()

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the model and processor
model_id = "WajeehAzeemX/whisper-small-ar2_onnx"
model = ORTModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
)
@app.post("/transcribe/")
async def transcribe_audio(request: Request):
    try:
        # Read binary audio data from the request body
        audio_data = await request.body()
        # Wrap the bytes in a file-like object
        audio_file = io.BytesIO(audio_data)
        # Decode the audio with pydub (expects WAV input)
        try:
            audio_segment = AudioSegment.from_file(audio_file, format="wav")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading audio file: {str(e)}")
        # Downmix to mono if the audio is stereo (multi-channel)
        if audio_segment.channels > 1:
            audio_segment = audio_segment.set_channels(1)
        # Resample the audio to the 16 kHz rate Whisper expects
        target_sample_rate = 16000
        if audio_segment.frame_rate != target_sample_rate:
            audio_segment = audio_segment.set_frame_rate(target_sample_rate)
        # Convert the samples to a float32 numpy array in [-1, 1]
        audio_array = np.array(audio_segment.get_array_of_samples())
        if audio_segment.sample_width == 2:
            audio_array = audio_array.astype(np.float32) / 32768.0
        else:
            raise HTTPException(status_code=400, detail="Unsupported sample width")
        # Convert to the format expected by the model
        inputs = processor(audio_array, sampling_rate=target_sample_rate, return_tensors="pt")
        inputs = inputs.to(device)
        # Run the ASR pipeline and return the transcription
        result = pipe(audio_array)
        transcription = result["text"]
        return {"transcription": transcription}
    except HTTPException:
        # Re-raise client errors unchanged instead of converting them to 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
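
# Example client call (a sketch, not part of the app): assumes the server is
# reachable at http://localhost:7860 and that test.wav is a local WAV file.
# import requests
# with open("test.wav", "rb") as f:
#     resp = requests.post("http://localhost:7860/transcribe/", data=f.read())
# print(resp.json()["transcription"])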
# Alternative implementation using faster-whisper (kept for reference):
# from fastapi import FastAPI, Request, HTTPException
# import io
# import time
# from faster_whisper import WhisperModel
# import uvicorn
#
# app = FastAPI()
# model = WhisperModel("WajeehAzeemX/faster-whisper-smallar2-int8", device="cpu", compute_type="int8")
#
# @app.post("/transcribe/")
# async def transcribe_audio(request: Request):
#     try:
#         # Read binary data from the request
#         audio_data = await request.body()
#         # Convert binary data to a file-like object
#         audio_file = io.BytesIO(audio_data)
#         # Start timing the transcription
#         start_time = time.time()
#         # Transcribe the audio
#         segments, info = model.transcribe(audio_file)
#         transcription = " ".join([segment.text for segment in segments])
#         # Calculate time taken
#         time_taken = time.time() - start_time
#         return {"transcription": transcription, "time_taken": time_taken}
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))
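
# Entry point for running the server directly; a minimal sketch (host and port
# are assumptions; Hugging Face Spaces typically expects port 7860).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)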