Update app.py
app.py CHANGED
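This commit disables the live-webcam path (`process` and `recv` are commented out) and adds `process_uploaded_file` / `recv_uploaded_file` to `VideoProcessor`, so the fitness trainer can run on uploaded image and video files; the new `BytesIO` and `PIL.Image` imports support the upload handling.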
@@ -11,6 +11,9 @@ import math
 
 from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
 import av
+from io import BytesIO
+import av
+from PIL import Image
 
 ## Build and Load Model
 def attention_block(inputs, time_steps):
@@ -286,87 +289,148 @@ class VideoProcessor:
 
         return output_frame
 
-    @st.cache()
-    def process(self, image):
-        """
-        Function to process the video frame from the user's webcam and run the fitness trainer AI
+    # @st.cache()
+    # def process(self, image):
+    #     """
+    #     Function to process the video frame from the user's webcam and run the fitness trainer AI
 
-        Args:
-            image (numpy array): input image from the webcam
+    #     Args:
+    #         image (numpy array): input image from the webcam
 
-        Returns:
-            numpy array: processed image with keypoint detection and fitness activity classification visualized
-        """
-        # Pose detection model
-        image.flags.writeable = False
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        results = pose.process(image)
+    #     Returns:
+    #         numpy array: processed image with keypoint detection and fitness activity classification visualized
+    #     """
+    #     # Pose detection model
+    #     image.flags.writeable = False
+    #     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    #     results = pose.process(image)
 
-        # Draw the hand annotations on the image.
-        image.flags.writeable = True
-        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-        self.draw_landmarks(image, results)
+    #     # Draw the hand annotations on the image.
+    #     image.flags.writeable = True
+    #     image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+    #     self.draw_landmarks(image, results)
 
-        # Prediction logic
-        keypoints = self.extract_keypoints(results)
-        self.sequence.append(keypoints.astype('float32', casting='same_kind'))
-        self.sequence = self.sequence[-self.sequence_length:]
+    #     # Prediction logic
+    #     keypoints = self.extract_keypoints(results)
+    #     self.sequence.append(keypoints.astype('float32', casting='same_kind'))
+    #     self.sequence = self.sequence[-self.sequence_length:]
 
-        if len(self.sequence) == self.sequence_length:
-            res = model.predict(np.expand_dims(self.sequence, axis=0), verbose=0)[0]
-            # interpreter.set_tensor(self.input_details[0]['index'], np.expand_dims(self.sequence, axis=0))
-            # interpreter.invoke()
-            # res = interpreter.get_tensor(self.output_details[0]['index'])
+    #     if len(self.sequence) == self.sequence_length:
+    #         res = model.predict(np.expand_dims(self.sequence, axis=0), verbose=0)[0]
+    #         # interpreter.set_tensor(self.input_details[0]['index'], np.expand_dims(self.sequence, axis=0))
+    #         # interpreter.invoke()
+    #         # res = interpreter.get_tensor(self.output_details[0]['index'])
 
-            self.current_action = self.actions[np.argmax(res)]
-            confidence = np.max(res)
+    #         self.current_action = self.actions[np.argmax(res)]
+    #         confidence = np.max(res)
 
-            # Erase current action variable if no probability is above threshold
-            if confidence < self.threshold:
-                self.current_action = ''
+    #         # Erase current action variable if no probability is above threshold
+    #         if confidence < self.threshold:
+    #             self.current_action = ''
 
-            # Viz probabilities
-            image = self.prob_viz(res, image)
+    #         # Viz probabilities
+    #         image = self.prob_viz(res, image)
 
-        # Count reps
-        try:
-            landmarks = results.pose_landmarks.landmark
-            self.count_reps(
-                image, landmarks, mp_pose)
-        except:
-            pass
+    #     # Count reps
+    #     try:
+    #         landmarks = results.pose_landmarks.landmark
+    #         self.count_reps(
+    #             image, landmarks, mp_pose)
+    #     except:
+    #         pass
 
-        # Display graphical information
-        cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
-        cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
-                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-        cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
-                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-        cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
-                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
+    #     # Display graphical information
+    #     cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
+    #     cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
+    #                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
+    #     cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
+    #                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
+    #     cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
+    #                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
 
-        # return cv2.flip(image, 1)
-        return image
+    #     # return cv2.flip(image, 1)
+    #     return image
 
-    def recv(self, frame):
-        """
-        Receive and process video stream from webcam
-
-        Args:
-            frame: current video frame
-
-        Returns:
-            av.VideoFrame: processed video
-        """
-        img = frame.to_ndarray(format="bgr24")
-        img = self.process(img)
-        return av.VideoFrame.from_ndarray(img, format="bgr24")
+    # def recv(self, frame):
+    #     """
+    #     Receive and process video stream from webcam
+    #
+    #     Args:
+    #         frame: current video frame
+    #
+    #     Returns:
+    #         av.VideoFrame: processed video frame
+    #     """
+    #     img = frame.to_ndarray(format="bgr24")
+    #     img = self.process(img)
+    #     return av.VideoFrame.from_ndarray(img, format="bgr24")
+
+    def process_uploaded_file(self, file):
+        """
+        Function to process an uploaded image or video file and run the fitness trainer AI
+        Args:
+            file (BytesIO): uploaded image or video file
+        Returns:
+            numpy array: processed image with keypoint detection and fitness activity classification visualized
+        """
+        # Initialize an empty list to store processed frames
+        processed_frames = []
+
+        # Check if the uploaded file is a video
+        is_video = hasattr(file, 'name') and file.name.endswith(('.mp4', '.avi', '.mov'))
+
+        if is_video:
+            container = av.open(file)
+            for frame in container.decode(video=0):
+                # Convert the frame to OpenCV format
+                image = frame.to_image().convert("RGB")
+                image = np.array(image)
+
+                # Process the frame
+                processed_frame = self.process(image)
+
+                # Append the processed frame to the list
+                processed_frames.append(processed_frame)
+
+            # Close the video file container
+            container.close()
+        else:
+            # If the uploaded file is an image,
+            # load the image from the BytesIO object
+            image = Image.open(file)
+            image = np.array(image)
+
+            # Process the image
+            processed_frame = self.process(image)
+
+            # Append the processed frame to the list
+            processed_frames.append(processed_frame)
+
+        return processed_frames
+
+    def recv_uploaded_file(self, file):
+        """
+        Receive and process an uploaded video file
+        Args:
+            file (BytesIO): uploaded video file
+        Returns:
+            List[av.VideoFrame]: list of processed video frames
+        """
+        # Process the uploaded file
+        processed_frames = self.process_uploaded_file(file)
+
+        # Convert processed frames to av.VideoFrame objects
+        av_frames = []
+        for frame in processed_frames:
+            av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
+            av_frames.append(av_frame)
+
+        return av_frames
 
-## Stream Webcam Video and Run Model
 # Options
 RTC_CONFIGURATION = RTCConfiguration(
     {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
 )
+
 # Streamer
 webrtc_ctx = webrtc_streamer(
     key="AI trainer",