# Source: whisper-timestamped / tests / test_transcribe.py
# Web page residue: "acharyasagar's picture" - "Upload folder using huggingface_hub"
# Commit 89c4248 (verified)
__author__ = "Jérôme Louradour"
__credits__ = ["Jérôme Louradour"]
__license__ = "GPLv3"
import unittest
import sys
import os
import subprocess
import shutil
import tempfile
import json
import torch
import jsonschema
# When True, a missing reference file makes the test fail (it is passed as `check=`
# when resolving the expected path) and any reference created during a test is an error
FAIL_IF_REFERENCE_NOT_FOUND = True
# When True, CLI tests whose expected reference files all exist already are skipped
GENERATE_NEW_ONLY = False
# When True, reference files are (re)created from the current outputs, once per test
GENERATE_ALL = False
# When True, only device-dependent CLI tests are run, and their references are (re)created
GENERATE_DEVICE_DEPENDENT = False
# When True, long tests are skipped if CUDA is not available
SKIP_LONG_TEST_IF_CPU = True
# Extra options appended to every command line; they also suffix the expected/output folder names
CMD_OPTIONS = []
class TestHelper(unittest.TestCase):
    """
    Base class with the utilities shared by the non-regression tests:
    path resolution, running commands and comparing outputs with references.
    """

    def skipLongTests(self):
        """Return True if long tests must be skipped (flag enabled and no CUDA device)."""
        return SKIP_LONG_TEST_IF_CPU and not torch.cuda.is_available()

    def setUp(self):
        self.maxDiff = None
        # Absolute paths of the reference files created while running the test
        self.createdReferences = []

    def tearDown(self):
        # References are created under the "expected" folder: strip that prefix to shorten the message
        created = ", ".join(self.createdReferences).replace(self.get_expected_path() + "/", "")
        if GENERATE_ALL or GENERATE_NEW_ONLY or not FAIL_IF_REFERENCE_NOT_FOUND or GENERATE_DEVICE_DEPENDENT:
            if len(self.createdReferences) > 0:
                print("WARNING: Created references: " + created)
        else:
            # Creating a reference when not in a generation mode is a failure
            self.assertEqual(self.createdReferences, [], "Created references: " + created)

    def get_main_path(self, fn=None, check=False):
        """Return the path of (a file in) the "whisper_timestamped" source folder."""
        return self._get_path("whisper_timestamped", fn, check=check)

    def get_output_path(self, fn=None):
        """Return the temporary directory, or a path inside it named after `fn` and the extra options."""
        if fn is None:
            return tempfile.gettempdir()
        return os.path.join(tempfile.gettempdir(), fn + self._extra_cmd_options())

    def get_expected_path(self, fn=None, check=False):
        """Return the path of (a file in) the folder of expected references."""
        return self._get_path("tests/expected" + self._extra_cmd_options(), fn, check=check)

    def _extra_cmd_options(self):
        """Return a suffix (".opt1opt2...") built from CMD_OPTIONS, or "" if there is none."""
        s = "".join([f.replace("-", "").strip() for f in CMD_OPTIONS])
        if s:
            return "." + s
        return ""

    def get_data_files(self, files=None, excluded_by_default=("apollo11.mp3", "music.mp4", "arabic.mp3", "japanese.mp3", "empty.wav", "words.wav")):
        """
        Return the full paths of the input data files.

        files: list of file names, or None to take all the files of the data folder
            except json files and those in excluded_by_default
        excluded_by_default: file names ignored when files is None
        """
        if files is None:
            files = os.listdir(self.get_data_path())
            files = [f for f in files if f not in excluded_by_default and not f.endswith("json")]
            files = sorted(files)
        return [self.get_data_path(fn) for fn in files]

    def get_generated_files(self, input_filename, output_path, extensions):
        """Yield, for each extension, the output file name expected for an input file."""
        for ext in extensions:
            yield os.path.join(output_path, os.path.basename(input_filename) + "." + ext.lstrip("."))

    def main_script(self, pyscript="transcribe.py", exename="whisper_timestamped"):
        """Return the path of the python script if it exists in the sources, the executable name otherwise."""
        main_script = self.get_main_path(pyscript, check=False)
        if not os.path.exists(main_script):
            main_script = exename
        return main_script

    def assertRun(self, cmd):
        """
        Run a command from the temporary directory and check that it exits with code 0.

        cmd: list of arguments, or a string that is split on whitespaces
        Returns the decoded (stdout, stderr).
        """
        if isinstance(cmd, str):
            return self.assertRun(cmd.split())
        if cmd[0].endswith(".py"):
            cmd = [sys.executable] + cmd
        print("Running:", " ".join(cmd))
        p = subprocess.Popen(
            cmd,
            # Run the child in the temporary directory, without changing the directory of this process
            cwd=tempfile.gettempdir(),
            # Otherwise ".local" path might be missing
            env=dict(os.environ, PYTHONPATH=os.pathsep.join(sys.path)),
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        )
        (stdout, stderr) = p.communicate()
        self.assertEqual(p.returncode, 0, msg=stderr.decode("utf-8"))
        return (stdout.decode("utf-8"), stderr.decode("utf-8"))

    def assertNonRegression(self, content, reference, string_is_file=True):
        """
        Check that a file/folder is the same as a reference file/folder.

        content: path of a file/folder, a dictionary (dumped as json),
            or a text (if string_is_file is False)
        reference: name of the reference, relative to the folder of expected references
        string_is_file: whether a string content is a path (True) or the text to compare (False)

        Raises ValueError if content is neither a dictionary nor a string.
        """
        if isinstance(content, dict):
            # Make a temporary file
            with tempfile.NamedTemporaryFile(mode="w", suffix=".json", encoding="utf8", delete=False) as f:
                json.dump(content, f, indent=2, ensure_ascii=False)
            # The file is closed (hence flushed) before being read, and removed even if the check fails
            try:
                return self.assertNonRegression(f.name, reference)
            finally:
                os.remove(f.name)
        elif not isinstance(content, str):
            raise ValueError(f"Invalid content type: {type(content)}")

        if not string_is_file:
            with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", encoding="utf8", delete=False) as f:
                f.write(content)
            try:
                return self.assertNonRegression(f.name, reference)
            finally:
                os.remove(f.name)

        self.assertTrue(os.path.exists(content), f"Missing file: {content}")
        # NOTE(review): `reference` is not resolved yet at this point, so this test is relative
        # to the current directory -- confirm that it is intended
        is_file = os.path.isfile(reference) if os.path.exists(reference) else os.path.isfile(content)

        reference = self.get_expected_path(
            reference, check=FAIL_IF_REFERENCE_NOT_FOUND)
        if not os.path.exists(reference) or ((GENERATE_ALL or GENERATE_DEVICE_DEPENDENT) and reference not in self.createdReferences):
            # (Re)create the reference from the current content
            dirname = os.path.dirname(reference)
            if not os.path.isdir(dirname):
                os.makedirs(dirname)
            if is_file:
                shutil.copyfile(content, reference)
            else:
                shutil.copytree(content, reference)
            self.createdReferences.append(reference)

        if is_file:
            self.assertTrue(os.path.isfile(content))
            self._check_file_non_regression(content, reference)
        else:
            self.assertTrue(os.path.isdir(content))
            # Files are matched by their path relative to the folder, so that sub-folders are supported
            for root, _dirs, files in os.walk(content):
                for f in files:
                    rel = os.path.relpath(os.path.join(root, f), content)
                    f_ref = os.path.join(reference, rel)
                    self.assertTrue(os.path.isfile(f_ref),
                                    f"Additional file: {rel}")
                    self._check_file_non_regression(
                        os.path.join(root, f), f_ref)
            for root, _dirs, files in os.walk(reference):
                for f in files:
                    rel = os.path.relpath(os.path.join(root, f), reference)
                    f = os.path.join(content, rel)
                    self.assertTrue(os.path.isfile(f), f"Missing file: {f}")

    def get_data_path(self, fn=None, check=True):
        """Return the path of (a file in) the folder of input data."""
        return self._get_path("tests/data", fn, check)

    def _get_path(self, prefix, fn=None, check=True):
        """
        Return the path of `prefix` (and `fn` inside it, if given), relative to
        the parent folder of the folder of this file.
        If check is True, assert that the path exists.
        """
        path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            prefix
        )
        if fn:
            path = os.path.join(path, fn)
        if check:
            self.assertTrue(os.path.exists(path), f"Cannot find {path}")
        return path

    def _check_file_non_regression(self, file, reference):
        """
        Compare a file with its reference: loosely (see assertClose) for json files,
        line by line otherwise.
        """
        # Files are written in utf8: do not depend on the default encoding of the platform
        if file.endswith(".json"):
            with open(file, encoding="utf8") as f:
                content = json.load(f)
            with open(reference, encoding="utf8") as f:
                reference_content = json.load(f)
            if "language" in content and "language" in reference_content:
                content["language"] = self.norm_language(content["language"])
                reference_content["language"] = self.norm_language(reference_content["language"])
            self.assertClose(content, reference_content,
                             msg=f"File {file} does not match reference {reference}")
            return
        with open(file, encoding="utf8") as f:
            content = f.readlines()
        with open(reference, encoding="utf8") as f:
            reference_content = f.readlines()
        self.assertEqual(content, reference_content,
                         msg=f"File {file} does not match reference {reference}")

    def assertClose(self, obj1, obj2, msg=None):
        """Check that two objects are equal, up to the rounding of floats done by loose()."""
        return self.assertEqual(self.loose(obj1), self.loose(obj2), msg=msg)

    def loose(self, obj):
        """
        Return an approximate value of an object:
        floats are rounded to one decimal, recursively in lists, dictionary values, tuples and sets.
        """
        if isinstance(obj, list):
            return [self.loose(a) for a in obj]
        if isinstance(obj, float):
            f = round(obj, 1)
            # Avoid negative zero
            return 0.0 if f == -0.0 else f
        if isinstance(obj, dict):
            return {k: self.loose(v) for k, v in obj.items()}
        if isinstance(obj, tuple):
            return tuple(self.loose(list(obj)))
        if isinstance(obj, set):
            return {self.loose(a) for a in obj}
        return obj

    def get_audio_duration(self, audio_file):
        """Return the duration of an audio file, in seconds."""
        # Get the duration in sec *without introducing additional dependencies*
        import whisper
        return len(whisper.load_audio(audio_file)) / whisper.audio.SAMPLE_RATE

    def get_device_str(self):
        """Return "cuda" if CUDA is available, "cpu" otherwise."""
        return "cpu" if not torch.cuda.is_available() else "cuda"

    def norm_language(self, language):
        """Map a language name to its code, for the languages known here; return it unchanged otherwise."""
        # Cheap custom stuff to avoid importing everything
        return {
            "japanese": "ja",
        }.get(language.lower(), language)
class TestHelperCli(TestHelper):
    """Helper to run the command line on data files and compare its outputs with references."""

    # Parsed json schema of the output, loaded on first use and shared by all the tests
    json_schema = None

    def _test_cli_(self, opts, name, files=None, extensions=("words.json",), prefix=None, one_per_call=True, device_specific=None):
        """
        Test command line
            opts: list of options
            name: name of the test
            files: list of files to process
            extensions: list of extensions to check, or None to test the stdout
            prefix: prefix to add to the reference files
            one_per_call: if True, each file is processed separately, otherwise all files are processed by a single process
            device_specific: whether the references depend on the device (a ".cpu" suffix is added
                to the reference folder when not running on cuda), or None to decide from the
                audio duration, the file name and the test name
        """
        opts = opts + CMD_OPTIONS

        output_dir = self.get_output_path(name)

        input_filenames = self.get_data_files(files)

        # When testing the stdout, the reference has the "stdout" extension
        reference_extensions = ["stdout"] if extensions is None else extensions

        stdout = None
        # In batch mode, the command must run once, even if the first files are skipped
        has_run = False

        for input_filename in input_filenames:

            # Butterfly effect: Results are different depending on the device for long files
            duration = self.get_audio_duration(input_filename)
            if device_specific is None:
                device_dependent = duration > 60 or (duration > 30 and "tiny_fr" in name) or ("empty" in input_filename and "medium_auto" in name)
            else:
                device_dependent = device_specific
            name_ = name
            if device_dependent and self.get_device_str() != "cuda":
                name_ += f".{self.get_device_str()}"

            def ref_name(output_filename):
                return name_ + "/" + (f"{prefix}_" if prefix else "") + os.path.basename(output_filename)
            generic_name = ref_name(input_filename + ".*")

            if GENERATE_DEVICE_DEPENDENT and not device_dependent:
                print("Skipping non-regression test", generic_name)
                continue

            if GENERATE_NEW_ONLY and all(
                os.path.exists(self.get_expected_path(ref_name(output_filename)))
                for output_filename in self.get_generated_files(input_filename, output_dir, extensions=reference_extensions)
            ):
                print("Skipping non-regression test", generic_name)
                continue

            print("Running non-regression test", generic_name)

            if one_per_call or not has_run:
                inputs = [input_filename] if one_per_call else input_filenames
                (stdout, stderr) = self.assertRun([self.main_script(), *inputs, "--output_dir", output_dir, *opts])
                has_run = True
                print(stdout)
                print(stderr)

            # The json output is validated against the schema whenever it was produced
            output_json = next(self.get_generated_files(input_filename, output_dir, extensions=["words.json"]))
            if os.path.isfile(output_json):
                self.check_json(output_json)

            if extensions is None:
                output_filename = next(self.get_generated_files(input_filename, output_dir, extensions=["stdout"]))
                self.assertNonRegression(stdout, ref_name(output_filename), string_is_file=False)
            else:
                for output_filename in self.get_generated_files(input_filename, output_dir, extensions=extensions):
                    self.assertNonRegression(output_filename, ref_name(output_filename))

        shutil.rmtree(output_dir, ignore_errors=True)

    def check_json(self, json_file):
        """Validate a json file against the schema "json_schema.json" located next to this file."""
        with open(json_file, encoding="utf8") as f:
            content = json.load(f)

        if self.json_schema is None:
            schema_file = os.path.join(os.path.dirname(__file__), "json_schema.json")
            self.assertTrue(os.path.isfile(schema_file), msg=f"Schema file {schema_file} not found")
            with open(schema_file, encoding="utf8") as f:
                # Cache on the class: a new instance is created for each test method
                TestHelperCli.json_schema = json.load(f)

        jsonschema.validate(instance=content, schema=self.json_schema)
class TestTranscribeTiny(TestHelperCli):
    """Command line non-regression tests with the "tiny" model, on the default data files."""

    def test_cli_tiny_auto(self):
        # No language is given on the command line
        self._test_cli_(opts=["--model", "tiny"], name="tiny_auto")

    def test_cli_tiny_fr(self):
        # Language forced to French
        self._test_cli_(opts=["--model", "tiny", "--language", "fr"], name="tiny_fr")
class TestTranscribeMedium(TestHelperCli):
    """Command line non-regression tests with the "medium" model, on the default data files."""

    def test_cli_medium_auto(self):
        # No language is given on the command line
        self._test_cli_(opts=["--model", "medium"], name="medium_auto")

    def test_cli_medium_fr(self):
        # Language forced to French
        self._test_cli_(opts=["--model", "medium", "--language", "fr"], name="medium_fr")
class TestTranscribeNaive(TestHelperCli):
    """Command line tests of the "--naive" option, and of a short run with the "tiny" model."""

    def test_naive(self):
        # Same file and model, once with "--efficient --naive" and once with "--accurate"
        common = ["--model", "small", "--language", "en"]
        for decoding, prefix in (
            (["--efficient", "--naive"], "naive"),
            (["--accurate"], "accurate"),
        ):
            self._test_cli_(
                common + decoding,
                "naive",
                files=["apollo11.mp3"],
                prefix=prefix,
            )

    def test_stucked_segments(self):
        options = ["--model", "tiny"]
        self._test_cli_(
            options,
            "corner_cases",
            files=["apollo11.mp3"],
            prefix="accurate.tiny",
        )
class TestTranscribeCornerCases(TestHelperCli):
    """Command line tests on corner cases; references are in the "corner_cases" folder."""

    def test_stucked_lm(self):
        if self.skipLongTests():
            return

        options = ["--model", "small", "--language", "en", "--efficient"]
        self._test_cli_(options, "corner_cases", files=["apollo11.mp3"], prefix="stucked_lm")

    def test_punctuation_only(self):
        # When there is only a punctuation detected in a segment, it could cause issue #24
        options = ["--model", "medium.en", "--efficient", "--punctuations", "False"]
        self._test_cli_(options, "corner_cases", files=["empty.wav"], prefix="issue24")

    def test_temperature(self):
        audio = ["apollo11.mp3"]

        # This first run is never skipped
        self._test_cli_(
            ["--model", "small", "--language", "English",
             "--condition", "False", "--temperature", "0.1", "--efficient"],
            "corner_cases",
            files=audio,
            prefix="random.nocond",
        )

        if self.skipLongTests():
            return

        self._test_cli_(
            ["--model", "small", "--language", "en", "--temperature", "0.2", "--efficient"],
            "corner_cases",
            files=audio,
            prefix="random",
        )

    def test_not_conditioned(self):
        # Nothing is tested if the (optional) data file is absent
        music_is_missing = not os.path.exists(self.get_data_path("music.mp4", check=False))
        if music_is_missing:
            return

        if self.skipLongTests():
            return

        audio = ["music.mp4"]
        self._test_cli_(
            ["--model", "medium", "--language", "en", "--condition", "False", "--efficient"],
            "corner_cases",
            files=audio,
            prefix="nocond",
        )
        self._test_cli_(
            ["--model", "medium", "--language", "en",
             "--condition", "False", "--temperature", "0.4", "--efficient"],
            "corner_cases",
            files=audio,
            prefix="nocond.random",
        )

    def test_large(self):
        if self.skipLongTests():
            return

        self._test_cli_(
            ["--model", "large-v2", "--language", "en",
             "--condition", "False", "--temperature", "0.4", "--efficient"],
            "corner_cases",
            files=["apollo11.mp3"],
            prefix="large",
        )

        # The Arabic file is optional
        arabic_is_available = os.path.exists(self.get_data_path("arabic.mp3", check=False))
        if arabic_is_available:
            self._test_cli_(
                ["--model", "large-v2", "--language", "Arabic", "--efficient"],
                "corner_cases",
                files=["arabic.mp3"]
            )

    def test_gloria(self):
        # All the combinations of model and decoding strategy
        for model_name in ("medium", "large-v2"):
            for decoding in ("efficient", "accurate"):
                options = ["--model", model_name, "--language", "en", "--" + decoding]
                self._test_cli_(
                    options,
                    "corner_cases",
                    files=["gloria.mp3"],
                    prefix=model_name + "." + decoding,
                )
class TestTranscribeMonolingual(TestHelperCli):
    """Command line tests with the ".en" models."""

    def test_monolingual_tiny(self):
        audio = ["bonjour_vous_allez_bien.mp3"]
        model = ["--model", "tiny.en"]

        # (additional options, prefix of the reference)
        variants = (
            (["--efficient"], "efficient"),
            (["--accurate"], "accurate"),
            (["--condition", "False", "--efficient"], "nocond"),
        )
        for extra, prefix in variants:
            self._test_cli_(model + extra, "tiny.en", files=audio, prefix=prefix)

    def test_monolingual_small(self):
        # The Arabic file is optional: nothing is tested without it
        if not os.path.exists(self.get_data_path("arabic.mp3", check=False)):
            return

        self._test_cli_(
            ["--model", "small.en", "--condition", "True", "--efficient"],
            "small.en",
            files=["arabic.mp3"],
            device_specific=True,
        )
class TestTranscribeWithVad(TestHelperCli):
    """Command line tests of the "--vad" option; the standard output is compared with the reference."""

    def test_vad_default(self):
        options = ["--model", "tiny", "--accurate", "--language", "en", "--vad", "True", "--verbose", "True"]
        self._test_cli_(
            options,
            "verbose",
            files=["words.wav"],
            prefix="vad",
            extensions=None,
        )

    def test_vad_custom_silero(self):
        # Two versions of silero, most recent first
        for version in ("3.1", "3.0"):
            options = ["--model", "tiny", "--accurate", "--language", "en", "--vad", f"silero:v{version}", "--verbose", "True"]
            self._test_cli_(
                options,
                "verbose",
                files=["words.wav"],
                prefix=f"vad_silero{version}",
                extensions=None,
            )

    def test_vad_custom_auditok(self):
        options = ["--model", "tiny", "--language", "en", "--vad", "auditok", "--verbose", "True"]
        self._test_cli_(
            options,
            "verbose",
            files=["words.wav"],
            prefix="vad_auditok",
            extensions=None,
        )
class TestTranscribeUnspacedLanguage(TestHelperCli):
    """Command line tests on a Japanese audio file."""

    def test_japanese(self):
        # (options after the model, prefix of the reference)
        variants = (
            (["--efficient"], None),
            (["--language", "Japanese", "--efficient"], "jp"),
            (["--accurate"], "accurate"),
            (["--language", "Japanese", "--accurate"], "accurate_jp"),
        )
        for extra, prefix in variants:
            self._test_cli_(
                ["--model", "tiny"] + extra,
                "tiny_auto",
                files=["japanese.mp3"],
                prefix=prefix,
                device_specific=True,
            )
class TestTranscribeFormats(TestHelperCli):
    """Command line tests of all the output formats, and of the verbose standard output."""

    def test_cli_outputs(self):
        audio = ["punctuations.mp3", "bonjour.wav"]
        formats = ["txt", "srt", "vtt", "words.srt", "words.vtt",
                   "words.json", "csv", "words.csv", "tsv", "words.tsv"]
        base_options = ["--model", "medium", "--language", "fr"]

        # An audio / model combination that produces comma
        # All the files are processed by a single call, with then without punctuations
        for extra, name in (
            ([], "punctuations_yes"),
            (["--punctuations", "False"], "punctuations_no"),
        ):
            self._test_cli_(
                base_options + extra,
                name,
                files=audio,
                extensions=formats,
                one_per_call=False,
            )

    def test_verbose(self):
        audio = ["bonjour_vous_allez_bien.mp3"]
        trailing = ["--model", "tiny", "--verbose", "True"]

        # (options placed before the model, prefix of the reference)
        variants = (
            (["--efficient"], "efficient.auto"),
            (["--language", "fr", "--efficient"], "efficient.fr"),
            ([], "accurate.auto"),
            (["--language", "fr"], "accurate.fr"),
        )
        for leading, prefix in variants:
            self._test_cli_(
                leading + trailing,
                "verbose", files=audio, extensions=None,
                prefix=prefix,
                device_specific=True,
            )
class TestMakeSubtitles(TestHelper):
    """Test of the script that makes subtitles from "words.json" files."""

    def test_make_subtitles(self):
        main_script = self.main_script("make_subtitles.py", "whisper_timestamped_make_subtitles")
        inputs = [
            self.get_data_path("smartphone.mp3.words.json"),
            self.get_data_path("no_punctuations.mp3.words.json", check=True),
            self.get_data_path("yes_punctuations.mp3.words.json", check=True),
        ]

        # Local names are chosen not to shadow the builtins input / len / format
        for i, input_file in enumerate(inputs):
            filename = os.path.basename(input_file).replace(".words.json", "")
            for max_length in 6, 20, 50:
                output_dir = self.get_output_path()
                # Folder to folder (the whole data folder is given for the first input only)
                self.assertRun([main_script,
                                input_file if i > 0 else self.get_data_path(), output_dir,
                                "--max_length", str(max_length),
                                ])
                for subtitle_format in "vtt", "srt":
                    output_file = os.path.join(output_dir, f"{filename}.{subtitle_format}")
                    self.assertTrue(os.path.isfile(output_file), msg=f"File {output_file} not found")
                    expected_file = f"split_subtitles/{filename.split('_')[-1]}_{max_length}.{subtitle_format}"
                    self.assertNonRegression(output_file, expected_file)
                    os.remove(output_file)
                    # File to file: must give the same result
                    self.assertRun([main_script,
                                    input_file, output_file,
                                    "--max_length", str(max_length),
                                    ])
                    self.assertTrue(os.path.isfile(output_file), msg=f"File {output_file} not found")
                    self.assertNonRegression(output_file, expected_file)
class TestHuggingFaceModel(TestHelperCli):
    """Command line tests with a model from Hugging Face, given by name or as a local folder."""

    def test_hugging_face_model(self):
        model_name = "qanastek/whisper-tiny-french-cased"

        # Model given by its name
        self._test_cli_(
            ["--model", model_name, "--verbose", "True"],
            "verbose", files=["bonjour.wav"], extensions=None,
            prefix="hf",
            device_specific=True,
        )

        from transformers import WhisperForConditionalGeneration, WhisperProcessor

        # Model given as a local folder, saved with several serialization settings.
        # The model and processor do not depend on these settings: they are loaded once
        model = WhisperForConditionalGeneration.from_pretrained(model_name)
        processor = WhisperProcessor.from_pretrained(model_name)

        tempfolder = os.path.join(tempfile.gettempdir(), "tmp_whisper-tiny-french-cased")

        for safe_serialization in False, True:
            for max_shard_size in "100MB", "10GB":
                shutil.rmtree(tempfolder, ignore_errors=True)
                try:
                    model.save_pretrained(tempfolder, safe_serialization=safe_serialization, max_shard_size=max_shard_size)
                    processor.save_pretrained(tempfolder)
                    # Same reference as when the model is given by its name
                    self._test_cli_(
                        ["--model", tempfolder, "--verbose", "True"],
                        "verbose", files=["bonjour.wav"], extensions=None,
                        prefix="hf",
                        device_specific=True,
                    )
                finally:
                    # Do not hide the actual error if the folder could not even be created
                    shutil.rmtree(tempfolder, ignore_errors=True)
# "ZZZ" to run this test at last (because it will fill the CUDA with some memory)
class TestZZZPythonImport(TestHelper):
    """Tests using whisper_timestamped as a python module (transcription and token splitting)."""

    def test_python_import(self):
        try:
            import whisper_timestamped
        except ModuleNotFoundError:
            # Fall back to the sources in the parent folder
            parent_folder = os.path.dirname(os.path.dirname(__file__))
            sys.path.append(os.path.realpath(parent_folder))
            import whisper_timestamped

        # Test version
        version = whisper_timestamped.__version__
        self.assertTrue(isinstance(version, str))

        cli_output = self.assertRun([self.main_script(), "-v"])
        self.assertEqual(cli_output[0].strip(), version)

        model = whisper_timestamped.load_model("tiny")

        # Check processing of different files, first without then with the language given
        audio_files = ("bonjour.wav", "laugh1.mp3", "laugh2.mp3")
        for transcribe_options, reference_folder in (
            ({}, "tiny_auto"),
            ({"language": "fr"}, "tiny_fr"),
        ):
            for filename in audio_files:
                res = whisper_timestamped.transcribe(
                    model, self.get_data_path(filename), **transcribe_options)
                if self._can_generate_reference():
                    self.assertNonRegression(res, f"{reference_folder}/{filename}.words.json")

    def _can_generate_reference(self):
        """Return False only when generating device-dependent references while running on cpu."""
        if not GENERATE_DEVICE_DEPENDENT:
            return True
        return self.get_device_str() != "cpu"

    def test_split_tokens(self):
        import whisper
        whisperversion = whisper.__version__

        import whisper_timestamped as whisper
        from whisper_timestamped.transcribe import split_tokens_on_spaces

        def expect_split(token_ids, tokenizer, words, word_pieces, word_token_ids):
            # The function returns (words, pieces of each word, token ids of each word)
            self.assertEqual(
                split_tokens_on_spaces(token_ids, tokenizer),
                (words, word_pieces, word_token_ids)
            )

        tokenizer = whisper.tokenizer.get_tokenizer(True, language=None)

        # 220 means space
        expect_split(
            [50364, 220, 6455, 11, 2232, 11, 286, 2041, 11, 2232, 11, 8660,
             291, 808, 493, 220, 365, 11, 220, 445, 718, 505, 458, 13, 220, 50714],
            tokenizer,
            ['<|0.00|>', 'So,', 'uh,', 'I', 'guess,', 'uh,', 'wherever', 'you', 'come', 'up', 'with,', 'just', 'let', 'us', 'know.', '<|7.00|>'],
            [['<|0.00|>'],
             [' ', 'So', ','],
             [' uh', ','],
             [' I'],
             [' guess', ','],
             [' uh', ','],
             [' wherever'],
             [' you'],
             [' come'],
             [' up'],
             [' ', ' with', ','],
             [' ', ' just'],
             [' let'],
             [' us'],
             [' know', '.', ' '],
             ['<|7.00|>']],
            [[50364],
             [220, 6455, 11],
             [2232, 11],
             [286],
             [2041, 11],
             [2232, 11],
             [8660],
             [291],
             [808],
             [493],
             [220, 365, 11],
             [220, 445],
             [718],
             [505],
             [458, 13, 220],
             [50714]
             ],
        )

        expect_split(
            [50366, 314, 6, 11771, 17134, 11, 4666, 11, 1022, 220, 875, 2557, 68, 11, 6992, 631, 269, 6, 377, 220, 409, 7282, 1956, 871, 566, 2707, 394, 1956, 256, 622, 8208, 631, 8208, 871, 517, 7282, 1956, 5977, 7418, 371, 1004, 306, 580, 11, 5977, 12, 9498, 9505, 84, 6, 50416],
            tokenizer,
            ['<|0.04|>', "T'façon,", 'nous,', 'sur', 'la', 'touche,', 'parce', 'que', "c'est", 'un', 'sport', 'qui', 'est', 'important', 'qui', 'tue', 'deux', 'que', 'deux', 'est', 'un', 'sport', 'qui', 'peut', 'être', 'violent,', 'peut-être', "qu'", '<|1.04|>'],
            [['<|0.04|>'],
             [' T', "'", 'fa', 'çon', ','],
             [' nous', ','],
             [' sur'],
             [' ', 'la'],
             [' touch', 'e', ','],
             [' parce'],
             [' que'],
             [' c', "'", 'est'],
             [' ', 'un'],
             [' sport'],
             [' qui'],
             [' est'],
             [' im', 'port', 'ant'],
             [' qui'],
             [' t', 'ue'],
             [' deux'],
             [' que'],
             [' deux'],
             [' est'],
             [' un'],
             [' sport'],
             [' qui'],
             [' peut'],
             [' être'],
             [' v', 'io', 'le', 'nt', ','],
             [' peut', '-', 'être'],
             [' q', 'u', "'"],
             ['<|1.04|>']],
            [[50366],
             [314, 6, 11771, 17134, 11],
             [4666, 11],
             [1022],
             [220, 875],
             [2557, 68, 11],
             [6992],
             [631],
             [269, 6, 377],
             [220, 409],
             [7282],
             [1956],
             [871],
             [566, 2707, 394],
             [1956],
             [256, 622],
             [8208],
             [631],
             [8208],
             [871],
             [517],
             [7282],
             [1956],
             [5977],
             [7418],
             [371, 1004, 306, 580, 11],
             [5977, 12, 9498],
             [9505, 84, 6],
             [50416]],
        )

        expect_split(
            [50364, 220, 220, 6455, 11, 220, 220, 2232, 220, 220, 11, 220, 50714],
            tokenizer,
            ['<|0.00|>', 'So,', 'uh', ',', '<|7.00|>'],
            [['<|0.00|>'],
             [' ', ' ', 'So', ','],
             [' ', ' ', ' uh'],
             [' ', ' ', ',', ' '],
             ['<|7.00|>']],
            [[50364], [220, 220, 6455, 11], [220, 220, 2232], [220, 220, 11, 220], [50714]],
        )

        # Careful with the double spaces at the end...
        expect_split(
            [50364, 220, 220, 6455, 11, 220, 220, 2232, 220, 220, 11, 220, 220, 50714],
            tokenizer,
            ['<|0.00|>', 'So,', 'uh', ',', '', '<|7.00|>'],
            [['<|0.00|>'],
             [' ', ' ', 'So', ','],
             [' ', ' ', ' uh'],
             [' ', ' ', ','],
             [' ', ' '],
             ['<|7.00|>']],
            [[50364], [220, 220, 6455, 11], [220, 220, 2232], [220, 220, 11], [220, 220], [50714]],
        )

        # Tokens that could be removed
        expect_split(
            [50364, 6024, 95, 8848, 7649, 8717, 38251, 11703, 3224, 51864],
            tokenizer,
            ['<|0.00|>', 'الآذان', 'نسمّه', '<|30.00|>'],
            [['<|0.00|>'], ['', ' الآ', 'ذ', 'ان'], [' ن', 'سم', 'ّ', 'ه'], ['<|30.00|>']],
            [[50364], [6024, 95, 8848, 7649], [8717, 38251, 11703, 3224], [51864]],
        )

        # issue #61
        # Special tokens that are not timestamps
        # 50299 is "<|te|>" and appears as ""
        te = ""
        expect_split(
            [50414, 805, 12, 17, 50299, 11, 568, 12, 18, 12, 21, 11, 502, 12, 17, 12, 51464],
            tokenizer,
            ['<|1.00|>', f'3-2{te},', '2-3-6,', '1-2-', '<|22.00|>'],
            [['<|1.00|>'], [' 3', '-', '2', f'{te}', ','], [' 2', '-', '3', '-','6', ','], [' 1', '-', '2', '-'], ['<|22.00|>']],
            [[50414], [805, 12, 17, 50299, 11], [568, 12, 18, 12, 21, 11], [502, 12, 17, 12], [51464]],
        )

        tokenizer = whisper.tokenizer.get_tokenizer(False, language="en")

        # Just a punctuation character
        # (the expected piece depends on the version of whisper)
        if whisperversion < "20230314":
            _dot = "."
        else:
            _dot = " ."
        expect_split(
            [50363, 764, 51813],
            tokenizer,
            ['<|0.00|>', ".", '<|29.00|>'],
            [['<|0.00|>'], [_dot], ['<|29.00|>']],
            [[50363], [764], [51813]],
        )