{ "cells": [ { "cell_type": "code", "execution_count": 242, "metadata": {}, "outputs": [], "source": [ "import cv2\n", "import numpy as np\n", "import os\n", "from matplotlib import pyplot as plt\n", "import time\n", "import mediapipe as mp\n" ] }, { "cell_type": "code", "execution_count": 243, "metadata": {}, "outputs": [], "source": [ "# Pre-trained pose estimation model from Google Mediapipe\n", "mp_pose = mp.solutions.pose\n", "\n", "# Supported Mediapipe visualization tools\n", "mp_drawing = mp.solutions.drawing_utils" ] }, { "cell_type": "code", "execution_count": 244, "metadata": {}, "outputs": [], "source": [ "def mediapipe_detection(image, model):\n", " \"\"\"\n", " This function detects human pose estimation keypoints from webcam footage\n", " \n", " \"\"\"\n", " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # COLOR CONVERSION BGR 2 RGB\n", " image.flags.writeable = False # Image is no longer writeable\n", " results = model.process(image) # Make prediction\n", " image.flags.writeable = True # Image is now writeable \n", " image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # COLOR COVERSION RGB 2 BGR\n", " return image, results" ] }, { "cell_type": "code", "execution_count": 245, "metadata": {}, "outputs": [], "source": [ "def draw_landmarks(image, results):\n", " \"\"\"\n", " This function draws keypoints and landmarks detected by the human pose estimation model\n", " \n", " \"\"\"\n", " mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,\n", " mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2), \n", " mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2) \n", " )" ] }, { "cell_type": "code", "execution_count": 246, "metadata": {}, "outputs": [], "source": [ "def draw_detection(image, results):\n", "\n", " h, w, c = image.shape\n", " cx_min = w\n", " cy_min = h\n", " cx_max = cy_max = 0\n", " center = [w//2, h//2]\n", " try:\n", " for id, lm in enumerate(results.pose_landmarks.landmark):\n", " cx, cy = int(lm.x * w), int(lm.y * h)\n", " if cx < cx_min:\n", " cx_min = cx\n", " if cy < cy_min:\n", " cy_min = cy\n", " if cx > cx_max:\n", " cx_max = cx\n", " if cy > cy_max:\n", " cy_max = cy\n", " \n", " boxW, boxH = cx_max - cx_min, cy_max - cy_min\n", " \n", " # center\n", " cx, cy = cx_min + (boxW // 2), \\\n", " cy_min + (boxH // 2) \n", " center = [cx, cy]\n", " \n", " cv2.rectangle(\n", " image, (cx_min, cy_min), (cx_max, cy_max), (255, 255, 0), 2\n", " )\n", " except:\n", " pass\n", " \n", " return [[cx_min, cy_min], [cx_max, cy_max]], center" ] }, { "cell_type": "code", "execution_count": 247, "metadata": {}, "outputs": [], "source": [ "def normalize(image, results, bounding_box, landmark_names):\n", " h, w, c = image.shape\n", " if results.pose_landmarks:\n", " xy = {}\n", " xy_norm = {}\n", " i = 0\n", " for res in results.pose_landmarks.landmark:\n", " x = res.x * w\n", " y = res.y * h\n", " \n", " x_norm = (x - bounding_box[0][0]) / (bounding_box[1][0] - bounding_box[0][0])\n", " y_norm = (y - bounding_box[0][1]) / (bounding_box[1][1] - bounding_box[0][1])\n", " \n", " # xy_norm.append([x_norm, y_norm])\n", " \n", " xy_norm[landmark_names[i]] = [x_norm, y_norm]\n", " i += 1\n", " else:\n", " # xy_norm = np.zeros([0,0] * 33)\n", " \n", " # xy = {landmark_names: [0,0]}\n", " # xy_norm = {landmark_names: [0,0]}\n", " \n", " xy_norm = dict(zip(landmark_names, [0,0] * 33))\n", " \n", " return xy_norm" ] }, { "cell_type": "code", "execution_count": 248, "metadata": {}, "outputs": [], "source": [ "def 
{ "cell_type": "code", "execution_count": 248, "metadata": {}, "outputs": [], "source": [ "def get_coordinates(landmarks, mp_pose, side, joint):\n", "    \"\"\"\n", "    Retrieves the x and y coordinates of a particular keypoint from the pose estimation model\n", "\n", "    Args:\n", "        landmarks: processed keypoints from the pose estimation model\n", "        mp_pose: Mediapipe pose estimation model\n", "        side: 'left' or 'right'. Denotes the side of the body of the landmark of interest.\n", "        joint: 'shoulder', 'elbow', 'wrist', 'hip', 'knee', or 'ankle'. Denotes which body joint is associated with the landmark of interest.\n", "\n", "    \"\"\"\n", "    coord = getattr(mp_pose.PoseLandmark, side.upper() + \"_\" + joint.upper())\n", "    x_coord_val = landmarks[coord.value].x\n", "    y_coord_val = landmarks[coord.value].y\n", "    return [x_coord_val, y_coord_val]" ] },
{ "cell_type": "code", "execution_count": 249, "metadata": {}, "outputs": [], "source": [ "def viz_coords(image, norm_coords, landmarks, mp_pose, side, joint):\n", "    \"\"\"\n", "    Displays a joint's raw and box-normalized coordinates near that joint within the image frame\n", "\n", "    \"\"\"\n", "    try:\n", "        h, w, _ = image.shape\n", "        point = side.upper() + \"_\" + joint.upper()\n", "        norm_coords = norm_coords[point]\n", "        joint_xy = get_coordinates(landmarks, mp_pose, side, joint)\n", "\n", "        coords = ' '.join(['%.2f' % elem for elem in joint_xy])\n", "        norm_coords = ' '.join(['%.2f' % elem for elem in norm_coords])\n", "        cv2.putText(image, coords,\n", "                    tuple(np.multiply(joint_xy, [w, h]).astype(int)),\n", "                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA\n", "                    )\n", "        cv2.putText(image, norm_coords,\n", "                    tuple(np.multiply(joint_xy, [w, h]).astype(int) + 20),\n", "                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2, cv2.LINE_AA\n", "                    )\n", "    except (AttributeError, KeyError, TypeError):\n", "        # Skip drawing when landmarks are unavailable for this frame\n", "        pass\n", "    return" ] },
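{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch (not part of the original notebook): a minimal joint-angle helper\n", "# showing how three keypoints returned by get_coordinates (e.g. shoulder, elbow, wrist)\n", "# could be turned into an elbow angle in degrees. The name 'calculate_angle' is an\n", "# assumed helper, not an existing Mediapipe API.\n", "def calculate_angle(a, b, c):\n", "    \"\"\"Returns the angle at point b (in degrees) formed by the segments b->a and b->c\"\"\"\n", "    a, b, c = np.array(a), np.array(b), np.array(c)\n", "    radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])\n", "    angle = np.abs(np.degrees(radians))\n", "    return 360 - angle if angle > 180 else angle\n", "\n", "# Example usage once landmarks are available from the webcam loop below:\n", "# calculate_angle(get_coordinates(landmarks, mp_pose, 'left', 'shoulder'),\n", "#                 get_coordinates(landmarks, mp_pose, 'left', 'elbow'),\n", "#                 get_coordinates(landmarks, mp_pose, 'left', 'wrist'))" ] },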
{ "cell_type": "code", "execution_count": 250, "metadata": {}, "outputs": [], "source": [ "cap = cv2.VideoCapture(0)                         # camera object\n", "HEIGHT = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # webcam video frame height\n", "WIDTH = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))    # webcam video frame width\n", "FPS = int(cap.get(cv2.CAP_PROP_FPS))              # webcam video frame rate\n", "\n", "# Landmark names in the same index order as the model's output\n", "landmark_names = [lm.name for lm in mp_pose.PoseLandmark]\n", "\n", "landmarks = None\n", "\n", "# Set and test mediapipe model using webcam\n", "with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5, enable_segmentation=True) as pose:\n", "    while cap.isOpened():\n", "\n", "        # Read feed\n", "        ret, frame = cap.read()\n", "        if not ret:\n", "            break\n", "\n", "        # Make detection\n", "        image, results = mediapipe_detection(frame, pose)\n", "\n", "        # Extract landmarks\n", "        if results.pose_landmarks:\n", "            landmarks = results.pose_landmarks.landmark\n", "\n", "        # Draw bounding box\n", "        bounding_box, box_center = draw_detection(image, results)\n", "\n", "        # Render detections\n", "        draw_landmarks(image, results)\n", "\n", "        # Normalize coordinates\n", "        xy_norm = normalize(image, results, bounding_box, landmark_names)\n", "        viz_coords(image, xy_norm, landmarks, mp_pose, 'left', 'wrist')\n", "        viz_coords(image, xy_norm, landmarks, mp_pose, 'right', 'wrist')\n", "\n", "        # Display frame on screen\n", "        cv2.imshow('OpenCV Feed', image)\n", "\n", "        # Draw segmentation on the image.\n", "        # To improve segmentation around boundaries, consider applying a joint\n", "        # bilateral filter to \"results.segmentation_mask\" with \"image\".\n", "        # tightness = 0.3  # Probability threshold in [0, 1] that says how \"tight\" to make the segmentation. Greater value => tighter.\n", "        # condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > tightness\n", "        # bg_image = np.zeros(image.shape, dtype=np.uint8)\n", "        # bg_image[:] = (192, 192, 192)  # gray\n", "        # image = np.where(condition, image, bg_image)\n", "\n", "        # Exit / break out logic\n", "        if cv2.waitKey(10) & 0xFF == ord('q'):\n", "            break\n", "\n", "cap.release()\n", "cv2.destroyAllWindows()" ] },
{ "cell_type": "code", "execution_count": 251, "metadata": {}, "outputs": [], "source": [ "# Release the webcam and close the display window if the loop above was interrupted\n", "cap.release()\n", "cv2.destroyAllWindows()" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('AItrainer')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "80aa1d3f3a8cfb37a38c47373cc49a39149184c5fa770d709389b1b8782c1d85" } } }, "nbformat": 4, "nbformat_minor": 2 }