# gemini-webrtc / app.py
# Voice chat with the Gemini Live API over WebRTC (fastrtc + FastAPI + Gradio).
import asyncio
import base64
import json
import os
import pathlib
from typing import AsyncGenerator, Literal
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import (
AsyncStreamHandler,
Stream,
get_cloudflare_turn_credentials_async,
wait_for_item,
)
from google import genai
from google.genai.types import (
LiveConnectConfig,
PrebuiltVoiceConfig,
SpeechConfig,
VoiceConfig,
Tool,
GoogleSearch,
Content,
Part,
)
from gradio.utils import get_space
from pydantic import BaseModel
from langdetect import detect, DetectorFactory
# Make langdetect results consistent
DetectorFactory.seed = 0  # langdetect is non-deterministic unless seeded

# Directory containing this file; used later to locate index.html.
current_dir = pathlib.Path(__file__).parent

# Load environment variables (e.g. GEMINI_API_KEY) from a local .env file.
load_dotenv()
def encode_audio(data: np.ndarray) -> str:
    """Return the raw bytes of *data* encoded as a base64 ASCII string.

    Used to package PCM audio frames for upload to the Gemini Live API.
    """
    raw = data.tobytes()
    return base64.b64encode(raw).decode("UTF-8")
def detect_language_code(text: str) -> str:
    """Map the detected language of *text* to a Google TTS language code.

    Returns:
        "es-ES" when *text* is detected as Spanish; "en-US" for English,
        any other language, or when detection fails (e.g. empty input).
    """
    try:
        # langdetect raises LangDetectException on empty/undetectable text.
        # The original elif/else both returned "en-US", so only the Spanish
        # branch needs an explicit check.
        return "es-ES" if detect(text).startswith("es") else "en-US"
    except Exception:
        # Best-effort: default to English rather than failing session setup.
        return "en-US"
class GeminiHandler(AsyncStreamHandler):
    """fastrtc handler that proxies audio between a WebRTC peer and the
    Gemini Live API.

    Microphone frames arrive via ``receive`` at 16 kHz mono, are base64
    encoded, and queued for upload; audio returned by Gemini is decoded to
    int16 PCM and queued for playback via ``emit``.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,  # rate of the frames we send to Gemini
        )
        # Base64-encoded audio chunks awaiting upload to Gemini.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # (sample_rate, int16 ndarray) tuples awaiting playback.
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Set by shutdown() to stop the stream() generator loop.
        self.quit: asyncio.Event = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        # Called per connection; returns a fresh handler so queues/events
        # are not shared between concurrent sessions.
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
        )

    async def start_up(self):
        """Open a Gemini Live session and pump its audio into output_queue."""
        if not self.phone_mode:
            # UI mode: block until the Gradio additional_inputs arrive.
            # latest_args[0] is the audio component; the remaining entries
            # follow the order declared in Stream(additional_inputs=[...]).
            await self.wait_for_args()
            api_key, voice_name, system_message = self.latest_args[1:]
        else:
            # Phone mode has no UI inputs; rely on env var and defaults.
            api_key, voice_name, system_message = None, "Kore", "You are a helpful assistant."
        # Auto-detect language from system_message
        lang_code = detect_language_code(system_message)
        client = genai.Client(
            api_key=api_key or os.getenv("GEMINI_API_KEY"),
            http_options={"api_version": "v1alpha"},  # Live API endpoint version
        )
        # Allow the model to ground its answers with Google Search.
        tools = [Tool(google_search=GoogleSearch())]
        system_instruction = Content(
            parts=[Part.from_text(text=f"{system_message}")],
            role="user"
        )
        config = LiveConnectConfig(
            response_modalities=["AUDIO"],
            speech_config=SpeechConfig(
                language_code=lang_code,
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            ),
            tools=tools,
            system_instruction=system_instruction,
        )
        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp", config=config
        ) as session:
            # Feed our queued mic audio in; forward Gemini's audio replies out.
            async for audio in session.start_stream(
                stream=self.stream(), mime_type="audio/pcm"
            ):
                if audio.data:
                    # Raw PCM bytes -> int16 samples for the playback queue.
                    array = np.frombuffer(audio.data, dtype=np.int16)
                    self.output_queue.put_nowait((self.output_sample_rate, array))

    async def stream(self) -> AsyncGenerator[str, None]:
        """Yield queued base64 audio chunks until shutdown() is called.

        Yields ``str`` (base64 text produced by encode_audio), not raw bytes.
        """
        while not self.quit.is_set():
            try:
                # Short timeout so the quit flag is re-checked periodically.
                audio = await asyncio.wait_for(self.input_queue.get(), 0.1)
                yield audio
            except (asyncio.TimeoutError, TimeoutError):
                pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Accept one microphone frame from WebRTC and queue it for Gemini."""
        _, array = frame
        array = array.squeeze()  # collapse the channel axis (mono input)
        audio_message = encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next (sample_rate, samples) chunk for playback, if any."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal stream() to stop, ending the upload to Gemini."""
        self.quit.set()
# Global fastrtc stream: bidirectional audio between the browser and a
# per-connection GeminiHandler clone.
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=GeminiHandler(),
    # On Spaces, fetch Cloudflare TURN credentials for NAT traversal;
    # locally the browser's default ICE servers suffice.
    rtc_configuration=get_cloudflare_turn_credentials_async if get_space() else None,
    # Resource limits only apply when hosted on a HF Space.
    concurrency_limit=5 if get_space() else None,
    time_limit=900 if get_space() else None,
    # Delivered to GeminiHandler.start_up via latest_args, in this order.
    additional_inputs=[
        gr.Textbox(
            label="API Key",
            type="password",
            # Pre-fill locally from the environment; never expose on Spaces.
            value=os.getenv("GEMINI_API_KEY") if not get_space() else "",
        ),
        gr.Dropdown(
            label="Voice",
            # Prebuilt Gemini Live voice names.
            choices=[
                "Puck",
                "Charon",
                "Kore",
                "Fenrir",
                "Aoede",
            ],
            value="Kore",
        ),
        gr.Textbox(
            label="System Message",
            placeholder="Enter system instructions for the AI...",
            value="You are a helpful assistant who answers questions and helps with tasks.",
            lines=3,
        ),
    ],
)
class InputData(BaseModel):
    """Payload for POST /input_hook: per-connection settings from the client."""

    webrtc_id: str  # identifies which WebRTC connection these settings apply to
    voice_name: str  # one of the prebuilt voice names offered in the UI
    api_key: str  # user-supplied Gemini key (may be empty -> env var fallback)
    system_message: str  # system instructions forwarded to the model
# FastAPI app serving both the custom UI and the WebRTC endpoints.
app = FastAPI()
stream.mount(app)  # registers fastrtc's signalling/streaming routes on the app
@app.post("/input_hook")
async def _(body: InputData):
    """Attach the client's per-connection settings to the stream.

    The argument order must match Stream(additional_inputs=[...]):
    API key, voice name, system message.
    """
    settings = (body.api_key, body.voice_name, body.system_message)
    stream.set_input(body.webrtc_id, *settings)
    return {"status": "ok"}
@app.get("/")
async def index():
    """Serve the single-page UI with the RTC configuration injected.

    Replaces the __RTC_CONFIGURATION__ placeholder in index.html with the
    TURN credentials (on Spaces) or JSON null (locally).
    """
    rtc_config = await get_cloudflare_turn_credentials_async() if get_space() else None
    # Read explicitly as UTF-8: bare read_text() uses the platform's locale
    # encoding, which can garble the page on non-UTF-8 systems.
    html_content = (current_dir / "index.html").read_text(encoding="utf-8")
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)
if __name__ == "__main__":
    # MODE selects how to serve:
    #   "UI"    -> launch the auto-generated Gradio UI
    #   "PHONE" -> fastrtc's telephony bridge
    #   other   -> serve the FastAPI app (custom index.html UI) via uvicorn
    # NOTE: the redundant `import os` was removed; os is imported at module level.
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn  # imported lazily: only needed for this serving mode

        uvicorn.run(app, host="0.0.0.0", port=7860)