"""Gradio app for Indian Sign Language (ISL) gesture recognition.

Records a webcam gesture, overlays MediaPipe landmarks on the frames, and
classifies the clip with a fine-tuned Swin3D-T video model.
"""
import os
import gc
import tempfile

import numpy as np
import cv2
import gradio as gr

import torch
import torch.nn as nn
from torchvision.transforms import v2
from torchvision.models import video as ptv

from transformers import VivitImageProcessor, set_seed
from accelerate import Accelerator

import decord
from decord import VideoReader
# Have decord return torch tensors from get_batch() instead of its native NDArray.
decord.bridge.set_bridge("torch")

import mediapipe as mp
from mediapipe.tasks.python import vision

CLIP_LENGTH = 32   # frames sampled per clip
FRAME_STEPS = 4    # sampling stride (kept for interface compatibility; unused below)
CLIP_SIZE = 224    # spatial resolution expected by the model
BATCH_SIZE = 1
SEED = 42

CLASSES = ['afternoon', 'animal', 'bad', 'beautiful', 'big', 'bird', 'blind', 'cat', 'cheap', 'clothing', 'cold', 'cow', 'curved', 'deaf', 'dog', 'dress', 'dry', 'evening',
           'expensive', 'famous', 'fast', 'female', 'fish', 'flat', 'friday', 'good', 'happy', 'hat', 'healthy', 'horse', 'hot', 'hour', 'light', 'long', 'loose', 'loud',
           'minute', 'monday', 'month', 'morning', 'mouse', 'narrow', 'new', 'night', 'old', 'pant', 'pocket', 'quiet', 'sad', 'saturday', 'second', 'shirt', 'shoes', 'short',
           'sick', 'skirt', 'slow', 'small', 'suit', 'sunday', 't_shirt', 'tall', 'thursday', 'time', 'today', 'tomorrow', 'tuesday', 'ugly', 'warm', 'wednesday', 'week', 'wet',
           'wide', 'year', 'yesterday', 'young']

# Model output indices follow the order of CLASSES.
idx_to_label = {i: label for i, label in enumerate(CLASSES)}

# Run on GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

IMAGE_PROCESSOR = 'google/vivit-b-16x2'   # HF checkpoint supplying the image-processor config
WEIGHTS = 'KINETICS400_V1'                # torchvision Swin3D-T pretraining weights

model_path = 'models/swin_tiny_ISL_pt_loss035.pt'
data_path = 'signs'

custom_css = """
#landmarked_video {
    max-height: 300px;
    max-width: 600px;
    object-fit: fill;
    width: 100%;
    height: 100%;
}
"""

# MediaPipe drawing helpers and solutions.
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands
mp_holistic = mp.solutions.holistic

hand_model_path = 'hand_landmarker.task'
pose_model_path = 'pose_landmarker.task'

# Tasks-API landmarkers. Note: the pipeline below draws landmarks with the
# Holistic solution; these detectors are initialized but not currently used.
BaseOptions = mp.tasks.BaseOptions
HandLandmarkerOptions = mp.tasks.vision.HandLandmarkerOptions
PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
VisionRunningMode = mp.tasks.vision.RunningMode

options_hand = HandLandmarkerOptions(
    base_options=BaseOptions(model_asset_path=hand_model_path),
    running_mode=VisionRunningMode.VIDEO)

options_pose = PoseLandmarkerOptions(
    base_options=BaseOptions(model_asset_path=pose_model_path),
    running_mode=VisionRunningMode.VIDEO)

detector_hand = vision.HandLandmarker.create_from_options(options_hand)
detector_pose = vision.PoseLandmarker.create_from_options(options_pose)

# Holistic tracker used to overlay hand and pose landmarks on each frame.
holistic = mp_holistic.Holistic(
    static_image_mode=False,
    model_complexity=1,
    smooth_landmarks=True,
    enable_segmentation=False,
    refine_face_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5
)

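# Preprocessing pipeline for one production clip. Tensor shapes through the
# stages below (T == CLIP_LENGTH):
#   decord read          -> (T, H, W, C) uint8
#   upscale + landmarks  -> (T, 2*224, 3*224, C) uint8
#   transform_prod       -> (T, C, 224, 224) uint8
#   VivitImageProcessor  -> pixel_values (T, C, 224, 224) float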
class CreateDatasetProd:
    def __init__(self, clip_len, clip_size, frame_step, image_processor):
        self.clip_len = clip_len
        self.clip_size = clip_size
        self.frame_step = frame_step
        self.image_processor = image_processor

        # Production transform: convert to tensor image, resize, keep uint8.
        self.transform_prod = v2.Compose([
            v2.ToImage(),
            v2.Resize((self.clip_size, self.clip_size)),
            v2.ToDtype(torch.uint8, scale=False)
        ])

    def read_video(self, video_path):
        # With the torch bridge set, get_batch() returns a uint8 tensor (T, H, W, C).
        vr = VideoReader(video_path)
        total_frames = len(vr)

        if total_frames < self.clip_len:
            # Short clip: take every frame, then pad by repeating the last one.
            key_indices = list(range(total_frames))
            for _ in range(self.clip_len - len(key_indices)):
                key_indices.append(key_indices[-1])
        else:
            # Sample evenly spaced frames and truncate to exactly clip_len.
            key_indices = list(range(0, total_frames, max(1, total_frames // self.clip_len)))[:self.clip_len]

        frames = vr.get_batch(key_indices)
        del vr
        gc.collect()

        return frames

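    # Example: a 130-frame video with clip_len=32 gives a stride of
    # 130 // 32 == 4, so indices 0, 4, 8, ... are taken and truncated
    # to the first 32.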
    def add_landmarks(self, video):
        annotated_image = []
        for frame in video:
            # (C, H, W) -> (H, W, C); MediaPipe needs a contiguous RGB uint8 array,
            # and a permuted tensor's .numpy() view is not contiguous.
            image = np.ascontiguousarray(frame.permute(1, 2, 0).numpy())

            results = holistic.process(image)

            # Draw left hand, right hand, and body pose landmarks in place.
            mp_drawing.draw_landmarks(
                image,
                results.left_hand_landmarks,
                mp_hands.HAND_CONNECTIONS,
                landmark_drawing_spec=mp_drawing_styles.get_default_hand_landmarks_style(),
                connection_drawing_spec=mp_drawing_styles.get_default_hand_connections_style()
            )
            mp_drawing.draw_landmarks(
                image,
                results.right_hand_landmarks,
                mp_hands.HAND_CONNECTIONS,
                landmark_drawing_spec=mp_drawing_styles.get_default_hand_landmarks_style(),
                connection_drawing_spec=mp_drawing_styles.get_default_hand_connections_style()
            )
            mp_drawing.draw_landmarks(
                image,
                results.pose_landmarks,
                mp_holistic.POSE_CONNECTIONS,
                landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style()
            )

            annotated_image.append(torch.from_numpy(image))
            del image, results

        gc.collect()
        return torch.stack(annotated_image)

    def create_dataset(self, video_paths):
        # Already a (T, H, W, C) torch tensor: decord's torch bridge is active,
        # so no .asnumpy() round-trip is needed (tensors have no such method).
        video = self.read_video(video_paths)

        # Upscale before landmark detection so MediaPipe sees more detail.
        video = v2.functional.resize(video.permute(0, 3, 1, 2), size=(self.clip_size * 2, self.clip_size * 3))
        video = self.add_landmarks(video)

        # Final resize + normalization via the ViViT image processor.
        video = self.transform_prod(video.permute(0, 3, 1, 2))
        video = self.image_processor(list(video), return_tensors='pt', input_data_format='channels_first')
        pixel_values = video['pixel_values'].squeeze(0)

        del video
        gc.collect()

        return pixel_values

# Note: attn_implementation / torch_dtype are model kwargs with no effect on an
# image processor, so they are dropped here.
image_processor = VivitImageProcessor.from_pretrained(IMAGE_PROCESSOR)
dataset_prod_obj = CreateDatasetProd(CLIP_LENGTH, CLIP_SIZE, FRAME_STEPS, image_processor)

class SwinTClassifications(nn.Module):
    """Swin3D-T backbone with a fresh linear classification head."""

    def __init__(self, classes, weights="DEFAULT"):
        super().__init__()
        self.classes = classes
        self.weights = weights

        self.base_model = ptv.swin3d_t(weights=self.weights)
        # Read head.in_features before replacing the head with Identity.
        self.classification_head = nn.Sequential(
            nn.Linear(self.base_model.head.in_features, len(self.classes))
        )
        self.base_model.head = nn.Identity()

    def forward(self, x):
        x = self.base_model(x)
        x = self.classification_head(x)
        return x

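# Shape check (illustrative): a batch of shape (1, 3, 32, 224, 224), i.e.
# (N, C, T, H, W), produces logits of shape (1, len(CLASSES)) == (1, 76).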
recon_model = SwinTClassifications(classes=CLASSES, weights=WEIGHTS).to(device)
recon_model.load_state_dict(torch.load(model_path, weights_only=True, map_location=device))
recon_model.eval()
print("Model loaded successfully!")

def prod_function(model_pretrained, prod_ds):
    accelerator = Accelerator()
    set_seed(SEED)

    # Accelerator.prepare() handles the model; a plain tensor passes through
    # prepare() unchanged, so it must be moved to the accelerator device explicitly.
    accelerated_model = accelerator.prepare(model_pretrained)
    accelerated_prod_ds = prod_ds.to(accelerator.device)

    accelerated_model.eval()
    with torch.no_grad():
        # (T, C, H, W) -> (1, C, T, H, W), the layout Swin3D expects.
        outputs = accelerated_model(accelerated_prod_ds.permute(1, 0, 2, 3).unsqueeze(0))

    prod_softmax = torch.nn.functional.softmax(outputs, dim=-1)
    prod_pred = prod_softmax.argmax(-1)

    return prod_pred, accelerated_prod_ds

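# Minimal usage sketch (the clip path below is illustrative, not part of the app):
#   pixel_values = dataset_prod_obj.create_dataset("signs/some_gesture.mp4")
#   pred, _ = prod_function(recon_model, pixel_values)
#   print(idx_to_label[pred.squeeze(0).item()])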
def save_video_to_mp4(video_tensor, fps=10):
    # (T, C, H, W) -> (T, H, W, C) for OpenCV.
    video_numpy = video_tensor.permute(0, 2, 3, 1).cpu().numpy()

    # Normalized pixel values land roughly in [-1, 1]; map them back to [0, 255].
    if video_numpy.max() <= 1.0:
        video_numpy = np.clip(((video_numpy + 1) / 2) * 255, 0, 255).astype(np.uint8)

    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    output_path = temp_file.name
    temp_file.close()  # close the handle so VideoWriter can reopen the path

    height, width, channels = video_numpy[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    for frame in video_numpy:
        # OpenCV expects BGR frames.
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        out.write(frame_bgr)

    out.release()
    return output_path

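# Note: 'mp4v' (MPEG-4 Part 2) is available in most OpenCV builds but is not
# H.264; if a browser refuses to play the rendered file, re-encoding the
# output (e.g. with ffmpeg) is a possible workaround.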
def list_videos():
    # Return an empty list (not None) when the gesture folder is missing.
    if os.path.exists(data_path):
        return [f for f in os.listdir(data_path) if f.endswith((".mp4", ".mov", ".MOV", ".webm", ".avi"))]
    return []

def play_video(selected_video):
    return os.path.join(data_path, selected_video) if selected_video else None

def translate_sign_language(gesture):
    # Preprocess the recorded clip and render the landmark-annotated video.
    prod_ds = dataset_prod_obj.create_dataset(gesture)
    prod_video_path = save_video_to_mp4(prod_ds)

    # Run inference and map the predicted index back to its gloss.
    predicted_prod_label, video_tensor = prod_function(recon_model, prod_ds)

    predicted_prod_label = predicted_prod_label.squeeze(0)
    gesture_translation = idx_to_label[predicted_prod_label.cpu().numpy().item()]

    return gesture_translation, prod_video_path

def load_about_md():
    with open("about.md", "r") as file:
        return file.read()

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Indian Sign Language Translation App")

    with gr.Tab("About the App"):
        gr.Markdown(load_about_md())

    with gr.Tab("Gesture recognition"):
        with gr.Row():
            with gr.Column(scale=1, variant="panel"):  # integer scale; Gradio warns on floats like 0.9
                with gr.Row(height=350, variant="panel"):
                    video_input = gr.Video(sources=["webcam"], format="mp4", label="Gesture")
                with gr.Row(variant="panel"):
                    video_button = gr.Button("Submit")
                text_output = gr.Textbox(label="Translation in English")

            with gr.Column(scale=1, variant="panel"):
                with gr.Row():
                    # elem_id hooks this player to the #landmarked_video rule in custom_css.
                    video_output = gr.Video(interactive=False, autoplay=True,
                                            streaming=False, label="Landmarked Gesture",
                                            elem_id="landmarked_video")

        video_button.click(translate_sign_language, inputs=video_input, outputs=[text_output, video_output])

    with gr.Tab("Indian Sign Language gesture reference"):
        with gr.Row(height=500, variant="panel", equal_height=False, show_progress=True):
            with gr.Column(scale=1, variant="panel"):
                video_dropdown = gr.Dropdown(choices=list_videos(), label="ISL gestures", info="More gestures coming soon!")
                search_button = gr.Button("Search Gesture")
            with gr.Column(scale=1, variant="panel"):
                search_output = gr.Video(streaming=False, label="ISL gestures Video")

        search_button.click(play_video, inputs=video_dropdown, outputs=search_output)


if __name__ == "__main__":
    demo.launch()