scooter7 committed on
Commit
6c85d16
·
verified ·
1 Parent(s): 2f5d7dc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +246 -0
app.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import json
4
+ import os
5
+ import pathlib
6
+ from typing import AsyncGenerator, Literal, List
7
+
8
+ import numpy as np
9
+ from dotenv import load_dotenv
10
+ from fastapi import FastAPI
11
+ from fastapi.responses import HTMLResponse
12
+ from fastrtc import AsyncStreamHandler, Stream, get_twilio_turn_credentials, wait_for_item
13
+ from pydantic import BaseModel
14
+ import uvicorn
15
+
16
+ # --- Document processing and RAG libraries ---
17
+ import PyPDF2
18
+ import docx
19
+ import faiss
20
+ from sentence_transformers import SentenceTransformer
21
+ from transformers import pipeline
22
+
23
+ # --- Speech processing libraries ---
24
+ import whisper
25
+ from gtts import gTTS
26
+ from pydub import AudioSegment
27
+ import io
28
+
29
+ # Load environment variables and define current directory
30
+ load_dotenv()
31
+ current_dir = pathlib.Path(__file__).parent
32
+
33
+ # ====================================================
34
+ # 1. Document Ingestion & RAG Pipeline Setup
35
+ # ====================================================
36
+
37
+ # Folder containing PDFs, Word docs, and text files (place this folder alongside app.py)
38
+ DOCS_FOLDER = current_dir / "docs"
39
+
40
def extract_text_from_pdf(file_path: pathlib.Path) -> str:
    """Return the concatenated text of every page in the PDF at *file_path*.

    Pages for which the reader yields no text (``None`` or an empty string)
    are skipped; the text of each remaining page is followed by a newline.
    """
    with open(file_path, "rb") as pdf_file:
        page_texts = (page.extract_text() for page in PyPDF2.PdfReader(pdf_file).pages)
        # The generator is consumed here, while the file is still open.
        return "".join(f"{page_text}\n" for page_text in page_texts if page_text)
49
+
50
def extract_text_from_docx(file_path: pathlib.Path) -> str:
    """Return the text of the Word document at *file_path*, one paragraph per line."""
    paragraphs = docx.Document(file_path).paragraphs
    return "\n".join(paragraph.text for paragraph in paragraphs)
53
+
54
def extract_text_from_txt(file_path: pathlib.Path) -> str:
    """Return the full contents of the UTF-8 text file at *file_path*."""
    return file_path.read_text(encoding="utf-8")
57
+
58
def load_documents(folder: pathlib.Path) -> List[str]:
    """Extract the text of every supported document directly inside *folder*.

    Supported suffixes (case-insensitive) are ``.pdf``, ``.docx``, ``.doc``
    and ``.txt``; all other entries are ignored. Files are visited in sorted
    order so that the resulting list is deterministic across runs.

    A file that cannot be parsed is reported and skipped rather than aborting
    the whole ingestion.
    NOTE(review): legacy binary ``.doc`` files are routed to the docx parser,
    which presumably cannot read them — they will end up being skipped here.

    Args:
        folder: Directory containing the source documents.

    Returns:
        One string per successfully parsed document (empty if *folder* does
        not exist or holds no supported files).
    """
    extractors = {
        ".pdf": extract_text_from_pdf,
        ".docx": extract_text_from_docx,
        ".doc": extract_text_from_docx,
        ".txt": extract_text_from_txt,
    }
    documents: List[str] = []
    if not folder.is_dir():
        print(f"Documents folder not found: {folder}")
        return documents

    for file_path in sorted(folder.glob("*")):
        extractor = extractors.get(file_path.suffix.lower())
        if extractor is None:
            continue
        try:
            text = extractor(file_path)
        except Exception as exc:  # ingestion boundary: one bad file must not stop startup
            print(f"Skipping unreadable document {file_path.name}: {exc}")
            continue
        documents.append(text)
    return documents
68
+
69
def split_text(text: str, max_length: int = 500, overlap: int = 100) -> List[str]:
    """Split *text* into overlapping fixed-size character chunks.

    Consecutive chunks start ``max_length - overlap`` characters apart, so
    each chunk repeats the last ``overlap`` characters of its predecessor.

    Args:
        text: The text to split.
        max_length: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by neighbouring chunks; must
            satisfy ``0 <= overlap < max_length``.

    Returns:
        The chunks in order of appearance (an empty list for empty text).

    Raises:
        ValueError: If the parameters would give a non-positive step, which
            would otherwise make the loop run forever.
    """
    if max_length <= 0:
        raise ValueError("max_length must be positive")
    if not 0 <= overlap < max_length:
        raise ValueError("overlap must satisfy 0 <= overlap < max_length")

    step = max_length - overlap
    return [text[start:start + max_length] for start in range(0, len(text), step)]
77
+
78
# Load and process documents
documents = load_documents(DOCS_FOLDER)
all_chunks = []
for doc in documents:
    all_chunks.extend(split_text(doc))

# Fail fast with a clear message: without any chunk there is nothing to embed,
# and reading the embedding dimension from an empty result would fail with an
# obscure error instead.
if not all_chunks:
    raise RuntimeError(
        f"No document text found in {DOCS_FOLDER}; add PDF, Word or text files before starting."
    )

# Compute embeddings and build FAISS index
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
chunk_embeddings = embedding_model.encode(all_chunks)
embedding_dim = chunk_embeddings.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(np.array(chunk_embeddings))

# Setup a text-generation pipeline (using GPT-2 here as an example)
generator = pipeline("text-generation", model="gpt2", max_length=256)
93
+
94
def retrieve_context(query: str, k: int = 5) -> List[str]:
    """Return up to *k* document chunks closest to *query* in embedding space.

    Args:
        query: The user question to embed and look up.
        k: Maximum number of chunks to return.

    Returns:
        The matching chunks, nearest first; empty when there are no chunks
        or ``k`` is not positive.
    """
    if not all_chunks or k <= 0:
        return []

    query_embedding = embedding_model.encode([query])
    # Never ask for more neighbours than there are chunks.
    _, indices = index.search(np.array(query_embedding), min(k, len(all_chunks)))
    # FAISS fills missing results with -1; the lower bound keeps those from
    # being read as Python's "last element" index.
    return [all_chunks[idx] for idx in indices[0] if 0 <= idx < len(all_chunks)]
98
+
99
def generate_answer(query: str) -> str:
    """Generate an answer to *query* grounded in the retrieved document chunks.

    The retrieved chunks are placed in a prompt and passed to the
    module-level text-generation pipeline.

    Args:
        query: The user's question.

    Returns:
        Only the newly generated answer text, stripped of surrounding
        whitespace. The prompt is not echoed back, since the caller may read
        the answer aloud.
    """
    context = "\n".join(retrieve_context(query))
    prompt = (
        f"You are a customer support agent. Use the following context to answer the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )
    response = generator(
        prompt,
        # Budget only the new tokens: a total max_length would be used up by
        # the (long) context-bearing prompt before any answer is produced.
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
        # Return the continuation only, not prompt + continuation.
        return_full_text=False,
    )
    return response[0]["generated_text"].strip()
110
+
111
+ # ====================================================
112
+ # 2. Speech-to-Text and Text-to-Speech Functions
113
+ # ====================================================
114
+
115
# Whisper model shared by every speech-to-text request (loaded once at import)
stt_model = whisper.load_model("base")


def speech_to_text(audio_array: np.ndarray, sample_rate: int = 16000) -> str:
    """Transcribe int16 PCM samples using the module-level Whisper model.

    ``sample_rate`` is accepted but not used: the samples are handed to the
    model unchanged apart from normalisation.
    NOTE(review): presumably the model expects 16 kHz mono input — confirm.
    """
    # Scale int16 PCM into float32 values in [-1, 1).
    samples = audio_array.astype(np.float32)
    samples /= 32768.0
    transcription = stt_model.transcribe(samples, fp16=False)
    return transcription["text"]
123
+
124
def text_to_speech(text: str, lang="en", target_sample_rate: int = 24000) -> np.ndarray:
    """Synthesise *text* with gTTS and return it as mono int16 samples.

    The MP3 produced by gTTS is decoded in memory, resampled to
    ``target_sample_rate`` and mixed down to a single channel.
    """
    mp3_buffer = io.BytesIO()
    gTTS(text, lang=lang).write_to_fp(mp3_buffer)
    mp3_buffer.seek(0)  # rewind so the decoder reads from the start

    segment = AudioSegment.from_file(mp3_buffer, format="mp3")
    mono_segment = segment.set_frame_rate(target_sample_rate).set_channels(1)
    return np.array(mono_segment.get_array_of_samples(), dtype=np.int16)
132
+
133
+ # ====================================================
134
+ # 3. RAGVoiceHandler: Integrating Voice & RAG
135
+ # ====================================================
136
+
137
class RAGVoiceHandler(AsyncStreamHandler):
    """Audio stream handler that answers spoken questions via the RAG pipeline.

    Incoming frames are buffered; when no frame arrives for 0.5 s the buffered
    audio is transcribed, answered and synthesised, and the resulting audio is
    queued for ``emit``.

    NOTE(review): utterance detection relies on a gap in frame delivery. If
    the transport delivers frames continuously (including during silence) the
    timeout never fires and the buffer keeps growing — confirm, and consider
    voice-activity detection.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        self.input_buffer = bytearray()
        # Event-loop timestamp of the most recent input frame. The handler is
        # constructed at import time, before any loop runs, so the clock is
        # only read once the processing loop is active.
        self.last_input_time = 0.0

    def copy(self) -> "RAGVoiceHandler":
        """Return a fresh handler with the same audio settings.

        The framework creates one handler per connection from the template
        instance; each copy gets its own queues and buffer.
        """
        return RAGVoiceHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    async def start_up(self) -> None:
        """Run the utterance-processing loop for the lifetime of the connection."""
        await self.stream()

    async def stream(self) -> None:
        """Buffer incoming audio and process it after a 0.5 s gap in input."""
        loop = asyncio.get_running_loop()
        while not self.quit.is_set():
            try:
                audio_data = await asyncio.wait_for(self.input_queue.get(), timeout=0.5)
            except asyncio.TimeoutError:
                if self.input_buffer:
                    try:
                        await self._process_utterance()
                    except Exception as exc:  # keep the call alive if one turn fails
                        print("Failed to process utterance:", exc)
            else:
                # No extra sleep here: wait_for already yields to the event
                # loop, and sleeping per frame would let the queue back up.
                self.input_buffer.extend(audio_data)
                self.last_input_time = loop.time()

    async def _process_utterance(self) -> None:
        """Transcribe the buffered audio, generate an answer and queue its speech."""
        audio_array = np.frombuffer(bytes(self.input_buffer), dtype=np.int16)
        self.input_buffer = bytearray()

        # STT, generation and TTS are blocking calls; run them in worker
        # threads so the event loop keeps receiving and emitting audio.
        query_text = await asyncio.to_thread(
            speech_to_text, audio_array, sample_rate=self.input_sample_rate
        )
        if not query_text.strip():
            return
        print("Transcribed query:", query_text)

        answer_text = await asyncio.to_thread(generate_answer, query_text)
        print("Generated answer:", answer_text)

        tts_audio = await asyncio.to_thread(
            text_to_speech, answer_text, target_sample_rate=self.output_sample_rate
        )
        # Emit as (channels, samples); presumably the layout the framework
        # expects for mono audio — TODO confirm.
        self.output_queue.put_nowait((self.output_sample_rate, tts_audio.reshape(1, -1)))

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Queue the raw bytes of an incoming ``(sample_rate, samples)`` frame."""
        _, audio_array = frame
        await self.input_queue.put(audio_array.tobytes())

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next queued ``(sample_rate, samples)`` output, if any."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal the processing loop to stop."""
        self.quit.set()
188
+
189
+ # ====================================================
190
+ # 4. Twilio Voice Streaming Setup & FastAPI Endpoints
191
+ # ====================================================
192
+
193
# Bidirectional audio stream: RAGVoiceHandler processes the audio, and the
# RTC configuration comes from Twilio TURN credentials.
stream = Stream(
    handler=RAGVoiceHandler(),
    modality="audio",
    mode="send-receive",
    rtc_configuration=get_twilio_turn_credentials(),
    concurrency_limit=5,
    time_limit=90,
)
202
+
203
# Request body for the input hook (used by the client when initialising a call)
class InputData(BaseModel):
    """Payload carrying the id of the client's WebRTC connection."""

    webrtc_id: str
206
+
207
+ app = FastAPI()
208
+ stream.mount(app)
209
+
210
@app.post("/input_hook")
async def input_hook(body: InputData):
    """Pass the client's WebRTC connection id to the stream and acknowledge."""
    webrtc_id = body.webrtc_id
    stream.set_input(webrtc_id)
    return dict(status="ok")
214
+
215
# Endpoint that receives the client's WebRTC offer (Twilio voice calls)
# NOTE(review): stream.mount(app) may already register a route at this path —
# confirm against the fastrtc documentation.
@app.post("/webrtc/offer")
async def webrtc_offer(offer: dict):
    """Hand the offer to the stream and return whatever it answers."""
    answer = await stream.handle_offer(offer)
    return answer
220
+
221
# Serve your existing HTML file (which contains your Twilio/WebRTC voice UI)
@app.get("/")
async def serve_index_page():
    """Return the contents of ``index.html`` located next to this module.

    The handler is deliberately not named ``index``: that would rebind the
    module-level ``index`` (the FAISS index used by ``retrieve_context``) to
    this function and break retrieval.
    """
    index_path = current_dir / "index.html"
    # Explicit encoding so the page is decoded the same way on every platform.
    html_content = index_path.read_text(encoding="utf-8")
    # If needed, replace any placeholders (for example, RTC configuration)
    return HTMLResponse(content=html_content)
228
+
229
+ # ====================================================
230
+ # 5. Application Runner
231
+ # ====================================================
232
+
233
if __name__ == "__main__":
    run_mode = os.getenv("MODE", "PHONE")
    if run_mode != "UI":
        # "PHONE" (the default) and any unrecognised mode run the FastAPI app,
        # so that callers can use the Twilio phone number to speak to the bot.
        uvicorn.run(app, host="0.0.0.0", port=7860)
    else:
        # Text-based Gradio interface for testing the RAG backend.
        import gradio as gr

        def gradio_chat(user_input):
            return generate_answer(user_input)

        demo = gr.Interface(
            fn=gradio_chat,
            inputs="text",
            outputs="text",
            title="Customer Support Chatbot",
        )
        demo.launch(server_port=7860)