# wolo-wolo
# V1.0
# 4d10ed1
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
import util.misc as misc
from util.metrics import *
@torch.no_grad()
def test_two_class(data_loader, model, device):
    """Run two-class inference and collect frame- and video-level scores.

    Args:
        data_loader: Iterable of batches. Each batch is indexed as
            ``batch[0]`` (images), ``batch[1]`` (targets) and ``batch[-1]``
            (one video name per sample).
        model: Module called on the image batch; its output is soft-maxed
            over ``dim=1`` and column 1 is kept as the frame score.
        device: Device the images are moved to before the forward pass.

    Returns:
        A ``(frame_preds_list, video_pred_list)`` tuple. The first item holds
        the class-1 softmax score of every frame in loader order; the second
        is the video-level prediction list returned by
        ``get_video_level_label_pred``.
    """
    # switch to evaluation mode
    model.eval()

    # Per-batch results are gathered in Python lists and joined once after
    # the loop: np.append re-allocates and copies the whole accumulated array
    # on every call, which is quadratic in the number of frames.
    label_chunks = []
    score_chunks = []
    video_names_list = []

    for batch in data_loader:
        images = batch[0].to(device, non_blocking=True)  # [BS, C, H, W]
        target = batch[1]  # [BS]
        video_name = batch[-1]  # list[BS]

        output = model(images)
        scores = F.softmax(output, dim=1)[:, 1].detach().cpu().numpy()

        # float64 + ravel keeps the values identical to what accumulating
        # into a flat np.array([]) with np.append used to produce.
        score_chunks.append(np.asarray(scores, dtype=np.float64).ravel())
        label_chunks.append(
            np.asarray(target.detach().cpu().numpy(), dtype=np.float64).ravel()
        )
        video_names_list.extend(video_name)

    # An empty loader yields empty lists rather than a concatenate error.
    frame_labels_list = np.concatenate(label_chunks).tolist() if label_chunks else []
    frame_preds_list = np.concatenate(score_chunks).tolist() if score_chunks else []

    # video-level metrics:
    _video_labels, video_pred_list, _video_y_preds = get_video_level_label_pred(
        frame_labels_list, video_names_list, frame_preds_list
    )

    return frame_preds_list, video_pred_list
@torch.no_grad()
def test_multi_class(data_loader, model, device):
    """Run multi-class inference and collect frame- and video-level scores.

    The number of classes is taken from the model output instead of being
    fixed up front, so any class count works as long as it is the same for
    every batch.

    Args:
        data_loader: Iterable of batches. Each batch is indexed as
            ``batch[0]`` (images), ``batch[1]`` (targets) and ``batch[-1]``
            (one video name per sample).
        model: Module called on the image batch; its output is soft-maxed
            over ``dim=1`` to obtain per-class probabilities.
        device: Device the images are moved to before the forward pass.

    Returns:
        A ``(frame_preds_list, video_pred_list)`` tuple. The first item is a
        list with one per-class probability list per frame, in loader order;
        the second is the video-level prediction list returned by
        ``get_video_level_label_pred_multi_class``.

    Raises:
        ValueError: If batches produce outputs with differing class counts
            (raised by ``np.concatenate``).
    """
    # switch to evaluation mode
    model.eval()

    # Per-batch results are gathered in Python lists and joined once after
    # the loop: np.append re-allocates and copies the whole accumulated array
    # on every call, which is quadratic in the number of frames.
    label_chunks = []
    prob_chunks = []
    video_names_list = []

    for batch in data_loader:
        images = batch[0].to(device, non_blocking=True)  # [BS, C, H, W]
        target = batch[1]  # [BS]
        video_name = batch[-1]  # list[BS]

        output = model(images)
        probs = F.softmax(output, dim=1).detach().cpu().numpy()

        # float64 keeps the values identical to accumulating into a default
        # (float64) NumPy array, as the previous implementation did.
        prob_chunks.append(np.asarray(probs, dtype=np.float64))
        label_chunks.append(
            np.asarray(target.detach().cpu().numpy(), dtype=np.float64).ravel()
        )
        video_names_list.extend(video_name)

    # An empty loader yields empty lists rather than a concatenate error.
    frame_labels_list = np.concatenate(label_chunks).tolist() if label_chunks else []
    frame_preds_list = (
        np.concatenate(prob_chunks, axis=0).tolist() if prob_chunks else []
    )

    # video-level metrics:
    _video_labels, video_pred_list, _video_y_preds = (
        get_video_level_label_pred_multi_class(
            frame_labels_list, video_names_list, frame_preds_list
        )
    )

    return frame_preds_list, video_pred_list
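# NOTE(review): the commented-out code below appears to be an earlier draft of
# test_multi_class. It starts frame_preds as a 1-D np.array([]), which
# np.append cannot extend along axis=0 with 2-D softmax outputs. It is not
# executed; confirm whether it can be removed.
# @torch.no_grad()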
# def test_multi_class(data_loader, model, device):
# criterion = torch.nn.CrossEntropyLoss()
#
# # switch to evaluation mode
# model.eval()
#
# frame_labels = np.array([]) # int label
# frame_preds = np.array([]) # pred logit
# frame_y_preds = np.array([]) # pred int
# video_names_list = list()
#
# for batch in data_loader:
# images = batch[0] # torch.Size([BS, C, H, W])
# target = batch[1] # torch.Size([BS])
# video_name = batch[-1] # list[BS]
# images = images.to(device, non_blocking=True)
# target = target.to(device, non_blocking=True)
#
# output = model(images).to(device, non_blocking=True)
# loss = criterion(output, target)
#
# frame_pred = F.softmax(output, dim=1).detach().cpu().numpy()
# frame_preds = np.append(frame_preds, frame_pred, axis=0)
# frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
# frame_y_preds = np.append(frame_y_preds, frame_y_pred)
#
# frame_label = target.detach().cpu().numpy()
# frame_labels = np.append(frame_labels, frame_label)
# video_names_list.extend(list(video_name))
#
# # video-level metrics:
# frame_labels_list = frame_labels.tolist()
# frame_preds_list = frame_preds.tolist()
# video_label_list, video_pred_list, video_y_pred_list = get_video_level_label_pred_multi_class(frame_labels_list, video_names_list, frame_preds_list)
#
# return frame_preds_list, video_pred_list