# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import numpy as np
import torch
import torch.nn.functional as F

import util.misc as misc
from util.metrics import *


@torch.no_grad()
def _collect_frame_predictions(data_loader, model, device, positive_class_only):
    """Run ``model`` over ``data_loader`` and gather per-frame results.

    Args:
        data_loader: Iterable of batches where ``batch[0]`` is the image
            tensor, ``batch[1]`` the label tensor and ``batch[-1]`` the
            per-sample video names.
        model: Module called as ``model(images)``; its output is soft-maxed
            over ``dim=1``.
        device: Device the images are moved to before the forward pass.
        positive_class_only: If true, keep only column 1 of the softmax
            output (one score per frame); otherwise keep every column.

    Returns:
        Tuple ``(frame_labels_list, frame_preds_list, video_names_list)`` of
        plain Python lists. All three are empty when the loader yields no
        batches.
    """
    # switch to evaluation mode
    model.eval()

    label_chunks = []
    pred_chunks = []
    video_names_list = []

    for batch in data_loader:
        images = batch[0].to(device, non_blocking=True)  # torch.Size([BS, C, H, W])
        target = batch[1]  # torch.Size([BS])
        video_name = batch[-1]  # list[BS]

        probs = F.softmax(model(images), dim=1)
        if positive_class_only:
            probs = probs[:, 1]

        # Keep per-batch arrays and concatenate once after the loop:
        # np.append copies the whole accumulator on every call.
        pred_chunks.append(probs.detach().cpu().numpy())
        label_chunks.append(target.detach().cpu().numpy().reshape(-1))
        video_names_list.extend(video_name)

    if not pred_chunks:
        return [], [], video_names_list

    # Cast to float64 so the resulting lists hold the same Python floats the
    # previous float64-accumulator implementation produced.
    frame_labels_list = np.concatenate(label_chunks).astype(np.float64).tolist()
    frame_preds_list = np.concatenate(pred_chunks, axis=0).astype(np.float64).tolist()
    return frame_labels_list, frame_preds_list, video_names_list


@torch.no_grad()
def test_two_class(data_loader, model, device):
    """Evaluate a two-class model and return frame- and video-level scores.

    Args:
        data_loader: Iterable of ``(images, target, ..., video_name)`` batches.
        model: Classifier whose output has a class axis at ``dim=1``.
        device: Device used for the forward pass.

    Returns:
        Tuple ``(frame_preds_list, video_pred_list)``: the per-frame softmax
        score of class index 1, and the video-level predictions returned by
        ``get_video_level_label_pred``.
    """
    frame_labels_list, frame_preds_list, video_names_list = _collect_frame_predictions(
        data_loader, model, device, positive_class_only=True
    )

    # video-level metrics:
    video_label_list, video_pred_list, video_y_pred_list = get_video_level_label_pred(
        frame_labels_list, video_names_list, frame_preds_list
    )

    return frame_preds_list, video_pred_list


@torch.no_grad()
def test_multi_class(data_loader, model, device):
    """Evaluate a multi-class model and return frame- and video-level scores.

    The number of classes is taken from the model output rather than being
    fixed in advance.

    Args:
        data_loader: Iterable of ``(images, target, ..., video_name)`` batches.
        model: Classifier whose output has a class axis at ``dim=1``.
        device: Device used for the forward pass.

    Returns:
        Tuple ``(frame_preds_list, video_pred_list)``: the per-frame softmax
        distribution (one list of class probabilities per frame), and the
        video-level predictions returned by
        ``get_video_level_label_pred_multi_class``.
    """
    frame_labels_list, frame_preds_list, video_names_list = _collect_frame_predictions(
        data_loader, model, device, positive_class_only=False
    )

    # video-level metrics:
    video_label_list, video_pred_list, video_y_pred_list = get_video_level_label_pred_multi_class(
        frame_labels_list, video_names_list, frame_preds_list
    )

    return frame_preds_list, video_pred_list