from fastapi import FastAPI, Request, HTTPException
import torch
from transformers import AutoProcessor, pipeline
import io
from pydub import AudioSegment
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
import numpy as np
import uvicorn

app = FastAPI()

# Device configuration (informational; the ONNX model itself runs on the
# execution provider selected by onnxruntime)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the ONNX model and processor. PyTorch-specific kwargs such as
# torch_dtype, low_cpu_mem_usage, and use_safetensors apply to PyTorch
# checkpoints, not ONNX Runtime models, so they are not passed here.
model_id = "WajeehAzeemX/whisper-small-ar2_onnx"
model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)


# Wrap the ONNX model in a transformers ASR pipeline, reusing the
# processor's tokenizer and feature extractor.
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
)

@app.post("/transcribe/")
async def transcribe_audio(request: Request):
    try:
        # Read binary data from the request
        audio_data = await request.body()

        # Convert binary data to a file-like object
        audio_file = io.BytesIO(audio_data)
        
        # Load the audio file using pydub
        try:
            audio_segment = AudioSegment.from_file(audio_file, format="wav")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading audio file: {str(e)}")

        # Convert to mono if the audio is stereo (multi-channel)
        if audio_segment.channels > 1:
            audio_segment = audio_segment.set_channels(1)

        # Resample the audio to 16kHz
        target_sample_rate = 16000
        if audio_segment.frame_rate != target_sample_rate:
            audio_segment = audio_segment.set_frame_rate(target_sample_rate)

        # Convert audio to a float32 numpy array in [-1.0, 1.0].
        # Only 16-bit PCM (sample_width == 2) is supported here.
        audio_array = np.array(audio_segment.get_array_of_samples())
        if audio_segment.sample_width == 2:
            audio_array = audio_array.astype(np.float32) / 32768.0
        else:
            raise HTTPException(status_code=400, detail="Unsupported sample width (expected 16-bit PCM)")

        # Run the ASR pipeline. It performs feature extraction itself and
        # assumes raw arrays are sampled at the feature extractor's rate
        # (16 kHz), which the resampling above guarantees, so no manual
        # processor call is needed.
        result = pipe(audio_array)
        transcription = result["text"]

        return {"transcription": transcription}
    except HTTPException:
        # Re-raise deliberate 4xx errors instead of wrapping them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    # Host/port are assumed defaults; adjust for your deployment.
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Alternative implementation using faster-whisper (kept for reference):
#
# from fastapi import FastAPI, Request, HTTPException
# import io
# import time
# from faster_whisper import WhisperModel
# import uvicorn

# app = FastAPI()

# model = WhisperModel("WajeehAzeemX/faster-whisper-smallar2-int8", device="cpu", compute_type="int8")

# @app.post("/transcribe/")
# async def transcribe_audio(request: Request):
#     try:
#         # Read binary data from the request
#         audio_data = await request.body()

#         # Convert binary data to a file-like object
#         audio_file = io.BytesIO(audio_data)

#         # Start timing the transcription
#         start_time = time.time()

#         # Transcribe the audio
#         segments, info = model.transcribe(audio_file)
#         transcription = " ".join([segment.text for segment in segments])

#         # Calculate time taken
#         time_taken = time.time() - start_time

#         return {"transcription": transcription, "time_taken": time_taken}
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))
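

# Example client call (a minimal sketch, not part of the original file).
# It assumes the server above is running on localhost:8000 and that
# "sample.wav" is a 16-bit PCM WAV file; both names are placeholders.
# The endpoint reads the raw request body, so the file bytes are posted
# directly rather than as multipart form data.
#
# import requests
#
# with open("sample.wav", "rb") as f:
#     resp = requests.post(
#         "http://localhost:8000/transcribe/",
#         data=f.read(),
#         headers={"Content-Type": "application/octet-stream"},
#     )
# resp.raise_for_status()
# print(resp.json()["transcription"])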