# LLMGameHub / src/audio/audio_generator.py
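"""Realtime background-music generation helpers built on the Lyria realtime model.

The module keeps one live music session per user hash, receives raw PCM chunks
into a thread-safe queue, and re-exposes them as WAV bytes for streaming.
"""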
import asyncio
from google.genai import types
import wave
import queue
import logging
import io
import time
from config import settings
from services.google import GoogleClientFactory
logger = logging.getLogger(__name__)
async def generate_music(user_hash: str, music_tone: str, receive_audio):
    """Open a Lyria realtime session for the user and start streaming audio."""
if user_hash in sessions:
logger.info(
f"Music generation already started for user hash {user_hash}, skipping new generation"
)
return
async with GoogleClientFactory.audio() as client:
async with (
client.live.music.connect(model="models/lyria-realtime-exp") as session,
asyncio.TaskGroup() as tg,
):
# Set up task to receive server messages.
tg.create_task(receive_audio(session, user_hash))
# Send initial prompts and config
await asyncio.wait_for(
session.set_weighted_prompts(
prompts=[types.WeightedPrompt(text=music_tone, weight=1.0)]
),
settings.request_timeout,
)
await asyncio.wait_for(
session.set_music_generation_config(
config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
),
settings.request_timeout,
)
await asyncio.wait_for(session.play(), settings.request_timeout)
logger.info(
f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
)
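            # Register the live session together with the thread-safe queue that
            # receive_audio() fills with raw PCM chunks.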
sessions[user_hash] = {"session": session, "queue": queue.Queue()}
async def change_music_tone(user_hash: str, new_tone: str):
    """Replace the weighted prompt of an existing session with the new tone."""
    logger.info(f"Changing music tone to {new_tone}")
session = sessions.get(user_hash, {}).get("session")
if not session:
logger.error(f"No session found for user hash {user_hash}")
return
await asyncio.wait_for(
session.set_weighted_prompts(
prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
),
settings.request_timeout,
)
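# Note: set_weighted_prompts() accepts more than one prompt, so tones can in
# principle be blended by weight (a sketch; the exact blending behavior is up
# to the Lyria model):
#
#   await session.set_weighted_prompts(
#       prompts=[
#           types.WeightedPrompt(text="calm piano", weight=0.7),
#           types.WeightedPrompt(text="rain ambience", weight=0.3),
#       ]
#   )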
SAMPLE_RATE = 48000
NUM_CHANNELS = 2 # Stereo
SAMPLE_WIDTH = 2 # 16-bit audio -> 2 bytes per sample
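# One frame is NUM_CHANNELS * SAMPLE_WIDTH = 4 bytes, so one second of audio
# amounts to SAMPLE_RATE * 4 = 192,000 bytes of raw PCM.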
async def receive_audio(session, user_hash):
"""Process incoming audio from the music generation."""
while True:
        try:
            async for message in session.receive():
                if message.server_content and message.server_content.audio_chunks:
                    audio_data = message.server_content.audio_chunks[0].data
                    # audio_data is already bytes (raw PCM); hand it to the
                    # thread-safe queue without blocking the event loop.
                    audio_queue = sessions[user_hash]["queue"]
                    await asyncio.to_thread(audio_queue.put, audio_data)
                await asyncio.sleep(0)  # Yield to the event loop between messages.
except Exception as e:
logger.error(f"Error in receive_audio: {e}")
break
# Module-level registry: user_hash -> {"session": live music session, "queue": queue.Queue of PCM chunks}.
sessions: dict = {}
async def start_music_generation(user_hash: str, music_tone: str):
    """Start music generation for the user and attach the audio receiver."""
    await generate_music(user_hash, music_tone, receive_audio)
async def cleanup_music_session(user_hash: str):
    """Stop and close the user's live session, then drop it from the registry."""
if user_hash in sessions:
logger.info(f"Cleaning up music session for user hash {user_hash}")
session = sessions[user_hash]["session"]
try:
await asyncio.wait_for(session.stop(), settings.request_timeout)
await asyncio.wait_for(session.close(), settings.request_timeout)
except Exception as e:
logger.error(f"Error stopping music session for user hash {user_hash}: {e}")
del sessions[user_hash]
def update_audio(user_hash):
"""Continuously stream audio from the queue as WAV bytes."""
if user_hash == "":
return
logger.info(f"Starting audio update loop for user hash: {user_hash}")
while True:
if user_hash not in sessions:
time.sleep(0.5)
continue
        audio_queue = sessions[user_hash]["queue"]
        pcm_data = audio_queue.get()  # Blocks until raw PCM audio bytes arrive.
        if not isinstance(pcm_data, bytes):
            logger.warning(
                f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping."
            )
            continue
# Lyria provides stereo, 16-bit PCM at 48kHz.
# Ensure the number of bytes is consistent with stereo 16-bit audio.
# Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
# If len(pcm_data) is not a multiple of (NUM_CHANNELS * SAMPLE_WIDTH),
# it might indicate an incomplete chunk or an issue.
bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
if len(pcm_data) % bytes_per_frame != 0:
logger.warning(
f"Received PCM data with length {len(pcm_data)}, which is not a multiple of "
f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
)
# Depending on strictness, you might want to skip this chunk:
# continue
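            # ...or trim the trailing partial frame instead:
            # pcm_data = pcm_data[: len(pcm_data) - (len(pcm_data) % bytes_per_frame)]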
wav_buffer = io.BytesIO()
with wave.open(wav_buffer, "wb") as wf:
wf.setnchannels(NUM_CHANNELS)
wf.setsampwidth(SAMPLE_WIDTH) # Corresponds to 16-bit audio
wf.setframerate(SAMPLE_RATE)
wf.writeframes(pcm_data)
wav_bytes = wav_buffer.getvalue()
yield wav_bytes
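# Example wiring (a minimal sketch; send_to_client and the surrounding endpoint
# are hypothetical, not part of this module):
#
#   await start_music_generation(user_hash, "calm ambient piano")
#   await change_music_tone(user_hash, "tense orchestral strings")
#
#   # update_audio() blocks on queue.get(), so run it off the event loop,
#   # e.g. in a worker thread, and forward each WAV chunk to the client:
#   for wav_chunk in update_audio(user_hash):
#       send_to_client(wav_chunk)
#
#   await cleanup_music_session(user_hash)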