import asyncio
import base64
import json
import os
from fractions import Fraction

import aiohttp
import gradio as gr
import numpy as np

# import websockets.sync.client  # No longer needed with FastRTC
from fastrtc import WebRTC  # FastRTC's Gradio WebRTC component
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaPlayer, MediaRelay
from aiortc.sdp import candidate_from_sdp
from av import AudioFrame  # PyAV frame type used by aiortc for audio
__version__ = "0.0.3"

# Best practice: keep API keys out of source control; read them from an
# environment variable instead (see GeminiConfig below).
# Configuration and Utilities
class GeminiConfig:
    """Configuration settings for the Gemini API."""

    def __init__(self):
        self.api_key = os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY environment variable is not set.")
        self.host = "generativelanguage.googleapis.com"
        self.model = "models/gemini-2.0-flash-exp"
        # FastRTC doesn't use WebSockets directly, so we call the streaming
        # REST endpoint instead.
        self.base_url = f"https://{self.host}/v1alpha/{self.model}:streamGenerateContent?key={self.api_key}"
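# Usage sketch (assuming a valid key is exported in the environment before
# running this script):
#
#   $ export GEMINI_API_KEY="your-api-key"
#
#   >>> config = GeminiConfig()
#   >>> config.base_url.startswith("https://generativelanguage.googleapis.com")
#   True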
class AudioProcessor:
    """Handles encoding and decoding of audio data."""

    @staticmethod
    def encode_audio(data: np.ndarray, sample_rate: int) -> str:
        """Encodes int16 PCM audio to base64. The sample rate is not embedded
        here; it travels in the request's mime_type."""
        # Ensure data is in the correct format (int16).
        if data.dtype != np.int16:
            data = data.astype(np.int16)
        return base64.b64encode(data.tobytes()).decode("utf-8")

    @staticmethod
    def process_audio_response(data: str) -> np.ndarray:
        """Decodes base64 audio data back into an int16 PCM array."""
        audio_data = base64.b64decode(data)
        return np.frombuffer(audio_data, dtype=np.int16)
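# Quick round-trip check (illustrative):
#
#   >>> pcm = np.array([0, 1000, -1000], dtype=np.int16)
#   >>> encoded = AudioProcessor.encode_audio(pcm, 24000)
#   >>> np.array_equal(AudioProcessor.process_audio_response(encoded), pcm)
#   True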
# We don't need a StreamHandler with FastRTC; instead, Gemini's audio replies
# are queued onto an outbound track (defined next) that aiortc pulls from.
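# A minimal bridge between push-style API responses and aiortc's pull-style
# media API. This class is not part of the FastRTC/aiortc surface; it is a
# sketch assuming standard aiortc semantics: the library repeatedly awaits
# recv() on any track passed to addTrack(), so we feed it from an asyncio.Queue.
class QueuedAudioTrack(MediaStreamTrack):
    """Outbound audio track that plays frames pushed onto a queue."""

    kind = "audio"

    def __init__(self, queue: asyncio.Queue, sample_rate: int):
        super().__init__()
        self.queue = queue
        self.sample_rate = sample_rate
        self._timestamp = 0  # Running sample counter used for PTS.

    async def recv(self) -> AudioFrame:
        frame = await self.queue.get()
        # Assign monotonically increasing timestamps in units of samples.
        frame.pts = self._timestamp
        frame.time_base = Fraction(1, self.sample_rate)
        self._timestamp += frame.samples
        return frame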
class GeminiHandler:
    """Handles interactions with the Gemini API."""

    def __init__(self, output_sample_rate=24000, output_frame_size=480):
        self.config = GeminiConfig()
        self.audio_processor = AudioProcessor()
        self.output_sample_rate = output_sample_rate
        self.output_frame_size = output_frame_size  # 480 samples @ 24 kHz = 20 ms per frame
        self.all_output_data = None
        self.pc = None  # RTCPeerConnection
        self.dc = None  # RTCDataChannel
        self.audio_track = None  # Remote (incoming) audio track
        self.output_queue = asyncio.Queue()  # Frames destined for QueuedAudioTrack
        self.relay = MediaRelay()
    async def _send_audio_to_gemini(self, encoded_audio: str):
        """Sends audio data to the Gemini API and processes the streamed response."""
        headers = {"Content-Type": "application/json"}
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": "Respond to the audio with audio."},  # Initial prompt, can be adjusted
                        {
                            "inline_data": {
                                "mime_type": "audio/pcm;rate=24000",
                                "data": encoded_audio,
                            }
                        },
                    ]
                }
            ]
        }
        # aiohttp (imported at the top) gives us asynchronous HTTP with
        # incremental response reads.
        async with aiohttp.ClientSession() as session:
            async with session.post(
                self.config.base_url, headers=headers, data=json.dumps(payload)
            ) as response:
                if response.status != 200:
                    print(f"Error: Gemini API returned status {response.status}")
                    print(await response.text())
                    return
                async for line in response.content:
                    try:
                        line = line.strip()
                        if not line:
                            continue
                        # Responses are chunked, often with multiple JSON objects
                        # per chunk; parse each one on a best-effort basis.
                        for chunk in line.decode("utf-8").split("\n"):
                            if not chunk.strip():
                                continue
                            try:
                                data = json.loads(chunk)
                            except json.JSONDecodeError:
                                print(f"JSONDecodeError: {chunk}")
                                continue
                            # Shape this parser expects:
                            # {"candidates": [{"content": {"parts":
                            #     [{"inlineData": {"data": "<base64 PCM>"}}]}}]}
                            if "candidates" in data:
                                for candidate in data["candidates"]:
                                    for part in candidate.get("content", {}).get("parts", []):
                                        if "inlineData" in part:
                                            audio_data = part["inlineData"].get("data", "")
                                            if audio_data:
                                                await self._process_server_audio(audio_data)
                    except Exception as e:
                        print(f"Error processing response chunk: {e}")
    async def _process_server_audio(self, audio_data: str):
        """Processes and buffers audio data received from the server."""
        audio_array = self.audio_processor.process_audio_response(audio_data)
        if self.all_output_data is None:
            self.all_output_data = audio_array
        else:
            self.all_output_data = np.concatenate((self.all_output_data, audio_array))
        # Slice the buffer into fixed-size frames and queue them for playback.
        while self.all_output_data.shape[-1] >= self.output_frame_size:
            chunk = self.all_output_data[: self.output_frame_size]
            self.all_output_data = self.all_output_data[self.output_frame_size:]
            # AudioFrame.from_ndarray expects a (channels, samples) array;
            # "s16" matches the int16 PCM decoded above.
            frame = AudioFrame.from_ndarray(chunk.reshape(1, -1), format="s16", layout="mono")
            frame.sample_rate = self.output_sample_rate
            await self.output_queue.put(frame)
    async def on_track(self, track):
        """Handles incoming media tracks."""
        print(f"Track received: {track.kind}")
        if track.kind == "audio":
            self.audio_track = track  # Keep a reference to the remote track.

            async def pump_audio():
                # aiortc delivers media pull-style: frames are read in a loop
                # with track.recv() rather than via a per-frame callback.
                while True:
                    try:
                        frame = await track.recv()
                    except Exception:
                        break  # Track ended or connection closed.
                    try:
                        # to_ndarray() yields a (channels, samples) array;
                        # flatten to mono int16 samples before encoding.
                        audio_data = frame.to_ndarray().flatten().astype(np.int16)
                        encoded_audio = self.audio_processor.encode_audio(
                            audio_data, frame.sample_rate
                        )
                        await self._send_audio_to_gemini(encoded_audio)
                    except Exception as e:
                        print(f"Error processing audio frame: {e}")

            asyncio.create_task(pump_audio())
    async def on_datachannel(self, channel):
        """Handles data channel events (unused here, but wired up for completeness)."""
        self.dc = channel
        print("Data channel created")

        @channel.on("message")
        def on_message(message):
            print(f"Received message: {message}")
    async def connect(self):
        """Establishes the PeerConnection."""
        self.pc = RTCPeerConnection()
        self.pc.on("track", self.on_track)
        self.pc.on("datachannel", self.on_datachannel)
        # Capture local microphone audio. NOTE: "avfoundation" is macOS-only;
        # on Linux use e.g. format="pulse" or "alsa" with a matching device name.
        self.local_audio_player = MediaPlayer(
            ":default",  # avfoundation "[video]:[audio]" syntax; default audio device
            format="avfoundation",
            options={"channels": "1", "sample_rate": str(self.output_sample_rate)},
        )
        self.local_audio = self.relay.subscribe(self.local_audio_player.audio)
        self.pc.addTrack(self.local_audio)
        # Outbound track that plays Gemini's audio replies back to the peer.
        self.pc.addTrack(QueuedAudioTrack(self.output_queue, self.output_sample_rate))
        # Add a data channel (optional, but useful for debugging).
        self.dc = self.pc.createDataChannel("data")
        # Create an offer and set the local description.
        offer = await self.pc.createOffer()
        await self.pc.setLocalDescription(offer)
        print("PeerConnection established")
        return self.pc.localDescription
    async def set_remote_description(self, sdp, sdp_type):
        """Sets the remote description."""
        await self.pc.setRemoteDescription(RTCSessionDescription(sdp=sdp, type=sdp_type))
        print("Remote description set")
        if self.pc.remoteDescription.type == "offer":
            answer = await self.pc.createAnswer()
            await self.pc.setLocalDescription(answer)
            return self.pc.localDescription
    async def add_ice_candidate(self, candidate, sdpMid, sdpMLineIndex):
        """Adds an ICE candidate."""
        if candidate:
            try:
                # aiortc's RTCIceCandidate doesn't accept the raw browser
                # "candidate:..." string; parse it with candidate_from_sdp and
                # attach the mid / mline index afterwards.
                sdp = candidate.split(":", 1)[1] if candidate.startswith("candidate:") else candidate
                ice_candidate = candidate_from_sdp(sdp)
                ice_candidate.sdpMid = sdpMid
                ice_candidate.sdpMLineIndex = sdpMLineIndex
                await self.pc.addIceCandidate(ice_candidate)
                print("ICE candidate added")
            except Exception as e:
                print(f"Error adding ICE candidate: {e}")
    def shutdown(self):
        """Closes the PeerConnection."""
        if self.pc:
            asyncio.create_task(self.pc.close())  # Close in the background
            self.pc = None
            print("PeerConnection closed")
# Gradio Interface
async def registry(
    name: str,
    token: str | None = None,
    **kwargs,
):
    """Sets up and returns the Gradio interface."""
    gemini_handler = GeminiHandler()

    async def connect_webrtc(webrtc_value):
        """Connects to the WebRTC client and handles ICE candidates."""
        # The component value is a dict: {"sdp": ..., "type": ..., "candidates": [...]}.
        webrtc_value = webrtc_value or {}
        sdp = webrtc_value.get("sdp")
        sdp_type = webrtc_value.get("type")
        candidates = webrtc_value.get("candidates", [])
        if gemini_handler.pc is None:
            local_description = await gemini_handler.connect()
            if local_description:
                yield json.dumps(
                    {
                        "sdp": local_description.sdp,
                        "type": local_description.type,
                        "candidates": [],
                    }
                )  # Return the initial SDP offer
        if sdp and sdp_type:
            answer = await gemini_handler.set_remote_description(sdp, sdp_type)
            if answer:
                yield json.dumps({"sdp": answer.sdp, "type": answer.type, "candidates": []})
        for candidate in candidates:
            if candidate and candidate.get("candidate"):
                await gemini_handler.add_ice_candidate(
                    candidate["candidate"], candidate.get("sdpMid"), candidate.get("sdpMLineIndex")
                )
        yield json.dumps({"sdp": "", "type": "", "candidates": []})  # Signal completion
    interface = gr.Blocks()
    with interface:
        with gr.Tabs():
            with gr.TabItem("Voice Chat"):
                gr.HTML(
                    """
                    <div style='text-align: left'>
                        <h1>Gemini API Voice Chat</h1>
                    </div>
                    """
                )
                with gr.Row():
                    webrtc_out = gr.JSON(label="WebRTC JSON")
                    # Gradio has no built-in WebRTC component; this one comes
                    # from fastrtc (imported above) and is used here without
                    # automatic streaming.
                    webrtc = WebRTC(
                        value={"sdp": "", "type": "", "candidates": []},
                        interactive=True,
                        label="Voice Chat",
                    )
                connect_button = gr.Button("Connect")
                connect_button.click(
                    connect_webrtc,
                    inputs=[webrtc],  # Pass the component's value (SDP, type, candidates)
                    outputs=[webrtc_out],  # Show the WebRTC connection data
                )
    return interface
# Launch the Gradio interface
if __name__ == "__main__":
    # Blocks.launch() blocks and is not awaitable, so build the interface in
    # an event loop first, then launch it normally.
    interface = asyncio.run(registry(name="gemini-2.0-flash-exp"))
    interface.queue()  # Enable queuing for better concurrency
    interface.launch()