# Spaces:
# Running
# Running
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Union

import pandas as pd
import torch
from torch.utils.data.dataset import Dataset

# Module-wide logger; getLogger() without a name returns the root logger.
log = logging.getLogger()
class AudioCapsData(Dataset):
    """Caption dataset backed by a CSV file with ``name`` and ``caption`` columns.

    Every item is a dict ``{'name': ..., 'caption': ...}`` taken from one CSV row.
    """

    def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
        """Read all rows of ``csv_path`` and remember both paths.

        Args:
            audio_path: Directory that is listed for ``.wav`` / ``.flac`` files.
            csv_path: CSV file; each row must provide ``name`` and ``caption``.

        Raises:
            OSError: If ``audio_path`` cannot be listed.
            KeyError: If a row has no ``name`` or ``caption`` entry.
        """
        rows = pd.read_csv(csv_path).to_dict(orient='records')

        # Listing the directory makes construction fail early when audio_path
        # is missing or unreadable.
        # NOTE(review): the collected stems are not used to filter the rows, so
        # the count logged below is the number of CSV rows, not the number of
        # audio files verified on disk -- confirm whether filtering was intended.
        audio_files = {
            Path(f).stem
            for f in os.listdir(audio_path)
            if f.endswith(('.wav', '.flac'))
        }

        # Keep only the two fields consumers read; any other CSV column is dropped.
        self.data = [{'name': row['name'], 'caption': row['caption']} for row in rows]

        self.audio_path = Path(audio_path)
        self.csv_path = Path(csv_path)

        log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')

    def __getitem__(self, idx: int) -> dict:
        """Return the ``{'name', 'caption'}`` dict stored at position ``idx``."""
        return self.data[idx]

    def __len__(self) -> int:
        """Return the number of stored rows."""
        return len(self.data)