# tortoise5c / tortoise /utils /wav2vec_alignment.py
# djkesu's picture
# added model
# 3bbf2c7
# raw
# history blame
# 6.4 kB
import torch
import torchaudio
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC
def max_alignment(s1, s2, skip_character="~", record=None):
    """
    Align s1 to s2 as best it can, masking the characters of s1 that have no counterpart in s2.

    The characters of s1 that take part in a longest common subsequence of s1 and s2 are kept;
    every other character of s1 is replaced by skip_character. With a single-character
    skip_character the result therefore has exactly len(s1) characters.

    Matching is greedy from the left: if the current characters of s1 and s2 are equal they are
    matched. Otherwise the character of s2 is dropped only when that yields a strictly longer
    match; on a tie the character of s1 is skipped.

    The implementation is an iterative dynamic program, so long inputs cannot exhaust the
    interpreter's recursion limit.

    Args:
        s1: The string to align. Must not contain skip_character.
        s2: The string to align against.
        skip_character: Placeholder emitted for unmatched characters of s1 (expected to be a
            single character).
        record: Unused. This was the memo table of the former recursive implementation and is
            only kept so that existing call sites keep working.

    Returns:
        s1 with every unmatched character replaced by skip_character.

    Raises:
        AssertionError: If skip_character occurs in s1.
    """
    # Raised explicitly (instead of via `assert`) so the check survives `python -O`;
    # the exception type is kept for backward compatibility.
    if skip_character in s1:
        raise AssertionError(
            f"Found the skip character {skip_character} in the provided string, {s1}"
        )

    n, m = len(s1), len(s2)

    # lcs[i][j] is the length of the longest common subsequence of s1[i:] and s2[j:].
    # The extra row/column of zeros represents the empty suffixes.
    lcs = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
        row, below = lcs[i], lcs[i + 1]
        ch = s1[i]
        for j in range(m - 1, -1, -1):
            if ch == s2[j]:
                row[j] = below[j + 1] + 1
            else:
                row[j] = max(row[j + 1], below[j])

    # Walk the table from the start of both strings and emit one entry per character of s1.
    aligned = []
    i = j = 0
    while i < n and j < m:
        if s1[i] == s2[j]:
            aligned.append(s1[i])
            i += 1
            j += 1
        elif lcs[i][j + 1] > lcs[i + 1][j]:
            # Dropping the character of s2 is strictly better: nothing is emitted for s1 yet.
            j += 1
        else:
            # Tie or worse: give up on this character of s1.
            aligned.append(skip_character)
            i += 1

    # s2 is exhausted; whatever is left of s1 cannot be matched.
    aligned.append(skip_character * (n - i))
    return "".join(aligned)
class Wav2VecAlignment:
    """
    Uses wav2vec2 to perform audio<->text alignment.
    """

    def __init__(self, device="cuda"):
        """
        Load the pretrained CTC model, feature extractor and tokenizer.

        The model is placed on the CPU here; align() moves it to `device` only for the duration
        of a forward pass.

        Args:
            device: Device on which align() runs inference.
        """
        self.model = Wav2Vec2ForCTC.from_pretrained(
            "jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli"
        ).cpu()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-large-960h"
        )
        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
            "jbetker/tacotron-symbols"
        )
        self.device = device

    def align(self, audio, expected_text, audio_sample_rate=24000):
        """
        Compute, for every character of expected_text, an offset into `audio`.

        The audio is resampled to 16 kHz, normalised and run through the CTC model. The
        per-frame argmax tokens are decoded to text, expected_text (lower-cased) is aligned
        against that text with max_alignment(), and every matched character is assigned the
        first not-yet-consumed frame whose argmax is that character's token. Characters without
        a match are linearly interpolated between their matched neighbours.

        Args:
            audio: Audio tensor; its last dimension is taken as the number of samples and the
                first item of the model output is used (presumably shaped (1, samples) --
                TODO confirm against callers).
            expected_text: The text spoken in `audio`.
            audio_sample_rate: Sample rate of `audio`.

        Returns:
            A list of integer offsets, in samples of the original (un-resampled) audio, with
            one entry per character of expected_text. `[0]` if the text encodes to one token.

        Raises:
            AssertionError: If the alignment bookkeeping does not end up with exactly one offset
                per character. The resampled audio and the text are dumped to
                'alignment_debug.pth' in the current working directory first.
        """
        orig_len = audio.shape[-1]

        with torch.no_grad():
            self.model = self.model.to(self.device)
            try:
                audio = audio.to(self.device)
                audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
                clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
                logits = self.model(clip_norm).logits
            finally:
                # Move the model back even when inference fails, so a failed call does not
                # leave it parked on the accelerator.
                self.model = self.model.cpu()

        logits = logits[0]
        # One argmax for all frames; reused for decoding and for the matching loop below.
        frame_tokens = logits.argmax(-1).tolist()
        pred_string = self.tokenizer.decode(frame_tokens)

        fixed_expectation = max_alignment(expected_text.lower(), pred_string)
        # Number of original audio samples covered by one model output frame.
        w2v_compression = orig_len // logits.shape[0]
        expected_tokens = self.tokenizer.encode(fixed_expectation)
        expected_chars = list(fixed_expectation)
        if len(expected_tokens) == 1:
            return [0]  # The alignment is simple; there is only one token.

        # The first token is a given: it is pinned to offset 0, so consumption starts at 1.
        alignments = [0]
        cursor = 1  # Index of the next token/character that has not been consumed yet.

        def next_matchable_token():
            """Consume tokens up to the next real one, recording -1 for every skipped '~'."""
            nonlocal cursor
            while cursor < len(expected_tokens):
                token, char = expected_tokens[cursor], expected_chars[cursor]
                cursor += 1
                if char != "~":
                    return token
                alignments.append(-1)
            return None

        next_expected_token = next_matchable_token()
        for frame_idx, top in enumerate(frame_tokens):
            if next_expected_token != top:
                continue
            alignments.append(frame_idx * w2v_compression)
            if cursor >= len(expected_tokens):
                break
            next_expected_token = next_matchable_token()

        # Drain trailing '~' placeholders so they get their -1 entries as well.
        next_matchable_token()

        if cursor < len(expected_tokens) or len(alignments) != len(expected_text):
            torch.save([audio, expected_text], "alignment_debug.pth")
            # Raised explicitly so the check survives `python -O`; the type matches the
            # `assert` this replaces.
            raise AssertionError(
                "Something went wrong with the alignment algorithm. I've dumped a file, 'alignment_debug.pth' to "
                "your current working directory. Please report this along with the file so it can get fixed."
            )

        # Now fix up alignments. Anything with -1 should be interpolated.
        return self._interpolate_missing(alignments, orig_len)

    @staticmethod
    def _interpolate_missing(alignments, orig_len):
        """Return a copy of `alignments` in which every run of -1 is linearly interpolated.

        A run is interpolated between the known offset before it and the next known offset
        after it; `orig_len` serves as the known offset after the last entry.
        """
        points = list(alignments)
        points.append(orig_len)  # Sentinel; removed again before returning.
        for i in range(len(points)):
            if points[i] != -1:
                continue
            # The sentinel guarantees that a known offset exists to the right.
            next_found = next(j for j in range(i + 1, len(points)) if points[j] != -1)
            base = points[i - 1]
            gap = points[next_found] - base
            span = next_found - i + 1
            for j in range(i, next_found):
                points[j] = (j - i + 1) * gap // span + base
        return points[:-1]

    def redact(self, audio, expected_text, audio_sample_rate=24000):
        """
        Cut the audio belonging to the "[bracketed]" parts of expected_text out of `audio`.

        The text is split into alternating kept / redacted segments, the text without brackets
        is aligned against the audio with align(), and the audio of the kept segments is
        concatenated. Each kept span runs from the offset of the segment's first character up
        to (but excluding) the offset of its last character.

        Args:
            audio: Audio tensor, sliced as audio[:, start:stop].
            expected_text: The spoken text; parts to remove are wrapped in "[...]".
            audio_sample_rate: Sample rate of `audio`, forwarded to align().

        Returns:
            `audio` itself if the text contains no "[", otherwise the concatenation of the kept
            spans along the last dimension (an empty clip if nothing is kept).

        Raises:
            AssertionError: If a "[" has no "]" before the next "[" or the end of the text.
        """
        if "[" not in expected_text:
            return audio

        head, *bracketed = expected_text.split("[")
        fully_split = [head]
        for spl in bracketed:
            if "]" not in spl:
                # Raised explicitly so the check survives `python -O`.
                raise AssertionError(
                    'Every "[" character must be paired with a "]" with no nesting.'
                )
            # Split on the first "]" only: exactly two parts per "[" keeps the strict
            # kept/redacted alternation that the parity check below relies on.
            fully_split.extend(spl.split("]", 1))

        # At this point, fully_split is a list of strings, with every other string being
        # something that should be redacted.
        non_redacted_intervals = []
        last_point = 0
        for i, segment in enumerate(fully_split):
            # Empty kept segments (text starting or ending with a redaction) contribute no
            # audio; a trailing one would otherwise index one past the last alignment.
            if i % 2 == 0 and segment:
                non_redacted_intervals.append((last_point, last_point + len(segment) - 1))
            last_point += len(segment)

        if not non_redacted_intervals:
            # Everything was redacted: nothing to align, nothing to keep.
            return audio[:, :0]

        bare_text = "".join(fully_split)
        alignments = self.align(audio, bare_text, audio_sample_rate)
        output_audio = [
            audio[:, alignments[start] : alignments[stop]]
            for start, stop in non_redacted_intervals
        ]
        return torch.cat(output_audio, dim=-1)