{ "cells": [ { "cell_type": "code", "execution_count": 242, "metadata": {}, "outputs": [], "source": [ "import cv2\n", "import numpy as np\n", "import os\n", "from matplotlib import pyplot as plt\n", "import time\n", "import mediapipe as mp\n" ] }, { "cell_type": "code", "execution_count": 243, "metadata": {}, "outputs": [], "source": [ "# Pre-trained pose estimation model from Google Mediapipe\n", "mp_pose = mp.solutions.pose\n", "\n", "# Supported Mediapipe visualization tools\n", "mp_drawing = mp.solutions.drawing_utils" ] }, { "cell_type": "code", "execution_count": 244, "metadata": {}, "outputs": [], "source": [ "def mediapipe_detection(image, model):\n", " \"\"\"\n", " This function detects human pose estimation keypoints from webcam footage\n", " \n", " \"\"\"\n", " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # COLOR CONVERSION BGR 2 RGB\n", " image.flags.writeable = False # Image is no longer writeable\n", " results = model.process(image) # Make prediction\n", " image.flags.writeable = True # Image is now writeable \n", " image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # COLOR COVERSION RGB 2 BGR\n", " return image, results" ] }, { "cell_type": "code", "execution_count": 245, "metadata": {}, "outputs": [], "source": [ "def draw_landmarks(image, results):\n", " \"\"\"\n", " This function draws keypoints and landmarks detected by the human pose estimation model\n", " \n", " \"\"\"\n", " mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,\n", " mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2), \n", " mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2) \n", " )" ] }, { "cell_type": "code", "execution_count": 246, "metadata": {}, "outputs": [], "source": [ "def draw_detection(image, results):\n", "\n", " h, w, c = image.shape\n", " cx_min = w\n", " cy_min = h\n", " cx_max = cy_max = 0\n", " center = [w//2, h//2]\n", " try:\n", " for id, lm in enumerate(results.pose_landmarks.landmark):\n", " cx, cy = int(lm.x * w), int(lm.y * h)\n", " if cx < cx_min:\n", " cx_min = cx\n", " if cy < cy_min:\n", " cy_min = cy\n", " if cx > cx_max:\n", " cx_max = cx\n", " if cy > cy_max:\n", " cy_max = cy\n", " \n", " boxW, boxH = cx_max - cx_min, cy_max - cy_min\n", " \n", " # center\n", " cx, cy = cx_min + (boxW // 2), \\\n", " cy_min + (boxH // 2) \n", " center = [cx, cy]\n", " \n", " cv2.rectangle(\n", " image, (cx_min, cy_min), (cx_max, cy_max), (255, 255, 0), 2\n", " )\n", " except:\n", " pass\n", " \n", " return [[cx_min, cy_min], [cx_max, cy_max]], center" ] }, { "cell_type": "code", "execution_count": 247, "metadata": {}, "outputs": [], "source": [ "def normalize(image, results, bounding_box, landmark_names):\n", " h, w, c = image.shape\n", " if results.pose_landmarks:\n", " xy = {}\n", " xy_norm = {}\n", " i = 0\n", " for res in results.pose_landmarks.landmark:\n", " x = res.x * w\n", " y = res.y * h\n", " \n", " x_norm = (x - bounding_box[0][0]) / (bounding_box[1][0] - bounding_box[0][0])\n", " y_norm = (y - bounding_box[0][1]) / (bounding_box[1][1] - bounding_box[0][1])\n", " \n", " # xy_norm.append([x_norm, y_norm])\n", " \n", " xy_norm[landmark_names[i]] = [x_norm, y_norm]\n", " i += 1\n", " else:\n", " # xy_norm = np.zeros([0,0] * 33)\n", " \n", " # xy = {landmark_names: [0,0]}\n", " # xy_norm = {landmark_names: [0,0]}\n", " \n", " xy_norm = dict(zip(landmark_names, [0,0] * 33))\n", " \n", " return xy_norm" ] }, { "cell_type": "code", "execution_count": 248, "metadata": {}, "outputs": [], "source": [ "def 
{ "cell_type": "code", "execution_count": 248, "metadata": {}, "outputs": [], "source": [ "def get_coordinates(landmarks, mp_pose, side, joint):\n", "    \"\"\"\n", "    Retrieves the x and y coordinates of a particular keypoint from the pose estimation model\n", "\n", "    Args:\n", "        landmarks: processed keypoints from the pose estimation model\n", "        mp_pose: Mediapipe pose estimation model\n", "        side: 'left' or 'right'. Denotes the side of the body of the landmark of interest.\n", "        joint: 'shoulder', 'elbow', 'wrist', 'hip', 'knee', or 'ankle'. Denotes which body joint is associated with the landmark of interest.\n", "\n", "    \"\"\"\n", "    coord = getattr(mp_pose.PoseLandmark, side.upper() + \"_\" + joint.upper())\n", "    x_coord_val = landmarks[coord.value].x\n", "    y_coord_val = landmarks[coord.value].y\n", "    return [x_coord_val, y_coord_val]" ] },
{ "cell_type": "code", "execution_count": 249, "metadata": {}, "outputs": [], "source": [ "def viz_coords(image, norm_coords, landmarks, mp_pose, side, joint):\n", "    \"\"\"\n", "    Displays a joint's raw and box-normalized coordinates near that joint within the image frame\n", "\n", "    \"\"\"\n", "    try:\n", "        h, w, _ = image.shape\n", "        point = side.upper() + \"_\" + joint.upper()\n", "        norm_coords = norm_coords[point]\n", "        joint_xy = get_coordinates(landmarks, mp_pose, side, joint)\n", "\n", "        coords = ' '.join(['%.2f' % elem for elem in joint_xy])\n", "        norm_coords = ' '.join(['%.2f' % elem for elem in norm_coords])\n", "        cv2.putText(image, coords,\n", "                    tuple(np.multiply(joint_xy, [w, h]).astype(int)),\n", "                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA\n", "                    )\n", "        cv2.putText(image, norm_coords,\n", "                    tuple(np.multiply(joint_xy, [w, h]).astype(int) + 20),\n", "                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2, cv2.LINE_AA\n", "                    )\n", "    except (AttributeError, KeyError, TypeError):\n", "        # Skip drawing when landmarks are unavailable for this frame\n", "        pass\n", "    return" ] },
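{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch (not part of the original notebook): a minimal joint-angle helper\n", "# showing how three keypoints returned by get_coordinates (e.g. shoulder, elbow, wrist)\n", "# could be turned into an elbow angle in degrees. The name 'calculate_angle' is an\n", "# assumed helper, not an existing Mediapipe API.\n", "def calculate_angle(a, b, c):\n", "    \"\"\"Returns the angle at point b (in degrees) formed by the segments b->a and b->c\"\"\"\n", "    a, b, c = np.array(a), np.array(b), np.array(c)\n", "    radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])\n", "    angle = np.abs(np.degrees(radians))\n", "    return 360 - angle if angle > 180 else angle\n", "\n", "# Example usage once landmarks are available from the webcam loop below:\n", "# calculate_angle(get_coordinates(landmarks, mp_pose, 'left', 'shoulder'),\n", "#                 get_coordinates(landmarks, mp_pose, 'left', 'elbow'),\n", "#                 get_coordinates(landmarks, mp_pose, 'left', 'wrist'))" ] },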
{ "cell_type": "code", "execution_count": 250, "metadata": {}, "outputs": [], "source": [ "cap = cv2.VideoCapture(0)                         # camera object\n", "HEIGHT = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # webcam video frame height\n", "WIDTH = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))    # webcam video frame width\n", "FPS = int(cap.get(cv2.CAP_PROP_FPS))              # webcam video frame rate\n", "\n", "# Landmark names in the same index order as the model's output\n", "landmark_names = [lm.name for lm in mp_pose.PoseLandmark]\n", "\n", "landmarks = None\n", "\n", "# Set and test mediapipe model using webcam\n", "with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5, enable_segmentation=True) as pose:\n", "    while cap.isOpened():\n", "\n", "        # Read feed\n", "        ret, frame = cap.read()\n", "        if not ret:\n", "            break\n", "\n", "        # Make detection\n", "        image, results = mediapipe_detection(frame, pose)\n", "\n", "        # Extract landmarks\n", "        if results.pose_landmarks:\n", "            landmarks = results.pose_landmarks.landmark\n", "\n", "        # Draw bounding box\n", "        bounding_box, box_center = draw_detection(image, results)\n", "\n", "        # Render detections\n", "        draw_landmarks(image, results)\n", "\n", "        # Normalize coordinates\n", "        xy_norm = normalize(image, results, bounding_box, landmark_names)\n", "        viz_coords(image, xy_norm, landmarks, mp_pose, 'left', 'wrist')\n", "        viz_coords(image, xy_norm, landmarks, mp_pose, 'right', 'wrist')\n", "\n", "        # Display frame on screen\n", "        cv2.imshow('OpenCV Feed', image)\n", "\n", "        # Draw segmentation on the image.\n", "        # To improve segmentation around boundaries, consider applying a joint\n", "        # bilateral filter to \"results.segmentation_mask\" with \"image\".\n", "        # tightness = 0.3  # Probability threshold in [0, 1] that says how \"tight\" to make the segmentation. Greater value => tighter.\n", "        # condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > tightness\n", "        # bg_image = np.zeros(image.shape, dtype=np.uint8)\n", "        # bg_image[:] = (192, 192, 192)  # gray\n", "        # image = np.where(condition, image, bg_image)\n", "\n", "        # Exit / break out logic\n", "        if cv2.waitKey(10) & 0xFF == ord('q'):\n", "            break\n", "\n", "cap.release()\n", "cv2.destroyAllWindows()" ] },
{ "cell_type": "code", "execution_count": 251, "metadata": {}, "outputs": [], "source": [ "# Release the webcam and close the display window if the loop above was interrupted\n", "cap.release()\n", "cv2.destroyAllWindows()" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('AItrainer')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "80aa1d3f3a8cfb37a38c47373cc49a39149184c5fa770d709389b1b8782c1d85" } } }, "nbformat": 4, "nbformat_minor": 2 }