mgokg committed on
Commit
3117482
·
verified ·
1 Parent(s): a52af47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -298
app.py CHANGED
@@ -1,300 +1,24 @@
1
- import os
2
- import base64
3
- import json
4
- import numpy as np
5
  import gradio as gr
6
- # import websockets.sync.client # No longer needed with FastRTC
7
- from fastrtc import (
8
- #PeerConnection,
9
- DataChannel,
10
- MediaStreamTrack,
11
- AudioFrame,
12
- VideoFrame,
13
- ) # Import FastRTC components
14
- from aiortc.contrib.media import MediaPlayer, MediaRelay
15
- import asyncio
16
-
17
- __version__ = "0.0.3"
18
-
19
- # KEY_NAME = "AIzaSyCWPviRPxj8IMLaijLGbRIsio3dO2rp3rU" # Best practice: Keep API keys out of the main code. Use environment variables.
20
-
21
-
22
- # Configuration and Utilities
23
- class GeminiConfig:
24
- """Configuration settings for Gemini API."""
25
-
26
- def __init__(self):
27
- self.api_key = os.environ.get("KEY_NAME") # Use a more descriptive name
28
- if not self.api_key:
29
- raise ValueError("GEMINI_API_KEY environment variable is not set.")
30
- self.host = "generativelanguage.googleapis.com"
31
- self.model = "models/gemini-2.0-flash-exp"
32
- # FastRTC doesn't use WebSockets directly in the same way. We'll handle the API calls differently.
33
- self.base_url = f"https://{self.host}/v1alpha/{self.model}:streamGenerateContent?key={self.api_key}"
34
-
35
-
36
- class AudioProcessor:
37
- """Handles encoding and decoding of audio data."""
38
-
39
- @staticmethod
40
- def encode_audio(data: np.ndarray, sample_rate: int) -> str:
41
- """Encodes audio data to base64."""
42
- # Ensure data is in the correct format (int16)
43
- if data.dtype != np.int16:
44
- data = data.astype(np.int16)
45
- encoded = base64.b64encode(data.tobytes()).decode("UTF-8")
46
- return encoded
47
-
48
- @staticmethod
49
- def process_audio_response(data: str) -> np.ndarray:
50
- """Decodes audio data from base64."""
51
- audio_data = base64.b64decode(data)
52
- return np.frombuffer(audio_data, dtype=np.int16)
53
-
54
-
55
- # We don't need a StreamHandler in the same way with FastRTC. We'll handle streaming directly.
56
- class GeminiHandler:
57
- """Handles interactions with the Gemini API."""
58
-
59
- def __init__(self, output_sample_rate=24000, output_frame_size=480):
60
- self.config = GeminiConfig()
61
- self.audio_processor = AudioProcessor()
62
- self.output_sample_rate = output_sample_rate
63
- self.output_frame_size = output_frame_size
64
- self.all_output_data = None
65
- self.pc = None # PeerConnection
66
- self.dc = None # DataChannel
67
- self.audio_track = None
68
- self._audio_buffer = []
69
- self.relay = MediaRelay()
70
-
71
- async def _send_audio_to_gemini(self, encoded_audio: str):
72
- """Sends audio data to the Gemini API and processes the response."""
73
- headers = {"Content-Type": "application/json"}
74
- payload = {
75
- "contents": [
76
- {
77
- "parts": [
78
- {
79
- "text": "Respond to the audio with audio."
80
- }, # Initial prompt, can be adjusted
81
- {"inline_data": {"mime_type": "audio/pcm;rate=24000", "data": encoded_audio}},
82
- ]
83
- }
84
- ]
85
- }
86
- # Use aiohttp for asynchronous HTTP requests
87
- import aiohttp
88
-
89
- async with aiohttp.ClientSession() as session:
90
- async with session.post(
91
- self.config.base_url, headers=headers, data=json.dumps(payload)
92
- ) as response:
93
- if response.status != 200:
94
- print(f"Error: Gemini API returned status {response.status}")
95
- print(await response.text())
96
- return
97
-
98
- async for line in response.content:
99
- try:
100
- line = line.strip()
101
- if not line:
102
- continue
103
- # Responses are chunked, often with multiple JSON objects per chunk. Handle that.
104
- for chunk in line.decode("utf-8").split("\n"):
105
- if not chunk.strip():
106
- continue
107
- try:
108
- data = json.loads(chunk)
109
- except json.JSONDecodeError:
110
- print(f"JSONDecodeError: {chunk}")
111
- continue
112
-
113
- if "candidates" in data:
114
- for candidate in data["candidates"]:
115
- for part in candidate.get("content", {}).get("parts", []):
116
- if "inlineData" in part:
117
- audio_data = part["inlineData"].get("data", "")
118
- if audio_data:
119
- await self._process_server_audio(audio_data)
120
-
121
- except Exception as e:
122
- print(f"Error processing response chunk: {e}")
123
-
124
- async def _process_server_audio(self, audio_data: str):
125
- """Processes and buffers audio data received from the server."""
126
- audio_array = self.audio_processor.process_audio_response(audio_data)
127
- if self.all_output_data is None:
128
- self.all_output_data = audio_array
129
- else:
130
- self.all_output_data = np.concatenate((self.all_output_data, audio_array))
131
-
132
- while self.all_output_data.shape[-1] >= self.output_frame_size:
133
- frame = AudioFrame(
134
- samples=self.output_frame_size,
135
- sample_rate=self.output_sample_rate,
136
- layout="mono", # mono channel
137
- data=self.all_output_data[: self.output_frame_size].tobytes()
138
- )
139
- self.all_output_data = self.all_output_data[self.output_frame_size:]
140
- if self.audio_track:
141
- await self.audio_track.emit(frame)
142
-
143
-
144
- async def on_track(self, track):
145
- """Handles incoming media tracks."""
146
- print(f"Track received: {track.kind}")
147
- if track.kind == "audio":
148
- self.audio_track = track # Store the audio track
149
-
150
- @track.on("frame")
151
- async def on_frame(frame):
152
- # Process received audio frames
153
- if isinstance(frame, AudioFrame):
154
- try:
155
- # Convert the frame data to a NumPy array
156
- audio_data = np.frombuffer(frame.data, dtype=np.int16)
157
- # Encode the audio and send it to Gemini
158
- encoded_audio = self.audio_processor.encode_audio(
159
- audio_data, frame.sample_rate
160
- ) # Pass sample rate
161
- await self._send_audio_to_gemini(encoded_audio)
162
- except Exception as e:
163
- print(f"Error processing audio frame: {e}")
164
-
165
- async def on_datachannel(self, channel):
166
- """Handles data channel events (not used in this example, but good practice)."""
167
- self.dc = channel
168
- print("Data channel created")
169
-
170
- @channel.on("message")
171
- async def on_message(message):
172
- print(f"Received message: {message}")
173
-
174
- async def connect(self):
175
- """Establishes the PeerConnection."""
176
- self.pc = PeerConnection()
177
- self.pc.on("track", self.on_track)
178
- self.pc.on("datachannel", self.on_datachannel)
179
-
180
- # Create a local audio track to send data
181
- self.local_audio_player = MediaPlayer("default", format="avfoundation", options={"channels": "1", "sample_rate": str(self.output_sample_rate)})
182
- self.local_audio = self.relay.subscribe(self.local_audio_player.audio)
183
- self.pc.addTrack(self.local_audio)
184
-
185
- # Add a data channel (optional, but good practice)
186
- self.dc = self.pc.createDataChannel("data")
187
-
188
- # Create an offer and set local description
189
- offer = await self.pc.createOffer()
190
- await self.pc.setLocalDescription(offer)
191
- print("PeerConnection established")
192
- return self.pc.localDescription
193
-
194
- async def set_remote_description(self, sdp, type):
195
- """Sets the remote description."""
196
- from aiortc import RTCSessionDescription
197
-
198
- await self.pc.setRemoteDescription(RTCSessionDescription(sdp=sdp, type=type))
199
- print("Remote description set")
200
-
201
- if self.pc.remoteDescription.type == "offer":
202
- answer = await self.pc.createAnswer()
203
- await self.pc.setLocalDescription(answer)
204
- return self.pc.localDescription
205
-
206
- async def add_ice_candidate(self, candidate, sdpMid, sdpMLineIndex):
207
- """Adds an ICE candidate."""
208
- from aiortc import RTCIceCandidate
209
-
210
- if candidate:
211
- try:
212
- ice_candidate = RTCIceCandidate(
213
- candidate=candidate, sdpMid=sdpMid, sdpMLineIndex=sdpMLineIndex
214
- )
215
- await self.pc.addIceCandidate(ice_candidate)
216
- print("ICE candidate added")
217
- except Exception as e:
218
- print(f"Error adding ICE candidate: {e}")
219
-
220
- def shutdown(self):
221
- """Closes the PeerConnection."""
222
- if self.pc:
223
- asyncio.create_task(self.pc.close()) # Close in the background
224
- self.pc = None
225
- print("PeerConnection closed")
226
-
227
-
228
- # Gradio Interface
229
- async def registry(
230
- name: str,
231
- token: str | None = None,
232
- **kwargs,
233
- ):
234
- """Sets up and returns the Gradio interface."""
235
- gemini_handler = GeminiHandler()
236
-
237
- async def connect_webrtc(sdp, type, candidates):
238
- """Connects to the WebRTC client and handles ICE candidates."""
239
- if gemini_handler.pc is None:
240
- local_description = await gemini_handler.connect()
241
- if local_description:
242
- yield json.dumps(
243
- {
244
- "sdp": local_description.sdp,
245
- "type": local_description.type,
246
- "candidates": [],
247
- }
248
- ) # Return initial SDP
249
- if sdp and type:
250
- answer = await gemini_handler.set_remote_description(sdp, type)
251
- if answer:
252
- yield json.dumps({"sdp": answer.sdp, "type": answer.type, "candidates": []})
253
-
254
- for candidate in candidates:
255
- if candidate and candidate.get("candidate"):
256
- await gemini_handler.add_ice_candidate(
257
- candidate["candidate"], candidate.get("sdpMid"), candidate.get("sdpMLineIndex")
258
- )
259
- yield json.dumps({"sdp": "", "type": "", "candidates": []}) # Signal completion
260
-
261
- interface = gr.Blocks()
262
- with interface:
263
- with gr.Tabs():
264
- with gr.TabItem("Voice Chat"):
265
- gr.HTML(
266
- """
267
- <div style='text-align: left'>
268
- <h1>Gemini API Voice Chat</h1>
269
- </div>
270
- """
271
- )
272
- with gr.Row():
273
- webrtc_out = gr.JSON(label="WebRTC JSON")
274
-
275
- # Use the built-in WebRTC component, but without automatic streaming.
276
- webrtc = gr.WebRTC(
277
- value={"sdp": "", "type": "", "candidates": []},
278
- interactive=True,
279
- label="Voice Chat",
280
- )
281
-
282
- connect_button = gr.Button("Connect")
283
- connect_button.click(
284
- connect_webrtc,
285
- inputs=[
286
- webrtc
287
- ], # Pass the WebRTC component's value (SDP, type, candidates)
288
- outputs=[webrtc_out], # show the webrtc connection data
289
- )
290
-
291
- return interface
292
-
293
- # Launch the Gradio interface
294
- async def main():
295
- interface = await registry(name="gemini-2.0-flash-exp")
296
- interface.queue() # Enable queuing for better concurrency
297
- await interface.launch()
298
 
299
- if __name__ == "__main__":
300
- asyncio.run(main())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from PyPDF2 import PdfReader
3
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
def process_pdf(file):
    """Extract the plain text of every page of an uploaded PDF.

    Args:
        file: Either a filesystem path string (what a
            ``gr.File(type="filepath")`` component delivers to its
            callback) or a file wrapper object exposing a ``.name``
            attribute (what older Gradio versions delivered).

    Returns:
        str: The concatenated text of all pages, in page order.
        Pages with no extractable text (e.g. scanned images)
        contribute an empty string.
    """
    # Bug fix: the component is declared with type="filepath", so the
    # callback receives a plain str — calling file.name on it raised
    # AttributeError. Accept both a path string and a file wrapper.
    path = file if isinstance(file, str) else file.name
    pdf_reader = PdfReader(path)
    # extract_text() may return None for image-only pages; coalesce to ""
    # so the join never sees a non-string. join avoids quadratic +=.
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)
12
+
13
# Build the Gradio UI: a file picker, a text box for the extracted PDF
# text, and an upload button wiring the two together via process_pdf.
with gr.Blocks() as demo:
    gr.Markdown("### File upload", elem_classes="tab-header")
    with gr.Row():
        text_output = gr.Textbox(label="text")
        # type="filepath": the callback receives the path as a str.
        file_input = gr.File(label="Wähle eine PDF-Datei aus", type="filepath")
        upload_output = gr.Textbox(label="Upload Status")
    with gr.Row():
        submit_button = gr.Button("upload")
    # Bug fix: the original left this call unclosed and the stray ')'
    # ended up after demo.launch() below — a SyntaxError that prevented
    # the app from starting at all.
    submit_button.click(process_pdf, inputs=file_input, outputs=text_output)

demo.launch()
24
+