import torch |
from torchvision import transforms |
from torchvision.transforms import v2 |
import transformers |
from transformers import VivitImageProcessor, VivitConfig, VivitModel, VivitForVideoClassification |
from transformers import set_seed |
import datasets |
from torch.utils.data import Dataset, DataLoader |
from accelerate import Accelerator, notebook_launcher |
import decord |
from decord.bridge import set_bridge |
decord.bridge.set_bridge("torch") |
from decord import VideoReader |
import os |
import PIL |
import gc |
import pandas as pd |
import numpy as np |
from torch.nn import Linear, Softmax |
import gradio as gr |
import cv2 |
import io |
import tempfile |
import mediapipe as mp |
from mediapipe.tasks import python |
from mediapipe.tasks.python import vision |
from mediapipe import solutions |
from mediapipe.framework.formats import landmark_pb2 |
# Side length (pixels) that CreateDatasetProd resizes every frame to before inference.
CLIP_SIZE = 224
# Seed handed to transformers.set_seed() in prod_function.
SEED = 42

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hub identifier of the ViViT checkpoint (not referenced elsewhere in this file).
MODEL_TRANSFORMER = "google/vivit-b-16x2"
# Pickled classifier that is loaded with torch.load() further down.
model_path_2_pytorch = "models/vivit_ISL_pt_6_76classes_loss051.pt"
# Directory scanned by list_videos() / play_video() for reference clips.
data_path = "signs"

# Stylesheet passed to gr.Blocks(css=...).
custom_css = """
#landmarked_video {
max-height: 300px;
max-width: 600px;
object-fit: fill;
width: 100%;
height: 100%;
}
"""
# --- MediaPipe "solutions" namespaces used for detection and drawing ---------
mp_holistic = mp.solutions.holistic
mp_pose = mp.solutions.pose
mp_face = mp.solutions.face_mesh
mp_hands = mp.solutions.hands
mp_drawing_styles = mp.solutions.drawing_styles
mp_drawing = mp.solutions.drawing_utils

# --- MediaPipe Tasks API aliases ---------------------------------------------
BaseOptions = mp.tasks.BaseOptions
VisionRunningMode = mp.tasks.vision.RunningMode
HandLandmarker = mp.tasks.vision.HandLandmarker
HandLandmarkerOptions = mp.tasks.vision.HandLandmarkerOptions
PoseLandmarker = mp.tasks.vision.PoseLandmarker
PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions

# Model bundles expected in the working directory.
hand_model_path = 'hand_landmarker.task'
pose_model_path = 'pose_landmarker.task'

# Both task detectors are configured for VIDEO running mode.
options_hand = HandLandmarkerOptions(
    base_options=BaseOptions(model_asset_path=hand_model_path),
    running_mode=VisionRunningMode.VIDEO,
)
options_pose = PoseLandmarkerOptions(
    base_options=BaseOptions(model_asset_path=pose_model_path),
    running_mode=VisionRunningMode.VIDEO,
)

# NOTE(review): detector_hand / detector_pose are created here but nothing in
# this file calls them; only `holistic` is used (in CreateDatasetProd).
detector_hand = HandLandmarker.create_from_options(options_hand)
detector_pose = PoseLandmarker.create_from_options(options_pose)

# Single shared Holistic pipeline used to annotate every frame.
holistic = mp_holistic.Holistic(
    static_image_mode=False,
    model_complexity=1,
    smooth_landmarks=True,
    enable_segmentation=False,
    refine_face_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)
class CreateDatasetProd():
    """Turn one gesture video into a landmark-annotated clip tensor.

    Pipeline (see ``create_dataset``): sample ``clip_len`` frames, resize them
    to ``(clip_size * 2, clip_size * 3)``, draw MediaPipe Holistic hand and pose
    landmarks onto every frame, then resize to ``(clip_size, clip_size)`` and
    scale to float32.

    Args:
        clip_len: Number of frames every returned clip contains.
        clip_size: Side length in pixels of the final square frames.
        frame_step: Stored on the instance; no method of this class reads it.
    """

    def __init__(self, clip_len, clip_size, frame_step):
        super().__init__()
        self.clip_len = clip_len
        self.clip_size = clip_size
        self.frame_step = frame_step
        # Final preprocessing applied after the landmarks have been drawn.
        self.transform_prod = v2.Compose([
            v2.ToImage(),
            v2.Resize((self.clip_size, self.clip_size)),
            v2.ToDtype(torch.float32, scale=True),
        ])

    def _sample_indices(self, total_frames):
        """Return exactly ``clip_len`` frame indices for a video of ``total_frames`` frames.

        Short videos are padded by repeating their last frame; longer videos are
        sampled with stride ``total_frames // clip_len`` and truncated.

        Raises:
            ValueError: If the video has no frames (nothing to pad with).
        """
        if total_frames <= 0:
            raise ValueError("Video contains no frames")
        if total_frames < self.clip_len:
            indices = list(range(total_frames))
            indices.extend([indices[-1]] * (self.clip_len - total_frames))
            return indices
        stride = max(1, total_frames // self.clip_len)
        return list(range(0, total_frames, stride))[:self.clip_len]

    def read_video(self, video_path):
        """Decode ``clip_len`` frames from ``video_path`` with decord.

        Returns:
            Whatever ``VideoReader.get_batch`` yields for the sampled indices
            (a torch tensor or a decord array, depending on the active bridge).

        Raises:
            ValueError: If the video has no frames.
        """
        vr = VideoReader(video_path)
        key_indices = self._sample_indices(len(vr))
        frames = vr.get_batch(key_indices)
        # Drop the reader explicitly so its decoder resources are freed early.
        del vr
        gc.collect()
        return frames

    def add_landmarks(self, video):
        """Draw Holistic hand and pose landmarks onto every frame.

        Args:
            video: Iterable of channel-first frames (each is permuted to
                channel-last before processing).

        Returns:
            The annotated channel-last frames stacked into one tensor.
        """
        # The default styles do not depend on the frame; look them up once.
        hand_landmark_style = mp_drawing_styles.get_default_hand_landmarks_style()
        hand_connection_style = mp_drawing_styles.get_default_hand_connections_style()
        pose_landmark_style = mp_drawing_styles.get_default_pose_landmarks_style()

        annotated_frames = []
        for frame in video:
            # Drawing happens in place, so hand it a contiguous buffer.
            # ascontiguousarray is a no-op when the view is already contiguous.
            image = np.ascontiguousarray(frame.permute(1, 2, 0).numpy())
            results = holistic.process(image)
            # NOTE(review): no connection list is passed for the hands, so only
            # the hand landmark points are drawn — kept as in the original.
            for hand_landmarks in (results.left_hand_landmarks, results.right_hand_landmarks):
                mp_drawing.draw_landmarks(
                    image,
                    hand_landmarks,
                    landmark_drawing_spec=hand_landmark_style,
                    connection_drawing_spec=hand_connection_style,
                )
            mp_drawing.draw_landmarks(
                image,
                results.pose_landmarks,
                mp_holistic.POSE_CONNECTIONS,
                landmark_drawing_spec=pose_landmark_style,
            )
            annotated_frames.append(torch.from_numpy(image))
        # One collection after the loop instead of a full GC pass per frame.
        gc.collect()
        return torch.stack(annotated_frames)

    def create_dataset(self, video_paths):
        """Build the model input for the video at ``video_paths``.

        Returns:
            The preprocessed clip moved to the module-level ``device``.
        """
        frames = self.read_video(video_paths)
        # With decord's torch bridge active, get_batch already returns a torch
        # tensor (which has no .asnumpy()); otherwise it is a decord array.
        if isinstance(frames, torch.Tensor):
            video = frames
        else:
            video = torch.from_numpy(frames.asnumpy())
        # Channel-last -> channel-first, then a larger intermediate size so the
        # landmarks are drawn before the final downscale.
        video = v2.functional.resize(
            video.permute(0, 3, 1, 2),
            size=(self.clip_size * 2, self.clip_size * 3),
        )
        video = self.add_landmarks(video)
        # add_landmarks returns channel-last frames; back to channel-first.
        video = self.transform_prod(video.permute(0, 3, 1, 2))
        pixel_values = video.to(device)
        del video
        gc.collect()
        return pixel_values
# CLIP_LENGTH and FRAME_STEPS are not defined anywhere else in this file, so the
# constructor call below raised NameError at import time.
# NOTE(review): 32 is the default frame count of the ViViT-B-16x2 configuration;
# confirm it matches the value the checkpoint was trained with.
CLIP_LENGTH = 32
# NOTE(review): CreateDatasetProd only stores this value and never reads it;
# confirm the intended step against the training setup.
FRAME_STEPS = 1
dataset_prod_obj = CreateDatasetProd(CLIP_LENGTH, CLIP_SIZE, FRAME_STEPS)
class SignClassificationModel(torch.nn.Module):
    """ViViT backbone followed by a mean-pooled linear classification head.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained``.
        idx_to_label: Mapping stored in the config as ``id2label``.
        label_to_idx: Mapping stored in the config as ``label2id``.
        classes_len: Output size of the linear head.
        dropout_rate: Dropout used for both hidden states and attention
            probabilities. When ``None``, the value is taken from a
            module-level ``hyperparameters['dropout_rate']`` if such a dict
            exists, otherwise 0.0.
    """

    def __init__(self, model_name, idx_to_label, label_to_idx, classes_len, dropout_rate=None):
        super().__init__()
        if dropout_rate is None:
            # The original read a global `hyperparameters` dict that this file
            # never defines; keep honouring it when present, but do not crash.
            dropout_rate = globals().get("hyperparameters", {}).get("dropout_rate", 0.0)
        self.config = VivitConfig.from_pretrained(
            model_name,
            id2label=idx_to_label,
            label2id=label_to_idx,
            hidden_dropout_prob=dropout_rate,
            attention_probs_dropout_prob=dropout_rate,
            return_dict=True,
        )
        self.backbone = VivitModel.from_pretrained(model_name, config=self.config)
        # Enable once at construction instead of on every forward pass (where it
        # previously ran only after the backbone had already been evaluated).
        self.backbone.gradient_checkpointing_enable()
        self.ff_head = Linear(self.backbone.config.hidden_size, classes_len)

    def forward(self, images):
        """Return the head's output for ``images``.

        The backbone's ``last_hidden_state`` is averaged over dim 1 and the
        result is fed to the linear head; a plain tensor is returned.
        """
        hidden_states = self.backbone(images).last_hidden_state
        pooled = hidden_states.mean(dim=1)
        return self.ff_head(pooled)
# Load the complete pickled model object onto the selected device.
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects,
# so this checkpoint file must come from a trusted source.
model_pretrained_2 = torch.load(
    model_path_2_pytorch,
    map_location=device,
    weights_only=False,
)
def prod_function(model_pretrained, prod_ds):
    """Run the classifier on one preprocessed clip and return the predicted class index.

    Args:
        model_pretrained: Model to evaluate. Its forward may return either an
            output object exposing ``.logits`` or the logits tensor directly
            (as ``SignClassificationModel.forward`` does).
        prod_ds: Clip tensor without a batch dimension; one is added here.

    Returns:
        Tensor holding the argmax over the last logits dimension.
    """
    accelerator = Accelerator()
    # Verbose library logging on the main process only.
    if accelerator.is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    set_seed(SEED)
    accelerated_model, accelerated_prod_ds = accelerator.prepare(model_pretrained, prod_ds)
    accelerated_model.eval()
    with torch.no_grad():
        outputs = accelerated_model(accelerated_prod_ds.unsqueeze(0))
    # A raw tensor has no `.logits` attribute; fall back to the output itself.
    prod_logits = getattr(outputs, "logits", outputs)
    # Softmax is monotonic, so the argmax of the logits is the predicted class.
    return prod_logits.argmax(-1)
def save_video_to_mp4(video_tensor, fps=10):
    """Write a channel-first clip tensor to a temporary ``.mp4`` file.

    Args:
        video_tensor: Tensor permuted here with ``(0, 2, 3, 1)``, i.e. frames
            first and channels second. Assumed to be RGB — TODO confirm.
        fps: Frame rate given to the video writer.

    Returns:
        Path of the temporary file. It is created with ``delete=False``, so the
        caller owns its cleanup.

    Raises:
        ValueError: If the clip contains no frames.
    """
    video_numpy = video_tensor.permute(0, 2, 3, 1).cpu().numpy()
    if video_numpy.shape[0] == 0:
        raise ValueError("Cannot save a video with no frames")

    # The writer needs 8-bit frames. Values in [0, 1] are treated as normalised
    # and rescaled; any other non-uint8 data is clipped into range so frames are
    # never handed to the writer as floats.
    if video_numpy.dtype != np.uint8:
        if video_numpy.max() <= 1.0:
            video_numpy = video_numpy * 255
        video_numpy = np.clip(video_numpy, 0, 255).astype(np.uint8)
    video_numpy = np.ascontiguousarray(video_numpy)

    # Only the unique name is needed; close the handle right away so it is not
    # leaked and does not stay open while OpenCV writes to the same path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
        output_path = temp_file.name

    height, width, _channels = video_numpy[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        for frame in video_numpy:
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Release even if a frame fails to convert, so the file is finalised.
        out.release()
    return output_path
def list_videos():
    """Return the video file names found directly inside ``data_path``.

    Matching is case-insensitive on the extensions .mp4, .mov, .webm and .avi.
    The result is sorted so the order is stable between runs, and an empty
    list is returned when ``data_path`` is missing or is not a directory.
    """
    if not os.path.isdir(data_path):
        return []
    video_extensions = (".mp4", ".mov", ".webm", ".avi")
    return sorted(
        name for name in os.listdir(data_path)
        if name.lower().endswith(video_extensions)
    )
def play_video(selected_video):
    """Return the path of ``selected_video`` inside ``data_path``, or None if nothing is selected."""
    if not selected_video:
        return None
    return os.path.join(data_path, selected_video)
def translate_sign_language(gesture):
    """Translate a recorded gesture video into its predicted label.

    Args:
        gesture: Path of the recorded video as delivered by the Gradio video
            input; empty/None when nothing was recorded.

    Returns:
        Tuple ``(label, annotated_video_path)`` where ``label`` comes from the
        model config's ``id2label`` and the path points to the landmarked clip.

    Raises:
        gr.Error: If no video was provided.
    """
    if not gesture:
        # Surface a readable message in the UI instead of a decoder failure.
        raise gr.Error("Please record a gesture video before submitting.")

    prod_ds = dataset_prod_obj.create_dataset(gesture)
    prod_video_path = save_video_to_mp4(prod_ds)

    predicted_prod_label = prod_function(model_pretrained_2, prod_ds)
    # Drop the batch dimension and convert to a plain Python int for the lookup.
    label_index = int(predicted_prod_label.squeeze(0).item())

    idx_to_label = model_pretrained_2.config.id2label
    gesture_translation = idx_to_label[label_index]
    return gesture_translation, prod_video_path
def load_about_md():
    """Return the contents of ``about.md`` from the current working directory.

    The file is decoded as UTF-8 explicitly so the text does not depend on the
    platform's default encoding.
    """
    with open("about.md", "r", encoding="utf-8") as file:
        return file.read()
# Gradio UI: an "about" tab, the webcam translation tab and a reference-video tab.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Indian Sign Language Translation App")

    with gr.Tab("About the App"):
        gr.Markdown(load_about_md())

    with gr.Tab("Gesture recognition"):
        with gr.Row():
            # Integer scales 9:10 keep the original 0.9:1 width ratio; Gradio
            # expects `scale` to be an integer.
            with gr.Column(scale=9, variant="panel"):
                with gr.Row(height=350, variant="panel"):
                    video_input = gr.Video(sources=["webcam"], format="mp4", label="Gesture")
                with gr.Row(variant="panel"):
                    video_button = gr.Button("Submit")
                    text_output = gr.Textbox(label="Translation in English")
            with gr.Column(scale=10, variant="panel"):
                with gr.Row():
                    # elem_id ties this component to the `#landmarked_video`
                    # rule in custom_css, which otherwise matched nothing.
                    video_output = gr.Video(
                        interactive=False,
                        autoplay=True,
                        streaming=False,
                        label="Landmarked Gesture",
                        elem_id="landmarked_video",
                    )
        video_button.click(
            translate_sign_language,
            inputs=video_input,
            outputs=[text_output, video_output],
        )

    with gr.Tab("Indian Sign Language gesture reference"):
        with gr.Row(height=500, variant="panel", equal_height=False, show_progress=True):
            with gr.Column(scale=1, variant="panel"):
                video_dropdown = gr.Dropdown(
                    choices=list_videos(),
                    label="ISL gestures",
                    info="More gestures coming soon!",
                )
                search_button = gr.Button("Search Gesture")
            with gr.Column(scale=1, variant="panel"):
                search_output = gr.Video(streaming=False, label="ISL gestures Video")
        search_button.click(play_video, inputs=video_dropdown, outputs=search_output)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()