Spaces:
Running
Running
File size: 5,323 Bytes
86b351a a277e33 45fabe9 86b351a 45fabe9 86b351a a277e33 4310b90 45fabe9 4310b90 21eb680 45fabe9 21eb680 86b351a 21eb680 45fabe9 21eb680 45fabe9 21eb680 45fabe9 21eb680 45fabe9 a277e33 86b351a 45fabe9 86b351a a277e33 86b351a 21eb680 45fabe9 86b351a 45fabe9 86b351a a277e33 86b351a 45fabe9 a277e33 86b351a 45fabe9 a277e33 86b351a a277e33 86b351a 45fabe9 86b351a 45fabe9 a277e33 86b351a a277e33 45fabe9 a277e33 45fabe9 8517b46 a277e33 45fabe9 86b351a a277e33 0a18f7d 45fabe9 0a18f7d 86b351a a277e33 45fabe9 a277e33 45fabe9 a277e33 45fabe9 a277e33 45fabe9 a277e33 45fabe9 a277e33 45fabe9 a277e33 21eb680 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import asyncio
from google.genai import types
import wave
import queue
import logging
import io
import time
from config import settings
from services.google import GoogleClientFactory
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
async def generate_music(user_hash: str, music_tone: str, receive_audio):
    """Open a realtime music session for a user and keep it streaming.

    Connects to the ``models/lyria-realtime-exp`` live-music model, registers
    the session together with a fresh audio queue in the module-level
    ``sessions`` mapping, starts ``receive_audio(session, user_hash)`` as a
    background task, then sends the initial prompt and generation config and
    starts playback.

    The session is registered *before* the receiver task is started, so the
    receiver can always find its queue, and a second call for the same user
    made while this one is still setting up is skipped instead of opening a
    competing session.

    After a successful start this coroutine keeps running: leaving the
    ``asyncio.TaskGroup`` waits for the receiver task to finish. The entry in
    ``sessions`` is left in place at that point; ``cleanup_music_session``
    is what removes it.

    Args:
        user_hash: Key identifying the user in ``sessions``.
        music_tone: Text of the initial weighted prompt (weight 1.0).
        receive_audio: Coroutine function invoked as
            ``receive_audio(session, user_hash)`` to consume server messages.

    Raises:
        Any error raised while connecting or while sending the setup
        requests, including the timeout raised by ``asyncio.wait_for`` when a
        request exceeds ``settings.request_timeout``. The registration is
        rolled back first so that a later call can retry.
    """
    already_started = (
        f"Music generation already started for user hash {user_hash}, skipping new generation"
    )
    if user_hash in sessions:
        logger.info(already_started)
        return
    async with GoogleClientFactory.audio() as client:
        async with (
            client.live.music.connect(model="models/lyria-realtime-exp") as session,
            asyncio.TaskGroup() as tg,
        ):
            # Connecting suspends this coroutine, so another call for the same
            # user may have registered a session in the meantime.
            if user_hash in sessions:
                logger.info(already_started)
                return
            entry = {"session": session, "queue": queue.Queue()}
            sessions[user_hash] = entry
            try:
                # Set up task to receive server messages.
                tg.create_task(receive_audio(session, user_hash))
                # Send initial prompts and config
                await asyncio.wait_for(
                    session.set_weighted_prompts(
                        prompts=[types.WeightedPrompt(text=music_tone, weight=1.0)]
                    ),
                    settings.request_timeout,
                )
                await asyncio.wait_for(
                    session.set_music_generation_config(
                        config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
                    ),
                    settings.request_timeout,
                )
                await asyncio.wait_for(session.play(), settings.request_timeout)
            except BaseException:
                # Setup failed or was cancelled: drop our registration (and
                # only ours) so the user is not stuck as "already started".
                if sessions.get(user_hash) is entry:
                    del sessions[user_hash]
                raise
            logger.info(
                f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
            )
async def change_music_tone(user_hash: str, new_tone):
    """Replace the prompt of a user's registered music session.

    Looks the session up in ``sessions``. When none is registered for
    ``user_hash`` an error is logged and nothing else happens; otherwise
    ``new_tone`` is sent as the only weighted prompt (weight 1.0), bounded by
    ``settings.request_timeout``.
    """
    logger.info(f"Changing music tone to {new_tone}")
    entry = sessions.get(user_hash, {})
    active_session = entry.get("session")
    if active_session:
        replacement = types.WeightedPrompt(text=new_tone, weight=1.0)
        await asyncio.wait_for(
            active_session.set_weighted_prompts(prompts=[replacement]),
            settings.request_timeout,
        )
        return
    logger.error(f"No session found for user hash {user_hash}")
# PCM format used when wrapping audio chunks as WAV.
NUM_CHANNELS = 2  # stereo
SAMPLE_WIDTH = 2  # bytes per sample, i.e. 16-bit audio
SAMPLE_RATE = 48_000  # frames per second
async def receive_audio(session, user_hash):
    """Forward audio received from a music session to the user's queue.

    Iterates over ``session.receive()`` and, for every message carrying audio
    chunks, puts the data of *each* chunk, in order, on the queue stored under
    ``sessions[user_hash]["queue"]``. If the message stream ends without an
    error, ``session.receive()`` is iterated again.

    Any exception (including a missing ``sessions`` entry) is logged with its
    traceback and ends the loop; nothing is re-raised.

    Args:
        session: Live music session object providing ``receive()``.
        user_hash: Key of the user's entry in ``sessions``.
    """
    while True:
        try:
            async for message in session.receive():
                content = message.server_content
                if content and content.audio_chunks:
                    audio_queue = sessions[user_hash]["queue"]
                    for chunk in content.audio_chunks:
                        # chunk.data is expected to be raw PCM bytes already.
                        await asyncio.to_thread(audio_queue.put, chunk.data)
                # Tiny sleep so other tasks get a turn between messages.
                await asyncio.sleep(10**-12)
        except Exception as e:
            logger.exception(f"Error in receive_audio: {e}")
            break
# Registry of music sessions, keyed by user hash. Each value is a dict
# with two keys: "session" (the live music session object) and "queue"
# (a queue.Queue holding audio chunks for that user).
sessions: dict[str, dict] = {}
async def start_music_generation(user_hash: str, music_tone: str) -> None:
    """Run music generation for a user, using this module's ``receive_audio``.

    No thread or background task is created here: ``generate_music`` is
    awaited directly, so this coroutine returns only when ``generate_music``
    returns. Callers that need it to run in the background must schedule it
    themselves.

    Args:
        user_hash: Key identifying the user's session.
        music_tone: Text of the initial music prompt.
    """
    await generate_music(user_hash, music_tone, receive_audio)
async def cleanup_music_session(user_hash: str):
    """Stop and close a user's music session, then unregister it.

    Does nothing when ``user_hash`` has no entry in ``sessions``. Otherwise
    ``session.stop()`` and ``session.close()`` are each attempted, bounded by
    ``settings.request_timeout``; a failure of either is logged and does not
    prevent the other from being tried. The entry is removed from
    ``sessions`` afterwards, even if this coroutine is cancelled part-way.

    Args:
        user_hash: Key of the user's entry in ``sessions``.
    """
    if user_hash not in sessions:
        return
    logger.info(f"Cleaning up music session for user hash {user_hash}")
    entry = sessions[user_hash]
    session = entry["session"]
    try:
        # Attempt stop and close independently so that a failing or timed-out
        # stop cannot leave the connection open.
        for shutdown_step in (session.stop, session.close):
            try:
                await asyncio.wait_for(shutdown_step(), settings.request_timeout)
            except Exception as e:
                logger.error(
                    f"Error stopping music session for user hash {user_hash}: {e}"
                )
    finally:
        # The entry may have been removed or replaced while we were awaiting;
        # only delete it if it is still the one we cleaned up.
        if sessions.get(user_hash) is entry:
            del sessions[user_hash]
def _pcm_to_wav(pcm_data: bytes) -> bytes:
    """Wrap raw PCM bytes in a WAV container using the module's PCM format."""
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, "wb") as wf:
        wf.setnchannels(NUM_CHANNELS)
        wf.setsampwidth(SAMPLE_WIDTH)  # Corresponds to 16-bit audio
        wf.setframerate(SAMPLE_RATE)
        wf.writeframes(pcm_data)
    return wav_buffer.getvalue()


def update_audio(user_hash):
    """Continuously stream audio from the user's queue as WAV bytes.

    Generator. Yields nothing when ``user_hash`` is the empty string.
    Otherwise it loops forever: while no session is registered for the user it
    polls ``sessions`` every 0.5 seconds; once one is, it takes PCM chunks
    from that session's queue and yields each one wrapped as a standalone WAV
    file (stereo, 16-bit, 48 kHz per the module constants).

    The queue is read with a 0.5 second timeout and the session is looked up
    again on every iteration, so the generator notices when the session has
    been removed or replaced instead of blocking forever on an old queue.

    Args:
        user_hash: Key of the user's entry in ``sessions``.

    Yields:
        bytes: A complete WAV file containing one PCM chunk.
    """
    if user_hash == "":
        return
    logger.info(f"Starting audio update loop for user hash: {user_hash}")
    # Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
    bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
    while True:
        # Single lookup: the entry can be removed concurrently, so avoid a
        # separate membership test followed by indexing.
        entry = sessions.get(user_hash)
        if entry is None:
            time.sleep(0.5)
            continue
        try:
            pcm_data = entry["queue"].get(timeout=0.5)  # raw PCM audio bytes
        except queue.Empty:
            # Nothing yet; go round again to re-check the session registry.
            continue
        if not isinstance(pcm_data, bytes):
            logger.warning(
                f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping."
            )
            continue
        # A length that is not a whole number of frames might indicate an
        # incomplete chunk; it is reported but still forwarded.
        if len(pcm_data) % bytes_per_frame != 0:
            logger.warning(
                f"Received PCM data with length {len(pcm_data)}, which is not a multiple of "
                f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
            )
        yield _pcm_to_wav(pcm_data)
|