import asyncio
import base64
import io
import json
import os
import pathlib
import time
from typing import AsyncGenerator, Literal, List

import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import AsyncStreamHandler, Stream, wait_for_item
from pydantic import BaseModel
import uvicorn
from gradio.utils import get_space
import PyPDF2
import docx
import faiss
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import whisper
from gtts import gTTS
from pydub import AudioSegment
# Load settings from a .env file (if any) before anything reads os.environ.
load_dotenv()

# Directory holding this module; the knowledge-base documents and index.html
# are looked up relative to it.
current_dir = pathlib.Path(__file__).parent
DOCS_FOLDER = current_dir.joinpath("docs")
def extract_text_from_pdf(file_path: pathlib.Path) -> str:
    """Return the text of a PDF, each page's text followed by a newline.

    Pages for which the reader yields no text (``None`` or an empty string)
    contribute nothing to the result.
    """
    with open(file_path, "rb") as pdf_file:
        # The generator is consumed inside the ``with`` so the file is still
        # open while pages are being decoded.
        page_texts = (
            page.extract_text() for page in PyPDF2.PdfReader(pdf_file).pages
        )
        return "".join(f"{extracted}\n" for extracted in page_texts if extracted)
def extract_text_from_docx(file_path: pathlib.Path) -> str:
    """Return the paragraph texts of a Word document joined by newlines."""
    word_document = docx.Document(file_path)
    paragraph_texts = (paragraph.text for paragraph in word_document.paragraphs)
    return "\n".join(paragraph_texts)
def extract_text_from_txt(file_path: pathlib.Path) -> str:
    """Return the full contents of a UTF-8 encoded text file."""
    source = pathlib.Path(file_path)
    return source.read_text(encoding="utf-8")
def load_documents(folder: pathlib.Path) -> List[str]:
    """Extract the text of every supported document directly inside *folder*.

    Supported suffixes (case-insensitive) are ``.pdf``, ``.docx``/``.doc`` and
    ``.txt``; anything else is ignored. Entries are visited in sorted path
    order so the resulting list is the same on every run.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        One string per successfully read document. An empty list when the
        folder does not exist or holds no readable supported files.
    """
    if not folder.is_dir():
        return []

    documents = []
    for file_path in sorted(folder.glob("*")):
        suffix = file_path.suffix.lower()
        if suffix == ".pdf":
            extractor = extract_text_from_pdf
        elif suffix in (".docx", ".doc"):
            extractor = extract_text_from_docx
        elif suffix == ".txt":
            extractor = extract_text_from_txt
        else:
            continue
        try:
            documents.append(extractor(file_path))
        except Exception as exc:
            # Best-effort loading: a single corrupt or unsupported file (for
            # example a legacy binary .doc) must not take the whole corpus
            # down, so report it and carry on with the remaining files.
            print(f"Skipping unreadable document {file_path}: {exc}")
    return documents
def split_text(text: str, max_length: int = 500, overlap: int = 100) -> List[str]:
    """Split *text* into overlapping windows of at most *max_length* characters.

    Consecutive windows start ``max_length - overlap`` characters apart, so
    each window repeats the last *overlap* characters of the previous one.
    Splitting stops as soon as a window reaches the end of the text, which
    avoids emitting a final fragment that is already fully contained in the
    preceding window.

    Args:
        text: The text to split.
        max_length: Maximum number of characters per chunk; must be positive.
        overlap: Characters shared by neighbouring chunks; must satisfy
            ``0 <= overlap < max_length``.

    Returns:
        The chunks in order of appearance; an empty list for empty text.

    Raises:
        ValueError: If *max_length* or *overlap* is out of range. An overlap
            that is not smaller than *max_length* would never advance.
    """
    if max_length <= 0:
        raise ValueError("max_length must be positive")
    if not 0 <= overlap < max_length:
        raise ValueError("overlap must satisfy 0 <= overlap < max_length")

    step = max_length - overlap
    chunks = []
    start = 0
    while start < len(text):
        end = start + max_length
        chunks.append(text[start:end])
        if end >= len(text):
            # This window already covers the tail of the text.
            break
        start += step
    return chunks
# --- Knowledge base: load documents, chunk them, embed and index the chunks ---
documents = load_documents(DOCS_FOLDER)
all_chunks: List[str] = [chunk for doc in documents for chunk in split_text(doc)]

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Take the vector size from the model rather than from the encoded corpus so
# the index can still be created when there are no chunks at all.
embedding_dim = embedding_model.get_sentence_embedding_dimension()
faiss_index = faiss.IndexFlatL2(embedding_dim)

if all_chunks:
    chunk_embeddings = np.asarray(
        embedding_model.encode(all_chunks), dtype=np.float32
    )
    faiss_index.add(chunk_embeddings)
else:
    # Empty corpus: keep the module importable with an empty index.
    chunk_embeddings = np.empty((0, embedding_dim), dtype=np.float32)
    print(f"No documents found in {DOCS_FOLDER}; answers will have no context.")

# Text generator used to phrase the final answer.
generator = pipeline("text-generation", model="gpt2", max_length=256)
def retrieve_context(query: str, k: int = 5) -> List[str]:
    """Return up to *k* indexed chunks closest to *query*.

    Args:
        query: Free-text question to embed and look up.
        k: Maximum number of chunks to return.

    Returns:
        Matching chunks in the order the index returns them; an empty list
        when *k* is not positive or nothing has been indexed.
    """
    if k <= 0 or not all_chunks:
        return []

    query_embedding = np.asarray(embedding_model.encode([query]), dtype=np.float32)
    # Never ask for more neighbours than there are chunks.
    top_k = min(k, len(all_chunks))
    _, indices = faiss_index.search(query_embedding, top_k)
    # FAISS marks missing results with -1; a negative index would otherwise
    # silently wrap around to the last chunk.
    return [all_chunks[idx] for idx in indices[0] if 0 <= idx < len(all_chunks)]
def generate_answer(query: str) -> str:
    """Answer *query* using retrieved context and the text-generation pipeline.

    The retrieved chunks are placed in a prompt ahead of the question, and
    only the newly generated continuation is returned — not the prompt.

    Args:
        query: The user's question.

    Returns:
        The generated answer text with surrounding whitespace removed.
    """
    context = "\n".join(retrieve_context(query))
    prompt = (
        "You are a customer support agent. Use the following context to answer the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )
    response = generator(
        prompt,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        # By default the pipeline echoes the prompt in front of the
        # continuation; callers (chat and speech) need only the answer.
        return_full_text=False,
    )
    return response[0]["generated_text"].strip()
# Whisper speech-to-text model, loaded once at import and run on the CPU.
stt_model = whisper.load_model("base", device="cpu")


def speech_to_text(audio_array: np.ndarray, sample_rate: int = 16000) -> str:
    """Transcribe 16-bit PCM audio with the Whisper model.

    Args:
        audio_array: int16 PCM samples; flattened before use, so it is
            assumed to hold a single (mono) channel.
        sample_rate: Sampling rate of *audio_array* in Hz. Audio at any other
            rate than 16 kHz is linearly resampled first.

    Returns:
        The transcribed text, or an empty string for empty audio.

    Raises:
        ValueError: If *sample_rate* is not positive.
    """
    if sample_rate <= 0:
        raise ValueError("sample_rate must be positive")

    samples = np.asarray(audio_array).reshape(-1)
    if samples.size == 0:
        return ""

    # Scale int16 PCM into the [-1.0, 1.0) float range.
    audio_float = samples.astype(np.float32) / 32768.0

    # Whisper works on 16 kHz waveforms.
    whisper_rate = 16000
    if sample_rate != whisper_rate:
        target_len = max(1, int(round(samples.size * whisper_rate / sample_rate)))
        src_pos = np.linspace(0.0, 1.0, num=samples.size, endpoint=False)
        dst_pos = np.linspace(0.0, 1.0, num=target_len, endpoint=False)
        audio_float = np.interp(dst_pos, src_pos, audio_float).astype(np.float32)

    result = stt_model.transcribe(audio_float, fp16=False)
    return result["text"]
def text_to_speech(text: str, lang="en", target_sample_rate: int = 24000) -> np.ndarray:
    """Synthesize *text* with gTTS and return it as mono int16 samples.

    The MP3 produced by gTTS is decoded in memory, converted to
    *target_sample_rate* and a single channel, and returned as a NumPy array.
    """
    mp3_buffer = io.BytesIO()
    gTTS(text, lang=lang).write_to_fp(mp3_buffer)
    mp3_buffer.seek(0)  # rewind so the decoder reads from the beginning

    segment = AudioSegment.from_file(mp3_buffer, format="mp3")
    mono_segment = segment.set_frame_rate(target_sample_rate).set_channels(1)
    samples = mono_segment.get_array_of_samples()
    return np.array(samples, dtype=np.int16)
class RAGVoiceHandler(AsyncStreamHandler):
    """Voice front-end for the RAG pipeline.

    Incoming audio frames are buffered; once no frame has arrived for half a
    second the buffered audio is transcribed, answered with
    ``generate_answer`` and the spoken answer is queued for playback.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        """Create a handler that receives 16 kHz audio and emits TTS audio."""
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        self.input_buffer = bytearray()
        # The prototype handler is constructed at import time, when no event
        # loop is running, so use the monotonic clock directly instead of
        # going through an event loop.
        self.last_input_time = time.monotonic()

    def copy(self) -> "RAGVoiceHandler":
        """Return a fresh handler with the same output settings."""
        return RAGVoiceHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    async def start_up(self) -> None:
        """Run the processing loop for the lifetime of the connection.

        ``start_up`` is the async-handler start-up hook; without it nothing
        in this class would ever call ``stream``.
        """
        await self.stream()

    async def stream(self) -> None:
        """Buffer incoming audio and answer each utterance until shutdown.

        NOTE(review): an utterance is flushed only after 0.5 s with no
        incoming frame. If the peer keeps sending frames during silence this
        never triggers — confirm the client pauses transmission, or add
        voice-activity detection.
        """
        while not self.quit.is_set():
            try:
                audio_data = await asyncio.wait_for(
                    self.input_queue.get(), timeout=0.5
                )
            except asyncio.TimeoutError:
                if self.input_buffer:
                    await self._respond_to_buffered_audio()
            else:
                self.input_buffer.extend(audio_data)
                self.last_input_time = time.monotonic()

    async def _respond_to_buffered_audio(self) -> None:
        """Transcribe the buffered audio, generate an answer and queue its speech."""
        audio_array = np.frombuffer(bytes(self.input_buffer), dtype=np.int16)
        self.input_buffer = bytearray()
        try:
            # Speech recognition, generation and synthesis are blocking calls;
            # run them in worker threads so the event loop keeps servicing
            # audio frames and other requests.
            query_text = await asyncio.to_thread(
                speech_to_text, audio_array, sample_rate=self.input_sample_rate
            )
            if not query_text.strip():
                return
            print("Transcribed query:", query_text)
            answer_text = await asyncio.to_thread(generate_answer, query_text)
            print("Generated answer:", answer_text)
            if not answer_text.strip():
                # Nothing to say; the TTS backend rejects empty text.
                return
            tts_audio = await asyncio.to_thread(
                text_to_speech,
                answer_text,
                target_sample_rate=self.output_sample_rate,
            )
        except Exception as exc:
            # One failed utterance must not end the whole conversation.
            print("Failed to process utterance:", exc)
            return
        self.output_queue.put_nowait((self.output_sample_rate, tts_audio))

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Queue the raw bytes of an incoming ``(sample_rate, samples)`` frame."""
        _, audio_array = frame
        await self.input_queue.put(audio_array.tobytes())

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next queued ``(sample_rate, samples)`` item, if any."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Ask the processing loop to stop."""
        self.quit.set()
# ICE servers offered to WebRTC clients. The TURN relay settings can be
# overridden through the environment (TURN_URL, TURN_USERNAME,
# TURN_CREDENTIAL) so real credentials need not live in source control; the
# defaults are the original hard-coded values.
rtc_config = {
    "iceServers": [
        {"urls": "stun:stun.l.google.com:19302"},
        {
            "urls": os.getenv(
                "TURN_URL", "turn:turn.anyfirewall.com:443?transport=tcp"
            ),
            "username": os.getenv("TURN_USERNAME", "webrtc"),
            "credential": os.getenv("TURN_CREDENTIAL", "webrtc"),
        },
    ]
}

# Bidirectional audio stream driven by the RAG voice handler.
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=RAGVoiceHandler(),
    rtc_configuration=rtc_config,
    concurrency_limit=5,
    time_limit=90,
)
# Request body of POST /input_hook: identifies the WebRTC connection to act on.
class InputData(BaseModel):
    webrtc_id: str
# HTTP application; the audio stream attaches its own routes to it.
app = FastAPI()
stream.mount(app)
# Forwards the posted WebRTC connection id to the stream and acknowledges it.
@app.post("/input_hook")
async def input_hook(body: InputData):
    connection_id = body.webrtc_id
    stream.set_input(connection_id)
    return dict(status="ok")
# Hands a client's WebRTC offer to the stream and returns whatever it produces.
# NOTE(review): stream.mount(app) presumably registers its own offer route as
# well — confirm this endpoint is reachable and still needed.
@app.post("/webrtc/offer")
async def webrtc_offer(offer: dict):
    negotiation_result = await stream.handle_offer(offer)
    return negotiation_result
@app.post("/chat")
async def chat_endpoint(payload: dict):
    """Answer a text question posted as ``{"question": "..."}``.

    Returns ``{"answer": ...}`` on success, or ``{"error": ...}`` when the
    question is missing, blank or not a string.
    """
    question = payload.get("question", "")
    if not isinstance(question, str) or not question.strip():
        return {"error": "No question provided"}
    # Retrieval and generation are blocking; run them in a worker thread so
    # the event loop (and any live audio streams) stay responsive.
    answer = await asyncio.to_thread(generate_answer, question)
    return {"answer": answer}
@app.get("/")
async def index_endpoint():
    """Serve ``index.html`` from the directory containing this module.

    Responds with 404 when the file is missing.
    """
    index_path = current_dir / "index.html"
    try:
        # Decode explicitly as UTF-8 instead of the platform default, matching
        # how text documents are read elsewhere in this module.
        html_content = index_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return HTMLResponse(content="index.html not found", status_code=404)
    return HTMLResponse(content=html_content)
if __name__ == "__main__":
    # MODE selects the front-end: "UI" starts a Gradio text chat; "PHONE"
    # (the default) and any other value serve the FastAPI app with uvicorn.
    launch_mode = os.getenv("MODE", "PHONE")
    if launch_mode != "UI":
        uvicorn.run(app, host="", port=7860)
    else:
        import gradio as gr

        def gradio_chat(user_input):
            return generate_answer(user_input)

        gr.Interface(
            fn=gradio_chat,
            inputs="text",
            outputs="text",
            title="Customer Support Chatbot",
        ).launch(server_port=7860)