import random
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Type

# from lhotse import CutSet
# from lhotse.dataset.collation import collate_features
# from lhotse.dataset.input_strategies import (
#     ExecutorType,
#     PrecomputedFeatures,
#     _get_executor,
# )
# from lhotse.utils import fastcopy


class PromptedFeatures:
    """Pairs a batch of prompt features with the corresponding target features."""

    def __init__(self, prompts, features):
        self.prompts = prompts
        self.features = features

    def to(self, device):
        # Move both tensors to the target device, keeping them paired.
        return PromptedFeatures(
            self.prompts.to(device), self.features.to(device)
        )

    def sum(self):
        return self.features.sum()

    @property
    def ndim(self):
        return self.features.ndim

    @property
    def data(self):
        return (self.prompts, self.features)
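

# Illustrative sketch (not part of the original module): how PromptedFeatures is
# meant to be used. It assumes torch tensors shaped (batch, num_frames,
# num_features); the shapes below are made-up placeholders.
def _demo_prompted_features():
    import torch

    prompts = torch.zeros(2, 150, 80)   # prompt features
    features = torch.zeros(2, 600, 80)  # target features

    batch = PromptedFeatures(prompts, features)
    batch = batch.to("cpu")             # moves both tensors together
    assert batch.ndim == 3              # mirrors the target feature tensor
    assert batch.sum().item() == 0.0    # sums the target features only
    return batch.data                   # -> (prompts, features) tuple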


# class PromptedPrecomputedFeatures(PrecomputedFeatures):
#     """
#     :class:`InputStrategy` that reads pre-computed features, whose manifests
#     are attached to cuts, from disk.
#
#     In addition to the target features, it pairs each cut with prompt features
#     taken from a neighboring (preceding or following) utterance, chosen from
#     the same speaker where the dataset provides one.
#
#     .. automethod:: __call__
#     """
#
#     def __init__(
#         self,
#         dataset: str,
#         cuts: CutSet,
#         num_workers: int = 0,
#         executor_type: Type[ExecutorType] = ThreadPoolExecutor,
#     ) -> None:
#         super(PromptedPrecomputedFeatures, self).__init__(
#             num_workers, executor_type
#         )
#
#         self.utt2neighbors = defaultdict(lambda: [])
#
#         if dataset.lower() == "libritts":
#             # 909_131041_000013_000002
#             # 909_131041_000013_000003
#             speaker2utts = defaultdict(lambda: [])
#
#             utt2cut = {}
#             for cut in cuts:
#                 speaker = cut.supervisions[0].speaker
#                 speaker2utts[speaker].append(cut.id)
#                 utt2cut[cut.id] = cut
#
#             for spk in speaker2utts:
#                 uttids = sorted(speaker2utts[spk])
#                 # Use the ordering of the sorted keys to find the previous utterance.
#                 # The keys have the structure speaker_book_x_y, e.g. 1089_134691_000004_000001
#                 if len(uttids) == 1:
#                     self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
#                     continue
#
#                 utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
#                 utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
#
#                 for utt in utt2prevutt:
#                     self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
#
#                 for utt in utt2postutt:
#                     self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
#         elif dataset.lower() == "ljspeech":
#             utt2cut = {}
#             uttids = []
#             for cut in cuts:
#                 uttids.append(cut.id)
#                 utt2cut[cut.id] = cut
#
#             if len(uttids) == 1:
#                 self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
#             else:
#                 # Use the ordering of the sorted keys to find the previous utterance.
#                 # The keys have the structure: LJ001-0010
#                 utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
#                 utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
#
#                 for utt in utt2postutt:
#                     postutt = utt2postutt[utt]
#                     if utt[:5] == postutt[:5]:
#                         self.utt2neighbors[utt].append(utt2cut[postutt])
#
#                 for utt in utt2prevutt:
#                     prevutt = utt2prevutt[utt]
#                     if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
#                         self.utt2neighbors[utt].append(utt2cut[prevutt])
#         else:
#             raise ValueError(f"Unsupported dataset: {dataset}")
#
#     def __call__(
#         self, cuts: CutSet
#     ) -> Tuple[PromptedFeatures, PromptedFeatures]:
#         """
#         Reads the pre-computed features from disk/other storage.
#         The returned shape is ``(B, T, F) => (batch_size, num_frames, num_features)``.
#
#         :return: a ``PromptedFeatures`` wrapping the collated prompt and
#             target features, and a ``PromptedFeatures`` wrapping their
#             ``num_frames`` before padding.
#         """
#         features, features_lens = collate_features(
#             cuts,
#             executor=_get_executor(
#                 self.num_workers, executor_type=self._executor_type
#             ),
#         )
#
#         prompts_cuts = []
#         for k, cut in enumerate(cuts):
#             prompts_cut = random.choice(self.utt2neighbors[cut.id])
#             prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
#
#         mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
#         # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
#         #     max_duration=mini_duration,
#         #     offset_type="random",
#         #     preserve_id=True,
#         # )
#         prompts_cuts = CutSet(
#             cuts={k: cut for k, cut in enumerate(prompts_cuts)}
#         ).truncate(
#             max_duration=mini_duration,
#             offset_type="random",
#             preserve_id=False,
#         )
#
#         prompts, prompts_lens = collate_features(
#             prompts_cuts,
#             executor=_get_executor(
#                 self.num_workers, executor_type=self._executor_type
#             ),
#         )
#
#         return PromptedFeatures(prompts, features), PromptedFeatures(
#             prompts_lens, features_lens
#         )
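

# The commented-out strategy above maps every utterance to its "neighbors"
# (previous/next utterance ids within the same speaker, found via sorted ids)
# and later samples one neighbor as the acoustic prompt. The sketch below is an
# illustrative, self-contained rewrite of that pairing step on plain id strings;
# it does not depend on lhotse, and the helper name and example ids are made up.
def build_utt2neighbors_sketch(speaker2utts):
    """Map each utterance id to its neighboring ids of the same speaker."""
    utt2neighbors = defaultdict(list)
    for spk, utts in speaker2utts.items():
        uttids = sorted(utts)
        if len(uttids) == 1:
            # A lone utterance can only prompt itself.
            utt2neighbors[uttids[0]].append(uttids[0])
            continue
        # Same pairing as above: each id gets its predecessor (the first id
        # falls back to the second) and, where it exists, its successor.
        utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
        utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
        for utt, prev in utt2prevutt.items():
            utt2neighbors[utt].append(prev)
        for utt, post in utt2postutt.items():
            utt2neighbors[utt].append(post)
    return utt2neighbors


# Example: prompts for "909_131041_000013_000002" would then be drawn with
# random.choice(neighbors["909_131041_000013_000002"]), mirroring __call__ above.
# neighbors = build_utt2neighbors_sketch(
#     {"909": ["909_131041_000013_000002", "909_131041_000013_000003"]}
# )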