File size: 4,358 Bytes
bbef1a2
 
 
 
 
a2ed330
 
 
bbef1a2
a2ed330
 
 
bbef1a2
 
 
 
a2ed330
bbef1a2
 
4a37dab
4e9b286
01a49c3
4e9b286
 
41d06ba
01a49c3
a2ed330
4e9b286
 
 
 
 
bc98115
4e9b286
c4d6bf6
01a49c3
 
 
d531709
4e9b286
 
 
4a37dab
 
 
c4d6bf6
 
4a37dab
a2ed330
 
5e3f570
4a37dab
a2ed330
4a37dab
01a49c3
 
 
 
 
4e9b286
01a49c3
 
 
 
 
 
 
 
4e9b286
891e37e
 
 
c4d6bf6
4a37dab
c4d6bf6
bc98115
d531709
 
4e9b286
d531709
 
 
 
c4d6bf6
4e9b286
c4d6bf6
4e9b286
c4d6bf6
 
 
 
4e9b286
 
 
 
 
 
 
 
 
 
 
 
d531709
 
 
 
c4d6bf6
bc98115
 
 
01a49c3
eb02780
 
bc98115
 
 
 
 
eb02780
01a49c3
4e9b286
01a49c3
4e9b286
 
d531709
4e9b286
 
 
 
 
 
5f58cac
4e9b286
 
 
 
d5ade87
4e9b286
 
 
fcb6d65
 
 
4e9b286
eb02780
31fe9de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import logging

# Root logger stays at WARNING so third-party libraries don't flood the output.
logging.basicConfig(level=logging.WARNING)

# Route this library's DEBUG output to a log file instead of the console.
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler = logging.FileHandler("gradio_webrtc.log")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)

# Dedicated logger for the gradio_webrtc library at full verbosity.
logger = logging.getLogger("gradio_webrtc")
logger.setLevel(logging.DEBUG)
logger.addHandler(file_handler)


import base64
import io
import os
import tempfile
import time
import traceback
from dataclasses import dataclass
from threading import Event, Thread

import gradio as gr
import librosa
import numpy as np
import requests
from gradio_webrtc import ReplyOnPause, WebRTC
from huggingface_hub import snapshot_download
from pydub import AudioSegment
from twilio.rest import Client

from server import serve

# from server import serve
from utils.vad import VadOptions, collect_chunks, get_speech_timestamps

# Fetch the model weights into ./checkpoint (no-op if already present).
repo_id = "gpt-omni/mini-omni"
snapshot_download(repo_id, local_dir="./checkpoint", revision="main")

# Address/port the local inference server binds to.
IP = "0.0.0.0"
PORT = 60808

# Run the model server in a background daemon thread so the Gradio app
# can start in this (main) thread.
thread = Thread(target=serve, daemon=True)
thread.start()


# Endpoint that speaking() posts audio to.
# NOTE(review): hard-coded rather than built from IP/PORT — keep the three in sync.
API_URL = "http://0.0.0.0:60808/chat"

# Optional Twilio TURN credentials: when both are set, WebRTC traffic is
# relayed through Twilio's ICE servers (helps behind restrictive NATs).
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

if account_sid and auth_token:
    client = Client(account_sid, auth_token)

    token = client.tokens.create()

    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
else:
    # No credentials: let the WebRTC component use its default configuration.
    rtc_configuration = None

# Output audio format: mono, 24 kHz, 16-bit samples, read in 20*4096-byte chunks.
OUT_CHANNELS = 1
OUT_RATE = 24000
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 20 * 4096


def speaking(audio_bytes: bytes):
    """POST audio to the chat API and yield decoded PCM chunks as they stream back.

    Args:
        audio_bytes: Raw audio payload (WAV bytes) sent to the model, base64-encoded.

    Yields:
        np.ndarray of shape (1, n_samples): 16-bit mono PCM at OUT_RATE.

    Raises:
        gr.Error: if the HTTP stream or audio decoding fails.
    """
    base64_encoded = str(base64.b64encode(audio_bytes), encoding="utf-8")
    payload = {"audio": base64_encoded}
    byte_buffer = b""  # full response audio, kept for the debug WAV dump below
    pending = b""      # trailing byte carried over when a sample straddles chunks
    with requests.post(API_URL, json=payload, stream=True) as response:
        try:
            for chunk in response.iter_content(chunk_size=OUT_CHUNK):
                if not chunk:
                    continue
                byte_buffer += chunk
                data = pending + chunk
                # The stream is 16-bit PCM, so an odd-length chunk means a sample
                # was split across the chunk boundary. Hold the trailing byte for
                # the next iteration instead of zero-padding: padding would shift
                # every subsequent sample by one byte and corrupt the audio.
                if len(data) % 2 != 0:
                    pending = data[-1:]
                    data = data[:-1]
                else:
                    pending = b""
                if not data:
                    continue
                # Wrap the raw PCM bytes in an AudioSegment to decode them.
                audio_segment = AudioSegment(
                    data,
                    frame_rate=OUT_RATE,
                    sample_width=OUT_SAMPLE_WIDTH,
                    channels=OUT_CHANNELS,
                )
                audio_np = np.array(audio_segment.get_array_of_samples())
                yield audio_np.reshape(1, -1)
            # Debug artifact: dump the complete response to a temp WAV file
            # (delete=False on purpose — the path is printed for inspection).
            # Trim a stray trailing byte so AudioSegment doesn't reject the buffer.
            all_output_audio = AudioSegment(
                byte_buffer if len(byte_buffer) % 2 == 0 else byte_buffer[:-1],
                frame_rate=OUT_RATE,
                sample_width=OUT_SAMPLE_WIDTH,
                channels=OUT_CHANNELS,
            )
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                all_output_audio.export(f.name, format="wav")
                print("output file written", f.name)
        except Exception as e:
            # Surface any streaming/decoding failure in the Gradio UI,
            # preserving the original traceback via exception chaining.
            raise gr.Error(f"Error during audio streaming: {e}") from e



def response(audio: tuple[int, np.ndarray]):
    """Gradio handler: encode the captured mic audio as WAV and stream replies.

    Args:
        audio: (sampling_rate, samples) pair as delivered by the WebRTC component.

    Yields:
        (OUT_RATE, pcm_chunk, "mono") tuples for the streaming audio output.
    """
    sampling_rate, audio_np = audio
    audio_np = audio_np.squeeze()

    # Wrap the raw samples in a WAV container held entirely in memory.
    segment = AudioSegment(
        audio_np.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_np.dtype.itemsize,
        channels=1,
    )
    wav_buffer = io.BytesIO()
    segment.export(wav_buffer, format="wav")

    # Relay each decoded chunk from the model straight back to the client.
    for chunk in speaking(wav_buffer.getvalue()):
        yield (OUT_RATE, chunk, "mono")


# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.HTML(
        """
    <h1 style='text-align: center'>
    Omni Chat (Powered by WebRTC ⚡️)
    </h1>
    """
    )
    with gr.Column():
        with gr.Group():
            # Bidirectional audio widget: captures mic input and plays replies.
            audio = WebRTC(
                label="Stream",
                rtc_configuration=rtc_configuration,
                mode="send-receive",
                modality="audio",
            )
        # ReplyOnPause buffers speech until the user pauses, then calls
        # `response` with the accumulated audio; each session is capped at 60 s.
        audio.stream(fn=ReplyOnPause(response,
                                     output_sample_rate=OUT_RATE,
                                     output_frame_size=480), inputs=[audio], outputs=[audio], time_limit=60)


# NOTE(review): ssr_mode=False disables server-side rendering — presumably
# required for the custom WebRTC component; confirm against gradio_webrtc docs.
demo.launch(ssr_mode=False)