Spaces:
Running
on
T4
Running
on
T4
File size: 17,601 Bytes
b76a81b 14af4d8 5d66b58 0ffc43b 89b516d 1028cad 9634e77 ba4dacb 8fcd249 14af4d8 1a6de5e aef7fad 14af4d8 74894bc 14af4d8 74894bc aef7fad 14af4d8 aef7fad 2cfed1b aef7fad 1dda6b6 aef7fad 14af4d8 aef7fad 1dda6b6 aef7fad 2cfed1b 14af4d8 9634e77 14af4d8 1dda6b6 14af4d8 a8a94b6 14af4d8 1dda6b6 a8a94b6 505e571 14af4d8 b76a81b e3f64dd efabdc6 b76a81b 8fcd249 b76a81b 8817130 b76a81b 8817130 8fcd249 8817130 14af4d8 aef7fad 14af4d8 e3f64dd 14af4d8 8fcd249 14af4d8 1a6de5e ee1911a 1028cad 14af4d8 1028cad 14af4d8 8fcd249 aef7fad a8a94b6 8fcd249 1a6de5e b76a81b a8a94b6 b76a81b a8a94b6 b76a81b a8a94b6 b76a81b a8a94b6 1a6de5e 14af4d8 1dda6b6 14af4d8 1dda6b6 e7edd0b b76a81b 14af4d8 a8a94b6 14af4d8 e83dc6d b76a81b e3f64dd 14af4d8 5d66b58 ba4dacb 5d66b58 ba4dacb 5d66b58 0ffc43b 1dda6b6 89b516d 0ffc43b 1dda6b6 50d48cc 1a6de5e 50d48cc 1dda6b6 0ffc43b 1dda6b6 e7edd0b 50d48cc 89b516d 0ffc43b 47b87f0 5d66b58 6542012 5d66b58 47b87f0 89b516d 5d66b58 47b87f0 5d66b58 aef7fad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 |
from tabnanny import verbose
import torch
import math
from audiocraft.models import MusicGen
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageColor
import string
import tempfile
import os
import textwrap
import requests
from io import BytesIO
from huggingface_hub import hf_hub_download
import librosa
import gradio as gr
import re
from tqdm import tqdm
INTERRUPTING = False
def separate_audio_segments(audio, segment_duration=30, overlap=1):
    """
    Split an audio clip into fixed-length segments with optional overlap.

    Parameters:
    - audio (tuple): (sample_rate, samples) pair; samples is a 1-D array.
    - segment_duration (int): length of each segment in seconds.
    - overlap (int): overlap between consecutive segments in seconds.

    Returns:
    - list[tuple]: list of (sample_rate, samples) segments. Audio shorter
      than one segment is returned whole as a single segment. At most 25
      segments' worth of audio is considered.

    Bug fix: the original tracked a *segment count* but compared and
    decremented it against *sample* counts, so for audio longer than one
    segment the collection loop never ran and only the trailing segment was
    returned (and empty input looped forever). We now track remaining
    samples throughout.
    """
    sr, audio_data = audio[0], audio[1]
    segment_samples = sr * segment_duration
    overlap_samples = sr * overlap
    # Cap the amount of audio considered at 25 segments' worth
    # (preserves the original 25-segment limit).
    total_samples = min(len(audio_data), 25 * segment_samples)
    segments = []
    start_sample = 0
    # handle the case where the audio is shorter than the segment duration:
    # emit the whole clip as one segment with no overlap
    if total_samples < segment_samples:
        segment_samples = len(audio_data)
        overlap_samples = 0
        total_samples = segment_samples
    if segment_samples > 0:  # guard: empty audio yields no segments
        while total_samples >= segment_samples:
            # Collect the segment; after the first, the start is pulled back
            # by overlap_samples to account for the overlap
            end_sample = start_sample + segment_samples
            segments.append((sr, audio_data[start_sample:end_sample]))
            start_sample += segment_samples - overlap_samples
            total_samples -= segment_samples
        # Collect the final partial segment, aligned to the end of the audio
        if total_samples > 0:
            segments.append((sr, audio_data[-segment_samples:]))
    print(f"separate_audio_segments: {len(segments)} segments of length {segment_samples // sr} seconds")
    return segments
def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:int=1, segment_duration:int=30, prompt_index:int=0, harmony_only:bool= False, progress= gr.Progress(track_tqdm=True)):
    """
    Generate music in fixed-length segments conditioned on a melody.

    Parameters:
    - text (str): text prompt used to condition generation; the seed is
      appended to it for reproducibility.
    - melody (tuple): (sample_rate, samples) audio pair; it is split into
      segments that condition each generated chunk.
    - seed (int): random seed, applied via torch.manual_seed.
    - MODEL: a MusicGen-style wrapper; assumed to expose
      set_generation_params, generate_with_all, generation_params,
      sample_rate, device and lm.cfg.dataset.segment_duration —
      TODO confirm against the wrapper class used by this Space.
    - duration (int): total requested duration in seconds (capped at 720).
    - overlap (int): seconds of overlap between segments (capped at 15).
    - segment_duration (int): length of each generated segment in seconds.
    - prompt_index (int): > 0 selects a specific melody verse as a fixed
      prompt for every segment; < 0 chains each generated segment as the
      prompt for the next one; 0 uses the first verse.
    - harmony_only (bool): if True, remove percussive content from the
      melody with librosa HPSS before conditioning.
    - progress: gradio progress tracker (gradio's documented pattern for
      tqdm integration).

    Returns:
    - tuple: (output_segments, excess_duration) — the list of generated
      segment tensors and the leftover duration in seconds; returns early
      with a partial list if INTERRUPTING is set.
    """
    # generate audio segments (overlap 0: the melody is split back-to-back)
    melody_segments = separate_audio_segments(melody, segment_duration, 0)
    # Create lists to store the melody tensors for each segment
    melodys = []
    output_segments = []
    last_chunk = []  # NOTE(review): unused — kept for byte-compatibility
    text += ", seed=" + str(seed)
    prompt_segment = None
    # prevent hacking: clamp user-supplied values to sane maxima
    duration = min(duration, 720)
    overlap = min(overlap, 15)
    # Calculate the total number of segments
    total_segments = max(math.ceil(duration / segment_duration),1)
    # calculate duration loss from segment overlap (each join loses ~overlap/2 s)
    duration_loss = max(total_segments - 1,0) * math.ceil(overlap / 2)
    # calc excess duration: how much of the last segment is actually needed
    excess_duration = segment_duration - (total_segments * segment_duration - duration)
    print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration} Overlap Loss {duration_loss}")
    duration += duration_loss
    pbar = tqdm(total=total_segments*2, desc="Generating segments", leave=False)
    # Keep adding segments until the overlap losses fit within one segment
    while excess_duration + duration_loss > segment_duration:
        total_segments += 1
        # calculate duration loss from segment overlap
        duration_loss += math.ceil(overlap / 2)
        # calc excess duration
        excess_duration = segment_duration - (total_segments * segment_duration - duration)
        print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration} Overlap Loss {duration_loss}")
        if excess_duration + duration_loss > segment_duration:
            duration += duration_loss
            duration_loss = 0
        pbar.update(1)
    # hard cap: never exceed 720 s worth of segments
    total_segments = min(total_segments, (720 // segment_duration))
    # If melody_segments is shorter than total_segments, repeat the segments until the total_segments is reached
    if len(melody_segments) < total_segments:
        # fix melody_segments by cycling from the start
        for i in range(total_segments - len(melody_segments)):
            segment = melody_segments[i]
            melody_segments.append(segment)
            pbar.update(1)
        print(f"melody_segments: {len(melody_segments)} fixed")
    # Iterate over the segments to create list of Melody tensors
    for segment_idx in range(total_segments):
        if INTERRUPTING:
            # user cancelled before any audio was generated
            return [], duration
        print(f"segment {segment_idx + 1} of {total_segments} \r")
        if harmony_only:
            # REMOVE PERCUSSION FROM MELODY
            # Apply HPSS using librosa (harmonic / percussive separation)
            verse_harmonic, verse_percussive = librosa.effects.hpss(melody_segments[segment_idx][1])
            # Convert the separated harmonic component back to a torch tensor
            # on the model device, shaped (1, samples, channels)-like via t()
            #harmonic_tensor = torch.from_numpy(verse_harmonic)
            #percussive_tensor = torch.from_numpy(verse_percussive)
            sr, verse = melody_segments[segment_idx][0], torch.from_numpy(verse_harmonic).to(MODEL.device).float().t().unsqueeze(0)
        else:
            sr, verse = melody_segments[segment_idx][0], torch.from_numpy(melody_segments[segment_idx][1]).to(MODEL.device).float().t().unsqueeze(0)
        print(f"shape:{verse.shape} dim:{verse.dim()}")
        if verse.dim() == 2:
            # add a leading batch dimension when missing
            verse = verse[None]
        # truncate to the model's configured conditioning window
        verse = verse[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
        # Append the segment to the melodys list
        melodys.append(verse)
        pbar.update(1)
    pbar.close()
    torch.manual_seed(seed)
    # If user selects a prompt segment, generate a new prompt segment to use on all segments
    # default to the first segment for prompt conditioning
    prompt_verse = melodys[0]
    if prompt_index > 0:
        # Get a prompt segment from the selected verse, clamped to range
        prompt_verse = melodys[prompt_index if prompt_index <= (total_segments - 1) else (total_segments -1)]
    # set the prompt segment MODEL generation params
    MODEL.set_generation_params(
        use_sampling=True,
        top_k=MODEL.generation_params["top_k"],
        top_p=MODEL.generation_params["top_p"],
        temperature=MODEL.generation_params["temp"],
        cfg_coef=MODEL.generation_params["cfg_coef"],
        duration=segment_duration,
        two_step_cfg=False,
        rep_penalty=0.5
    )
    # Generate a new prompt segment. This will be applied to all segments for consistency
    print(f"Generating New Prompt Segment: {text} from verse {prompt_index}\r")
    prompt_segment = MODEL.generate_with_all(
        descriptions=[text],
        melody_wavs=prompt_verse,
        sample_rate=sr,
        progress=False,
        prompt=None,
    )
    for idx, verse in tqdm(enumerate(melodys), total=len(melodys), desc="Generating melody segments"):
        if INTERRUPTING:
            # user cancelled mid-run; return what we have so far
            return output_segments, duration
        print(f'Segment duration: {segment_duration}, duration: {duration}, overlap: {overlap} Overlap Loss: {duration_loss}')
        # Compensate for the length of final segment
        if ((idx + 1) == len(melodys)) or (duration < segment_duration):
            mod_duration = max(min(duration, segment_duration),1)
            print(f'Modify verse length, duration: {duration}, overlap: {overlap} Overlap Loss: {duration_loss} to mod duration: {mod_duration}')
            MODEL.set_generation_params(
                use_sampling=True,
                top_k=MODEL.generation_params["top_k"],
                top_p=MODEL.generation_params["top_p"],
                temperature=MODEL.generation_params["temp"],
                cfg_coef=MODEL.generation_params["cfg_coef"],
                duration=mod_duration,
                two_step_cfg=False,
                rep_penalty=0.5
            )
            try:
                # get last chunk of the conditioning tensors
                verse = verse[:, :, -mod_duration*MODEL.sample_rate:]
                prompt_segment = prompt_segment[:, :, -mod_duration*MODEL.sample_rate:]
            except:
                # get first chunk as a fallback if end-slicing fails
                # NOTE(review): bare except — silently falls back on any error
                verse = verse[:, :, :mod_duration*MODEL.sample_rate]
                prompt_segment = prompt_segment[:, :, :mod_duration*MODEL.sample_rate]
        print(f"Generating New Melody Segment {idx + 1}: {text}\r")
        output = MODEL.generate_with_all(
            descriptions=[text],
            melody_wavs=verse,
            sample_rate=sr,
            progress=True,
            prompt=prompt_segment,
        )
        # If user selects a prompt segment, use the prompt segment for all segments
        # Otherwise (prompt_index < 0), use the previous output as the prompt
        if prompt_index < 0:
            prompt_segment = output
        # Append the generated output to the list of segments
        #output_segments.append(output[:, :segment_duration])
        output_segments.append(output)
        print(f"output_segments: {len(output_segments)}: shape: {output.shape} dim {output.dim()}")
        # track remaining duration
        if duration > segment_duration:
            duration -= segment_duration
    return output_segments, excess_duration
def save_image(image):
    """
    Saves a PIL image to a temporary file and returns the file path.

    Parameters:
    - image: PIL.Image
        The PIL image object to be saved (any object with a .save(path)
        method works).

    Returns:
    - str or None: The file path where the image was saved,
        or None if there was an error saving the image.

    Bug fix: the original `finally: return file_path` overrode the
    `return None` in the except branch, so failures still returned a path
    to an empty temp file.
    """
    temp_dir = tempfile.gettempdir()
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", dir=temp_dir, delete=False)
    temp_file.close()
    file_path = temp_file.name
    try:
        image.save(file_path)
    except Exception as e:
        print("Unable to save image:", str(e))
        # Best-effort cleanup of the now-useless temp file.
        try:
            os.remove(file_path)
        except OSError:
            pass
        return None
    return file_path
def detect_color_format(color):
    """
    Detects if the color is in RGB, RGBA, or hex format,
    and converts it to an RGBA tuple with integer components.

    Args:
        color (str or tuple): The color to detect — an (r, g, b[, a]) tuple,
            a hex/named color string understood by PIL, or an
            'rgb(...)'/'rgba(...)' CSS-style string.

    Returns:
        tuple: The color in RGBA format as a tuple of 4 integers (0-255).

    Raises:
        ValueError: If the input color is not in a recognized format.
    """
    def clamp8(value):
        # Round to nearest int and clamp into the 0-255 channel range.
        return max(0, min(255, int(round(value))))

    def alpha8(a):
        # Fractional alpha (<= 1) is scaled to 0-255; larger values are
        # assumed to already be on the 0-255 scale.
        return max(0, min(255, int(round(a * 255)) if a <= 1 else round(a)))

    # Handle color as a tuple of floats or integers
    if isinstance(color, tuple):
        if len(color) != 3 and len(color) != 4:
            raise ValueError(f"Invalid color tuple length: {len(color)}")
        # Only convert when every component is numeric; otherwise fall
        # through to the generic failure below.
        if all(isinstance(c, (int, float)) for c in color):
            r, g, b = color[:3]
            a = color[3] if len(color) == 4 else 255
            return (clamp8(r), clamp8(g), clamp8(b), alpha8(a))

    if isinstance(color, str):
        color = color.strip()
        # Let PIL handle hex codes and named colors first.
        try:
            return ImageColor.getcolor(color, "RGBA")
        except ValueError:
            pass
        # Handle 'rgba(r, g, b, a)' string format
        rgba_match = re.match(r'rgba\(\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+)\s*\)', color)
        if rgba_match:
            r, g, b, a = map(float, rgba_match.groups())
            return (clamp8(r), clamp8(g), clamp8(b), alpha8(a))
        # Handle 'rgb(r, g, b)' string format
        rgb_match = re.match(r'rgb\(\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+)\s*\)', color)
        if rgb_match:
            r, g, b = map(float, rgb_match.groups())
            return (clamp8(r), clamp8(g), clamp8(b), 255)

    # If none of the above conversions work, raise an error
    raise ValueError(f"Invalid color format: {color}")
def hex_to_rgba(hex_color):
    """
    Convert a hex color string ('#rgb', '#rrggbb', '#rrggbbaa') — or any
    other format understood by detect_color_format — into an RGBA tuple
    of four ints. Invalid input falls back to opaque yellow.

    Fixes over the original: '#rgb' shorthand is expanded instead of being
    mis-paired, 3-component results are padded with an opaque alpha so the
    function always returns 4 components, and malformed lengths fall back
    to the yellow default instead of returning a short tuple.
    """
    try:
        if hex_color.startswith("#"):
            clean_hex = hex_color.replace('#', '')
            if len(clean_hex) == 3:
                # Expand CSS shorthand '#rgb' to '#rrggbb'
                clean_hex = ''.join(c * 2 for c in clean_hex)
            if len(clean_hex) not in (6, 8):
                raise ValueError(f"Invalid hex color: {hex_color}")
            # Convert pairs of hexadecimal digits to integer channels
            rgba = tuple(int(clean_hex[i:i + 2], 16) for i in range(0, len(clean_hex), 2))
        else:
            rgba = tuple(map(int, detect_color_format(hex_color)))
        if len(rgba) == 3:
            # No alpha channel supplied: default to fully opaque
            rgba = rgba + (255,)
    except ValueError:
        # If the color is invalid, default to yellow
        rgba = (255, 255, 0, 255)
    return rgba
def load_font(font_name, font_size=16):
    """
    Load a font using the provided font name and font size.

    Parameters:
        font_name (str): The name of the font to load. Can be a font name
            recognized by the system, a local file path, a file in the
            Space's assets folder (fetched via hf_hub_download), or a URL
            to a font file.
        font_size (int, optional): The size of the font. Default is 16.

    Returns:
        ImageFont.FreeTypeFont: The loaded font object, or PIL's default
        bitmap font if every source fails.

    Notes:
        Sources are tried in order until one succeeds: system/local name,
        Hugging Face Hub download, local assets folder, URL download.

        Bug fix: the Hugging Face branch previously wrapped the downloaded
        path in ImageFont.truetype and then passed the resulting *font
        object* to ImageFont.truetype again as a path, so that branch could
        never succeed. It now passes the downloaded file path directly.

    Example:
        font = load_font("Arial.ttf", font_size=20)
    """
    font = None
    # 1) Try the name as a system font or local path (skip for URLs).
    if "http" not in font_name:
        try:
            font = ImageFont.truetype(font_name, font_size)
        except (FileNotFoundError, OSError):
            print("Font not found. Using Hugging Face download..\n")
    # 2) Try downloading from this Space's assets on the Hugging Face Hub.
    if font is None:
        try:
            # hf_hub_download returns a local file *path* to the cached file.
            font_path = hf_hub_download(repo_id=os.environ.get('SPACE_ID', ''), filename="assets/" + font_name, repo_type="space")
            font = ImageFont.truetype(font_path, font_size)
        except (FileNotFoundError, OSError):
            print("Font not found. Trying to download from local assets folder...\n")
    # 3) Try the repo-local assets folder.
    if font is None:
        try:
            font = ImageFont.truetype("assets/" + font_name, font_size)
        except (FileNotFoundError, OSError):
            print("Font not found. Trying to download from URL...\n")
    # 4) Try fetching the font file over HTTP.
    if font is None:
        try:
            req = requests.get(font_name)
            font = ImageFont.truetype(BytesIO(req.content), font_size)
        except (FileNotFoundError, OSError):
            print(f"Font not found: {font_name} Using default font\n")
    if font:
        print(f"Font loaded {font.getname()}")
    else:
        # Last resort: PIL's built-in bitmap font (ignores font_size).
        font = ImageFont.load_default()
    return font
def add_settings_to_image(title: str = "title", description: str = "", width: int = 768, height: int = 512, background_path: str = "", font: str = "arial.ttf", font_color: str = "#ffffff", font_size: int = 28, progress=gr.Progress(track_tqdm=True)):
    """
    Render a title and description onto a transparent canvas, composite it
    centered over a background, and save the result as a temporary PNG.

    Parameters:
        title (str): title text, word-wrapped and drawn near the top center.
        description (str): body text drawn below the title.
        width, height (int): canvas dimensions in pixels.
        background_path (str): optional background image path; blank means
            a solid white background of the canvas size.
        font (str): font name/path/URL, resolved via load_font.
        font_color (str): hex color for the text, converted via hex_to_rgba.
        font_size (int): title font size; description uses 2/3 of it.
        progress: gradio progress tracker (unused directly; kept for the
            gradio tqdm-tracking convention).

    Returns:
        str or None: path of the saved PNG (from save_image), or None on
        save failure.

    Bug fix: the wrapped `title_text` was computed but the *unwrapped*
    `title` was drawn, so long titles overflowed the image; the wrapped
    text is now drawn (matching the description branch).
    """
    # Create a new RGBA image with the specified dimensions
    image = Image.new("RGBA", (width, height), (255, 255, 255, 0))
    # If a background image is specified, open it; otherwise use solid white
    if background_path == "":
        background = Image.new("RGBA", (width, height), (255, 255, 255, 255))
    else:
        background = Image.open(background_path).convert("RGBA")
    # Convert font color to RGBA tuple
    font_color = hex_to_rgba(font_color)
    print(f"Font Color: {font_color}\n")
    # Calculate the center coordinates for placing the text
    text_x = width // 2
    text_y = height // 2
    # Draw the title text at the center top
    title_font = load_font(font, font_size)
    title_text = '\n'.join(textwrap.wrap(title, width // 12))
    # NOTE(review): getbbox returns (left, top, right, bottom) for a single
    # line of text; right/bottom are used here as width/height
    # approximations — confirm this is the intended layout behavior.
    title_x, title_y, title_text_width, title_text_height = title_font.getbbox(title_text)
    title_x = max(text_x - (title_text_width // 2), title_x, 0)
    title_y = text_y - (height // 2) + 10  # 10 pixels padding from the top
    title_draw = ImageDraw.Draw(image)
    title_draw.multiline_text((title_x, title_y), title_text, fill=font_color, font=title_font, align="center")
    # Draw the description text below the title at 2/3 the title size
    description_font = load_font(font, int(font_size * 2 / 3))
    description_text = '\n'.join(textwrap.wrap(description, width // 12))
    description_x, description_y, description_text_width, description_text_height = description_font.getbbox(description_text)
    description_x = max(text_x - (description_text_width // 2), description_x, 0)
    description_y = title_y + title_text_height + 20  # 20 pixels spacing between title and description
    description_draw = ImageDraw.Draw(image)
    description_draw.multiline_text((description_x, description_y), description_text, fill=font_color, font=description_font, align="center")
    # Calculate the offset to center the text canvas on the background
    bg_w, bg_h = background.size
    offset = ((bg_w - width) // 2, (bg_h - height) // 2)
    # Paste the text canvas onto the background using its alpha as the mask
    background.paste(image, offset, mask=image)
    # Save the image and return the file path
    return save_image(background)