Update vit_model_test.py
vit_model_test.py (changed: +8, -6)
@@ -6,9 +6,11 @@ from transformers import ViTForImageClassification
 import os
 import pandas as pd
 from sklearn.model_selection import train_test_split
-from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score
+from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score
 import matplotlib.pyplot as plt
 import seaborn as sns
+from sklearn.metrics import recall_score
+from vit_model_traning import labeling, CustomDataset
 
 # Function that returns the HTML for a video
 def display_video(video_url):
@@ -37,7 +39,7 @@ if __name__ == "__main__":
     model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
 
     model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
-
+
     # Define the image preprocessing pipeline
     preprocess = transforms.Compose([
         transforms.Resize((224, 224)),
@@ -56,10 +58,10 @@ if __name__ == "__main__":
     model.load_state_dict(torch.load('trained_model.pth'))
 
     # Link to the video
-    video_url = '
+    video_url = '"C:\Users\litav\Downloads\0001-0120.mp4"'  # Replace this with the URL of your video
     video_html = display_video(video_url)
 
-    #
+    # Show the video when the button is clicked
     print(video_html)  # This should display the HTML in your dashboard
 
     # Evaluate the model
@@ -67,8 +69,8 @@ if __name__ == "__main__":
     true_labels = []
     predicted_labels = []
 
-    #
-
+    # Add JavaScript code here to play the video when the SUBMIT button is clicked
+    # Example: <button onclick="playVideo()">Submit</button>
 
     with torch.no_grad():
         for images, labels in test_loader:
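Review note: the new video_url value wraps a Windows path in a second pair of double quotes, and the bare backslashes mean the literal contains \U (a SyntaxError in a non-raw Python 3 string) and \000 (silently interpreted as a null byte). If the intent is simply to point at the local file, a raw string sidesteps both issues; this is a suggested form, not what the commit contains:

    # Suggested form only; the committed line keeps the extra quotes and bare backslashes.
    video_url = r"C:\Users\litav\Downloads\0001-0120.mp4"  # raw string, so \U and \0 are not treated as escapes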
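The display_video helper is called with the new video_url, but its body is outside this diff. A minimal implementation consistent with how it is used here (it returns an HTML string that is then printed to the dashboard) might look like the following; the <video> markup is an assumption, not the committed code:

    def display_video(video_url):
        # Return an HTML snippet that embeds the video; assumes an MP4 source.
        return f'''
        <video width="640" height="360" controls>
            <source src="{video_url}" type="video/mp4">
            Your browser does not support the video tag.
        </video>
        '''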
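The commit also imports recall_score and pulls labeling and CustomDataset from vit_model_traning, but the body of the evaluation loop is not shown. A sketch of how the visible pieces (true_labels, predicted_labels, test_loader, the sklearn metrics, and the seaborn/matplotlib imports) typically fit together is given below; everything inside the loop is assumed, not taken from the diff:

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            logits = model(images).logits           # ViTForImageClassification returns logits
            preds = torch.argmax(logits, dim=1)     # predicted class index per image
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(preds.cpu().numpy())

    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels)
    recall = recall_score(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels)
    cm = confusion_matrix(true_labels, predicted_labels)

    # The seaborn import suggests the confusion matrix is plotted, e.g.:
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()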