import os
import base64
import json
import asyncio
import fractions
import aiohttp
import numpy as np
import gradio as gr
from av import AudioFrame  # aiortc exchanges audio as av.AudioFrame objects
from aiortc import (
    RTCPeerConnection,
    RTCSessionDescription,
    MediaStreamTrack,
)
from aiortc.sdp import candidate_from_sdp
from aiortc.contrib.media import MediaPlayer, MediaRelay
from fastrtc import WebRTC  # FastRTC supplies the Gradio WebRTC component
__version__ = "0.0.3"
# Best practice: keep API keys out of the source. The key is read from the
# GEMINI_API_KEY environment variable in GeminiConfig below.
# Configuration and Utilities
class GeminiConfig:
"""Configuration settings for Gemini API."""
def __init__(self):
        self.api_key = os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY environment variable is not set.")
self.host = "generativelanguage.googleapis.com"
self.model = "models/gemini-2.0-flash-exp"
        # FastRTC carries the browser audio; Gemini itself is reached over plain
        # HTTPS streaming here rather than the Live WebSocket API.
        self.base_url = f"https://{self.host}/v1alpha/{self.model}:streamGenerateContent?key={self.api_key}"
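        # Example shell setup before launching, matching the variable read above:
        #   export GEMINI_API_KEY="..."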
class AudioProcessor:
"""Handles encoding and decoding of audio data."""
@staticmethod
    def encode_audio(data: np.ndarray, sample_rate: int) -> str:
        """Encodes int16 PCM audio to base64; the sample rate travels in the
        request's audio/pcm MIME type rather than in the payload itself."""
# Ensure data is in the correct format (int16)
if data.dtype != np.int16:
data = data.astype(np.int16)
encoded = base64.b64encode(data.tobytes()).decode("UTF-8")
return encoded
@staticmethod
def process_audio_response(data: str) -> np.ndarray:
"""Decodes audio data from base64."""
audio_data = base64.b64decode(data)
return np.frombuffer(audio_data, dtype=np.int16)
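# Round-trip sanity check for the two helpers above (a sketch; run it in a REPL):
#   >>> samples = np.array([0, 1, -1, 32767], dtype=np.int16)
#   >>> encoded = AudioProcessor.encode_audio(samples, 24000)
#   >>> np.array_equal(AudioProcessor.process_audio_response(encoded), samples)
#   True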
# We don't need a StreamHandler in the same way with FastRTC. We'll handle streaming directly.
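# aiortc sends outgoing media by repeatedly awaiting a MediaStreamTrack's recv()
# coroutine, so the Gemini response audio needs a sendable track. The class below
# is a minimal sketch of such a queue-backed track; its name and timing scheme are
# assumptions of this example, not part of FastRTC or aiortc.
class QueueAudioTrack(MediaStreamTrack):
    """Outgoing audio track that replays AudioFrames from an asyncio.Queue."""
    kind = "audio"

    def __init__(self, queue: asyncio.Queue, sample_rate: int = 24000):
        super().__init__()
        self.queue = queue
        self.sample_rate = sample_rate
        self._timestamp = 0  # running sample count, used as the frame pts

    async def recv(self) -> AudioFrame:
        frame = await self.queue.get()
        # aiortc requires monotonically increasing timestamps on outgoing frames.
        frame.pts = self._timestamp
        frame.time_base = fractions.Fraction(1, self.sample_rate)
        self._timestamp += frame.samples
        return frame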
class GeminiHandler:
"""Handles interactions with the Gemini API."""
def __init__(self, output_sample_rate=24000, output_frame_size=480):
self.config = GeminiConfig()
self.audio_processor = AudioProcessor()
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
        self.all_output_data = None
        self.pc = None  # RTCPeerConnection
        self.dc = None  # data channel
        self.audio_track = None  # most recently received remote audio track
        self._audio_buffer = []
        self.relay = MediaRelay()
        self.out_queue = asyncio.Queue()  # AudioFrames waiting to be sent back to the client
async def _send_audio_to_gemini(self, encoded_audio: str):
"""Sends audio data to the Gemini API and processes the response."""
headers = {"Content-Type": "application/json"}
payload = {
"contents": [
{
"parts": [
{
"text": "Respond to the audio with audio."
}, # Initial prompt, can be adjusted
{"inline_data": {"mime_type": "audio/pcm;rate=24000", "data": encoded_audio}},
]
}
]
}
        # aiohttp (imported at module top) handles the asynchronous HTTP request
async with aiohttp.ClientSession() as session:
async with session.post(
self.config.base_url, headers=headers, data=json.dumps(payload)
) as response:
if response.status != 200:
print(f"Error: Gemini API returned status {response.status}")
print(await response.text())
return
async for line in response.content:
try:
line = line.strip()
if not line:
continue
                        # Without `alt=sse`, this endpoint streams a JSON array, so a
                        # line may carry array punctuation ("[", "]", trailing commas)
                        # around each object. Strip it before parsing and skip anything
                        # that still is not standalone JSON.
                        for chunk in line.decode("utf-8").split("\n"):
                            chunk = chunk.strip().lstrip("[").rstrip("],").strip()
                            if not chunk:
                                continue
                            try:
                                data = json.loads(chunk)
                            except json.JSONDecodeError:
                                print(f"JSONDecodeError: {chunk[:120]}")
                                continue
if "candidates" in data:
for candidate in data["candidates"]:
for part in candidate.get("content", {}).get("parts", []):
if "inlineData" in part:
audio_data = part["inlineData"].get("data", "")
if audio_data:
await self._process_server_audio(audio_data)
except Exception as e:
print(f"Error processing response chunk: {e}")
    async def _process_server_audio(self, audio_data: str):
        """Buffers decoded server audio and queues fixed-size frames for playback."""
        audio_array = self.audio_processor.process_audio_response(audio_data)
        if self.all_output_data is None:
            self.all_output_data = audio_array
        else:
            self.all_output_data = np.concatenate((self.all_output_data, audio_array))
        while self.all_output_data.shape[-1] >= self.output_frame_size:
            chunk = self.all_output_data[: self.output_frame_size]
            # av.AudioFrame.from_ndarray expects a (channels, samples) int16 array
            # for the packed "s16" format.
            frame = AudioFrame.from_ndarray(chunk.reshape(1, -1), format="s16", layout="mono")
            frame.sample_rate = self.output_sample_rate
            self.all_output_data = self.all_output_data[self.output_frame_size:]
            # Queue the frame; the outgoing QueueAudioTrack (defined above) pulls
            # from this queue in its recv().
            await self.out_queue.put(frame)
    async def on_track(self, track):
        """Handles incoming media tracks by pulling frames in a reader task."""
        print(f"Track received: {track.kind}")
        if track.kind == "audio":
            self.audio_track = track  # remember the remote audio track

            async def read_frames():
                # aiortc delivers media by awaiting track.recv(); tracks do not
                # emit "frame" events.
                while True:
                    try:
                        frame = await track.recv()
                    except Exception:
                        break  # track ended or connection closed
                    try:
                        # Flatten to int16 mono PCM before base64-encoding.
                        # NOTE: the request payload advertises rate=24000, but
                        # browser capture is often 48 kHz; resampling may be needed.
                        audio_data = frame.to_ndarray().astype(np.int16).flatten()
                        encoded_audio = self.audio_processor.encode_audio(
                            audio_data, frame.sample_rate
                        )
                        await self._send_audio_to_gemini(encoded_audio)
                    except Exception as e:
                        print(f"Error processing audio frame: {e}")

            asyncio.ensure_future(read_frames())
async def on_datachannel(self, channel):
"""Handles data channel events (not used in this example, but good practice)."""
self.dc = channel
print("Data channel created")
@channel.on("message")
async def on_message(message):
print(f"Received message: {message}")
    async def connect(self):
        """Establishes the PeerConnection."""
        self.pc = RTCPeerConnection()
        self.pc.on("track", self.on_track)
        self.pc.on("datachannel", self.on_datachannel)
        # Capture the local microphone. Note: the "avfoundation" capture format
        # only exists on macOS; on Linux use e.g. format="pulse" or "alsa" with a
        # suitable device name instead of "default".
        self.local_audio_player = MediaPlayer(
            "default",
            format="avfoundation",
            options={"channels": "1", "sample_rate": str(self.output_sample_rate)},
        )
        self.local_audio = self.relay.subscribe(self.local_audio_player.audio)
        self.pc.addTrack(self.local_audio)
        # Outgoing track that plays Gemini's audio responses back to the peer
        self.pc.addTrack(QueueAudioTrack(self.out_queue, self.output_sample_rate))
        # Add a data channel (optional, but useful for debugging/signalling)
        self.dc = self.pc.createDataChannel("data")
        # Create an offer and set the local description
        offer = await self.pc.createOffer()
        await self.pc.setLocalDescription(offer)
        print("PeerConnection established")
        return self.pc.localDescription
    async def set_remote_description(self, sdp, sdp_type):
        """Sets the remote description and answers if the peer sent an offer."""
        await self.pc.setRemoteDescription(RTCSessionDescription(sdp=sdp, type=sdp_type))
        print("Remote description set")
        if self.pc.remoteDescription.type == "offer":
            answer = await self.pc.createAnswer()
            await self.pc.setLocalDescription(answer)
            return self.pc.localDescription
    async def add_ice_candidate(self, candidate, sdpMid, sdpMLineIndex):
        """Adds an ICE candidate received from the remote peer."""
        if candidate:
            try:
                # aiortc's RTCIceCandidate is not built from the raw SDP string
                # directly; candidate_from_sdp parses the string (minus any
                # leading "candidate:" prefix).
                sdp_str = candidate.split(":", 1)[1] if candidate.startswith("candidate:") else candidate
                ice_candidate = candidate_from_sdp(sdp_str)
                ice_candidate.sdpMid = sdpMid
                ice_candidate.sdpMLineIndex = sdpMLineIndex
                await self.pc.addIceCandidate(ice_candidate)
                print("ICE candidate added")
            except Exception as e:
                print(f"Error adding ICE candidate: {e}")
def shutdown(self):
"""Closes the PeerConnection."""
if self.pc:
asyncio.create_task(self.pc.close()) # Close in the background
self.pc = None
print("PeerConnection closed")
# Gradio Interface
async def registry(
name: str,
token: str | None = None,
**kwargs,
):
"""Sets up and returns the Gradio interface."""
gemini_handler = GeminiHandler()
    async def connect_webrtc(webrtc_value):
        """Connects to the WebRTC client and handles ICE candidates."""
        # The component delivers a single dict; unpack it here so the click
        # handler's one input matches this function's signature.
        webrtc_value = webrtc_value or {}
        sdp = webrtc_value.get("sdp")
        sdp_type = webrtc_value.get("type")
        candidates = webrtc_value.get("candidates", [])
        if gemini_handler.pc is None:
            local_description = await gemini_handler.connect()
            if local_description:
                yield json.dumps(
                    {
                        "sdp": local_description.sdp,
                        "type": local_description.type,
                        "candidates": [],
                    }
                )  # Return the initial SDP offer
        if sdp and sdp_type:
            answer = await gemini_handler.set_remote_description(sdp, sdp_type)
            if answer:
                yield json.dumps({"sdp": answer.sdp, "type": answer.type, "candidates": []})
        for candidate in candidates:
            if candidate and candidate.get("candidate"):
                await gemini_handler.add_ice_candidate(
                    candidate["candidate"], candidate.get("sdpMid"), candidate.get("sdpMLineIndex")
                )
        yield json.dumps({"sdp": "", "type": "", "candidates": []})  # Signal completion
interface = gr.Blocks()
with interface:
with gr.Tabs():
with gr.TabItem("Voice Chat"):
gr.HTML(
"""
<div style='text-align: left'>
<h1>Gemini API Voice Chat</h1>
</div>
"""
)
with gr.Row():
webrtc_out = gr.JSON(label="WebRTC JSON")
                    # FastRTC's WebRTC component (Gradio itself ships no WebRTC
                    # component). The value/interactive kwargs mirror the original
                    # code and may need adjusting to the installed FastRTC version.
                    webrtc = WebRTC(
                        value={"sdp": "", "type": "", "candidates": []},
                        interactive=True,
                        label="Voice Chat",
                    )
                connect_button = gr.Button("Connect")
                connect_button.click(
                    connect_webrtc,
                    inputs=[webrtc],  # the component's value dict: sdp, type, candidates
                    outputs=[webrtc_out],  # show the WebRTC connection data
                )
return interface
# Launch the Gradio interface
def main():
    interface = asyncio.run(registry(name="gemini-2.0-flash-exp"))
    interface.queue()  # Enable queuing for better concurrency
    interface.launch()  # Blocks.launch() is synchronous; do not await it
if __name__ == "__main__":
    main()