# NOTE: the following lines are page text captured together with the file.
# voicechat / app.py
# scooter7's picture
# Update app.py
# b3a570b verified
# raw
# history blame
# 9.08 kB
import asyncio
import base64
import io
import json
import os
import pathlib
import time
from typing import AsyncGenerator, Literal, List

import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import AsyncStreamHandler, Stream, wait_for_item
from pydantic import BaseModel
import uvicorn
# --- Import get_space (optional) ---
from gradio.utils import get_space
# --- Document processing and RAG libraries ---
import PyPDF2
import docx
import faiss
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# --- Speech processing libraries ---
import whisper
from gtts import gTTS
from pydub import AudioSegment
# Load environment variables via python-dotenv before anything reads os.environ.
load_dotenv()
# Directory containing this file; the docs folder and index.html are resolved against it.
current_dir = pathlib.Path(__file__).parent
# ====================================================
# 1. Document Ingestion & RAG Pipeline Setup
# ====================================================
# Folder whose files are loaded at import time to build the retrieval corpus.
DOCS_FOLDER = current_dir / "docs"
def extract_text_from_pdf(file_path: pathlib.Path) -> str:
    """Return the text of every page of a PDF, one newline after each page.

    Pages for which ``extract_text()`` returns an empty/falsy value are skipped.
    """
    with open(file_path, "rb") as pdf_file:
        page_texts = (page.extract_text() for page in PyPDF2.PdfReader(pdf_file).pages)
        # Joined inside the ``with`` block so pages are read while the file is open.
        return "".join(page_text + "\n" for page_text in page_texts if page_text)
def extract_text_from_docx(file_path: pathlib.Path) -> str:
    """Return the paragraph texts of a Word document joined by newlines."""
    paragraphs = docx.Document(file_path).paragraphs
    return "\n".join(paragraph.text for paragraph in paragraphs)
def extract_text_from_txt(file_path: pathlib.Path) -> str:
    """Return the full contents of a text file decoded as UTF-8."""
    # Wrapping in Path keeps plain string paths working, as open() did.
    return pathlib.Path(file_path).read_text(encoding="utf-8")
def load_documents(folder: pathlib.Path) -> List[str]:
    """Extract the text of every supported document directly inside *folder*.

    Supported suffixes (case-insensitive) are ``.pdf``, ``.docx``/``.doc`` and
    ``.txt``; anything else is ignored. Files are visited in sorted order so
    the resulting list (and the chunk indices derived from it) is
    deterministic across runs.

    A file whose extractor raises is reported and skipped instead of aborting
    the whole load. This matters because this function runs at import time,
    and because ``.doc`` files are handed to the ``.docx`` extractor.
    NOTE(review): python-docx presumably cannot parse legacy binary ``.doc``
    files, in which case they will always be skipped - confirm.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        One extracted text string per successfully read document.
    """
    documents = []
    for file_path in sorted(folder.glob("*")):
        suffix = file_path.suffix.lower()
        if suffix == ".pdf":
            extractor = extract_text_from_pdf
        elif suffix in (".docx", ".doc"):
            extractor = extract_text_from_docx
        elif suffix == ".txt":
            extractor = extract_text_from_txt
        else:
            continue
        try:
            documents.append(extractor(file_path))
        except Exception as exc:  # one bad file must not take the service down
            print(f"Skipping unreadable document {file_path}: {exc}")
    return documents
def split_text(text: str, max_length: int = 500, overlap: int = 100) -> List[str]:
    """Split *text* into windows of at most *max_length* characters.

    Consecutive windows start ``max_length - overlap`` characters apart, so
    each window repeats the last *overlap* characters of the previous one.
    Splitting stops as soon as a window reaches the end of the text; a further
    window would only repeat the tail of the one just emitted.

    Args:
        text: The text to split. An empty string yields an empty list.
        max_length: Maximum window size in characters; must be positive.
        overlap: Characters shared between neighbouring windows; must be
            smaller than *max_length*.

    Returns:
        The windows in document order.

    Raises:
        ValueError: If *max_length* is not positive or *overlap* is not
            smaller than *max_length* (the window would never advance).
    """
    if max_length <= 0:
        raise ValueError("max_length must be positive")
    if overlap >= max_length:
        raise ValueError("overlap must be smaller than max_length")
    step = max_length - overlap
    chunks = []
    for start in range(0, len(text), step):
        end = start + max_length
        chunks.append(text[start:end])
        if end >= len(text):
            # This window already covers the tail of the text.
            break
    return chunks
# Build the retrieval corpus once at import time: load the documents, cut them
# into overlapping chunks, embed the chunks and index them for L2 search.
documents = load_documents(DOCS_FOLDER)
all_chunks = []
for doc in documents:
    all_chunks.extend(split_text(doc))
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
if all_chunks:
    chunk_embeddings = embedding_model.encode(all_chunks)
else:
    # With no documents there is nothing to encode, so take the width from the
    # model itself rather than from an empty result, and keep the app bootable.
    chunk_embeddings = np.empty(
        (0, embedding_model.get_sentence_embedding_dimension()), dtype=np.float32
    )
embedding_dim = chunk_embeddings.shape[1]
faiss_index = faiss.IndexFlatL2(embedding_dim)
if all_chunks:
    faiss_index.add(np.array(chunk_embeddings))
# Text generator used by generate_answer().
# NOTE(review): generate_answer() passes max_new_tokens on every call, which
# presumably takes precedence over max_length here - confirm.
generator = pipeline("text-generation", model="gpt2", max_length=256)
def retrieve_context(query: str, k: int = 5) -> List[str]:
    """Return up to *k* indexed chunks closest to *query*.

    Args:
        query: The user question to embed and search for.
        k: Maximum number of chunks to return.

    Returns:
        The matching chunks in the order FAISS returned them. The list is
        shorter than *k* (possibly empty) when fewer chunks are indexed.
    """
    query_embedding = embedding_model.encode([query])
    _distances, indices = faiss_index.search(np.array(query_embedding), k)
    # FAISS fills missing neighbours with -1. Without the lower bound, -1 would
    # index from the end of the list and return an unrelated chunk (or raise
    # IndexError when the corpus is empty).
    return [all_chunks[idx] for idx in indices[0] if 0 <= idx < len(all_chunks)]
def generate_answer(query: str) -> str:
    """Generate an answer to *query* grounded in the retrieved document chunks.

    The retrieved chunks are placed in a customer-support prompt and passed to
    the module-level text-generation pipeline.

    Args:
        query: The user's question.

    Returns:
        Only the newly generated continuation, stripped of surrounding
        whitespace. The prompt (instructions, context and question) is not
        part of the result.
    """
    context = "\n".join(retrieve_context(query))
    prompt = (
        "You are a customer support agent. Use the following context to answer the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )
    # Use max_new_tokens to generate additional tokens beyond the prompt.
    response = generator(prompt, max_new_tokens=100, do_sample=True, temperature=0.7)
    generated = response[0]["generated_text"]
    # By default the pipeline returns prompt + continuation. Drop the prompt so
    # callers (including text-to-speech) receive the answer alone.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()
# ====================================================
# 2. Speech-to-Text and Text-to-Speech Functions
# ====================================================
# Load Whisper's "base" model once at import time, explicitly on the CPU.
stt_model = whisper.load_model("base", device="cpu")
def speech_to_text(audio_array: np.ndarray, sample_rate: int = 16000) -> str:
    """Transcribe PCM samples with the module-level Whisper model.

    The samples are converted to float32 and divided by 32768 before
    transcription. ``sample_rate`` is accepted but not used by this function.
    NOTE(review): assumes the input is int16 PCM at the rate Whisper expects -
    confirm against callers.
    """
    normalized = audio_array.astype(np.float32) / 32768.0
    transcription = stt_model.transcribe(normalized, fp16=False)
    return transcription["text"]
def text_to_speech(text: str, lang="en", target_sample_rate: int = 24000) -> np.ndarray:
    """Synthesize *text* with gTTS and return it as mono int16 samples.

    The MP3 produced by gTTS is kept in memory, decoded with pydub, converted
    to ``target_sample_rate`` and one channel, and returned as a NumPy array.
    """
    mp3_buffer = io.BytesIO()
    gTTS(text, lang=lang).write_to_fp(mp3_buffer)
    mp3_buffer.seek(0)  # rewind so the decoder reads from the start
    segment = AudioSegment.from_file(mp3_buffer, format="mp3")
    resampled = segment.set_frame_rate(target_sample_rate)
    mono = resampled.set_channels(1)
    return np.array(mono.get_array_of_samples(), dtype=np.int16)
# ====================================================
# 3. RAGVoiceHandler: Integrating Voice & RAG
# ====================================================
class RAGVoiceHandler(AsyncStreamHandler):
    """Stream handler that answers spoken questions through the RAG pipeline.

    ``receive`` queues the bytes of each incoming frame. The loop in
    ``stream`` collects them into one buffer and, once no frame has arrived
    for 0.5 s, transcribes the buffer, generates an answer and queues the
    synthesized speech for ``emit``.

    NOTE(review): the end of an utterance is detected only by frames no
    longer arriving. If the client keeps sending frames during silence the
    timeout will not fire - confirm against the client's behaviour.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        """Create the handler with a fixed 16 kHz input sample rate."""
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        self.input_buffer = bytearray()
        # A handler instance is built at import time (handler=RAGVoiceHandler()),
        # before any event loop runs, so use a clock that does not need one.
        self.last_input_time = time.monotonic()

    def copy(self) -> "RAGVoiceHandler":
        """Return a fresh handler with the same output settings."""
        return RAGVoiceHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    async def start_up(self) -> None:
        """Run the processing loop until ``shutdown`` is called.

        Nothing else in this file calls ``stream``; this hook is what starts
        it. NOTE(review): relies on fastrtc awaiting ``start_up`` when a
        connection begins - verify against the installed fastrtc version.
        """
        await self.stream()

    async def stream(self) -> None:
        """Buffer incoming audio and answer it after 0.5 s without new frames."""
        while not self.quit.is_set():
            try:
                audio_data = await asyncio.wait_for(self.input_queue.get(), timeout=0.5)
            except asyncio.TimeoutError:
                if self.input_buffer:
                    await self._answer_buffered_audio()
                continue
            # No extra sleep here: wait_for already yields to the event loop,
            # and pausing per frame made the queue fill faster than it drained.
            self.input_buffer.extend(audio_data)
            self.last_input_time = time.monotonic()

    async def _answer_buffered_audio(self) -> None:
        """Transcribe the buffered audio, generate a reply and queue its speech."""
        audio_array = np.frombuffer(self.input_buffer, dtype=np.int16)
        self.input_buffer = bytearray()
        # NOTE(review): the three pipeline calls below are synchronous and run
        # on the event loop. Moving them to worker threads would require the
        # shared models to be safe for concurrent use - confirm before changing.
        try:
            query_text = speech_to_text(audio_array, sample_rate=self.input_sample_rate)
            if not query_text.strip():
                return
            print("Transcribed query:", query_text)
            answer_text = generate_answer(query_text)
            print("Generated answer:", answer_text)
            tts_audio = text_to_speech(answer_text, target_sample_rate=self.output_sample_rate)
        except Exception as exc:
            # One failed utterance must not end the loop for this connection.
            print("Failed to answer utterance:", exc)
            return
        self.output_queue.put_nowait((self.output_sample_rate, tts_audio))

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Queue the raw bytes of an incoming ``(sample_rate, samples)`` frame."""
        _sample_rate, audio_array = frame
        await self.input_queue.put(audio_array.tobytes())

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next queued ``(sample_rate, samples)`` reply, as given by ``wait_for_item``."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal the processing loop to stop."""
        self.quit.set()
# ====================================================
# 4. Voice Streaming Setup & FastAPI Endpoints
# ====================================================
# Supply a dummy (but valid) RTC configuration with a TURN server entry
_STUN_SERVER = {"urls": "stun:stun.l.google.com:19302"}
_TURN_SERVER = {
    "urls": "turn:turn.anyfirewall.com:443?transport=tcp",
    "username": "webrtc",
    "credential": "webrtc",
}
# ICE servers handed to the Stream below via rtc_configuration.
rtc_config = {"iceServers": [_STUN_SERVER, _TURN_SERVER]}
# Module-level fastrtc stream: bidirectional audio handled by RAGVoiceHandler.
# The RAGVoiceHandler() instance below is created at import time.
# NOTE(review): presumably fastrtc calls copy() on it for each connection - confirm.
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=RAGVoiceHandler(),
    rtc_configuration=rtc_config,
    concurrency_limit=5,  # presumably the max number of simultaneous connections - confirm
    time_limit=90,  # presumably seconds allowed per connection - confirm
)
class InputData(BaseModel):
    """Request body for ``/input_hook``: the id of a WebRTC connection."""

    # Connection id that input_hook passes to stream.set_input().
    webrtc_id: str
# FastAPI application; the fastrtc stream attaches itself to it via mount().
# NOTE(review): presumably mount() registers fastrtc's own WebRTC routes - confirm.
app = FastAPI()
stream.mount(app)
@app.post("/input_hook")
async def input_hook(body: InputData):
    """Pass the posted WebRTC connection id to ``stream.set_input`` and acknowledge."""
    connection_id = body.webrtc_id
    stream.set_input(connection_id)
    return dict(status="ok")
@app.post("/webrtc/offer")
async def webrtc_offer(offer: dict):
    """Forward a WebRTC offer to the stream and return whatever it answers.

    NOTE(review): ``stream.mount(app)`` runs before this route is declared and
    may already register a handler for the same path - confirm this one is
    reachable and that ``handle_offer`` accepts a single argument.
    """
    answer = await stream.handle_offer(offer)
    return answer
# Added /chat endpoint for text-based queries (fallback)
@app.post("/chat")
async def chat_endpoint(payload: dict):
    """Answer a text question posted as ``{"question": ...}``.

    Returns ``{"answer": ...}``, or ``{"error": ...}`` when the question is
    missing or empty.
    """
    question = payload.get("question", "")
    if question:
        return {"answer": generate_answer(question)}
    return {"error": "No question provided"}
@app.get("/")
async def index_endpoint():
    """Serve ``index.html`` from the directory containing this file."""
    index_path = current_dir / "index.html"
    # Decode explicitly as UTF-8, as extract_text_from_txt does, instead of
    # relying on the platform's locale encoding.
    html_content = index_path.read_text(encoding="utf-8")
    return HTMLResponse(content=html_content)
# ====================================================
# 5. Application Runner
# ====================================================
if __name__ == "__main__":
    # MODE=UI starts a Gradio text chat; "PHONE" (the default) and any other
    # value start the FastAPI app under uvicorn. Both listen on port 7860.
    mode = os.getenv("MODE", "PHONE")
    if mode != "UI":
        uvicorn.run(app, host="0.0.0.0", port=7860)
    else:
        import gradio as gr

        def gradio_chat(user_input):
            """Answer one text message with the RAG pipeline."""
            return generate_answer(user_input)

        iface = gr.Interface(
            fn=gradio_chat,
            inputs="text",
            outputs="text",
            title="Customer Support Chatbot",
        )
        iface.launch(server_port=7860)