# voicechat / app.py — scooter7, commit 59c59d4 ("Update app.py", verified), 8.88 kB
import asyncio
import base64
import io
import json
import os
import pathlib
import threading
import time
import traceback
from typing import AsyncGenerator, Literal, List

import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import AsyncStreamHandler, Stream, wait_for_item
from pydantic import BaseModel
import uvicorn
# --- Import get_space to detect Hugging Face Spaces (optional) ---
from gradio.utils import get_space
# --- Document processing and RAG libraries ---
import PyPDF2
import docx
import faiss
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# --- Speech processing libraries ---
import whisper
from gtts import gTTS
from pydub import AudioSegment
# Pull settings from a local .env file into the environment, then anchor paths to this script.
load_dotenv()
current_dir = pathlib.Path(__file__).parent
# ====================================================
# 1. Document Ingestion & RAG Pipeline Setup
# ====================================================
# Knowledge-base folder, expected next to app.py, holding the PDF / Word / plain-text files to index.
DOCS_FOLDER = current_dir.joinpath("docs")
def extract_text_from_pdf(file_path: pathlib.Path) -> str:
    """Return the text of the PDF at *file_path*.

    Each page whose extracted text is non-empty contributes that text followed by a newline;
    pages with no extractable text are skipped.
    """
    with open(file_path, "rb") as pdf_file:
        # Pages are read lazily from the open file, so extraction must finish inside the `with`.
        page_texts = (page.extract_text() for page in PyPDF2.PdfReader(pdf_file).pages)
        return "".join(f"{content}\n" for content in page_texts if content)
def extract_text_from_docx(file_path: pathlib.Path) -> str:
    """Return the paragraphs of the Word document at *file_path*, joined by newlines."""
    paragraphs = docx.Document(file_path).paragraphs
    return "\n".join(paragraph.text for paragraph in paragraphs)
def extract_text_from_txt(file_path: pathlib.Path) -> str:
    """Return the contents of the text file at *file_path*, decoded as UTF-8.

    Bytes that are not valid UTF-8 are replaced with U+FFFD instead of raising
    UnicodeDecodeError, so a single mis-encoded file cannot abort document loading.
    """
    with open(file_path, "r", encoding="utf-8", errors="replace") as f:
        return f.read()
def load_documents(folder: pathlib.Path) -> List[str]:
    """Extract the text of every supported file directly inside *folder*.

    Supported suffixes (case-insensitive): ``.pdf``, ``.docx`` and ``.txt``. Legacy binary
    ``.doc`` files are skipped with a warning, because python-docx only reads the ``.docx``
    format and would raise on them. Sub-directories and other suffixes are ignored.

    Files are visited in sorted path order so the chunk order (and therefore the index built
    from it) is reproducible. A file whose extraction raises is reported and skipped, so one
    corrupt document cannot abort start-up. A missing *folder* yields ``[]``.

    Returns:
        One extracted text string per successfully read file.
    """
    if not folder.is_dir():
        print(f"Warning: documents folder {folder} does not exist; no documents loaded.")
        return []

    documents = []
    for file_path in sorted(folder.glob("*")):
        if not file_path.is_file():
            continue
        suffix = file_path.suffix.lower()
        if suffix == ".doc":
            print(f"Skipping {file_path.name}: legacy .doc files are not supported; convert it to .docx.")
            continue
        if suffix not in (".pdf", ".docx", ".txt"):
            continue
        try:
            if suffix == ".pdf":
                text = extract_text_from_pdf(file_path)
            elif suffix == ".docx":
                text = extract_text_from_docx(file_path)
            else:
                text = extract_text_from_txt(file_path)
        except Exception as exc:  # start-up boundary: report the bad file and keep going
            print(f"Skipping {file_path.name}: could not extract text ({exc!r}).")
            continue
        documents.append(text)
    return documents
def split_text(text: str, max_length: int = 500, overlap: int = 100) -> List[str]:
    """Split *text* into fixed-size character windows that overlap by *overlap* characters.

    Args:
        text: Text to split.
        max_length: Maximum number of characters per chunk; must be positive.
        overlap: Characters shared by consecutive chunks; must be smaller than *max_length*
            so that each window advances.

    Returns:
        The chunks in order; ``[]`` for empty *text*. Chunking stops once a window reaches the
        end of *text*, so no trailing chunk lies entirely inside the previous one.

    Raises:
        ValueError: If *max_length* is not positive or *overlap* >= *max_length*
            (either would make the window stop advancing and loop forever).
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    if overlap >= max_length:
        raise ValueError(f"overlap ({overlap}) must be smaller than max_length ({max_length})")

    step = max_length - overlap
    chunks = []
    for start in range(0, len(text), step):
        end = start + max_length
        chunks.append(text[start:end])
        if end >= len(text):
            break  # this window already covers the tail of the text
    return chunks
# Load and process documents: every supported file in DOCS_FOLDER, split into overlapping chunks.
documents = load_documents(DOCS_FOLDER)
all_chunks = [chunk for doc in documents for chunk in split_text(doc)]

# Compute embeddings and build FAISS index.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
# Take the dimension from the model, not from the encoded matrix: with no chunks, encode([])
# may return an array without a second axis, and `.shape[1]` would crash start-up.
embedding_dim = embedding_model.get_sentence_embedding_dimension()
index = faiss.IndexFlatL2(embedding_dim)
if all_chunks:
    # FAISS expects a float32 (n, dim) matrix.
    chunk_embeddings = np.asarray(embedding_model.encode(all_chunks), dtype=np.float32)
    index.add(chunk_embeddings)
else:
    print(f"Warning: no document chunks loaded from {DOCS_FOLDER}; answers will have no retrieved context.")
    chunk_embeddings = np.empty((0, embedding_dim), dtype=np.float32)

# Setup a text-generation pipeline (using GPT-2 here as an example)
generator = pipeline("text-generation", model="gpt2", max_length=256)
def retrieve_context(query: str, k: int = 5) -> List[str]:
    """Return up to *k* document chunks most similar to *query*.

    Args:
        query: Free-text question; it is embedded with ``embedding_model`` and looked up in
            the module-level FAISS ``index``.
        k: Maximum number of chunks to return.

    Returns:
        Matching entries of ``all_chunks`` in the order the index returns them. Returns
        ``[]`` when nothing is indexed or *k* is not positive.
    """
    # Never request more neighbours than there are chunks.
    k = min(k, len(all_chunks))
    if k <= 0:
        return []
    query_embedding = np.asarray(embedding_model.encode([query]), dtype=np.float32)
    _, indices = index.search(query_embedding, k)
    # FAISS pads missing results with -1; `idx < len(all_chunks)` alone would let -1 through
    # and silently return the last chunk via negative indexing.
    return [all_chunks[idx] for idx in indices[0] if 0 <= idx < len(all_chunks)]
def generate_answer(query: str, max_new_tokens: int = 128) -> str:
    """Answer *query* with the text-generation pipeline, grounded on retrieved document chunks.

    Args:
        query: The user's question.
        max_new_tokens: Upper bound on the number of tokens generated for the answer. It is
            counted separately from the prompt, so long retrieved context cannot use up the
            budget the way the pipeline's ``max_length=256`` would.

    Returns:
        Only the generated continuation (``return_full_text=False``), stripped of surrounding
        whitespace. The prompt and context are not echoed back, so they are not read aloud by
        the TTS step.
    """
    context = "\n".join(retrieve_context(query))
    prompt = (
        f"You are a customer support agent. Use the following context to answer the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )
    response = generator(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        return_full_text=False,
    )
    return response[0]["generated_text"].strip()
# ====================================================
# 2. Speech-to-Text and Text-to-Speech Functions
# ====================================================
# Whisper "base" speech-recognition model, deliberately pinned to the CPU.
stt_model = whisper.load_model(name="base", device="cpu")
def speech_to_text(audio_array: np.ndarray, sample_rate: int = 16000) -> str:
    """Transcribe int16 PCM audio with the module-level Whisper model (``stt_model``).

    Args:
        audio_array: int16 PCM samples, assumed single-channel; any shape is flattened.
        sample_rate: Sample rate of *audio_array* in Hz. Whisper's ``transcribe`` expects
            16 kHz audio when given an array, so other rates are linearly resampled first.

    Returns:
        The transcript text, or ``""`` when *audio_array* is empty.

    Raises:
        ValueError: If *sample_rate* is not positive.
    """
    whisper_rate = 16000
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate}")
    samples = np.asarray(audio_array).reshape(-1)
    if samples.size == 0:
        return ""
    # Convert int16 PCM to float32 normalized to [-1, 1).
    audio_float = samples.astype(np.float32) / 32768.0
    if sample_rate != whisper_rate:
        n_out = max(1, round(samples.size * whisper_rate / sample_rate))
        src_times = np.arange(samples.size) / sample_rate
        dst_times = np.arange(n_out) / whisper_rate
        audio_float = np.interp(dst_times, src_times, audio_float).astype(np.float32)
    result = stt_model.transcribe(audio_float, fp16=False)
    return result["text"]
def text_to_speech(text: str, lang: str = "en", target_sample_rate: int = 24000) -> np.ndarray:
    """Synthesize *text* with Google TTS and return mono 16-bit PCM samples.

    Args:
        text: Text to speak.
        lang: gTTS language code.
        target_sample_rate: Sample rate (Hz) the decoded MP3 is converted to.

    Returns:
        A 1-D int16 array at *target_sample_rate*. Blank *text* yields an empty int16 array
        without calling gTTS, which raises on empty input.
    """
    if not text or not text.strip():
        return np.zeros(0, dtype=np.int16)
    tts = gTTS(text, lang=lang)
    mp3_fp = io.BytesIO()
    tts.write_to_fp(mp3_fp)
    mp3_fp.seek(0)
    audio = AudioSegment.from_file(mp3_fp, format="mp3")
    # Force 2-byte samples: get_array_of_samples() uses the segment's own sample width, and
    # casting wider samples to int16 would wrap them into noise.
    audio = audio.set_frame_rate(target_sample_rate).set_channels(1).set_sample_width(2)
    return np.array(audio.get_array_of_samples(), dtype=np.int16)
# ====================================================
# 3. RAGVoiceHandler: Integrating Voice & RAG
# ====================================================
class RAGVoiceHandler(AsyncStreamHandler):
    """fastrtc audio handler that answers spoken questions with spoken, document-grounded replies.

    ``receive`` queues incoming PCM frames as bytes. The loop in ``stream`` (started from
    ``start_up``) groups them into utterances with a simple RMS-energy silence detector.
    Each utterance is transcribed (``speech_to_text``), answered (``generate_answer``) and
    synthesized (``text_to_speech``); the reply audio is queued for ``emit``.

    An utterance is answered once speech has been heard and one of these happens:
    - the audio stays below ``silence_threshold`` for ``pause_seconds``;
    - no frame arrives for 0.5 s;
    - the buffered audio reaches ``max_utterance_seconds``.

    The thresholds are heuristics; tune them for the microphone and environment.
    """

    # Serializes speech-to-text / generation / TTS across all connections. The models are
    # shared module-level objects, and the work now runs in worker threads.
    # NOTE(review): Whisper's decoder appears to install per-call hooks on the shared model,
    # so concurrent transcriptions are presumably unsafe — confirm before removing this lock.
    _inference_lock = threading.Lock()

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
        silence_threshold: float = 500.0,
        pause_seconds: float = 0.8,
        max_utterance_seconds: float = 30.0,
    ) -> None:
        """Configure audio formats and end-of-utterance detection.

        Args:
            expected_layout: Channel layout passed to the fastrtc base class.
            output_sample_rate: Sample rate (Hz) of the reply audio returned by ``emit``.
            output_frame_size: Passed through to the fastrtc base class.
            silence_threshold: RMS level (int16 units) at or above which a frame counts as speech.
            pause_seconds: Trailing silence that ends an utterance.
            max_utterance_seconds: Buffered speech length that forces an answer.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.silence_threshold = silence_threshold
        self.pause_seconds = pause_seconds
        self.max_utterance_seconds = max_utterance_seconds
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        self.input_buffer = bytearray()
        # The module-level Stream builds one handler at import time, with no running event loop.
        # asyncio.get_event_loop() is deprecated in that situation; time.monotonic() is always
        # safe to call here.
        self.last_input_time = time.monotonic()
        self._heard_speech = False
        self._silence_seconds = 0.0

    def copy(self) -> "RAGVoiceHandler":
        """Return a new handler with the same configuration and empty queues/buffers."""
        return RAGVoiceHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
            silence_threshold=self.silence_threshold,
            pause_seconds=self.pause_seconds,
            max_utterance_seconds=self.max_utterance_seconds,
        )

    async def start_up(self) -> None:
        """Run the utterance loop for this connection until ``shutdown`` is called.

        fastrtc invokes ``start_up`` when a connection starts. Nothing else calls ``stream()``,
        so without this override incoming audio would only pile up in ``input_queue``.
        """
        await self.stream()

    async def stream(self) -> None:
        """Collect queued audio into utterances and answer each one until ``shutdown``."""
        bytes_per_second = 2 * self.input_sample_rate  # mono int16: 2 bytes per sample
        max_bytes = int(self.max_utterance_seconds * bytes_per_second)
        while not self.quit.is_set():
            try:
                audio_data = await asyncio.wait_for(self.input_queue.get(), timeout=0.5)
            except asyncio.TimeoutError:
                # Input stopped arriving altogether: answer whatever speech is buffered.
                if self._heard_speech:
                    await self._answer_buffered_utterance()
                continue

            self.last_input_time = time.monotonic()
            # Frames presumably keep arriving during silence on a live call, so the timeout
            # above rarely fires; end-of-speech is detected from the audio level instead.
            if self._is_speech(audio_data):
                self._heard_speech = True
                self._silence_seconds = 0.0
            elif self._heard_speech:
                self._silence_seconds += len(audio_data) / bytes_per_second
            else:
                # Nobody has spoken yet: drop the silence so the buffer cannot grow unbounded.
                continue

            self.input_buffer.extend(audio_data)
            if self._silence_seconds >= self.pause_seconds or len(self.input_buffer) >= max_bytes:
                await self._answer_buffered_utterance()

    def _is_speech(self, audio_data: bytes) -> bool:
        """Return True when the chunk's int16 RMS level reaches ``silence_threshold``."""
        samples = np.frombuffer(audio_data, dtype=np.int16)
        if samples.size == 0:
            return False
        rms = float(np.sqrt(np.mean(samples.astype(np.float32) ** 2)))
        return rms >= self.silence_threshold

    async def _answer_buffered_utterance(self) -> None:
        """Answer the buffered utterance (if any), then reset the speech-detection state."""
        pcm = bytes(self.input_buffer)
        self.input_buffer = bytearray()
        self._heard_speech = False
        self._silence_seconds = 0.0
        if not pcm:
            return
        audio_array = np.frombuffer(pcm, dtype=np.int16)
        try:
            # Whisper, GPT-2 and gTTS block; run them in a worker thread so the event loop keeps
            # receiving frames and emitting audio meanwhile.
            reply = await asyncio.to_thread(self._answer_utterance, audio_array)
        except Exception:
            # One failed turn (e.g. a TTS network error) must not end this connection's loop.
            traceback.print_exc()
            return
        if reply is not None and reply.size:
            self.output_queue.put_nowait((self.output_sample_rate, reply))

    def _answer_utterance(self, audio_array: np.ndarray) -> np.ndarray | None:
        """Blocking STT -> RAG -> TTS for one utterance; None when there is nothing to say."""
        with self._inference_lock:
            query_text = speech_to_text(audio_array, sample_rate=self.input_sample_rate)
            if not query_text.strip():
                return None
            print("Transcribed query:", query_text)
            answer_text = generate_answer(query_text)
            print("Generated answer:", answer_text)
            if not answer_text.strip():
                return None
            return text_to_speech(answer_text, target_sample_rate=self.output_sample_rate)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Queue one incoming ``(sample_rate, samples)`` frame as raw bytes for ``stream``."""
        sample_rate, audio_array = frame
        audio_bytes = audio_array.tobytes()
        await self.input_queue.put(audio_bytes)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next queued ``(sample_rate, samples)`` reply, or whatever ``wait_for_item`` yields when none is ready."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal the ``stream`` loop to exit."""
        self.quit.set()
# ====================================================
# 4. Voice Streaming Setup & FastAPI Endpoints
# ====================================================
# Static ICE configuration with only Google's public STUN server. Passing it explicitly on
# ZeroGPU spaces avoids get_twilio_turn_credentials(), which depends on NVML.
# NOTE(review): with STUN only and no TURN relay, clients behind restrictive NATs may be
# unable to connect.
rtc_config = {
    "iceServers": [
        {"urls": "stun:stun.l.google.com:19302"},
    ],
}
stream = Stream(
    handler=RAGVoiceHandler(),
    modality="audio",
    mode="send-receive",
    rtc_configuration=rtc_config,
    concurrency_limit=5,
    time_limit=90,
)
class InputData(BaseModel):
    """JSON body accepted by the ``POST /input_hook`` endpoint."""

    webrtc_id: str  # id of the WebRTC session passed on to stream.set_input()
# FastAPI application; the fastrtc stream attaches its own routes to it via mount().
app: FastAPI = FastAPI()
stream.mount(app)
@app.post("/input_hook")
async def input_hook(body: InputData):
    """Hand the client's WebRTC connection id to ``stream.set_input`` and acknowledge."""
    webrtc_id = body.webrtc_id
    stream.set_input(webrtc_id)
    return {"status": "ok"}
@app.post("/webrtc/offer")
async def webrtc_offer(offer: dict):
    """Pass a raw WebRTC offer payload to ``stream.handle_offer`` and return its result.

    NOTE(review): ``stream.mount(app)`` runs before this route is declared. If fastrtc's
    ``mount`` already registers ``POST /webrtc/offer``, FastAPI matches that route first and
    this handler is never reached. The single-argument ``handle_offer`` call is also
    unverified against fastrtc's signature — confirm before relying on this endpoint.
    """
    answer = await stream.handle_offer(offer)
    return answer
@app.get("/", name="index")
async def serve_index():
    """Serve the front-end page ``index.html`` located next to this script.

    The function is deliberately not named ``index``: a module-level ``def index`` would
    rebind the FAISS ``index`` global that retrieve_context() searches, making every query
    fail with AttributeError. The route keeps the name "index" for URL reversal.
    """
    index_path = current_dir / "index.html"
    html_content = index_path.read_text(encoding="utf-8")
    return HTMLResponse(content=html_content)
# ====================================================
# 5. Application Runner
# ====================================================
if __name__ == "__main__":
    # MODE=UI (case-insensitive) serves a text-only Gradio chat over the RAG pipeline; any other
    # value, including the default "PHONE", serves the FastAPI app with the voice stream.
    # PORT overrides the listening port (default 7860).
    mode = os.getenv("MODE", "PHONE").strip().upper()
    port = int(os.getenv("PORT", "7860"))
    if mode == "UI":
        import gradio as gr

        def gradio_chat(user_input):
            """Answer one typed question with the RAG pipeline."""
            return generate_answer(user_input)

        iface = gr.Interface(fn=gradio_chat, inputs="text", outputs="text", title="Customer Support Chatbot")
        iface.launch(server_port=port)
    else:
        uvicorn.run(app, host="0.0.0.0", port=port)