|
from fastapi import FastAPI, UploadFile, File, Response, Request, Form, Body |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.responses import FileResponse |
|
import ggwave |
|
import scipy.io.wavfile as wav |
|
import numpy as np |
|
import os |
|
from pydantic import BaseModel |
|
from groq import Groq |
|
import io |
|
import wave |
|
import json |
|
from typing import List, Dict, Optional |
|
|
|
|
|
app = FastAPI()


# Serve front-end assets (index.html, conv.html, and any JS/CSS) from ./static.
app.mount("/static", StaticFiles(directory="static"), name="static")


# Groq LLM client; reads the API key from the GROQ_API_KEY environment
# variable (None if unset — requests will then fail at call time).
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
|
|
class TextInput(BaseModel):
    """Request body for the /tts/ endpoint."""

    text: str  # plain-text message to encode into a ggwave audio signal
|
|
|
@app.get("/")
async def serve_homepage():
    """Return the main chat UI page."""
    homepage = "static/index.html"
    return FileResponse(homepage)
|
|
|
@app.get("/conv/")
async def serve_convpage():
    """Return the continuous-conversation UI page."""
    conv_page = "static/conv.html"
    return FileResponse(conv_page)
|
|
|
@app.post("/stt/")
async def speech_to_text(file: UploadFile = File(...)):
    """Convert an uploaded WAV audio file to text using ggwave.

    Reads the upload fully into memory (no shared temp file, so concurrent
    requests cannot clobber each other), decodes the waveform with a fresh
    ggwave instance, and returns the recovered text.

    Returns:
        {"text": <decoded message>} on success, or a 400 plain-text
        response when the file is empty or no message can be detected.
    """
    file_content = await file.read()
    if not file_content:
        return Response(
            content="Empty file uploaded",
            media_type="text/plain",
            status_code=400,
        )

    # ggwave.decode expects float32 samples in [-1, 1]; WAV files carry
    # int16 PCM, so normalize first (same conversion the /chat/ endpoint
    # uses — the original `astype(np.uint8)` mangled the samples).
    with io.BytesIO(file_content) as audio_buffer:
        fs, recorded_waveform = wav.read(audio_buffer)
    waveform_bytes = (recorded_waveform.astype(np.float32) / 32767.0).tobytes()

    # The original referenced an undefined `instance` (NameError on every
    # call); create one per request and always release it.
    instance = ggwave.init()
    try:
        decoded_message = ggwave.decode(instance, waveform_bytes)
    finally:
        ggwave.free(instance)

    if decoded_message is None:
        return Response(
            content="No message detected in audio",
            media_type="text/plain",
            status_code=400,
        )

    return {"text": decoded_message.decode("utf-8")}
|
|
|
@app.post("/tts/")
def text_to_speech(input_text: TextInput):
    """Encode text into a ggwave audio signal and return it as a WAV response."""
    # ggwave produces raw float32 PCM samples.
    raw_samples = ggwave.encode(input_text.text, protocolId=1, volume=100)
    samples_f32 = np.frombuffer(raw_samples, dtype=np.float32)

    # Scale to 16-bit integer PCM for a standard WAV container.
    samples_i16 = np.int16(samples_f32 * 32767)

    # Write a mono, 16-bit, 48 kHz WAV file into an in-memory buffer.
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, "wb") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(2)
        writer.setframerate(48000)
        writer.writeframes(samples_i16.tobytes())

    return Response(content=wav_buffer.getvalue(), media_type="audio/wav")
|
|
|
@app.post("/chat/")
async def chat_with_llm(file: UploadFile = File(...)):
    """Process input WAV, send decoded text to the LLM, return the reply as WAV.

    Pipeline: ggwave-decode the uploaded audio -> single-turn chat completion
    via Groq -> ggwave-encode the reply into a mono 16-bit 48 kHz WAV. The
    decoded user message and the LLM reply are echoed back in the
    X-User-Message / X-LLM-Response headers.

    Returns 400 for an empty upload or undecodable audio, 500 on any
    processing error.
    """
    try:
        print(f"File received: {file.filename}, Content-Type: {file.content_type}")

        file_content = await file.read()
        if not file_content:
            return Response(
                content="Empty file uploaded",
                media_type="text/plain",
                status_code=400
            )

        instance = ggwave.init()
        try:
            # ggwave expects float32 samples in [-1, 1]; WAV carries int16 PCM.
            with io.BytesIO(file_content) as audio_buffer:
                fs, recorded_waveform = wav.read(audio_buffer)
            waveform_bytes = (recorded_waveform.astype(np.float32) / 32767.0).tobytes()
            user_message = ggwave.decode(instance, waveform_bytes)

            if user_message is None:
                return Response(
                    content="No message detected in audio",
                    media_type="text/plain",
                    status_code=400
                )

            decoded_message = user_message.decode("utf-8")
            print("Decoded user message:", decoded_message)

            chat_completion = client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "you are a helpful assistant. answer always in one sentence"},
                    {"role": "user", "content": decoded_message}
                ],
                model="llama-3.3-70b-versatile",
            )
            llm_response = chat_completion.choices[0].message.content
            print("LLM Response:", llm_response)

            # Encode the reply and wrap it in a mono 16-bit 48 kHz WAV container.
            encoded_waveform = ggwave.encode(llm_response, protocolId=1, volume=100)
            waveform_float32 = np.frombuffer(encoded_waveform, dtype=np.float32)
            waveform_int16 = np.int16(waveform_float32 * 32767)

            wav_buffer = io.BytesIO()
            with wave.open(wav_buffer, "wb") as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)
                wf.setframerate(48000)
                wf.writeframes(waveform_int16.tobytes())

            # NOTE(review): HTTP headers are latin-1 only; non-ASCII LLM
            # output here would fail to encode — confirm reply charset.
            return Response(
                content=wav_buffer.getvalue(),
                media_type="audio/wav",
                headers={
                    "X-User-Message": decoded_message,
                    "X-LLM-Response": llm_response
                }
            )
        except Exception as e:
            print(f"Error processing audio: {str(e)}")
            return Response(
                content=f"Error processing audio: {str(e)}",
                media_type="text/plain",
                status_code=500
            )
        finally:
            # Guarantee exactly one release of the ggwave instance; the
            # original freed it on only some paths (leaking it if building
            # the success Response raised).
            ggwave.free(instance)

    except Exception as e:
        print(f"Unexpected error: {str(e)}")
        return Response(
            content=f"Unexpected error: {str(e)}",
            media_type="text/plain",
            status_code=500
        )
|
@app.post("/continuous-chat/")
async def continuous_chat(
    file: UploadFile = File(...),
    chat_history: Optional[str] = Form(None)
):
    """Process input WAV with chat history, send text to LLM, return reply as WAV.

    `chat_history` is an optional JSON-encoded list of
    {"role": ..., "content": ...} messages; only "user"/"assistant" roles
    are replayed to the model, after a fixed system prompt. The decoded
    user message and the LLM reply are echoed in the X-User-Message /
    X-LLM-Response headers.

    Returns 400 when no message can be decoded from the audio, 500 on any
    processing error.
    """
    # Rebuild the conversation, starting from the fixed system prompt.
    messages = [{"role": "system", "content": "you are a helpful assistant. answer always in one sentence"}]

    if chat_history:
        try:
            history = json.loads(chat_history)
            for msg in history:
                if msg["role"] in ["user", "assistant"]:
                    messages.append(msg)
        except Exception as e:
            # Malformed history is ignored rather than failing the request.
            print(f"Error parsing chat history: {str(e)}")

    file_content = await file.read()

    # Create the ggwave instance only after the upload is fully read, and
    # guarantee release via finally: the original initialized it before the
    # fallible await and freed it on only some paths, leaking it whenever
    # the read failed or the success Response raised.
    instance = ggwave.init()
    try:
        # ggwave expects float32 samples in [-1, 1]; WAV carries int16 PCM.
        with io.BytesIO(file_content) as audio_buffer:
            fs, recorded_waveform = wav.read(audio_buffer)
        waveform_bytes = (recorded_waveform.astype(np.float32) / 32767.0).tobytes()
        user_message = ggwave.decode(instance, waveform_bytes)

        if user_message is None:
            return Response(
                content="No message detected in audio",
                media_type="text/plain",
                status_code=400
            )

        decoded_message = user_message.decode("utf-8")
        print("user_message: " + decoded_message)

        messages.append({"role": "user", "content": decoded_message})

        chat_completion = client.chat.completions.create(
            messages=messages,
            model="llama-3.3-70b-versatile",
        )
        llm_response = chat_completion.choices[0].message.content
        print(llm_response)

        # Encode the reply and wrap it in a mono 16-bit 48 kHz WAV container.
        encoded_waveform = ggwave.encode(llm_response, protocolId=1, volume=100)
        waveform_float32 = np.frombuffer(encoded_waveform, dtype=np.float32)
        waveform_int16 = np.int16(waveform_float32 * 32767)

        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(48000)
            wf.writeframes(waveform_int16.tobytes())

        return Response(
            content=wav_buffer.getvalue(),
            media_type="audio/wav",
            headers={
                "X-User-Message": decoded_message,
                "X-LLM-Response": llm_response
            }
        )

    except Exception as e:
        print(f"Error processing audio: {str(e)}")
        return Response(
            content=f"Error processing audio: {str(e)}",
            media_type="text/plain",
            status_code=500
        )
    finally:
        ggwave.free(instance)