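"""Voice-enabled Chainlit assistant: records browser audio, transcribes it
with OpenAI Whisper, and answers through the SmartQuery SQL agent."""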
from io import BytesIO
import os
import chainlit as cl
import httpx
from dotenv import load_dotenv
from langchain.schema.runnable.config import RunnableConfig
from smartquery.sql_agent import SQLAgent
from openai import AsyncOpenAI
from chainlit.element import Audio

# Load the .env file
load_dotenv()

# ElevenLabs credentials (text-to-speech voice for spoken answers)
ELEVENLABS_API_KEY = os.environ.get("ELEVENLABS_API_KEY")
ELEVENLABS_VOICE_ID = os.environ.get("ELEVENLABS_VOICE_ID")

if not ELEVENLABS_API_KEY or not ELEVENLABS_VOICE_ID:
    raise ValueError("ELEVENLABS_API_KEY and ELEVENLABS_VOICE_ID must be set")

# OpenAI client for Whisper transcription and chat completions
# (reads OPENAI_API_KEY from the environment)
client = AsyncOpenAI()

@cl.step(type="tool")
async def speech_to_text(audio_file):
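    # Transcribe the recorded audio with OpenAI's Whisper model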
    response = await client.audio.transcriptions.create(
        model="whisper-1", file=audio_file
    )
    return response.text
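
# The ElevenLabs credentials above are otherwise unused in this file, so a
# text-to-speech step presumably belongs here. This is a minimal sketch
# against ElevenLabs' public HTTP API (POST /v1/text-to-speech/{voice_id}
# with an "xi-api-key" header); the model_id and timeout values are
# assumptions, not the original implementation.
@cl.step(type="tool")
async def text_to_speech(text: str, mime_type: str):
    url = f"https://api.elevenlabs.io/v1/text-to-speech/{ELEVENLABS_VOICE_ID}"
    headers = {
        "Accept": mime_type,
        "Content-Type": "application/json",
        "xi-api-key": ELEVENLABS_API_KEY,
    }
    payload = {"text": text, "model_id": "eleven_monolingual_v1"}
    async with httpx.AsyncClient(timeout=30.0) as http_client:
        response = await http_client.post(url, json=payload, headers=headers)
        response.raise_for_status()
        # Raw audio bytes, suitable for a cl.Audio element or a file write
        return response.content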

@cl.step(type="tool")
async def generate_text_answer(transcription):
    # Direct chat-model answer for a transcription; note that the handlers
    # below are not wired to this step and route through the SQL agent instead
    model = "gpt-4o"
    messages = [{"role": "user", "content": transcription}]
    response = await client.chat.completions.create(
        messages=messages, model=model, temperature=0.3
    )
    return response.choices[0].message.content

@cl.on_chat_start
async def on_chat_start():
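    # Keep the SQL agent in the user session so every handler can retrieve it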
    cl.user_session.set("agent", SQLAgent)

@cl.on_message
async def on_message(message: cl.Message):
    await process_message(message.content)

@cl.on_audio_chunk
async def on_audio_chunk(chunk: cl.AudioChunk):
    if chunk.isStart:
        buffer = BytesIO()
        # This is required for whisper to recognize the file type
        buffer.name = f"input_audio.{chunk.mimeType.split('/')[1]}"
        # Initialize the session for a new audio stream
        cl.user_session.set("audio_buffer", buffer)
        cl.user_session.set("audio_mime_type", chunk.mimeType)

    # Append each subsequent chunk to the session's audio buffer
    cl.user_session.get("audio_buffer").write(chunk.data)

@cl.on_audio_end
async def on_audio_end(elements: list[Audio]):
    try:
        audio_buffer: BytesIO = cl.user_session.get("audio_buffer")
        audio_buffer.seek(0)
        audio_file = audio_buffer.read()
        audio_mime_type: str = cl.user_session.get("audio_mime_type")

        input_audio_el = Audio(
            mime=audio_mime_type, content=audio_file, name=audio_buffer.name
        )
        await cl.Message(
            author="You",
            type="user_message",
            content="",
            elements=[input_audio_el, *elements]
        ).send()

        # The OpenAI SDK accepts a (filename, bytes, content_type) tuple
        whisper_input = (audio_buffer.name, audio_file, audio_mime_type)
        transcription = await speech_to_text(whisper_input)

        await process_message(transcription)
    except Exception as e:
        print(f"Error processing audio: {e}")
        await cl.Message(content="Error processing audio. Please try again.").send()
    finally:
        # Reset audio buffer and mime type
        cl.user_session.set("audio_buffer", None)
        cl.user_session.set("audio_mime_type", None)
        print("Audio buffer reset")

async def process_message(content: str, answer_message=None):
    agent = cl.user_session.get("agent")
    cb = cl.AsyncLangchainCallbackHandler(stream_final_answer=True)
    config = RunnableConfig(callbacks=[cb])

    async with cl.Step(name="SmartQuery Agent", root=True) as step:
        step.input = content
        result = await agent.ainvoke(content, config=config)

        final_answer = result.get("output", "No answer returned")

        # Emit the final answer on the root step; intermediate tokens are
        # streamed by the Langchain callback handler above
        await step.stream_token(final_answer)

        if answer_message:
            answer_message.content = final_answer
            await answer_message.update()