File size: 2,831 Bytes
ccabf63
6b694b5
3a508f7
 
 
 
 
0fef32f
3a508f7
5b29361
6b694b5
3a508f7
1620753
3a508f7
 
 
 
1620753
3a508f7
 
 
 
 
 
 
ccabf63
3a508f7
 
ccabf63
3a508f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fef32f
3a508f7
 
 
 
0fef32f
3a508f7
 
 
0fef32f
3a508f7
dfb92f3
3a508f7
 
 
 
 
 
 
 
 
 
 
 
163138e
 
3a508f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import streamlit as st
import requests
import asyncio
import aiohttp
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaBlackhole, MediaRecorder
import av
import base64
import json

# Hugging Face Inference API endpoint for the Whisper large-v3-turbo
# speech-to-text model.
API_URL = "https://api-inference.huggingface.co/models/openai/whisper-large-v3-turbo"
# Bearer-token auth header; requires an `hf_token` entry in the Streamlit
# secrets store (.streamlit/secrets.toml).
headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"}

class AudioTranscriber:
    """Accumulates audio chunks and grows a transcript via the HF Whisper API."""

    def __init__(self):
        # Pending audio frames (ndarrays) waiting to be sent for transcription.
        self.buffer = []
        # Running transcript text, extended as API results arrive.
        self.text = ""

    async def transcribe(self, audio_data):
        """POST raw audio bytes to the inference API; append any text returned
        to the running transcript and refresh the on-screen text area."""
        async with aiohttp.ClientSession() as session:
            async with session.post(API_URL, headers=headers, data=audio_data) as resp:
                payload = await resp.json()
                if 'text' not in payload:
                    # API errors (e.g. model loading) come back without 'text'.
                    return
                self.text += payload['text'] + " "
                st.text_area("Transcription", self.text, height=200)

class AudioTrack(MediaStreamTrack):
    """Pass-through audio track that batches incoming frames and hands each
    batch to an AudioTranscriber as a background task.

    Frames are returned unchanged; transcription is a side effect.
    """

    kind = "audio"

    # Number of frames per transcription batch.
    # NOTE(review): at typical 20 ms WebRTC frames this is only ~0.1 s of
    # audio, not "5 seconds" as the original comment claimed — tune as needed.
    BATCH_FRAMES = 5

    def __init__(self, track, transcriber):
        super().__init__()
        self.track = track              # upstream track frames are pulled from
        self.transcriber = transcriber  # sink for batched audio

    async def recv(self):
        """Pull the next frame, buffer its samples, and flush a full batch
        to the transcriber. Returns the frame unchanged."""
        frame = await self.track.recv()
        # Always buffer the current frame. The original skipped appending on
        # the flush iteration, silently dropping that frame's audio.
        self.transcriber.buffer.append(frame.to_ndarray())
        if len(self.transcriber.buffer) >= self.BATCH_FRAMES:
            # ndarray.tobytes() yields the raw sample bytes directly. The
            # original round-tripped through av.AudioFrame.from_ndarray(buf)
            # .to_bytes(), which fails at runtime: from_ndarray requires
            # format/layout arguments and AudioFrame has no to_bytes() method.
            # NOTE(review): the HF endpoint generally expects an encoded
            # container (wav/flac); raw PCM may be rejected — confirm.
            audio_data = b''.join(buf.tobytes() for buf in self.transcriber.buffer)
            self.transcriber.buffer = []
            asyncio.create_task(self.transcriber.transcribe(audio_data))
        return frame

async def process_offer(offer, transcriber):
    """Answer a WebRTC offer and wire incoming audio into the transcriber.

    Parameters:
        offer: dict with "sdp" and "type" keys (browser-generated offer JSON).
        transcriber: AudioTranscriber that receives batched audio.

    Returns:
        dict with the local answer's "sdp" and "type" for the browser.
    """
    pc = RTCPeerConnection()
    pc.addTransceiver("audio", direction="recvonly")

    # A media sink is required to actually drive AudioTrack.recv(). The
    # original called pc.addTrack() on a recv-only connection, so nothing
    # ever consumed the wrapped track and no audio reached the transcriber.
    sink = MediaBlackhole()

    @pc.on("track")
    def on_track(track):
        if track.kind == "audio":
            sink.addTrack(AudioTrack(track, transcriber))
            # Start consuming as soon as the remote track arrives.
            asyncio.ensure_future(sink.start())

    await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
    answer = await pc.createAnswer()
    # In aiortc, setLocalDescription also waits for ICE gathering to finish,
    # so the SDP returned below is complete.
    await pc.setLocalDescription(answer)

    return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}

st.title("Real-time Speech Recognition with Whisper")

# True when Streamlit's XSRF protection is enabled, which blocks the manual
# SDP exchange below. (The previous name `webrtc_ctx` was misleading — this
# is a config flag, not a WebRTC context.)
xsrf_enabled = st.config.get_option("server.enableXsrfProtection")
if xsrf_enabled:
    st.warning("To use this app, you need to disable XSRF protection. Set server.enableXsrfProtection=false in your Streamlit config.")
else:
    offer = st.text_input("Paste the offer SDP here")
    if offer:
        try:
            offer_dict = json.loads(offer)
        except json.JSONDecodeError:
            # A malformed paste previously crashed the script with a raw
            # traceback; surface a readable error instead.
            st.error("Invalid offer: expected JSON with 'sdp' and 'type' keys.")
        else:
            transcriber = AudioTranscriber()
            answer = asyncio.run(process_offer(offer_dict, transcriber))
            st.text_area("Answer SDP", json.dumps(answer))
            st.write("Speak into your microphone. The transcription will appear below.")

st.markdown("---")
st.write("Note: This app uses the Whisper API from Hugging Face for real-time speech recognition.")