|
"""Voice-enabled RAG customer-support agent.

Documents in ./docs are chunked, embedded, and indexed with FAISS at startup.
Queries arrive either as WebRTC audio (Whisper STT -> retrieval -> GPT-2 ->
gTTS reply) or as plain text on the /chat endpoint.
"""

import asyncio
import io
import os
import pathlib
from typing import List, Literal

import numpy as np
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import AsyncStreamHandler, Stream, wait_for_item
from pydantic import BaseModel

import PyPDF2
import docx
import faiss
import whisper
from gtts import gTTS
from pydub import AudioSegment
from sentence_transformers import SentenceTransformer
from transformers import pipeline

load_dotenv()
|
current_dir = pathlib.Path(__file__).parent |
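
# --- Document ingestion and retrieval index (built once at import) ---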
DOCS_FOLDER = current_dir / "docs" |
|
|
|
def extract_text_from_pdf(file_path: pathlib.Path) -> str:
    text = ""
    with open(file_path, "rb") as f:
        reader = PyPDF2.PdfReader(f)
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text += page_text + "\n"
    return text
|
|
|
def extract_text_from_docx(file_path: pathlib.Path) -> str:
    doc = docx.Document(file_path)
    return "\n".join(para.text for para in doc.paragraphs)
|
|
|
def extract_text_from_txt(file_path: pathlib.Path) -> str:
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()
|
|
|
def load_documents(folder: pathlib.Path) -> List[str]:
    documents = []
    for file_path in folder.glob("*"):
        suffix = file_path.suffix.lower()
        if suffix == ".pdf":
            documents.append(extract_text_from_pdf(file_path))
        elif suffix == ".docx":
            # python-docx cannot parse legacy binary .doc files, so only
            # .docx is handled here.
            documents.append(extract_text_from_docx(file_path))
        elif suffix == ".txt":
            documents.append(extract_text_from_txt(file_path))
    return documents
|
|
|
def split_text(text: str, max_length: int = 500, overlap: int = 100) -> List[str]:
    chunks = []
    start = 0
    while start < len(text):
        end = start + max_length
        chunks.append(text[start:end])
        start += max_length - overlap
    return chunks
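
# With the defaults the window advances 400 characters per step, so adjacent
# chunks share a 100-character overlap: split_text("x" * 1200) yields the
# slices [0:500], [400:900], and [800:1200].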
|
|
|
documents = load_documents(DOCS_FOLDER)
all_chunks = []
for doc in documents:
    all_chunks.extend(split_text(doc))

if not all_chunks:
    raise RuntimeError(f"No readable documents found in {DOCS_FOLDER}")

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
chunk_embeddings = embedding_model.encode(all_chunks)
embedding_dim = chunk_embeddings.shape[1]
faiss_index = faiss.IndexFlatL2(embedding_dim)
faiss_index.add(np.array(chunk_embeddings))
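
# IndexFlatL2 performs exact brute-force search, which is fine at this scale;
# approximate indexes (e.g. faiss.IndexIVFFlat) only pay off on much larger
# corpora.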
|
|
|
generator = pipeline("text-generation", model="gpt2")
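
# gpt2 is a tiny base model chosen so the demo runs anywhere; answers will be
# loose, and an instruction-tuned model is a drop-in upgrade. Generation
# length is capped per call with max_new_tokens; five 500-character chunks
# plus the question is roughly 700 tokens, inside GPT-2's 1024-token window
# even after the 100 generated tokens.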
|
|
|
def retrieve_context(query: str, k: int = 5) -> List[str]:
    query_embedding = embedding_model.encode([query])
    _, indices = faiss_index.search(np.array(query_embedding), k)
    # FAISS pads results with -1 when the index holds fewer than k vectors.
    return [all_chunks[idx] for idx in indices[0] if 0 <= idx < len(all_chunks)]
|
|
|
def generate_answer(query: str) -> str:
    context_chunks = retrieve_context(query)
    context = "\n".join(context_chunks)
    prompt = (
        "You are a customer support agent. Use the following context to answer the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )
    response = generator(prompt, max_new_tokens=100, do_sample=True, temperature=0.7)
    generated_text = response[0]["generated_text"]
    # The pipeline echoes the prompt, so keep only what follows "Answer:".
    if "Answer:" in generated_text:
        answer = generated_text.split("Answer:", 1)[1].strip()
    else:
        answer = generated_text.strip()
    return answer
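
# --- Speech-to-text and text-to-speech ---
# Whisper "base" runs locally on CPU; larger checkpoints are more accurate
# but slower.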
stt_model = whisper.load_model("base", device="cpu") |
|
|
|
def speech_to_text(audio_array: np.ndarray, sample_rate: int = 16000) -> str:
    # Whisper expects mono float32 in [-1, 1] at 16 kHz, which is exactly what
    # the handler below buffers, so the int16 samples are just rescaled here.
    audio_float = audio_array.astype(np.float32) / 32768.0
    result = stt_model.transcribe(audio_float, fp16=False)
    return result["text"]
|
|
|
def text_to_speech(text: str, lang: str = "en", target_sample_rate: int = 24000) -> np.ndarray:
    tts = gTTS(text, lang=lang)
    mp3_fp = io.BytesIO()
    tts.write_to_fp(mp3_fp)
    mp3_fp.seek(0)
    audio = AudioSegment.from_file(mp3_fp, format="mp3")
    audio = audio.set_frame_rate(target_sample_rate).set_channels(1)
    return np.array(audio.get_array_of_samples(), dtype=np.int16)
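
# gTTS synthesizes via Google's public TTS endpoint, so each answer needs
# network access and pays one HTTP round-trip of latency.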
class RAGVoiceHandler(AsyncStreamHandler):
    """Buffers caller audio until a pause, then replies with a spoken answer."""

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        self.input_buffer = bytearray()

    def copy(self) -> "RAGVoiceHandler":
        # fastrtc clones the handler for every new connection.
        return RAGVoiceHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )
|
|
|
    async def start_up(self) -> None:
        # fastrtc invokes start_up once the connection is live; a 0.5 s gap
        # with no incoming frames is treated as the end of an utterance.
        while not self.quit.is_set():
            try:
                audio_data = await asyncio.wait_for(self.input_queue.get(), timeout=0.5)
                self.input_buffer.extend(audio_data)
            except asyncio.TimeoutError:
                if self.input_buffer:
                    audio_array = np.frombuffer(self.input_buffer, dtype=np.int16)
                    self.input_buffer = bytearray()
                    # STT, generation, and TTS all block, so run them in
                    # worker threads to keep the event loop responsive.
                    query_text = await asyncio.to_thread(
                        speech_to_text, audio_array, self.input_sample_rate
                    )
                    if query_text.strip():
                        print("Transcribed query:", query_text)
                        answer_text = await asyncio.to_thread(generate_answer, query_text)
                        print("Generated answer:", answer_text)
                        tts_audio = await asyncio.to_thread(
                            text_to_speech,
                            answer_text,
                            target_sample_rate=self.output_sample_rate,
                        )
                        # Mono audio is reshaped to (1, num_samples) for fastrtc.
                        self.output_queue.put_nowait(
                            (self.output_sample_rate, tts_audio.reshape(1, -1))
                        )
                await asyncio.sleep(0.1)
|
|
|
    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        # Frames arrive at input_sample_rate (16 kHz); queue the raw bytes.
        _, audio_array = frame
        await self.input_queue.put(audio_array.tobytes())
|
|
|
    async def emit(self) -> tuple[int, np.ndarray] | None:
        return await wait_for_item(self.output_queue)
|
|
|
    def shutdown(self) -> None:
        self.quit.set()
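
# --- WebRTC configuration and FastAPI wiring ---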
rtc_config = {
    "iceServers": [
        {"urls": "stun:stun.l.google.com:19302"},
        {
            # Public demo TURN relay with well-known credentials; swap in your
            # own TURN server for production traffic.
            "urls": "turn:turn.anyfirewall.com:443?transport=tcp",
            "username": "webrtc",
            "credential": "webrtc",
        },
    ]
}
|
|
|
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=RAGVoiceHandler(),
    rtc_configuration=rtc_config,
    concurrency_limit=5,
    time_limit=90,
)
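
# concurrency_limit caps how many callers can stream at once; time_limit
# ends any single session after 90 seconds.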
|
|
|
class InputData(BaseModel):
    webrtc_id: str


app = FastAPI()
stream.mount(app)
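
# stream.mount registers fastrtc's WebRTC routes (including /webrtc/offer)
# on the FastAPI app, so no hand-written offer endpoint is needed.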
|
|
|
@app.post("/input_hook") |
|
async def input_hook(body: InputData): |
|
stream.set_input(body.webrtc_id) |
|
return {"status": "ok"} |
|
|
|
@app.post("/webrtc/offer") |
|
async def webrtc_offer(offer: dict): |
|
return await stream.handle_offer(offer) |
|
|
|
@app.post("/chat") |
|
async def chat_endpoint(payload: dict): |
|
question = payload.get("question", "") |
|
if not question: |
|
return {"error": "No question provided"} |
|
answer = generate_answer(question) |
|
return {"answer": answer} |
|
|
|
@app.get("/") |
|
async def index_endpoint(): |
|
index_path = current_dir / "index.html" |
|
html_content = index_path.read_text() |
|
return HTMLResponse(content=html_content) |
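
# MODE=UI launches a minimal Gradio text chat for local testing; any other
# value (default "PHONE") serves the FastAPI/WebRTC app on port 7860.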
if __name__ == "__main__":
    mode = os.getenv("MODE", "PHONE")
    if mode == "UI":
        import gradio as gr

        def gradio_chat(user_input):
            return generate_answer(user_input)

        iface = gr.Interface(
            fn=gradio_chat,
            inputs="text",
            outputs="text",
            title="Customer Support Chatbot",
        )
        iface.launch(server_port=7860)
    else:
        # "PHONE" and any unrecognized mode both serve the FastAPI app.
        uvicorn.run(app, host="0.0.0.0", port=7860)
|
|