Superxixixi commited on
Commit
a25806a
·
1 Parent(s): b89c008

Delete test_multicard.py

Browse files
Files changed (1) hide show
  1. test_multicard.py +0 -99
test_multicard.py DELETED
@@ -1,99 +0,0 @@
1
- import time, os, torch, argparse, warnings, glob, pandas, json
2
-
3
- from utils.tools import *
4
- from dlhammer import bootstrap
5
-
6
- from dataLoader_multiperson import val_loader
7
- from loconet import loconet
8
-
9
-
10
- class DataPrep():
11
-
12
- def __init__(self, cfg):
13
- self.cfg = cfg
14
-
15
- def val_dataloader(self):
16
- cfg = self.cfg
17
- loader = val_loader(cfg, trialFileName = cfg.evalTrialAVA, \
18
- audioPath = os.path.join(cfg.audioPathAVA , cfg.evalDataType), \
19
- visualPath = os.path.join(cfg.visualPathAVA, cfg.evalDataType), \
20
- num_speakers=cfg.MODEL.NUM_SPEAKERS,
21
- )
22
- valLoader = torch.utils.data.DataLoader(loader,
23
- batch_size=cfg.VAL.BATCH_SIZE,
24
- shuffle=False,
25
- num_workers=16)
26
- return valLoader
27
-
28
-
29
- def prepare_context_files(cfg):
30
- path = os.path.join(cfg.DATA.dataPathAVA, "csv")
31
- for phase in ["val", "test"]:
32
- csv_f = f"{phase}_loader.csv"
33
- csv_orig = f"{phase}_orig.csv"
34
- entity_f = os.path.join(path, phase + "_entity.json")
35
- ts_f = os.path.join(path, phase + "_ts.json")
36
- if os.path.exists(entity_f) and os.path.exists(ts_f):
37
- continue
38
- orig_df = pandas.read_csv(os.path.join(path, csv_orig))
39
- entity_data = {}
40
- ts_to_entity = {}
41
-
42
- for index, row in orig_df.iterrows():
43
-
44
- entity_id = row['entity_id']
45
- video_id = row['video_id']
46
- if row['label'] == "SPEAKING_AUDIBLE":
47
- label = 1
48
- else:
49
- label = 0
50
- ts = float(row['frame_timestamp'])
51
- if video_id not in entity_data.keys():
52
- entity_data[video_id] = {}
53
- if entity_id not in entity_data[video_id].keys():
54
- entity_data[video_id][entity_id] = {}
55
- if ts not in entity_data[video_id][entity_id].keys():
56
- entity_data[video_id][entity_id][ts] = []
57
-
58
- entity_data[video_id][entity_id][ts] = label
59
-
60
- if video_id not in ts_to_entity.keys():
61
- ts_to_entity[video_id] = {}
62
- if ts not in ts_to_entity[video_id].keys():
63
- ts_to_entity[video_id][ts] = []
64
- ts_to_entity[video_id][ts].append(entity_id)
65
-
66
- with open(entity_f) as f:
67
- json.dump(entity_data, f)
68
-
69
- with open(ts_f) as f:
70
- json.dump(ts_to_entity, f)
71
-
72
-
73
- def main():
74
- cfg = bootstrap(print_cfg=False)
75
- print(cfg)
76
- epoch = cfg.RESUME_EPOCH
77
-
78
- warnings.filterwarnings("ignore")
79
-
80
- cfg = init_args(cfg)
81
-
82
- data = DataPrep(cfg)
83
-
84
- prepare_context_files(cfg)
85
-
86
- if cfg.downloadAVA == True:
87
- preprocess_AVA(cfg)
88
- quit()
89
-
90
- s = loconet(cfg)
91
-
92
- s.loadParameters(cfg.RESUME_PATH)
93
- mAP = s.evaluate_network(epoch=epoch, loader=data.val_dataloader())
94
- print(f"evaluate ckpt: {cfg.RESUME_PATH}")
95
- print(mAP)
96
-
97
-
98
- if __name__ == '__main__':
99
- main()