Commit · b41fa6a
Parent(s):
75d7cea
Update app.py
app.py CHANGED
@@ -5,7 +5,7 @@ from torchvision import transforms
 from torchvision.transforms import v2
 # For ML Model
 import transformers
-from transformers import VivitImageProcessor, VivitConfig, VivitModel
+from transformers import VivitImageProcessor, VivitConfig, VivitModel, VivitForVideoClassification
 from transformers import set_seed
 # For Data Loaders
 import datasets
@@ -47,7 +47,8 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 MODEL_TRANSFORMER = 'google/vivit-b-16x2'
 # Set Paths
 #model_path = 'vivit_pytorch_loss051.pt'
-
+model_path_2_pytorch = 'vivit_pytorch_GPU_6_acc087.pt'
+model_path_2_transformer = ''
 data_path = 'signs'
 
 # Custom CSS to control output video size
@@ -225,7 +226,9 @@ class SignClassificationModel(torch.nn.Module):
         return reduced_tensor
 
 # Load the model
-model_pretrained = torch.load(model_path, map_location=device, weights_only=False) #torch.device('cpu')
+#model_pretrained = torch.load(model_path, map_location=device, weights_only=False) #torch.device('cpu')
+#model_pretrained_2 = torch.load(model_path_2, map_location=device, weights_only=False)
+model_pretrained_2 = VivitForVideoClassification.from_pretrained(model_path_2_transformer)
 
 # Evaluation Function
 def prod_function(model_pretrained, prod_ds):
@@ -307,7 +310,8 @@ def translate_sign_language(gesture):
     #prod_video = np.random.randint(0, 255, (32, 225, 225, 3), dtype=np.uint8)
 
     # Run ML Model
-    predicted_prod_label = prod_function(model_pretrained, prod_ds)
+    #predicted_prod_label = prod_function(model_pretrained, prod_ds)
+    predicted_prod_label = prod_function(model_pretrained_2, prod_ds)
 
     # Identify the hand gesture
     predicted_prod_label = predicted_prod_label.squeeze(0)
@@ -347,7 +351,7 @@ with gr.Blocks(css=custom_css) as demo:
     # text_output = gr.Textbox(label="Translation in English")
 
     with gr.Row():
-        with gr.Column(scale=1, variant="panel"):
+        with gr.Column(scale=1.25, variant="panel"):
            with gr.Row(height=350, variant="panel"):
                # Add webcam input for sign language video capture
                video_input = gr.Video(sources=["webcam"], format="mp4", label="Gesture")
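The substantive change is the loading path: the pickled-checkpoint load via torch.load is commented out in favor of VivitForVideoClassification.from_pretrained. A minimal sketch of how that path is typically used, assuming a hypothetical checkpoint directory ./vivit_finetuned (the commit itself leaves model_path_2_transformer as an empty string) and a dummy 32-frame clip sized for 'google/vivit-b-16x2':

import torch
from transformers import VivitImageProcessor, VivitForVideoClassification

# Hypothetical checkpoint directory saved earlier with save_pretrained();
# the commit leaves model_path_2_transformer = '' unfilled.
checkpoint_dir = "./vivit_finetuned"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# from_pretrained restores config and weights from a directory,
# unlike torch.load, which unpickles a single .pt file.
processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2")
model = VivitForVideoClassification.from_pretrained(checkpoint_dir).to(device).eval()

# Dummy clip: 32 RGB frames of 224x224, the shape vivit-b-16x2 expects.
frames = [torch.randint(0, 255, (224, 224, 3), dtype=torch.uint8).numpy() for _ in range(32)]
inputs = processor(frames, return_tensors="pt").to(device)

with torch.no_grad():
    logits = model(**inputs).logits        # shape: (1, num_labels)
predicted_label = logits.argmax(-1).item() # index of the most likely sign class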