# NOTE: removed scraped Hugging Face page chrome (Space status lines, file
# size, commit hashes, and the line-number gutter) that was not part of the
# actual source file and made it syntactically invalid.
import os
import base64
import json
import numpy as np
import gradio as gr
# import websockets.sync.client # No longer needed with FastRTC
from fastrtc import (
#PeerConnection,
DataChannel,
MediaStreamTrack,
AudioFrame,
VideoFrame,
) # Import FastRTC components
from aiortc.contrib.media import MediaPlayer, MediaRelay
import asyncio
__version__ = "0.0.3"
# SECURITY: a real API key was previously hard-coded in a comment here.
# Keep keys in environment variables only (see GeminiConfig) — never in
# source control, not even commented out.
# Configuration and Utilities
class GeminiConfig:
    """Configuration settings for the Gemini API.

    Reads the API key from the environment and builds the streaming REST
    endpoint URL for the configured model.

    Raises:
        ValueError: if no API key is present in the environment.
    """
    def __init__(self):
        # Bugfix: the original read only KEY_NAME but the error message
        # claimed GEMINI_API_KEY.  Prefer the conventional GEMINI_API_KEY and
        # keep KEY_NAME as a backward-compatible fallback.
        self.api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("KEY_NAME")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY environment variable is not set.")
        self.host = "generativelanguage.googleapis.com"
        self.model = "models/gemini-2.0-flash-exp"
        # Plain HTTPS streaming endpoint; FastRTC carries the media while the
        # model call itself goes over HTTP (no WebSocket needed).
        self.base_url = f"https://{self.host}/v1alpha/{self.model}:streamGenerateContent?key={self.api_key}"
class AudioProcessor:
    """Converts between raw int16 PCM sample arrays and base64 strings."""

    @staticmethod
    def encode_audio(data: np.ndarray, sample_rate: int) -> str:
        """Return *data* as a base64 string of int16 PCM bytes.

        ``sample_rate`` is accepted for interface symmetry with the decoder
        but does not influence the encoding itself.
        """
        # Coerce to int16 only when necessary (asarray is a no-op for int16).
        pcm = np.asarray(data, dtype=np.int16)
        return base64.b64encode(pcm.tobytes()).decode("UTF-8")

    @staticmethod
    def process_audio_response(data: str) -> np.ndarray:
        """Decode a base64 payload back into an int16 sample array."""
        raw_bytes = base64.b64decode(data)
        return np.frombuffer(raw_bytes, dtype=np.int16)
# We don't need a StreamHandler in the same way with FastRTC. We'll handle streaming directly.
class GeminiHandler:
    """Bridges a WebRTC peer connection to the Gemini streaming REST API.

    Microphone audio arriving on the peer connection is base64-encoded and
    POSTed to the Gemini endpoint; audio returned by the model is decoded,
    buffered, and re-emitted as fixed-size frames on the stored audio track.
    """

    def __init__(self, output_sample_rate=24000, output_frame_size=480):
        self.config = GeminiConfig()
        self.audio_processor = AudioProcessor()
        self.output_sample_rate = output_sample_rate
        self.output_frame_size = output_frame_size
        # Rolling buffer of decoded server audio; None until the first chunk.
        self.all_output_data = None
        self.pc = None  # RTCPeerConnection, created in connect()
        self.dc = None  # DataChannel
        self.audio_track = None  # incoming audio track, set in on_track()
        self._audio_buffer = []
        self.relay = MediaRelay()

    async def _send_audio_to_gemini(self, encoded_audio: str):
        """POSTs one base64 PCM chunk to Gemini and streams back audio parts.

        Each line of the chunked response may contain several JSON objects;
        every candidate part carrying inlineData audio is forwarded to
        _process_server_audio().
        """
        headers = {"Content-Type": "application/json"}
        payload = {
            "contents": [
                {
                    "parts": [
                        {
                            "text": "Respond to the audio with audio."
                        },  # Initial prompt, can be adjusted
                        {"inline_data": {"mime_type": "audio/pcm;rate=24000", "data": encoded_audio}},
                    ]
                }
            ]
        }
        # Use aiohttp for asynchronous HTTP requests (imported lazily so the
        # module loads even where aiohttp is absent).
        import aiohttp

        async with aiohttp.ClientSession() as session:
            async with session.post(
                self.config.base_url, headers=headers, data=json.dumps(payload)
            ) as response:
                if response.status != 200:
                    print(f"Error: Gemini API returned status {response.status}")
                    print(await response.text())
                    return
                async for line in response.content:
                    try:
                        line = line.strip()
                        if not line:
                            continue
                        # Responses are chunked, often with multiple JSON
                        # objects per chunk — split and parse each one.
                        for chunk in line.decode("utf-8").split("\n"):
                            if not chunk.strip():
                                continue
                            try:
                                data = json.loads(chunk)
                            except json.JSONDecodeError:
                                print(f"JSONDecodeError: {chunk}")
                                continue
                            if "candidates" in data:
                                for candidate in data["candidates"]:
                                    for part in candidate.get("content", {}).get("parts", []):
                                        if "inlineData" in part:
                                            audio_data = part["inlineData"].get("data", "")
                                            if audio_data:
                                                await self._process_server_audio(audio_data)
                    except Exception as e:
                        print(f"Error processing response chunk: {e}")

    async def _process_server_audio(self, audio_data: str):
        """Decodes server audio, buffers it, and emits full-size frames."""
        audio_array = self.audio_processor.process_audio_response(audio_data)
        if self.all_output_data is None:
            self.all_output_data = audio_array
        else:
            self.all_output_data = np.concatenate((self.all_output_data, audio_array))
        # Emit as many complete frames as the buffer currently holds; any
        # remainder stays buffered for the next chunk.
        while self.all_output_data.shape[-1] >= self.output_frame_size:
            # NOTE(review): the AudioFrame kwargs and track.emit() call assume
            # the fastrtc API — confirm against the installed fastrtc version.
            frame = AudioFrame(
                samples=self.output_frame_size,
                sample_rate=self.output_sample_rate,
                layout="mono",  # mono channel
                data=self.all_output_data[: self.output_frame_size].tobytes()
            )
            self.all_output_data = self.all_output_data[self.output_frame_size:]
            if self.audio_track:
                await self.audio_track.emit(frame)

    async def on_track(self, track):
        """Handles incoming media tracks from the remote peer."""
        print(f"Track received: {track.kind}")
        if track.kind == "audio":
            self.audio_track = track  # Store the audio track

            # NOTE(review): a "frame" event is not part of aiortc's
            # MediaStreamTrack (which uses recv()); verify fastrtc emits it.
            @track.on("frame")
            async def on_frame(frame):
                # Forward each received audio frame to Gemini.
                if isinstance(frame, AudioFrame):
                    try:
                        # Convert the frame data to a NumPy array.
                        audio_data = np.frombuffer(frame.data, dtype=np.int16)
                        # Encode the audio and send it to Gemini.
                        encoded_audio = self.audio_processor.encode_audio(
                            audio_data, frame.sample_rate
                        )  # Pass sample rate
                        await self._send_audio_to_gemini(encoded_audio)
                    except Exception as e:
                        print(f"Error processing audio frame: {e}")

    async def on_datachannel(self, channel):
        """Handles data channel events (not used in this example, but good practice)."""
        self.dc = channel
        print("Data channel created")

        @channel.on("message")
        async def on_message(message):
            print(f"Received message: {message}")

    async def connect(self):
        """Creates the peer connection and returns the local SDP offer."""
        # Bugfix: the fastrtc import of PeerConnection is commented out at the
        # top of the file, so the original raised NameError here.  aiortc's
        # RTCPeerConnection provides exactly the API this code targets
        # (on / addTrack / createDataChannel / createOffer).
        from aiortc import RTCPeerConnection

        self.pc = RTCPeerConnection()
        self.pc.on("track", self.on_track)
        self.pc.on("datachannel", self.on_datachannel)
        # Capture the default microphone as the local audio track.
        # NOTE(review): format="avfoundation" is macOS-only; use "pulse"/"alsa"
        # (Linux) or "dshow" (Windows) when deploying elsewhere.
        self.local_audio_player = MediaPlayer("default", format="avfoundation", options={"channels": "1", "sample_rate": str(self.output_sample_rate)})
        self.local_audio = self.relay.subscribe(self.local_audio_player.audio)
        self.pc.addTrack(self.local_audio)
        # Add a data channel (optional, but good practice)
        self.dc = self.pc.createDataChannel("data")
        # Create an offer and set local description
        offer = await self.pc.createOffer()
        await self.pc.setLocalDescription(offer)
        print("PeerConnection established")
        return self.pc.localDescription

    async def set_remote_description(self, sdp, type):
        """Sets the remote description; returns an answer if it was an offer."""
        from aiortc import RTCSessionDescription

        await self.pc.setRemoteDescription(RTCSessionDescription(sdp=sdp, type=type))
        print("Remote description set")
        if self.pc.remoteDescription.type == "offer":
            answer = await self.pc.createAnswer()
            await self.pc.setLocalDescription(answer)
            return self.pc.localDescription

    async def add_ice_candidate(self, candidate, sdpMid, sdpMLineIndex):
        """Parses and adds a remote ICE candidate."""
        # Bugfix: aiortc's RTCIceCandidate constructor takes parsed fields
        # (foundation, ip, port, ...) and has no `candidate=` keyword, so the
        # original raised TypeError on every call.  Parse the browser-style
        # "candidate:..." SDP string with aiortc's helper instead.
        from aiortc.sdp import candidate_from_sdp

        if candidate:
            try:
                sdp_str = candidate.split(":", 1)[1] if candidate.startswith("candidate:") else candidate
                ice_candidate = candidate_from_sdp(sdp_str)
                ice_candidate.sdpMid = sdpMid
                ice_candidate.sdpMLineIndex = sdpMLineIndex
                await self.pc.addIceCandidate(ice_candidate)
                print("ICE candidate added")
            except Exception as e:
                print(f"Error adding ICE candidate: {e}")

    def shutdown(self):
        """Closes the PeerConnection and clears the reference."""
        if self.pc:
            # NOTE(review): create_task requires a running event loop — call
            # shutdown() from async context, or await pc.close() directly.
            asyncio.create_task(self.pc.close())  # Close in the background
            self.pc = None
            print("PeerConnection closed")
# Gradio Interface
async def registry(
    name: str,
    token: str | None = None,
    **kwargs,
):
    """Sets up and returns the Gradio interface.

    Args:
        name: Model name (informational; the model is fixed in GeminiConfig).
        token: Optional auth token — unused here, kept for API compatibility.
        **kwargs: Extra options, currently ignored.

    Returns:
        A gr.Blocks interface hosting the voice-chat tab.
    """
    gemini_handler = GeminiHandler()

    async def connect_webrtc(payload):
        """Streams WebRTC signalling state (as JSON strings) to the client.

        Bugfix: the click handler is wired with ONE input — the WebRTC
        component's value, a dict with "sdp"/"type"/"candidates" — so this
        must accept a single payload instead of three positional arguments
        (the original raised TypeError when clicked).
        """
        payload = payload or {}
        sdp = payload.get("sdp")
        sdp_type = payload.get("type")
        candidates = payload.get("candidates") or []
        if gemini_handler.pc is None:
            local_description = await gemini_handler.connect()
            if local_description:
                yield json.dumps(
                    {
                        "sdp": local_description.sdp,
                        "type": local_description.type,
                        "candidates": [],
                    }
                )  # Return initial SDP
        if sdp and sdp_type:
            answer = await gemini_handler.set_remote_description(sdp, sdp_type)
            if answer:
                yield json.dumps({"sdp": answer.sdp, "type": answer.type, "candidates": []})
        for candidate in candidates:
            if candidate and candidate.get("candidate"):
                await gemini_handler.add_ice_candidate(
                    candidate["candidate"], candidate.get("sdpMid"), candidate.get("sdpMLineIndex")
                )
        yield json.dumps({"sdp": "", "type": "", "candidates": []})  # Signal completion

    interface = gr.Blocks()
    with interface:
        with gr.Tabs():
            with gr.TabItem("Voice Chat"):
                gr.HTML(
                    """
                    <div style='text-align: left'>
                        <h1>Gemini API Voice Chat</h1>
                    </div>
                    """
                )
                with gr.Row():
                    webrtc_out = gr.JSON(label="WebRTC JSON")
                # Use the built-in WebRTC component, but without automatic
                # streaming.  NOTE(review): gr.WebRTC is provided by fastrtc's
                # Gradio integration, not stock Gradio — confirm it registers.
                webrtc = gr.WebRTC(
                    value={"sdp": "", "type": "", "candidates": []},
                    interactive=True,
                    label="Voice Chat",
                )
                connect_button = gr.Button("Connect")
                connect_button.click(
                    connect_webrtc,
                    inputs=[
                        webrtc
                    ],  # Pass the WebRTC component's value (SDP, type, candidates)
                    outputs=[webrtc_out],  # show the webrtc connection data
                )
    return interface
# Launch the Gradio interface
async def main():
    """Builds the voice-chat interface and launches the Gradio server."""
    interface = await registry(name="gemini-2.0-flash-exp")
    interface.queue()  # Enable queuing for better concurrency
    # Bugfix: Blocks.launch() is a regular blocking method, not a coroutine —
    # the original `await interface.launch()` raised TypeError.
    interface.launch()


if __name__ == "__main__":
    asyncio.run(main())