import hashlib
import os
import re
import shutil
import subprocess
from collections.abc import Mapping
from typing import Iterable, Optional, Union

import torch
from torch import Tensor

import server
import folder_paths
from .logger import logger

# ±(2**53 - 1) is JavaScript's safe integer range (Number.MAX_SAFE_INTEGER).
BIGMIN = -(2**53 - 1)
BIGMAX = 2**53 - 1

DIMMAX = 8192


def ffmpeg_suitability(path):
    """Score an ffmpeg binary by the codecs and build year its banner reports."""
    try:
        version = subprocess.run([path, "-version"], check=True,
                                 capture_output=True).stdout.decode("utf-8")
    except Exception:
        return 0
    score = 0
    # Prefer builds that ship the codecs most useful for video output.
    simple_criterion = [("libvpx", 20), ("264", 10), ("265", 3),
                        ("svtav1", 5), ("libopus", 1)]
    for codec, points in simple_criterion:
        if codec in version:
            score += points
    # The copyright notice ends with the build year, e.g. "2000-2024";
    # its last three digits are added so more recent builds rank higher.
    copyright_index = version.find('2000-2')
    if copyright_index >= 0:
        copyright_year = version[copyright_index+6:copyright_index+9]
        if copyright_year.isnumeric():
            score += int(copyright_year)
    return score
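

# Scoring sketch (hypothetical version banners): a build advertising "libvpx"
# (+20) and "libx264" (+10) with a copyright line "2000-2024" (+24) scores 54,
# outranking a bare build whose banner only yields "2000-2020" (+20).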


class ImageOrLatent(str):
    # Only __ne__ is overridden, so inequality checks of the form `a != b`
    # treat this type as equal to IMAGE, LATENT, or the wildcard "*".
    def __ne__(self, other):
        return not (other == "IMAGE" or other == "LATENT" or other == "*")
imageOrLatent = ImageOrLatent("IMAGE")
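
# Behavior sketch (ComfyUI compares socket types with !=, which this exploits):
#   imageOrLatent != "LATENT"  -> False (treated as a match)
#   imageOrLatent != "MASK"    -> True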


if "VHS_FORCE_FFMPEG_PATH" in os.environ:
    ffmpeg_path = os.environ.get("VHS_FORCE_FFMPEG_PATH")
else:
    ffmpeg_paths = []
    try:
        from imageio_ffmpeg import get_ffmpeg_exe
        imageio_ffmpeg_path = get_ffmpeg_exe()
        ffmpeg_paths.append(imageio_ffmpeg_path)
    except Exception:
        if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
            raise
        logger.warning("Failed to import imageio_ffmpeg")
    if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
        ffmpeg_path = imageio_ffmpeg_path
    else:
        system_ffmpeg = shutil.which("ffmpeg")
        if system_ffmpeg is not None:
            ffmpeg_paths.append(system_ffmpeg)
        if os.path.isfile("ffmpeg"):
            ffmpeg_paths.append(os.path.abspath("ffmpeg"))
        if os.path.isfile("ffmpeg.exe"):
            ffmpeg_paths.append(os.path.abspath("ffmpeg.exe"))
        if len(ffmpeg_paths) == 0:
            logger.error("No valid ffmpeg found.")
            ffmpeg_path = None
        elif len(ffmpeg_paths) == 1:
            # A sole candidate needs no suitability check; skipping it
            # reduces startup time.
            ffmpeg_path = ffmpeg_paths[0]
        else:
            ffmpeg_path = max(ffmpeg_paths, key=ffmpeg_suitability)

gifski_path = os.environ.get("VHS_GIFSKI", None)
if gifski_path is None:
    gifski_path = os.environ.get("JOV_GIFSKI", None)
if gifski_path is None:
    gifski_path = shutil.which("gifski")

ytdl_path = os.environ.get("VHS_YTDL", None) or shutil.which('yt-dlp') \
        or shutil.which('youtube-dl')

download_history = {}
def try_download_video(url):
    if ytdl_path is None:
        return None
    if url in download_history:
        return download_history[url]
    os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
    try:
        # --print after_move:filepath makes yt-dlp emit the downloaded
        # file's final location on stdout.
        res = subprocess.run([ytdl_path, "--print", "after_move:filepath",
                              "-P", folder_paths.get_temp_directory(), url],
                             capture_output=True, check=True)
        # Drop the trailing newline from the printed path.
        file = res.stdout.decode('utf-8')[:-1]
    except subprocess.CalledProcessError as e:
        raise Exception("An error occurred in the yt-dl process:\n"
                        + e.stderr.decode("utf-8"))
    download_history[url] = file
    return file


def is_safe_path(path):
    if "VHS_STRICT_PATHS" not in os.environ:
        return True
    basedir = os.path.abspath('.')
    try:
        common_path = os.path.commonpath([basedir, path])
    except ValueError:
        # commonpath raises ValueError when the paths are on different
        # drives (Windows) or mix absolute and relative forms.
        return False
    return common_path == basedir


def get_sorted_dir_files_from_directory(directory: str, skip_first_images: int = 0,
                                        select_every_nth: int = 1,
                                        extensions: Optional[Iterable] = None):
    directory = strip_path(directory)
    dir_files = os.listdir(directory)
    dir_files = sorted(dir_files)
    dir_files = [os.path.join(directory, x) for x in dir_files]
    dir_files = list(filter(lambda filepath: os.path.isfile(filepath), dir_files))
    # filter by extension
    if extensions is not None:
        extensions = list(extensions)
        new_dir_files = []
        for filepath in dir_files:
            ext = "." + filepath.split(".")[-1]
            if ext.lower() in extensions:
                new_dir_files.append(filepath)
        dir_files = new_dir_files
    # skip the first images, then keep every nth file
    dir_files = dir_files[skip_first_images:]
    dir_files = dir_files[0::select_every_nth]
    return dir_files
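

# Selection sketch (hypothetical directory holding f0.png .. f9.png, sorted):
# skip_first_images=2 drops f0-f1, then select_every_nth=3 keeps every third
# survivor, so the call returns the paths of f2.png, f5.png, and f8.png.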


def calculate_file_hash(filename: str, hash_every_n: int = 1):
    # Hashing the contents of large videos is slow, so the path and the
    # filesystem modification time stand in for a content hash.
    # hash_every_n is currently unused.
    h = hashlib.sha256()
    h.update(filename.encode())
    h.update(str(os.path.getmtime(filename)).encode())
    return h.hexdigest()


prompt_queue = server.PromptServer.instance.prompt_queue
def requeue_workflow_unchecked():
    """Requeues the current workflow without checking for multiple requeues"""
    currently_running = prompt_queue.currently_running
    (_, _, prompt, extra_data, outputs_to_execute) = next(iter(currently_running.values()))

    # Increment the requeue count on each batch manager so its cached
    # state is treated as stale on the next execution.
    prompt = prompt.copy()
    for uid in prompt:
        if prompt[uid]['class_type'] == 'VHS_BatchManager':
            prompt[uid]['inputs']['requeue'] = prompt[uid]['inputs'].get('requeue', 0) + 1

    # Queue with a negated number so the requeued prompt sorts ahead of
    # anything already waiting.
    number = -server.PromptServer.instance.number
    server.PromptServer.instance.number += 1
    prompt_id = str(server.uuid.uuid4())
    prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))


requeue_guard = [None, 0, 0, {}]
def requeue_workflow(requeue_required=(-1, True)):
    assert len(prompt_queue.currently_running) == 1
    global requeue_guard
    (run_number, _, prompt, _, _) = next(iter(prompt_queue.currently_running.values()))
    if requeue_guard[0] != run_number:
        # New run: reset the guard to [run_number, requests_seen,
        # managed_outputs, {id: requeue_required}] and count how many
        # outputs are fed by a batch manager.
        managed_outputs = 0
        for bm_uid in prompt:
            if prompt[bm_uid]['class_type'] == 'VHS_BatchManager':
                for output_uid in prompt:
                    if prompt[output_uid]['class_type'] in ["VHS_VideoCombine"]:
                        for inp in prompt[output_uid]['inputs'].values():
                            if inp == [bm_uid, 0]:
                                managed_outputs += 1
        requeue_guard = [run_number, 0, managed_outputs, {}]
    requeue_guard[1] = requeue_guard[1] + 1
    requeue_guard[3][requeue_required[0]] = requeue_required[1]
    # Requeue only once every managed output has reported in and at least
    # one of them still requires another pass.
    if requeue_guard[1] == requeue_guard[2] and max(requeue_guard[3].values()):
        requeue_workflow_unchecked()


def get_audio(file, start_time=0, duration=0):
    args = [ffmpeg_path, "-i", file]
    if start_time > 0:
        args += ["-ss", str(start_time)]
    if duration > 0:
        args += ["-t", str(duration)]
    try:
        # Decode to raw 32-bit float PCM on stdout; ffmpeg's stream
        # description on stderr carries the sample rate and channel layout.
        res = subprocess.run(args + ["-f", "f32le", "-"],
                             capture_output=True, check=True)
        audio = torch.frombuffer(bytearray(res.stdout), dtype=torch.float32)
        match = re.search(', (\\d+) Hz, (\\w+), ', res.stderr.decode('utf-8'))
    except subprocess.CalledProcessError as e:
        raise Exception(f"VHS failed to extract audio from {file}:\n"
                        + e.stderr.decode("utf-8"))
    if match:
        ar = int(match.group(1))
        # Only mono and stereo layouts are handled; anything else raises KeyError.
        ac = {"mono": 1, "stereo": 2}[match.group(2)]
    else:
        ar = 44100
        ac = 2
    # Interleaved samples -> (batch=1, channels, samples)
    audio = audio.reshape((-1, ac)).transpose(0, 1).unsqueeze(0)
    return {'waveform': audio, 'sample_rate': ar}
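

# Shape sketch: a 2-second stereo clip at 44100 Hz decodes to 176400 floats
# and is returned as a (1, 2, 88200) waveform tensor with sample_rate=44100.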


class LazyAudioMap(Mapping):
    """Defers the ffmpeg audio extraction until the mapping is first accessed."""
    def __init__(self, file, start_time, duration):
        self.file = file
        self.start_time = start_time
        self.duration = duration
        self._dict = None
    def _resolve(self):
        if self._dict is None:
            self._dict = get_audio(self.file, self.start_time, self.duration)
        return self._dict
    def __getitem__(self, key):
        return self._resolve()[key]
    def __iter__(self):
        return iter(self._resolve())
    def __len__(self):
        return len(self._resolve())
def lazy_get_audio(file, start_time=0, duration=0):
    return LazyAudioMap(file, start_time, duration)
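

# Usage sketch (hypothetical path): constructing the mapping is free, and
# ffmpeg only runs if the audio is actually consumed:
#   audio = lazy_get_audio("clip.mp4", start_time=1.0, duration=5.0)
#   audio["waveform"]  # first access triggers get_audio()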


def is_url(url):
    return url.split("://")[0] in ["http", "https"]


def validate_sequence(path):
    """Check if path is a valid ffmpeg %d-style sequence matching at least one file."""
    (path, file) = os.path.split(path)
    if not os.path.isdir(path):
        return False
    # \d* (rather than \d+) also accepts an unpadded %d.
    match = re.search('%0?\\d*d', file)
    if not match:
        return False
    seq = match.group()
    if seq == '%d':
        seq = '\\\\d+'
    else:
        # Backslashes are doubled because seq is used as a re.sub replacement.
        seq = '\\\\d{%s}' % seq[1:-1]
    file_matcher = re.compile(re.sub('%0?\\d*d', seq, file))
    for file in os.listdir(path):
        if file_matcher.fullmatch(file):
            return True
    return False
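

# Example (hypothetical paths): "/frames/out_%05d.png" validates if /frames
# contains a file like out_00042.png; "/frames/out_%d.png" accepts any digit
# run, e.g. out_7.png.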


def strip_path(path):
    # Trim surrounding whitespace, then at most one wrapping quote per side.
    path = path.strip()
    if path.startswith("\""):
        path = path[1:]
    if path.endswith("\""):
        path = path[:-1]
    return path
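
# e.g. strip_path(' "input video.mp4" ') -> 'input video.mp4'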

def hash_path(path):
    if path is None:
        return "input"
    if is_url(path):
        return "url"
    return calculate_file_hash(strip_path(path))


def validate_path(path, allow_none=False, allow_url=True):
    # Returns True, False, or an error string, matching ComfyUI's
    # VALIDATE_INPUTS convention.
    if path is None:
        return allow_none
    if is_url(path):
        # Whether the URL actually resolves is only discovered at download time.
        if not allow_url:
            return "URLs are unsupported for this path"
        return is_safe_path(path)
    if not os.path.isfile(strip_path(path)):
        return "Invalid file path: {}".format(path)
    return is_safe_path(path)


def validate_index(index: int, length: int = 0, is_range: bool = False,
                   allow_negative=False, allow_missing=False) -> int:
    # Range endpoints are validated by the slicing logic instead.
    if is_range:
        return index
    # Bounds are only checked when a length is known.
    if length > 0 and index > length - 1 and not allow_missing:
        raise IndexError(f"Index '{index}' out of range for {length} item(s).")
    # Convert negative indices to their positive equivalents.
    if index < 0:
        if not allow_negative:
            raise IndexError(f"Negative indices not allowed, but was '{index}'.")
        conv_index = length + index
        if conv_index < 0 and not allow_missing:
            raise IndexError(f"Index '{index}', converted to '{conv_index}' out of range for {length} item(s).")
        index = conv_index
    return index
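

# Examples (assuming length=4):
#   validate_index(2, length=4)                       -> 2
#   validate_index(-1, length=4, allow_negative=True) -> 3
#   validate_index(5, length=4)                       raises IndexError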


def convert_to_index_int(raw_index: str, length: int = 0, is_range: bool = False,
                         allow_negative=False, allow_missing=False) -> int:
    try:
        return validate_index(int(raw_index), length=length, is_range=is_range,
                              allow_negative=allow_negative, allow_missing=allow_missing)
    except ValueError as e:
        raise ValueError(f"Index '{raw_index}' must be an integer.") from e


def convert_str_to_indexes(indexes_str: str, length: int = 0, allow_missing=False) -> list[int]:
    """Parse a comma-separated list of indexes and Python-style start:end:step slices."""
    if not indexes_str:
        return []
    int_indexes = list(range(0, length))
    allow_negative = length > 0
    chosen_indexes = []
    # parse string for indexes
    groups = indexes_str.split(",")
    groups = [g.strip() for g in groups]
    for g in groups:
        # parse slice-style syntax
        if ':' in g:
            index_range = g.split(":", 2)
            index_range = [r.strip() for r in index_range]

            start_index = index_range[0]
            if len(start_index) > 0:
                start_index = convert_to_index_int(start_index, length=length, is_range=True,
                                                   allow_negative=allow_negative, allow_missing=allow_missing)
            else:
                start_index = 0
            end_index = index_range[1]
            if len(end_index) > 0:
                end_index = convert_to_index_int(end_index, length=length, is_range=True,
                                                 allow_negative=allow_negative, allow_missing=allow_missing)
            else:
                end_index = length
            # step defaults to 1 when omitted or empty
            step = 1
            if len(index_range) > 2:
                step = index_range[2]
                if len(step) > 0:
                    step = convert_to_index_int(step, length=length, is_range=True,
                                                allow_negative=True, allow_missing=True)
                else:
                    step = 1
            # with a known length, slice the materialized range so negative
            # bounds behave like Python slicing; otherwise build the range directly
            if len(int_indexes) > 0:
                chosen_indexes.extend(int_indexes[start_index:end_index][::step])
            else:
                chosen_indexes.extend(list(range(start_index, end_index, step)))
        else:
            chosen_indexes.append(convert_to_index_int(g, length=length,
                                                       allow_negative=allow_negative, allow_missing=allow_missing))
    return chosen_indexes
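

# Parsing sketch (assuming length=10): plain indexes and slices can be mixed,
# and negative values count from the end once a length is known:
#   convert_str_to_indexes("0, 2, 5:8", length=10) -> [0, 2, 5, 6, 7]
#   convert_str_to_indexes("-1, ::4", length=10)   -> [9, 0, 4, 8]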


def select_indexes(input_obj: Union[Tensor, list], idxs: list):
    if isinstance(input_obj, Tensor):
        # Tensors support fancy indexing, so select in a single operation.
        return input_obj[idxs]
    else:
        return [input_obj[i] for i in idxs]


def select_indexes_from_str(input_obj: Union[Tensor, list], indexes: str,
                            err_if_missing=True, err_if_empty=True):
    real_idxs = convert_str_to_indexes(indexes, len(input_obj), allow_missing=not err_if_missing)
    if err_if_empty and len(real_idxs) == 0:
        raise Exception(f"Nothing was selected based on indexes found in '{indexes}'.")
    return select_indexes(input_obj, real_idxs)
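
# End-to-end sketch (hypothetical list input):
#   frames = ["a", "b", "c", "d", "e"]
#   select_indexes_from_str(frames, "1:4") -> ["b", "c", "d"]
#   select_indexes_from_str(frames, "-1")  -> ["e"]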