import math
import time
from functools import partial

import torch
import torch.nn.functional as F
from torch import Tensor

def load_what_you_can(checkpoint: dict, model: torch.nn.Module):
    """
    Loads as many weights as possible from `checkpoint` into `model`:
    parameters with matching shapes are copied directly; when shapes differ
    but have the same number of dimensions, the overlapping slice (the
    elementwise minimum of the two shapes) is copied.
    """
    model_state_dict = model.state_dict()
    checkpoint_state_dict = checkpoint
    for name, param in checkpoint_state_dict.items():
        if name not in model_state_dict:
            print(f"Ignoring parameter '{name}' because it is not found in the model")
            continue
        model_state = model_state_dict[name]
        mshape = model_state.shape
        pshape = param.shape
        if pshape == mshape:
            model_state.copy_(param)
            continue
        if len(pshape) != len(mshape):
            # Completely different ranks, so probably unwise to merge
            continue
        min_shape = [
            min(param.shape[i], model_state.shape[i]) for i in range(len(param.shape))
        ]
        print(name, "model:", mshape, "chkpt:", pshape, "loading:", min_shape)
        # Use slices rather than advanced indexing: advanced indexing returns a
        # copy, so copy_() on it would not write back into model_state.
        slices = tuple(slice(0, s) for s in min_shape)
        model_state[slices].copy_(param[slices])
    return model.load_state_dict(model_state_dict)

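# Illustrative usage (a sketch, not part of the original module): copy the
# overlapping 4x4 weight block and first 4 bias entries of a small Linear
# into a larger one.
#
#   small = torch.nn.Linear(4, 4)
#   large = torch.nn.Linear(8, 8)
#   load_what_you_can(small.state_dict(), large)
#   # large.weight[:4, :4] and large.bias[:4] now match `small`.
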
def multimap(
    items: list, func: callable, workers=4, desc=None, thread=False, chunk_size=128
) -> list:
    """
    Quick and dirty parallel map over `items`; results for which `func`
    returns None are dropped.
    """
    from tqdm.contrib.concurrent import process_map, thread_map

    m = thread_map if thread else process_map
    length = None
    try:
        length = len(items)
    except Exception as e:
        print(f"Could not get length: {e}")
    results = m(
        func,
        items,
        leave=False,
        desc=desc,
        max_workers=workers,
        total=length,
        chunksize=chunk_size,
    )
    return list(filter(lambda x: x is not None, results))

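# Illustrative usage (sketch): square even numbers in worker threads; odd
# inputs return None and are filtered out of the result.
#
#   def square_evens(x):
#       return x * x if x % 2 == 0 else None
#
#   multimap(list(range(10)), square_evens, workers=2, thread=True)
#   # -> [0, 4, 16, 36, 64]
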
def round_up(num: float, factor: int):
    return factor * math.ceil(num / factor)

def left_padding_mask(lengths, max_len, device=None, dtype=None):
    """
    Builds additive causal attention masks for left-padded sequences: each
    mask is max_len x max_len with -inf above the diagonal and in the first
    `max_len - l` rows/columns occupied by padding.
    """
    masks = []
    if not max_len:
        max_len = max(lengths)
    for l in lengths:
        # causal mask for the real tokens: -inf strictly above the diagonal
        mask = torch.empty(l, l, device=device, dtype=dtype).fill_(-torch.inf).triu_(1)
        diff = max_len - l
        # extend with -inf on the left/top for the padding positions
        mask = F.pad(mask, (diff, 0, diff, 0), value=-torch.inf)
        masks.append(mask)
    masks = torch.stack(masks)
    return masks[:, None]  # add a broadcastable head dimension: (B, 1, T, T)

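# Illustrative usage (sketch): masks for two sequences of lengths 2 and 3,
# left-padded to length 3.
#
#   mask = left_padding_mask([2, 3], max_len=3)
#   mask.shape  # torch.Size([2, 1, 3, 3])
#   # mask[0, 0] is -inf everywhere except the lower triangle of the
#   # bottom-right 2x2 block, which is 0.
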
def seed_all(seed: int):
    import random

    import numpy as np

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def split_bucket_path(url: str) -> tuple[str, str]:
    url = url.replace("s3://", "")
    url = url.replace("sj://", "")
    url = url.replace("r2://", "")
    bucket = url.split("/")[0]
    path = "/".join(url.split("/")[1:])
    return bucket, path

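# Illustrative usage (sketch):
#
#   split_bucket_path("s3://fluxions/audio/24000/clip.ogg")
#   # -> ("fluxions", "audio/24000/clip.ogg")
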
def prob_mask_like(shape, prob: float, device):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob

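# Illustrative usage (sketch): a boolean mask where each element is True with
# probability 0.3, e.g. for classifier-free-guidance style condition dropout.
#
#   mask = prob_mask_like((2, 5), prob=0.3, device="cpu")
#   mask.dtype  # torch.bool
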
def round_up_to_multiple(n: int, multiple: int) -> int:
    if n % multiple != 0:
        n += multiple - (n % multiple)
    return n

def warmup_then_cosine_decay(
    step: int, *, warmup_steps: int, steps: int, min_lr: float, max_lr: float
):
    eps = 1e-9
    cooldown_steps = warmup_steps
    if step < warmup_steps:
        # linear warmup from min_lr to max_lr
        return min_lr + step * (max_lr - min_lr) / warmup_steps
    elif step > steps:
        return min_lr
    elif step < steps - cooldown_steps:
        # cosine decay from max_lr back down to min_lr
        decay_ratio = (step - warmup_steps) / (steps - warmup_steps - cooldown_steps)
        # assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (max_lr - min_lr)
    else:
        # final linear cooldown from min_lr to (nearly) zero
        return min_lr * (steps - step) / cooldown_steps + eps

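# Illustrative usage (sketch): a 1000-step schedule with 100 warmup steps,
# using `partial` from the module imports.
#
#   sched = partial(
#       warmup_then_cosine_decay,
#       warmup_steps=100, steps=1000, min_lr=1e-5, max_lr=1e-3,
#   )
#   sched(0)     # ~1e-5 (start of warmup)
#   sched(100)   # 1e-3  (peak)
#   sched(500)   # mid cosine decay, between 1e-5 and 1e-3
#   sched(1000)  # ~0    (end of the final cooldown)
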
def decay_to_zero(step: int, *, decay_steps: int, steps: int, max_lr: float):
    if step > steps:
        return 0.0
    else:
        # linear decay from max_lr, hitting zero at step == decay_steps;
        # clamped so the LR never goes negative when decay_steps < steps
        gradient = -max_lr / decay_steps
        return max(0.0, max_lr + gradient * step)

def cross_entropy_loss(logits, mask, targets):
    B, Q, T, _ = logits.size()
    assert logits.shape[:-1] == targets.shape
    assert mask.shape == targets.shape
    loss = torch.zeros([], device=targets.device)
    codebook_losses = []
    for q in range(Q):
        logits_q = (
            logits[:, q, ...].contiguous().view(-1, logits.size(-1))
        )  # [B x T, card]
        targets_q = targets[:, q, ...].contiguous().view(-1)  # [B x T]
        mask_q = mask[:, q, ...].contiguous().view(-1)  # [B x T]
        ce_targets = targets_q[mask_q]
        ce_logits = logits_q[mask_q]
        q_ce = F.cross_entropy(ce_logits, ce_targets)
        loss += q_ce
        codebook_losses.append(q_ce.detach())
    # average cross entropy across codebooks
    loss = loss / Q
    return loss, codebook_losses

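# Illustrative usage (sketch): per-codebook cross entropy over random data,
# with B=2 sequences, Q=4 codebooks, T=10 steps, and a 1024-entry codebook.
#
#   logits = torch.randn(2, 4, 10, 1024)
#   targets = torch.randint(0, 1024, (2, 4, 10))
#   mask = torch.ones(2, 4, 10, dtype=torch.bool)
#   loss, per_q = cross_entropy_loss(logits, mask, targets)
#   len(per_q)  # 4, one detached loss per codebook
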
def build_optimizer(
    module, *, weight_decay: float, lr: float, betas: tuple[float, float]
):
    param_dict = {pn: p for pn, p in module.named_parameters() if p.requires_grad}
    # Create optim groups. Any parameter that is at least 2D will be weight
    # decayed, otherwise not; i.e. all weight tensors in matmuls + embeddings
    # decay, all biases and layernorms don't.
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    # num_decay_params = sum(p.numel() for p in decay_params)
    # num_nodecay_params = sum(p.numel() for p in nodecay_params)
    # print(
    #     f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
    # )
    # print(
    #     f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
    # )
    # Note: fused AdamW requires the parameters to live on a CUDA device.
    optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=True)
    return optimizer

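# Illustrative usage (sketch): AdamW with decoupled weight decay on the 2D+
# parameters of a small model; fused=True assumes the model is on CUDA.
#
#   model = torch.nn.Linear(16, 16).cuda()
#   opt = build_optimizer(model, weight_decay=0.1, lr=3e-4, betas=(0.9, 0.95))
#   # opt.param_groups[0] holds the weight matrix, [1] the bias (no decay)
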
def pad_or_cut_right(t: Tensor, padlen: int, value=0) -> Tensor:
    current_len = t.shape[-1]
    if current_len == padlen:
        return t
    if current_len < padlen:
        # Need to pad
        pad_size = (0, padlen - current_len)
        return F.pad(t, pad_size, value=value)
    # Need to cut: slice the last dimension, matching the padding above
    return t[..., :padlen]

def pad_or_cut_left(t: Tensor, padlen: int) -> Tensor:
    dims = t.ndim
    current_len = t.shape[0]
    if current_len == padlen:
        return t
    if current_len < padlen:
        # Need to pad on the left of the first dimension
        pad_size = (0,) * (2 * (dims - 1)) + (padlen - current_len, 0)
        return F.pad(t, pad_size)
    # Need to cut: keep the last `padlen` entries
    return t[-padlen:]

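# Illustrative usage (sketch):
#
#   x = torch.ones(3)
#   pad_or_cut_right(x, 5)  # tensor([1., 1., 1., 0., 0.])
#   pad_or_cut_left(x, 5)   # tensor([0., 0., 1., 1., 1.])
#   pad_or_cut_right(x, 2)  # tensor([1., 1.])
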
def dl_pt(orig: str):
    import io
    from os.path import exists

    from vui.storage import s3, split_bucket_path

    if not orig.endswith(".pt"):
        orig = orig + ".pt"
    load = partial(torch.load, weights_only=True)
    if exists(orig):
        return load(orig)
    url = "/data/" + orig
    if exists(url):
        return load(url)
    url = "s3://fluxions/" + orig
    bucket, key = split_bucket_path(url)
    response = s3.get_object(Bucket=bucket, Key=key)
    # torch.load needs a seekable file, so buffer the streaming S3 body
    return load(io.BytesIO(response["Body"].read()))

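# Illustrative usage (sketch): resolves "checkpoints/codec" as
# checkpoints/codec.pt locally, then under /data/, then from the
# s3://fluxions bucket. The path here is hypothetical.
#
#   state = dl_pt("checkpoints/codec")
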
def dl_ogg(url: str, start=0, end=-1, sr=None):
    import io
    import re
    from os.path import exists

    import soundfile as sf

    # Paths are laid out as ".../<sample_rate>/<file>.ogg"; recover sr from the path
    search_sr = re.search(r"(\d+)/", url)
    if search_sr:
        sr = int(search_sr.group(1))
    local_file = exists(url)
    if exists("/data/audio/" + url):
        local_file = True
        url = "/data/audio/" + url
    if not local_file:
        from vui.storage import s3

        url = "s3://fluxions/" + url
        b, p = split_bucket_path(url)
        # soundfile needs a seekable file, so buffer the streaming S3 body
        url = io.BytesIO(s3.get_object(Bucket=b, Key=p)["Body"].read())
    if sr is None:
        sr = sf.info(url).samplerate
        if not local_file:
            url.seek(0)
    start_frame = int(start * sr)
    # soundfile reads to the end of the file when frames is negative
    num_frames = -1 if end < 0 else int(end * sr) - start_frame
    wav, _ = sf.read(url, frames=num_frames, start=start_frame, always_2d=True)
    wav = torch.from_numpy(wav).float()
    wav = wav.T.mean(0, keepdim=True)  # downmix to mono: (1, samples)
    return wav, sr

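# Illustrative usage (sketch): load two seconds of a clip whose path encodes
# its sample rate; the file name here is hypothetical.
#
#   wav, sr = dl_ogg("24000/clip.ogg", start=0, end=2)
#   wav.shape  # (1, 48000) at sr == 24000
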
class timer:
    def __init__(self, name=""):
        self.name = name

    def __enter__(self):
        self.t = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.perf_counter() - self.t
        print(f"{self.name} {elapsed:.4f}s")

def decode_audio_from_indices(model, indices, chunk_size=64):
    """
    Decodes audio from codec indices in chunks to avoid memory issues.

    Args:
        model: Codec
        indices: Tensor of shape (1, n_quantizers, sequence_length)
        chunk_size: Number of index timesteps to decode at once

    Returns:
        Tensor of reconstructed audio
    """
    device = model.device
    indices = indices.to(device)
    _, _, seq_len = indices.shape
    chunks = seq_len // chunk_size + (1 if seq_len % chunk_size != 0 else 0)
    audio_chunks = []
    for i in range(chunks):
        start_idx = i * chunk_size
        end_idx = min(start_idx + chunk_size, seq_len)
        chunk_indices = indices[:, :, start_idx:end_idx]
        chunk_audio = model.from_indices(chunk_indices)
        audio_chunks.append(chunk_audio.cpu())
    full_audio = torch.cat(audio_chunks, dim=-1)
    return full_audio.flatten()

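# Illustrative usage (sketch): assumes a codec exposing `.device` and
# `.from_indices(...)` as above; the codec name here is hypothetical.
#
#   codec = MyCodec().eval()                       # hypothetical codec
#   indices = torch.randint(0, 1024, (1, 4, 500))  # (1, n_quantizers, T)
#   audio = decode_audio_from_indices(codec, indices, chunk_size=64)
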
def normalize_loudness(waveform, sample_rate: int, lufs: float = -12.0):
    """
    Normalize the loudness of an audio tensor using torchaudio.transforms.Loudness.

    Args:
        waveform (torch.Tensor): Input audio tensor of shape (channels, samples)
        sample_rate (int): Sampling rate of the audio
        lufs (float): Target loudness in LUFS (default: -12.0 LUFS)

    Returns:
        torch.Tensor: Loudness-normalized audio tensor
    """
    import torchaudio

    # Ensure the input tensor is 2D (add channel dimension if it's 1D)
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    # Measure the current loudness
    loudness_transform = torchaudio.transforms.Loudness(sample_rate)
    current_loudness = loudness_transform(waveform)
    # Calculate the required gain
    gain_db = lufs - current_loudness
    # Convert gain from dB to linear scale
    gain_linear = torch.pow(10, gain_db / 20)
    # Apply the gain to normalize loudness
    normalized_audio = waveform * gain_linear
    return normalized_audio

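# Illustrative usage (sketch): bring a quiet 24 kHz mono clip up to -12 LUFS.
#
#   wav = 0.01 * torch.randn(1, 24000)
#   louder = normalize_loudness(wav, 24000, lufs=-12.0)
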
def get_basename_without_extension(file_path):
    from pathlib import Path

    return Path(file_path).stem

def ollama(prompt, MODEL=None):
    import os

    import requests

    OLLAMA_HOST = "http://localhost:11434"
    API = f"{OLLAMA_HOST}/api/generate"
    if MODEL is None:
        MODEL = os.environ.get("OLLAMA_MODEL", "gemma:1b")
    payload = {
        "model": MODEL,
        "prompt": prompt,
        "stream": False,
        # Ollama's generation cap is called num_predict; max_tokens is ignored
        "options": {"temperature": 0.9, "top_p": 0.9, "num_predict": 1000},
    }
    try:
        response = requests.post(API, json=payload, timeout=120)
        response.raise_for_status()  # Raise an exception for HTTP errors
        result = response.json()
        return result.get("response", "")
    except requests.exceptions.RequestException as e:
        print(f"Error calling Ollama API: {e}")
        return ""

def decompile_state_dict(state_dict):
    """Strip torch.compile ("_orig_mod.") and DataParallel/DDP ("module.") key prefixes."""
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    # state_dict = convert_old_weight_norm_to_new(state_dict)
    return {k.replace("module.", ""): v for k, v in state_dict.items()}
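
# Illustrative usage (sketch): clean up a checkpoint saved from a compiled
# and/or DDP-wrapped model before loading it into a plain module; the
# checkpoint path here is hypothetical.
#
#   ckpt = torch.load("model.pt", weights_only=True)
#   model.load_state_dict(decompile_state_dict(ckpt))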