|
import streamlit as st |
|
import requests |
|
import asyncio |
|
import aiohttp |
|
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription |
|
from aiortc.contrib.media import MediaBlackhole, MediaRecorder |
|
import av |
|
import base64 |
|
import json |
|
|
|
API_URL = "https://api-inference.huggingface.co/models/openai/whisper-large-v3-turbo" |
|
headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"} |
|
|
|
class AudioTranscriber:
    """Accumulates a running transcript by sending audio chunks to the
    Hugging Face Whisper inference API.

    Attributes:
        buffer: pending audio chunks collected by the WebRTC track callback.
        text: full transcript so far; each API response is appended with a
            trailing space.
    """

    def __init__(self):
        # Pending audio awaiting transcription (filled by AudioTrack.recv).
        self.buffer = []
        # Running transcript accumulated across API calls.
        self.text = ""

    async def transcribe(self, audio_data):
        """POST raw audio bytes to the inference API and append the result.

        Malformed or non-JSON responses are skipped rather than crashing
        the media pipeline; explicit API errors are surfaced to the UI.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(API_URL, headers=headers, data=audio_data) as response:
                # The HF inference API returns JSON even for failures,
                # e.g. {"error": "Model ... is currently loading"} — but a
                # proxy/5xx may return non-JSON, which .json() raises on.
                try:
                    result = await response.json()
                except (aiohttp.ContentTypeError, ValueError):
                    return
                if isinstance(result, dict) and 'text' in result:
                    self.text += result['text'] + " "
                    # NOTE(review): calling Streamlit from a background
                    # asyncio task only renders during an active script
                    # run — confirm this actually updates the page.
                    st.text_area("Transcription", self.text, height=200)
                elif isinstance(result, dict) and 'error' in result:
                    # Previously errors were dropped silently; show them.
                    st.error(f"Transcription API error: {result['error']}")
|
|
|
class AudioTrack(MediaStreamTrack):
    """Pass-through audio track that batches incoming frames' PCM data
    and hands each batch to an AudioTranscriber."""

    kind = "audio"

    # Number of frames to accumulate before firing a transcription request.
    _BATCH_SIZE = 5

    def __init__(self, track, transcriber):
        super().__init__()
        # Upstream aiortc track we pull frames from.
        self.track = track
        # Shared transcriber that receives batched audio bytes.
        self.transcriber = transcriber

    async def recv(self):
        """Pull the next frame, buffer its PCM bytes, and flush a full batch.

        Fixes two defects in the original: the frame that triggered the
        flush was dropped (its audio never buffered), and the flush built
        bytes via av.AudioFrame.from_ndarray(...).to_bytes() — PyAV's
        AudioFrame has no .to_bytes() method, so that path raised at
        runtime. We store raw bytes up front instead.
        """
        frame = await self.track.recv()
        # Buffer every frame, including the one that fills the batch.
        self.transcriber.buffer.append(frame.to_ndarray().tobytes())
        if len(self.transcriber.buffer) >= self._BATCH_SIZE:
            audio_data = b''.join(self.transcriber.buffer)
            self.transcriber.buffer = []
            # NOTE(review): the HF endpoint expects an encoded audio file
            # (wav/flac/ogg); raw PCM may need a WAV header — verify.
            # Fire-and-forget so transcription never blocks the pipeline.
            asyncio.create_task(self.transcriber.transcribe(audio_data))
        return frame
|
|
|
async def process_offer(offer, transcriber):
    """Answer a WebRTC offer and start draining inbound audio into
    ``transcriber``.

    Args:
        offer: dict with ``"sdp"`` and ``"type"`` keys from the remote peer.
        transcriber: AudioTranscriber that receives the batched audio.

    Returns:
        dict with the local answer's ``"sdp"`` and ``"type"``.
    """
    pc = RTCPeerConnection()
    pc.addTransceiver("audio", direction="recvonly")

    # A consumer is required: AudioTrack.recv() is only called when
    # something pulls frames from it. The original re-added the track via
    # pc.addTrack(), but with a recvonly transceiver nothing drains it,
    # so no audio was ever transcribed. MediaBlackhole (imported above)
    # is aiortc's standard "pull frames and discard" sink.
    recorder = MediaBlackhole()

    @pc.on("track")
    def on_track(track):
        if track.kind == "audio":
            recorder.addTrack(AudioTrack(track, transcriber))
            # Start pulling frames; scheduled on the running loop since
            # this callback is synchronous.
            asyncio.ensure_future(recorder.start())

    await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)

    # NOTE(review): `pc` is never closed; the connection lives until GC.
    return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
|
|
|
st.title("Real-time Speech Recognition with Whisper")

# This flag is Streamlit's XSRF-protection setting, not a WebRTC context —
# the original name `webrtc_ctx` was misleading. XSRF protection blocks
# the manual SDP exchange below.
xsrf_enabled = st.config.get_option("server.enableXsrfProtection")
if xsrf_enabled:
    st.warning("To use this app, you need to disable XSRF protection. Set server.enableXsrfProtection=false in your Streamlit config.")
else:
    offer_sdp = st.text_input("Paste the offer SDP here")
    if offer_sdp:
        # User-pasted text: guard against invalid JSON instead of letting
        # json.loads raise and crash the script run.
        try:
            offer = json.loads(offer_sdp)
        except json.JSONDecodeError:
            st.error("Invalid offer: expected a JSON object with 'sdp' and 'type' fields.")
        else:
            transcriber = AudioTranscriber()
            # NOTE(review): asyncio.run() tears down its event loop when
            # process_offer returns, which cancels background transcription
            # tasks — negotiation completes, but long-lived media tasks
            # need a persistent loop. Confirm intended lifecycle.
            answer = asyncio.run(process_offer(offer, transcriber))
            st.text_area("Answer SDP", json.dumps(answer))
            st.write("Speak into your microphone. The transcription will appear below.")

st.markdown("---")
st.write("Note: This app uses the Whisper API from Hugging Face for real-time speech recognition.")