Image Classification
Transformers
English
art
litav committed on
Commit
0c1c962
verified
1 Parent(s): 10b8909

Update vit_model_test.py

Browse files
Files changed (1) hide show
  1. vit_model_test.py +8 -6
vit_model_test.py CHANGED
@@ -6,9 +6,11 @@ from transformers import ViTForImageClassification
6
  import os
7
  import pandas as pd
8
  from sklearn.model_selection import train_test_split
9
- from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score, recall_score
10
  import matplotlib.pyplot as plt
11
  import seaborn as sns
 
 
12
 
13
  # פונקציה להחזרת HTML של סרטון
14
  def display_video(video_url):
@@ -37,7 +39,7 @@ if __name__ == "__main__":
37
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
38
 
39
  model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
40
-
41
  # Define the image preprocessing pipeline
42
  preprocess = transforms.Compose([
43
  transforms.Resize((224, 224)),
@@ -56,10 +58,10 @@ if __name__ == "__main__":
56
  model.load_state_dict(torch.load('trained_model.pth'))
57
 
58
  # קישור לסרטון
59
- video_url = 'https://rr5---sn-33uxaxjvh-aixe.googlevideo.com/videoplayback?expire=1727025979&ei=2_7vZrzMAuGdp-oPuaTo-QI&ip=39.62.1.120&id=o-AJ04-wA4jR6nhlg7B-yNUOXEwR7yoNlJetni5NaAoWRl&itag=134&aitags=133%2C134%2C135%2C136%2C137%2C160%2C242%2C243%2C244%2C247%2C248%2C278&source=youtube&requiressl=yes&xpc=EgVo2aDSNQ%3D%3D&mh=9z&mm=31%2C29&mn=sn-33uxaxjvh-aixe%2Csn-hju7enll&ms=au%2Crdu&mv=m&mvi=5&pl=24&initcwndbps=306250&bui=AXLXGFS9xlNb5y-figGb1FTTN1Ma8zVRiN7RtpZjebiJgICl7QFK5ab9UDZVXvwn2GwOYj4m4rXuQlYc&spc=54MbxY0qT7L8eXI7eMdKq6id860EyvqxATj5F0MLSzmNFdC1mD-XNkZUkcL1EWQ&vprv=1&svpuc=1&mime=video%2Fmp4&ns=qN73Wubd4RAEtRCu3S2dItYQ&rqh=1&gir=yes&clen=242900&dur=5.000&lmt=1727003020660351&mt=1727004294&fvip=2&keepalive=yes&fexp=51299152&c=WEB&sefc=1&txp=630A224&n=Y84SAecGmAZzwg&sparams=expire%2Cei%2Cip%2Cid%2Caitags%2Csource%2Crequiressl%2Cxpc%2Cbui%2Cspc%2Cvprv%2Csvpuc%2Cmime%2Cns%2Crqh%2Cgir%2Cclen%2Cdur%2Clmt&sig=AJfQdSswRAIgGjjE8lnq2bVWML91M2fA0A3qtumgsH-bASH-qjraIRwCIBh9oYh7GnjGwTNescuIZ1qgv4PBj0WOzJJbveuTUOb8&lsparams=mh%2Cmm%2Cmn%2Cms%2Cmv%2Cmvi%2Cpl%2Cinitcwndbps&lsig=ABPmVW0wRQIgCwVg3G31n-JXtH0t66MDGpnLR8s-mRwiTjMQP9TeTawCIQC2zaC1iwicMoTjn6ha46-W1UZrW6Rv9D8HP5I96C1hfg%3D%3D&extt=mp4' # 讛讞诇讬驻讬 讻讗谉 注诐 讛-URL 砖诇 讛住专讟讜谉 砖诇讱
60
  video_html = display_video(video_url)
61
 
62
- # הראי את הסרטון לפני החיזוי
63
  print(video_html) # זה אמור להציג את ה-HTML בדשבורד שלך
64
 
65
  # Evaluate the model
@@ -67,8 +69,8 @@ if __name__ == "__main__":
67
  true_labels = []
68
  predicted_labels = []
69
 
70
- # הראה את הסרטון בעת חיזוי
71
- print(video_html) # הצג את ה-HTML של הסרטון
72
 
73
  with torch.no_grad():
74
  for images, labels in test_loader:
 
6
  import os
7
  import pandas as pd
8
  from sklearn.model_selection import train_test_split
9
+ from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score
10
  import matplotlib.pyplot as plt
11
  import seaborn as sns
12
+ from sklearn.metrics import recall_score
13
+ from vit_model_traning import labeling, CustomDataset
14
 
15
  # פונקציה להחזרת HTML של סרטון
16
  def display_video(video_url):
 
39
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
40
 
41
  model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
42
+
43
  # Define the image preprocessing pipeline
44
  preprocess = transforms.Compose([
45
  transforms.Resize((224, 224)),
 
58
  model.load_state_dict(torch.load('trained_model.pth'))
59
 
60
  # קישור לסרטון
61
+ video_url = '"C:\Users\litav\Downloads\0001-0120.mp4"' # החליפי כאן עם ה-URL של הסרטון שלך
62
  video_html = display_video(video_url)
63
 
64
+ # הראה את הסרטון כאשר הכפתור נלחץ
65
  print(video_html) # זה אמור להציג את ה-HTML בדשבורד שלך
66
 
67
  # Evaluate the model
 
69
  true_labels = []
70
  predicted_labels = []
71
 
72
+ # כאן תוסיף קוד JavaScript להפעיל את הסרטון בעת לחיצה על כפתור ה-SUBMIT
73
+ # דוגמה: <button onclick="playVideo()">Submit</button>
74
 
75
  with torch.no_grad():
76
  for images, labels in test_loader: