File size: 5,323 Bytes
86b351a
 
 
 
 
a277e33
 
45fabe9
 
86b351a
 
 
45fabe9
 
86b351a
a277e33
4310b90
45fabe9
 
 
4310b90
21eb680
 
45fabe9
21eb680
 
 
 
86b351a
21eb680
 
 
 
 
45fabe9
21eb680
 
 
 
45fabe9
 
21eb680
45fabe9
21eb680
 
 
45fabe9
 
 
a277e33
86b351a
45fabe9
86b351a
a277e33
86b351a
21eb680
 
 
 
45fabe9
86b351a
45fabe9
86b351a
 
a277e33
 
86b351a
45fabe9
a277e33
86b351a
 
 
 
 
 
45fabe9
a277e33
 
86b351a
 
 
a277e33
86b351a
45fabe9
86b351a
 
45fabe9
a277e33
86b351a
a277e33
45fabe9
 
a277e33
 
 
45fabe9
8517b46
 
 
 
 
a277e33
45fabe9
86b351a
a277e33
 
0a18f7d
 
45fabe9
0a18f7d
86b351a
a277e33
 
 
45fabe9
 
 
a277e33
45fabe9
 
 
a277e33
 
 
 
 
45fabe9
a277e33
 
 
 
 
 
 
 
45fabe9
a277e33
 
45fabe9
a277e33
45fabe9
a277e33
 
 
21eb680
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import asyncio
from google.genai import types
import wave
import queue
import logging
import io
import time
from config import settings
from services.google import GoogleClientFactory

logger = logging.getLogger(__name__)




# User hashes whose music session is currently being set up, i.e. a call to
# generate_music() has passed its duplicate check but has not yet registered
# the session in `sessions`.
_starting_user_hashes: set[str] = set()


async def generate_music(user_hash: str, music_tone: str, receive_audio):
    """Open a realtime music session for a user and keep it running.

    Connects to the "models/lyria-realtime-exp" live music model, starts
    ``receive_audio(session, user_hash)`` as a background task, sends the
    initial weighted prompt and generation config, starts playback and then
    registers the session together with a fresh audio queue in ``sessions``.

    The coroutine does not return right after start-up: the surrounding
    ``asyncio.TaskGroup`` waits for the receiver task, so the call stays
    suspended until ``receive_audio`` finishes.

    If a session for ``user_hash`` is already registered, or is still being
    started by a concurrent call, the function logs and returns immediately.

    Args:
        user_hash: Key under which the session is stored in ``sessions``.
        music_tone: Text of the initial weighted prompt.
        receive_audio: Coroutine function invoked as
            ``receive_audio(session, user_hash)`` to consume server messages.

    Raises:
        Any connection or setup error, including ``wait_for`` timeouts.
        Errors raised inside the task group are re-raised by asyncio wrapped
        in an exception group.
    """
    if user_hash in sessions or user_hash in _starting_user_hashes:
        logger.info(
            f"Music generation already started for user hash {user_hash}, skipping new generation"
        )
        return

    # Reserve the user before the first await. Without this, several calls
    # for the same user could all pass the check above while the first one is
    # still connecting, and each would open its own session.
    _starting_user_hashes.add(user_hash)
    try:
        async with GoogleClientFactory.audio() as client:
            async with (
                client.live.music.connect(model="models/lyria-realtime-exp") as session,
                asyncio.TaskGroup() as tg,
            ):
                # Set up task to receive server messages.
                tg.create_task(receive_audio(session, user_hash))

                # Send initial prompts and config
                await asyncio.wait_for(
                    session.set_weighted_prompts(
                        prompts=[types.WeightedPrompt(text=music_tone, weight=1.0)]
                    ),
                    settings.request_timeout,
                )
                await asyncio.wait_for(
                    session.set_music_generation_config(
                        config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
                    ),
                    settings.request_timeout,
                )
                await asyncio.wait_for(session.play(), settings.request_timeout)
                logger.info(
                    f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
                )
                sessions[user_hash] = {"session": session, "queue": queue.Queue()}
                # From here on `sessions` itself guards against duplicates, so
                # the reservation is released while the task group keeps
                # waiting for the receiver task.
                _starting_user_hashes.discard(user_hash)
    finally:
        # Also release the reservation when connecting or setup failed.
        _starting_user_hashes.discard(user_hash)


async def change_music_tone(user_hash: str, new_tone):
    """Replace the prompt of a user's running music session.

    Looks up the session registered for ``user_hash`` in ``sessions`` and
    sends ``new_tone`` as its only weighted prompt (weight 1.0), bounded by
    ``settings.request_timeout``. When no session is registered, an error is
    logged and nothing is sent.
    """
    logger.info(f"Changing music tone to {new_tone}")
    entry = sessions.get(user_hash, {})
    live_session = entry.get("session")
    if live_session:
        prompt = types.WeightedPrompt(text=new_tone, weight=1.0)
        request = live_session.set_weighted_prompts(prompts=[prompt])
        await asyncio.wait_for(request, settings.request_timeout)
        return
    logger.error(f"No session found for user hash {user_hash}")


# Format of the raw PCM audio handled by this module; these values are written
# into the WAV header of every chunk that is streamed out.
SAMPLE_RATE = 48_000  # frames per second (48 kHz)
NUM_CHANNELS = 2  # stereo
SAMPLE_WIDTH = 2  # bytes per sample, i.e. 16-bit audio


async def receive_audio(session, user_hash):
    """Forward audio received from a music session into the user's queue.

    Iterates over ``session.receive()`` and, for every message that carries
    audio chunks, puts each chunk's ``data`` on the queue stored under
    ``sessions[user_hash]["queue"]``, in the order received.

    The loop ends when receiving raises an exception (logged with its
    traceback) or when the user no longer has an entry in ``sessions``.

    Args:
        session: Live music session object exposing ``receive()``.
        user_hash: Key of the user's entry in ``sessions``.
    """
    while True:
        try:
            async for message in session.receive():
                content = message.server_content
                if content and content.audio_chunks:
                    entry = sessions.get(user_hash)
                    if entry is None:
                        # Nothing can consume the audio without a registered
                        # queue, so stop instead of failing with a KeyError.
                        logger.warning(
                            f"No session registered for user hash {user_hash}, stopping audio receiver"
                        )
                        return
                    audio_queue = entry["queue"]
                    # Enqueue every chunk of the message, not only the first,
                    # so no audio is dropped when several arrive together.
                    for chunk in content.audio_chunks:
                        # chunk.data is expected to be raw PCM bytes.
                        await asyncio.to_thread(audio_queue.put, chunk.data)
                # Yield to the event loop so other tasks can run between messages.
                await asyncio.sleep(0)
        except Exception as e:
            logger.exception(f"Error in receive_audio: {e}")
            break


# Registry of active music sessions, keyed by user hash. Each value is a dict
# holding the live "session" object and the "queue" that receives its audio.
sessions: dict[str, dict] = dict()


async def start_music_generation(user_hash: str, music_tone: str):
    """Start music generation for a user with the module's audio receiver.

    Awaits ``generate_music`` in the calling task, passing ``receive_audio``
    as the coroutine function that consumes the session's messages.
    """
    await generate_music(
        user_hash=user_hash,
        music_tone=music_tone,
        receive_audio=receive_audio,
    )


async def cleanup_music_session(user_hash: str):
    """Stop and close a user's music session and remove it from ``sessions``.

    Does nothing when no session is registered for ``user_hash``. Stopping
    and closing are attempted independently, each bounded by
    ``settings.request_timeout``; a failure of either step is logged and does
    not prevent the other step or the removal of the registry entry.

    Args:
        user_hash: Key of the user's entry in ``sessions``.
    """
    entry = sessions.get(user_hash)
    if entry is None:
        return

    logger.info(f"Cleaning up music session for user hash {user_hash}")
    session = entry["session"]

    try:
        await asyncio.wait_for(session.stop(), settings.request_timeout)
    except Exception as e:
        logger.error(f"Error stopping music session for user hash {user_hash}: {e}")

    # Always attempt to close, even if stop() failed or timed out; otherwise
    # the connection would be left open.
    try:
        await asyncio.wait_for(session.close(), settings.request_timeout)
    except Exception as e:
        logger.error(f"Error closing music session for user hash {user_hash}: {e}")

    # The registry may have changed while awaiting above; only remove the
    # entry this call actually cleaned up.
    if sessions.get(user_hash) is entry:
        del sessions[user_hash]


def _pcm_to_wav(pcm_data: bytes) -> bytes:
    """Wrap raw PCM bytes in a WAV container using the module's audio format."""
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, "wb") as wf:
        wf.setnchannels(NUM_CHANNELS)
        wf.setsampwidth(SAMPLE_WIDTH)  # Corresponds to 16-bit audio
        wf.setframerate(SAMPLE_RATE)
        wf.writeframes(pcm_data)
    return wav_buffer.getvalue()


def update_audio(user_hash):
    """Continuously stream a user's queued audio as WAV bytes.

    Blocking generator: polls ``sessions`` for the user's entry, takes raw
    PCM chunks from its queue and yields each one wrapped in its own WAV
    container. While no session is registered it sleeps and retries. It never
    finishes on its own unless ``user_hash`` is empty.

    Args:
        user_hash: Key of the user's entry in ``sessions``. An empty value
            (``""`` or ``None``) ends the generator immediately.

    Yields:
        bytes: A complete WAV file for each PCM chunk taken from the queue.
    """
    if not user_hash:
        return

    logger.info(f"Starting audio update loop for user hash: {user_hash}")
    # Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
    bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
    while True:
        # Single lookup, so an entry removed between a membership test and
        # the indexing cannot raise KeyError.
        entry = sessions.get(user_hash)
        if entry is None:
            time.sleep(0.5)
            continue

        try:
            # Wait with a timeout so the registry is re-read periodically.
            # A get() without timeout would stay blocked on the queue of a
            # session that has since been cleaned up or replaced.
            pcm_data = entry["queue"].get(timeout=0.5)  # raw PCM audio bytes
        except queue.Empty:
            continue

        if not isinstance(pcm_data, bytes):
            logger.warning(
                f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping."
            )
            continue

        # Lyria provides stereo, 16-bit PCM at 48kHz. A length that is not a
        # multiple of the frame size might indicate an incomplete chunk; the
        # chunk is still forwarded, only a warning is logged.
        if len(pcm_data) % bytes_per_frame != 0:
            logger.warning(
                f"Received PCM data with length {len(pcm_data)}, which is not a multiple of "
                f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
            )

        yield _pcm_to_wav(pcm_data)