from torch.utils.data import Dataset
from beartype.typing import Sequence, Callable, Optional, Dict, List
from beartype.door import is_bearable
import random
import os
from torchaudio.functional import resample
import torch
import typing as tp
from pathlib import Path
import torchaudio as ta
import torch.nn.functional as F
import soundfile
import numpy as np
import json
import yaml
import librosa
from loguru import logger
import re

def _av_read(filepath, seek_time=0, duration=None):
    # Read audio with librosa, optionally seeking into the file and limiting the duration.
    if duration is not None:
        sr = librosa.get_samplerate(filepath)
        wav, _ = librosa.load(filepath, sr=sr, offset=seek_time, duration=duration)
    else:
        wav, sr = librosa.load(filepath, sr=None, offset=seek_time)
    return wav, sr

def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
               duration: float = -1., pad: bool = True) -> tp.Tuple[torch.Tensor, int]:
    """Read audio by picking the most appropriate backend tool based on the audio format.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
        pad (bool): Pad output audio if not reaching expected duration.
    Returns:
        tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
    """
    fp = Path(filepath)
    if fp.suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
        # There is some bug with ffmpeg and reading flac
        info = soundfile.info(filepath)
        frames = -1 if duration <= 0 else int(duration * info.samplerate)
        frame_offset = int(seek_time * info.samplerate)
        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
        assert info.samplerate == sr, f"Mismatch of sample rates {info.samplerate} {sr}"
        wav = torch.from_numpy(wav).t().contiguous()
        if len(wav.shape) == 1:
            wav = torch.unsqueeze(wav, 0)
    elif (
        fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
        and duration <= 0 and seek_time == 0
    ):
        # Loading the entire file at once is faster than seeking into it.
        wav, sr = librosa.load(fp, sr=None, mono=True)
    else:
        wav, sr = _av_read(filepath, seek_time, duration)
    if pad and duration > 0:
        expected_frames = int(duration * sr)
        wav = F.pad(torch.as_tensor(wav), (0, expected_frames - wav.shape[-1]))
    if not isinstance(wav, torch.Tensor):
        wav = torch.tensor(wav)
    return wav, sr
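
# Illustrative usage (not part of the original module; the file path below is a
# hypothetical placeholder): read 10 seconds starting at 5.0 s, padded to the full
# 10 s if the file is shorter.
#
#     wav, sr = audio_read("assets/example.wav", seek_time=5.0, duration=10.0, pad=True)
#     print(wav.shape, sr)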

def random_seek_read(filepath, duration):
    # Read a random window of `duration` seconds; read the whole file if duration <= 0.
    if duration > 0:
        total_duration = librosa.get_duration(path=filepath)
        acceptable_start = max(0, total_duration - duration)
        wav, sr = audio_read(filepath, random.uniform(0, acceptable_start), duration, pad=True)
    else:
        wav, sr = audio_read(filepath, 0, -1, pad=False)
    return wav, sr

def safe_random_seek_read(filepath, duration, sample_rate):
    try:
        wav, sr = random_seek_read(filepath, duration)
        if sr != sample_rate:
            wav = resample(wav, sr, sample_rate)
            sr = sample_rate
    except Exception as e:
        logger.error(f"Error reading {filepath}: {e}")
        sr = sample_rate
        # Fall back to silence of the requested duration (torch.zeros needs an int size).
        wav = torch.zeros(int(sr * max(duration, 0)), dtype=torch.float32)
    return wav, sr

def read_jsonlike(path: os.PathLike):
    # Accepts either a .json file or a .jsonl file (one JSON object per line).
    if str(path).endswith(".json"):
        with open(path, 'r', encoding='utf8') as f:
            data = json.load(f)
        return data
    elif str(path).endswith(".jsonl"):
        with open(path, 'r', encoding='utf8') as f:
            data = [json.loads(line) for line in f]
        return data
    else:
        raise ValueError(f"Unknown file format: {path}")

# Sampling weights: dist_prob_map[n][k-1] is the probability of keeping k tags
# when n candidate tags are available.
dist_prob_map = {
    1: (1.0,),
    2: (0.5, 0.5),
    3: (0.3, 0.4, 0.3),
    4: (0.2, 0.3, 0.3, 0.2),
    5: (0.2, 0.2, 0.3, 0.2, 0.1),
    6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
    7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
    8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
    9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
    10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
}

dist_prob_map_low = {
    1: (1.0,),
    2: (0.8, 0.2),
    3: (0.8, 0.1, 0.1),
    4: (0.7, 0.1, 0.1, 0.1),
    5: (0.7, 0.1, 0.1, 0.05, 0.05),
    6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
}

_bpm_range_rights = (
    (40, '20-40'),
    (60, '40-60'),
    (66, '60-66'),
    (76, '66-76'),
    (108, '76-108'),
    (120, '108-120'),
    (168, '120-168'),
    (176, '168-176'),
    (200, '176-200')
)

_bpm_desc_map = {
    '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
    '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
    '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
    '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
    '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
    '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
    '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
    '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
    '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
    '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
}

_bpm_desc_map_zh = {
    '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
    '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
    '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
    '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
    '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
    '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
    '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
    '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
    '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
    '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
}

def get_bpm_range(bpm):
    bpm = int(bpm)
    for right, tag in _bpm_range_rights:
        if bpm <= right:
            return tag
    return '>200'

def gen_bpm_descript(bpm, lang='en'):
    bpm_range = get_bpm_range(bpm)
    if lang == 'en':
        return random.choice(_bpm_desc_map[bpm_range])
    elif lang == 'zh':
        return random.choice(_bpm_desc_map_zh[bpm_range])
    else:
        raise ValueError(f"Unknown language {lang}")
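
# Illustrative check (not part of the original module): 128 BPM falls into the
# '120-168' bucket, so the description is drawn from that bucket's phrases.
#
#     assert get_bpm_range(128) == '120-168'
#     print(gen_bpm_descript(128))         # e.g. "brisk pace" or "Allegro"
#     print(gen_bpm_descript(128, 'zh'))   # e.g. "快板"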

def read_translate(translate: Optional[Dict[str, os.PathLike]]):
    if translate is None:
        return None
    return {k: read_jsonlike(path) for k, path in translate.items()}

def tags_to_desc(tag_list, sep=',') -> str:
    if not isinstance(tag_list, Sequence):
        return str(tag_list)
    if isinstance(tag_list, str):
        return tag_list
    if len(tag_list) <= 0:
        return ''
    elif len(tag_list) <= 5:
        probs = dist_prob_map[len(tag_list)]
        tags_num = random.choices(range(1, len(tag_list) + 1), probs)[0]
        random.shuffle(tag_list)
        tag_list = tag_list[:tags_num]
        return sep.join(tag_list)
    else:
        # Cap the number of kept tags at 5 for long tag lists.
        probs = dist_prob_map[5]
        tags_num = random.choices(range(1, 6), probs)[0]
        random.shuffle(tag_list)
        tag_list = tag_list[:tags_num]
        return sep.join(tag_list)
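
# Illustrative example (hypothetical tag list, not part of the original module):
# with four tags, dist_prob_map[4] decides how many survive, so one possible
# outcome is "piano,ambient,calm".
#
#     print(tags_to_desc(["piano", "ambient", "calm", "cinematic"]))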

class PromptTemplate:
    def __init__(self, template_text: str, tag_map: Dict[str, str], lang: str = 'en'):
        self.template_text = template_text
        self.tag_map = tag_map
        self.lang = lang

    @property
    def tags(self):
        return tuple(self.tag_map.keys())

    def apply(self, **kwargs):
        for tag in list(kwargs.keys()):
            if kwargs[tag] == '':
                kwargs.pop(tag)
        for tag in self.tags:
            if tag in kwargs:
                kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
            else:
                kwargs[tag] = ''
        prompt = self.template_text.format(**kwargs)
        return self.beautify(prompt)

    def beautify(self, text):
        if self.lang == 'en':
            return self._beautify_en(text)
        elif self.lang == 'zh':
            return self._beautify_zh(text)
        else:
            raise ValueError(f'Unknown language {self.lang}')

    @staticmethod
    def _beautify_en(text):
        # no consecutive commas without content between them
        text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
        # no consecutive whitespace
        text = re.sub(r'\s+', ' ', text)
        # no whitespace before a comma, exactly one space after it
        text = re.sub(r'\s+,', r',', text)
        text = re.sub(r',\s+', r', ', text)
        # no whitespace before the full stop
        text = re.sub(r'\s+\.', r'.', text)
        # strip leading/trailing whitespace and commas, collapse ',.' into '.'
        text = text.strip(' ,')
        text = text.replace(',.', '.')
        return text

    @staticmethod
    def _beautify_zh(text):
        # no consecutive commas without content between them
        text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
        text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
        # assume there should be NO whitespace in Chinese text
        text = re.sub(r'\s+', r'', text)
        # strip leading/trailing whitespace and commas, collapse ',。' into '。'
        text = text.strip(', 、')
        text = text.replace(',。', '。')
        return text

    def __repr__(self):
        return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'

    __str__ = __repr__

def parse_prompt_template(prompt_template_text, lang='en'):
    span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
    tag_pattern = re.compile(r'{.+?}', re.DOTALL)
    template_text = prompt_template_text.strip()
    span_texts = span_pattern.findall(prompt_template_text)
    tag_map = {}
    for span_text in span_texts:
        tag = tag_pattern.findall(span_text)[0].strip('{}')
        tag_map[tag] = span_text
        template_text = template_text.replace(span_text, '{' + tag + '}')
    return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
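
# Illustrative example (hypothetical template line, not part of the original module):
# optional spans are wrapped in square brackets around a {tag} placeholder; each span
# becomes an entry in tag_map and is dropped from the prompt when its tag is empty.
#
#     pt = parse_prompt_template("A [{Genre} ]track[ at a {BPM} tempo].")
#     print(pt.tags)                  # ('Genre', 'BPM')
#     print(pt.apply(Genre="jazz"))   # "A jazz track."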

def load_prompt_templates(path, num=5, lang='en') -> List[PromptTemplate]:
    with open(path, 'r') as f:
        lines = f.readlines()
    cnt = 0
    pts = []
    for line in lines:
        pt = parse_prompt_template(line, lang=lang)
        cnt += 1
        if len(pt.tags) < num:
            logger.error(f'Not enough tags in {path} at line {cnt}: {pt.tags}')
        pts.append(pt)
    return pts

class AudioStockDataset(Dataset):
    def __init__(self,
                 num_examples: int,
                 metadata_path: str,
                 duration: float = 60,
                 sr: int = 0,
                 return_path=False,
                 return_audio=True,
                 prompt_template_path: Optional[os.PathLike] = None,
                 tag_types=[],
                 lang='en',
                 translate: Optional[Dict[str, os.PathLike]] = None
                 ):
        self.duration = duration
        self.MAX_DURATION = 360
        self._load_metadata(metadata_path)
        if num_examples > 0:
            self.random_choose = True
            self.dataset_len = num_examples
        else:
            self.random_choose = False
            self.dataset_len = len(self.data)
        self.sr = sr
        self.return_path = return_path
        self.return_audio = return_audio
        self.use_dynamic_prompt = prompt_template_path is not None
        if self.use_dynamic_prompt:
            self.prompt_templates = load_prompt_templates(prompt_template_path, num=len(tag_types), lang=lang)
        self.tag_types = tag_types
        self.lang = lang
        self.translate = read_translate(translate)

    def _load_metadata(self, metadata_path):
        total_len = 0
        valid_len = 0
        with open(metadata_path) as fp:
            lines = fp.readlines()
        self.data = []
        for line in lines:
            item = json.loads(line)
            total_len += 1
            # Keep only clips longer than the target duration and shorter than MAX_DURATION.
            if self.duration < item['duration'] < self.MAX_DURATION:
                valid_len += 1
                self.data.append(item)
        print("Filtered metadata from {} to {} items".format(total_len, valid_len))
        self.is_info_recorded = bool('Tags' in self.data[0])

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        first_try = True
        try_cnt = 0
        while True:
            try:
                if self.random_choose or not first_try:
                    index2 = np.random.randint(0, len(self.data))
                else:
                    index2 = idx
                first_try = False
                return self.getitem_main(index2)
            except Exception:
                print("Error loading ", self.data[index2]["path"])
                try_cnt += 1
                if try_cnt > 10:
                    raise ValueError("Failed to load a valid item after 10 retries")

    def getitem_main(self, idx):
        path: str = self.data[idx]["path"]
        json_path = path[:path.rfind('.')] + ".json"
        if self.is_info_recorded:
            item = self.data[idx]
        else:
            # Tag info lives in a sidecar .json file next to the audio file.
            with open(json_path) as fp:
                item: dict = json.load(fp)
        description = self.generate_description(item)
        if self.return_audio:
            audio, sr = safe_random_seek_read(path, duration=self.duration, sample_rate=self.sr)
        else:
            audio = None
        if self.return_path:
            return audio, description, path
        return audio, description

    def generate_description(self, item):
        if self.use_dynamic_prompt:
            # dynamically generate prompt from given prompt template
            prompt_template = random.choice(self.prompt_templates)
            description = self.generate_description_dynamic(item, prompt_template)
        else:
            # use ordinary static prompt instead
            description = self.generate_description_ordinary(item)
        return description

    def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
        exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
        if len(exists_tag) > 0:
            probs = dist_prob_map[len(exists_tag)]
            tags_num = random.choices(range(1, len(exists_tag) + 1), probs)[0]
            random.shuffle(exists_tag)
            tags = exists_tag[:tags_num]
            tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
            tags_args = self.handle_BPM_tag(tags_args)
            prompt = prompt_template.apply(**tags_args)
        else:
            # no usable tags: apply the template with all fields left empty
            prompt = prompt_template.apply()
        return prompt

    def tags_to_desc(self, tag_list, tag_type) -> str:
        if self.lang == 'en':
            return tags_to_desc(tag_list)
        elif self.lang == 'zh':
            if tag_type == 'BPM':
                return tags_to_desc(tag_list, sep='、')
            translator = self.translate[tag_type]
            translated_tag_list = [translator[tag] for tag in tag_list if tag in translator]
            return tags_to_desc(translated_tag_list, sep='、')
        else:
            raise ValueError(f'Unknown language {self.lang}')

    def handle_BPM_tag(self, tags_args):
        if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
            bpm = tags_args["BPM"]
            del tags_args["BPM"]
            # Randomly keep the raw BPM value, a textual description, or both.
            tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
            for tag_type in tag_types_used:
                tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
        return tags_args
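
    # Illustrative behaviour (hypothetical values, not part of the original module):
    # when the sampled tags include BPM and 'BPMDescript' is among the configured
    # tag types, {"BPM": 120} may become {"BPMDescript": "medium pace"},
    # stay as {"BPM": 120}, or expand to {"BPM": 120, "BPMDescript": "Moderato"}.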

    def generate_description_ordinary(self, data, thresh=0.3):
        if self.lang != 'en':
            raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
        description = f'a piece of music by {data["Artist"]}'
        # Add genres if available
        if data["Genre"] and random.random() > thresh:
            genres = ', '.join(data["Genre"])
            description += f', belonging to the {genres} genres'
        # Add tags if available
        if data["Tags"] and random.random() > thresh:
            tags = ', '.join(data["Tags"])
            description += f'. This track contains the tags: {tags}'
        # Add moods if available
        if data["Mood"] and random.random() > thresh:
            moods = ', '.join(data["Mood"])
            description += f'. This track conveys a {moods} mood'
        # Add instruments if available
        if data["Instrument"] and random.random() > thresh:
            instruments = ', '.join(data["Instrument"])
            description += f' and primarily features the following instruments: {instruments}'
        # End the description with a period
        description += '.'
        return description
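

# Minimal usage sketch (not part of the original module). All paths below are
# hypothetical placeholders; the metadata file is assumed to be a .jsonl with one
# item per line containing at least "path" and "duration" fields, as _load_metadata expects.
if __name__ == '__main__':
    dataset = AudioStockDataset(
        num_examples=4,                       # sample 4 random items per epoch
        metadata_path='data/metadata.jsonl',  # hypothetical metadata file
        duration=10.0,
        sr=24000,
        return_path=True,
        prompt_template_path=None,            # fall back to the ordinary static prompt
        tag_types=[],
        lang='en',
    )
    audio, description, path = dataset[0]
    print(path, description, None if audio is None else audio.shape)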