Docfile committed on
Commit dd6e891 · verified · 1 Parent(s): f8e9d91

Create app.py
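
Adds a single-file Streamlit demo: a text chat mode that talks to Gemini through the google-genai live API, plus an audio/video mode that captures microphone audio with PyAudio and webcam frames via streamlit-webrtc, with optional MediaPipe face detection drawn on each frame.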

Files changed (1)
  1. app.py +203 -0
app.py ADDED
@@ -0,0 +1,203 @@
import streamlit as st
import mediapipe as mp
import numpy as np
import base64
import io
import PIL.Image
import asyncio
import os
from google import genai
from streamlit_webrtc import webrtc_streamer
import av
import pyaudio
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

# Configuration
FORMAT = pyaudio.paInt16
CHANNELS = 1
SEND_SAMPLE_RATE = 16000
RECEIVE_SAMPLE_RATE = 24000
CHUNK_SIZE = 1024

# Initialize the Gemini client. The google-genai SDK takes the API key in the
# Client constructor; genai.configure() belongs to the older
# google.generativeai package and does not exist in this SDK.
client = genai.Client(
    api_key=os.getenv("GOOGLE_API_KEY"),
    http_options={"api_version": "v1alpha"},
)
MODEL = "models/gemini-2.0-flash-exp"
CONFIG = {"generation_config": {"response_modalities": ["AUDIO"]}}

class AudioProcessor:
    def __init__(self):
        self.audio = pyaudio.PyAudio()
        self.stream = None
        self.audio_queue = asyncio.Queue()

    def start_stream(self):
        # Open a 16 kHz mono capture stream on the default microphone.
        mic_info = self.audio.get_default_input_device_info()
        self.stream = self.audio.open(
            format=FORMAT,
            channels=CHANNELS,
            rate=SEND_SAMPLE_RATE,
            input=True,
            input_device_index=mic_info["index"],
            frames_per_buffer=CHUNK_SIZE,
        )

    def stop_stream(self):
        if self.stream:
            self.stream.stop_stream()
            self.stream.close()
            self.stream = None

class VideoProcessor:
    def __init__(self):
        self.frame_queue = asyncio.Queue(maxsize=5)
        self.mp_draw = mp.solutions.drawing_utils
        self.mp_face_detection = mp.solutions.face_detection
        self.face_detection = self.mp_face_detection.FaceDetection(
            min_detection_confidence=0.5)

    def video_frame_callback(self, frame):
        # Convert the frame to RGB
        img = frame.to_ndarray(format="rgb24")

        # Process the frame with MediaPipe
        results = self.face_detection.process(img)

        # Draw face detection annotations if faces are detected
        if results.detections:
            for detection in results.detections:
                self.mp_draw.draw_detection(img, detection)

        # Convert to PIL Image, bounding its size for upload
        pil_img = PIL.Image.fromarray(img)
        pil_img.thumbnail([1024, 1024])

        # Prepare frame data for Gemini
        image_io = io.BytesIO()
        pil_img.save(image_io, format="jpeg")
        image_io.seek(0)

        frame_data = {
            "mime_type": "image/jpeg",
            "data": base64.b64encode(image_io.read()).decode()
        }

        # Drop the frame rather than block when the queue is full. Note that
        # this callback runs on a streamlit-webrtc worker thread, while
        # asyncio.Queue is not thread-safe; a queue.Queue would be safer if
        # the frames are consumed from an event loop.
        try:
            self.frame_queue.put_nowait(frame_data)
        except asyncio.QueueFull:
            pass

        return av.VideoFrame.from_ndarray(img, format="rgb24")

    def __del__(self):
        # Clean up MediaPipe resources
        if hasattr(self, 'face_detection'):
            self.face_detection.close()

def initialize_session_state():
    if 'audio_processor' not in st.session_state:
        st.session_state.audio_processor = AudioProcessor()
    if 'video_processor' not in st.session_state:
        st.session_state.video_processor = VideoProcessor()
    if 'session' not in st.session_state:
        st.session_state.session = None
    if 'messages' not in st.session_state:
        st.session_state.messages = []

def display_chat_messages():
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

def main():
    st.title("Gemini Interactive Assistant")

    # Initialize session state
    initialize_session_state()

    # Sidebar configuration
    st.sidebar.title("Settings")
    input_mode = st.sidebar.radio(
        "Input Mode",
        ["Text Only", "Audio + Video", "Audio Only"]
    )

    # Enable face detection option
    enable_face_detection = st.sidebar.checkbox("Enable Face Detection", value=True)

    if enable_face_detection:
        detection_confidence = st.sidebar.slider(
            "Face Detection Confidence",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.1
        )
        # Rebuild the detector so the slider value applies to the next frames
        st.session_state.video_processor.face_detection = (
            st.session_state.video_processor.mp_face_detection.FaceDetection(
                min_detection_confidence=detection_confidence
            )
        )

    # Display chat history
    display_chat_messages()

    # Main interaction area
    if input_mode == "Text Only":
        user_input = st.chat_input("Your message")
        if user_input:
            # Add user message to chat
            st.session_state.messages.append({"role": "user", "content": user_input})
            with st.chat_message("user"):
                st.markdown(user_input)

            # The global CONFIG requests AUDIO responses, which would leave
            # response.text empty; ask for TEXT on this chat-only path.
            text_config = {"generation_config": {"response_modalities": ["TEXT"]}}

            async def send_message():
                async with client.aio.live.connect(model=MODEL, config=text_config) as session:
                    await session.send(user_input, end_of_turn=True)
                    turn = session.receive()
                    async for response in turn:
                        if text := response.text:
                            # Add assistant response to chat
                            st.session_state.messages.append(
                                {"role": "assistant", "content": text}
                            )
                            with st.chat_message("assistant"):
                                st.markdown(text)

            asyncio.run(send_message())

    else:
        # Video stream setup
        if input_mode == "Audio + Video":
            ctx = webrtc_streamer(
                key="gemini-stream",
                video_frame_callback=st.session_state.video_processor.video_frame_callback,
                rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
                media_stream_constraints={"video": True, "audio": True},
            )

        # Audio controls
        col1, col2 = st.columns(2)
        with col1:
            if st.button("Start Recording", type="primary"):
                st.session_state.audio_processor.start_stream()
                st.session_state['recording'] = True

        with col2:
            if st.button("Stop Recording", type="secondary"):
                st.session_state.audio_processor.stop_stream()
                st.session_state['recording'] = False

        # Reads microphone chunks into the audio queue. As committed, nothing
        # schedules this coroutine and nothing drains the queue; see the note
        # after the listing for one way to wire it up.
        async def process_audio_stream():
            while st.session_state.get('recording', False):
                if st.session_state.audio_processor.stream:
                    data = st.session_state.audio_processor.stream.read(CHUNK_SIZE)
                    await st.session_state.audio_processor.audio_queue.put({
                        "data": data,
                        "mime_type": "audio/pcm"
                    })
                await asyncio.sleep(0.1)

if __name__ == "__main__":
    main()
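
As committed, process_audio_stream() fills audio_queue but is never scheduled, and nothing forwards the queued chunks to Gemini. A minimal sketch of the missing glue, assuming the same v1alpha live API that send_message() uses (the payload session.send accepts can differ between google-genai versions, so treat this as illustrative rather than part of the commit):

async def stream_audio_to_gemini(audio_processor):
    # Hypothetical consumer, not part of this commit: drain captured audio
    # chunks into a live Gemini session while recording is active. It would
    # need to run concurrently with process_audio_stream(), e.g. via
    # asyncio.gather(...) inside the same branch of main().
    async with client.aio.live.connect(model=MODEL, config=CONFIG) as session:
        while st.session_state.get('recording', False):
            chunk = await audio_processor.audio_queue.get()
            await session.send(chunk)  # {"data": ..., "mime_type": "audio/pcm"}

To try the app locally, the imports suggest something like `pip install streamlit streamlit-webrtc mediapipe pyaudio av pillow numpy google-genai`, a GOOGLE_API_KEY environment variable, and `streamlit run app.py`.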