Spaces:
Running
Running
File size: 9,222 Bytes
fbb1d85 aaef8e0 fbb1d85 aaef8e0 bdf49c6 aaef8e0 bdf49c6 aaef8e0 bdf49c6 aaef8e0 bdf49c6 aaef8e0 bdf49c6 aaef8e0 bdf49c6 aaef8e0 bdf49c6 fbb1d85 bdf49c6 fbb1d85 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import os
import json
import unittest
from pathlib import Path
from zipfile import ZipFile
from typing import List, Dict, Any, Union
from tempfile import TemporaryDirectory
def validate_zip(submission_track: str, submission_zip: Union[Path, str]):
    """
    Validates the submission format and contents

    The archive is extracted into a temporary directory (removed again when
    this function returns) and the track name selects which validator is run
    on the extracted files.

    Args:
        submission_track: the track of the submission
        submission_zip: path to the submission zip file
    Raises:
        ValueError: if the submission zip is invalid (this includes a file that
            is not a readable zip archive) or the track name is unknown
    """
    # Local import: `BadZipFile` is not among the names imported at file level.
    from zipfile import BadZipFile

    with TemporaryDirectory() as temp_dir:
        try:
            with ZipFile(submission_zip, 'r') as submission_zip_file:
                submission_zip_file.extractall(temp_dir)
        except BadZipFile as err:
            # BadZipFile is not a ValueError subclass; convert it so callers that
            # rely on the documented contract see a corrupt archive as ValueError.
            raise ValueError(f'Invalid submission zip file: {err}') from err

        # Validation must run while the temporary directory still exists.
        submission_dir = Path(temp_dir)
        if submission_track in ['NOTSOFAR-SC', 'NOTSOFAR-MC']:
            validate_notsofar_submission(submission_dir=submission_dir)
        elif submission_track in ['DASR-Constrained-LM', 'DASR-Unconstrained-LM']:
            validate_dasr_submission(submission_dir=submission_dir)
        else:
            raise ValueError(f'Invalid submission track: {submission_track}')
def validate_notsofar_submission(submission_dir: Path):
    """
    Checks the files of a NOTSOFAR submission

    `tcp_wer_hyp.json` must be present; `tc_orc_wer_ref.json` is validated
    only when it is present. Each file that exists is handed to
    `validate_json_file_structure`.

    Args:
        submission_dir: path to the submission directory
    Raises:
        ValueError: if a mandatory file is missing or a file has a bad structure
    """
    required_names = ['tcp_wer_hyp.json']
    optional_names = ['tc_orc_wer_ref.json']
    entry_fields = ['session_id', 'words', 'speaker', 'start_time', 'end_time']

    # Mandatory files first, then optional ones.
    for mandatory, names in ((True, required_names), (False, optional_names)):
        for name in names:
            candidate = submission_dir / name
            if candidate.exists():
                validate_json_file_structure(candidate, entry_fields)
            elif mandatory:
                raise ValueError(f'Missing {name}')
            # An absent optional file is simply skipped.
def validate_dasr_submission(submission_dir: Path):
    """
    Checks the files of a DASR submission

    The submission must contain a `dev` directory holding `chime6.json`,
    `dipco.json`, `mixer6.json` and `notsofar1.json`. The files are checked
    one after another, each handed to `validate_json_file_structure`.

    Args:
        submission_dir: path to the submission directory
    Raises:
        ValueError: if `dev` or one of its files is missing, or a file has a bad structure
    """
    entry_fields = ['session_id', 'words', 'speaker', 'start_time', 'end_time']
    dev_dir = submission_dir / 'dev'

    if not dev_dir.exists():
        raise ValueError('Missing `dev` directory, expecting a directory named `dev` with the submission files in it.')

    for name in ['chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json']:
        json_path = dev_dir / name
        if not json_path.exists():
            raise ValueError(f'Missing {name}')
        validate_json_file_structure(json_path, entry_fields)
def validate_json_file_structure(file_path: Path, fields: List[str]):
    """
    Validates the structure of a json file

    The file must hold a json list whose entries are all json objects that
    contain every name in `fields`; additional keys are allowed.

    Args:
        file_path: path to the json file
        fields: list of fields that are required in each entry
    Raises:
        ValueError: if the json file is invalid (malformed json surfaces as
            `json.JSONDecodeError`, which is a `ValueError` subclass)
    """
    # Read as UTF-8 explicitly instead of relying on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as json_file:
        json_data: List[Dict[str, Any]] = json.load(json_file)

    if not isinstance(json_data, list):
        raise ValueError(f'Invalid `{file_path.name}` format, expecting a list of entries')

    for data in json_data:
        # `field in data` is a substring test on a str and a membership test on a
        # list, and raises TypeError on a number, so anything that is not a json
        # object must be rejected before the key check.
        if not isinstance(data, dict):
            raise ValueError(f'Invalid `{file_path.name}` format, expecting each entry to be a json object')
        if not all(field in data for field in fields):
            raise ValueError(f'Invalid `{file_path.name}` format, fields: {fields} are required in each entry')
####################################################################################################
# Tests
####################################################################################################
class TestValidateZip(unittest.TestCase):
    """Tests `validate_zip` against zip archives built on the fly for every track."""

    # Number of identical entries written into each generated json file.
    DATA_SAMPLES = 10
    # The json files a DASR submission is expected to carry inside `dev`.
    DASR_FILES = ['chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json']

    @classmethod
    def setUpClass(cls):
        """Builds one payload with every required field and one lacking `speaker` and `end_time`."""
        cls.valid_data = [{'session_id': 'session_id', 'words': 'words', 'speaker': 'speaker',
                           'start_time': 0.0, 'end_time': 1.0} for _ in range(cls.DATA_SAMPLES)]
        cls.invalid_data = [{'session_id': 'session_id', 'words': 'words',
                             'start_time': 0.0} for _ in range(cls.DATA_SAMPLES)]

    def setUp(self):
        """Creates a fresh temporary directory holding the zip for each test."""
        self.temp_dir = TemporaryDirectory()
        self.submission_zip = Path(self.temp_dir.name) / 'submission.zip'

    def create_test_data(self, submission_track: str, data: List[Dict[str, Any]], json_file_names: List[str],
                         parent_zip_dir: Union[str, None] = None):
        """
        Writes a submission zip in which every listed json file contains `data`

        Args:
            submission_track: track name, returned unchanged for the caller's convenience
            data: entries serialized into every json file
            json_file_names: names of the json files to put in the archive
            parent_zip_dir: optional directory inside the archive that holds the files
        Returns:
            `(submission_track, path to the zip)`, ready to be unpacked into `validate_zip`
        """
        with ZipFile(self.submission_zip, 'w') as submission_zip_file:
            for json_file_name in json_file_names:
                if parent_zip_dir:
                    json_file_name = str(Path(parent_zip_dir) / json_file_name)
                submission_zip_file.writestr(json_file_name, json.dumps(data))
        return submission_track, self.submission_zip

    def tearDown(self):
        """Removes the temporary directory created in `setUp`."""
        self.temp_dir.cleanup()

    def test_NOTSOFAR_SC_valid_data_tcp(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-SC', self.valid_data, ['tcp_wer_hyp.json'])))

    def test_NOTSOFAR_SC_valid_data_tcp_and_tcorc(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-SC', self.valid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'])))

    def test_NOTSOFAR_SC_missing_tcp_file(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-SC', self.valid_data, ['tc_orc_wer_ref.json']))

    def test_NOTSOFAR_SC_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-SC', self.invalid_data, ['tcp_wer_hyp.json']))

    def test_NOTSOFAR_MC_valid_data_tcp(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-MC', self.valid_data, ['tcp_wer_hyp.json'])))

    def test_NOTSOFAR_MC_valid_data_tcp_and_tcorc(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-MC', self.valid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'])))

    def test_NOTSOFAR_MC_missing_tcp_file(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-MC', self.valid_data, ['tc_orc_wer_ref.json']))

    def test_NOTSOFAR_MC_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-MC', self.invalid_data, ['tcp_wer_hyp.json']))

    def test_DASR_Constrained_LM_valid_data(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'DASR-Constrained-LM', self.valid_data, self.DASR_FILES, 'dev')))

    def test_DASR_Constrained_LM_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Constrained-LM', self.invalid_data, self.DASR_FILES, 'dev'))

    def test_DASR_Constrained_LM_missing_dev_dir(self):
        # Files are placed at the archive root instead of inside `dev`.
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Constrained-LM', self.valid_data, self.DASR_FILES))

    def test_DASR_Constrained_LM_missing_json_file(self):
        # `notsofar1.json` is left out.
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Constrained-LM', self.valid_data, self.DASR_FILES[:-1], 'dev'))

    def test_DASR_Unconstrained_LM_valid_data(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'DASR-Unconstrained-LM', self.valid_data, self.DASR_FILES, 'dev')))

    def test_DASR_Unconstrained_LM_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Unconstrained-LM', self.invalid_data, self.DASR_FILES, 'dev'))

    def test_DASR_Unconstrained_LM_missing_dev_dir(self):
        # Files are placed at the archive root instead of inside `dev`.
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Unconstrained-LM', self.valid_data, self.DASR_FILES))

    def test_DASR_Unconstrained_LM_missing_json_file(self):
        # `notsofar1.json` is left out.
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Unconstrained-LM', self.valid_data, self.DASR_FILES[:-1], 'dev'))
# Run the test suite above when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|