#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import os.path as osp
import platform
import subprocess
from copy import copy
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import torch


def none_or_int(value):
    if value == "None":
        return None
    return int(value)
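
# Example (illustrative): typical use is as an argparse ``type=`` converter,
# where an optional integer arrives as the literal string "None":
#   none_or_int("None")  # -> None
#   none_or_int("42")    # -> 42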


def inside_slurm():
    """Check whether the python process was launched through slurm"""
    # TODO(rcadene): return False for interactive mode `--pty bash`
    return "SLURM_JOB_ID" in os.environ


def auto_select_torch_device() -> torch.device:
    """Tries to automatically select a torch device."""
    if torch.cuda.is_available():
        logging.info("Cuda backend detected, using cuda.")
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        logging.info("Metal backend detected, using mps.")
        return torch.device("mps")
    else:
        logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
        return torch.device("cpu")
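
# Example (illustrative sketch; `model` is a hypothetical nn.Module):
#   device = auto_select_torch_device()
#   model = model.to(device)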


# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
    """Given a string, return a torch.device with checks on whether the device is available."""
    try_device = str(try_device)
    match try_device:
        case "cuda":
            assert torch.cuda.is_available()
            device = torch.device("cuda")
        case "mps":
            assert torch.backends.mps.is_available()
            device = torch.device("mps")
        case "cpu":
            device = torch.device("cpu")
            if log:
                logging.warning("Using CPU, this will be slow.")
        case _:
            device = torch.device(try_device)
            if log:
                logging.warning(f"Using custom {try_device} device.")

    return device
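
# Example (illustrative): named devices are validated; anything else is passed
# through to torch.device as-is (e.g. indexed devices hit the default branch):
#   get_safe_torch_device("cpu")     # -> torch.device("cpu")
#   get_safe_torch_device("cuda:1")  # -> torch.device("cuda:1"), no availability check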


def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
    """Return a dtype that is safe for the given device: mps does not support float64, so fall back to float32."""
    if isinstance(device, torch.device):
        device = device.type
    if device == "mps" and dtype == torch.float64:
        return torch.float32
    else:
        return dtype
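
# Example (illustrative): float64 is downgraded only on mps:
#   get_safe_dtype(torch.float64, "mps")   # -> torch.float32
#   get_safe_dtype(torch.float64, "cuda")  # -> torch.float64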


def is_torch_device_available(try_device: str) -> bool:
    try_device = str(try_device)  # Ensure try_device is a string
    if try_device == "cuda":
        return torch.cuda.is_available()
    elif try_device == "mps":
        return torch.backends.mps.is_available()
    elif try_device == "cpu":
        return True
    else:
        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")


def is_amp_available(device: str):
    if device in ["cuda", "cpu"]:
        return True
    elif device == "mps":
        return False
    else:
        raise ValueError(f"Unknown device '{device}'.")


def init_logging():
    def custom_format(record):
        dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        fnameline = f"{record.pathname}:{record.lineno}"
        # Use getMessage() so %-style logging arguments are interpolated correctly.
        message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
        return message

    logging.basicConfig(level=logging.INFO)

    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    formatter = logging.Formatter()
    formatter.format = custom_format
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logging.getLogger().addHandler(console_handler)
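
# Example (illustrative): call once at startup; records are then printed as
# "<LEVEL> <YYYY-mm-dd HH:MM:SS> <path:lineno, right-aligned to 15 chars> <message>":
#   init_logging()
#   logging.info("hello")  # e.g. "INFO 2024-01-01 12:00:00     utils.py:42 hello"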


def format_big_number(num, precision=0):
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0

    for suffix in suffixes:
        if abs(num) < divisor:
            return f"{num:.{precision}f}{suffix}"
        num /= divisor

    # Beyond the largest suffix: undo the final division and clamp to "Q".
    return f"{num * divisor:.{precision}f}{suffixes[-1]}"
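
# Example (illustrative):
#   format_big_number(1_234)         # -> "1K"
#   format_big_number(1_500_000, 1)  # -> "1.5M"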


def _relative_path_between(path1: Path, path2: Path) -> Path:
    """Returns path1 relative to path2."""
    path1 = path1.absolute()
    path2 = path2.absolute()
    try:
        return path1.relative_to(path2)
    except ValueError:  # most likely because path1 is not a subpath of path2
        common_parts = Path(osp.commonpath([path1, path2])).parts
        return Path(
            "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
        )
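
# Example (illustrative): when path1 is not under path2, the result climbs up
# through ".." components:
#   _relative_path_between(Path("/a/b/c"), Path("/a/d"))  # -> Path("../b/c")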


def print_cuda_memory_usage():
    """Use this function to locate and debug memory leaks."""
    import gc

    gc.collect()
    # Also clear the cache if you want to fully release the memory
    torch.cuda.empty_cache()
    print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
    print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
    print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
    print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))


def capture_timestamp_utc():
    return datetime.now(timezone.utc)


def say(text, blocking=False):
    system = platform.system()

    if system == "Darwin":
        cmd = ["say", text]

    elif system == "Linux":
        cmd = ["spd-say", text]
        if blocking:
            cmd.append("--wait")

    elif system == "Windows":
        cmd = [
            "PowerShell",
            "-Command",
            "Add-Type -AssemblyName System.Speech; "
            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')",
        ]

    else:
        raise RuntimeError("Unsupported operating system for text-to-speech.")

    if blocking:
        subprocess.run(cmd, check=True)
    else:
        subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)


def log_say(text, play_sounds, blocking=False):
    logging.info(text)

    if play_sounds:
        say(text, blocking)
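
# Example (illustrative): log a message and, optionally, voice it for the operator:
#   log_say("Recording episode 1", play_sounds=True, blocking=True)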


def get_channel_first_image_shape(image_shape: tuple) -> tuple:
    shape = copy(image_shape)
    if shape[2] < shape[0] and shape[2] < shape[1]:  # (h, w, c) -> (c, h, w)
        shape = (shape[2], shape[0], shape[1])
    elif not (shape[0] < shape[1] and shape[0] < shape[2]):
        raise ValueError(f"Cannot infer a channel-first shape from {image_shape}.")

    return shape
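
# Example (illustrative):
#   get_channel_first_image_shape((480, 640, 3))  # -> (3, 480, 640)
#   get_channel_first_image_shape((3, 480, 640))  # -> (3, 480, 640), unchanged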


def has_method(cls: object, method_name: str) -> bool:
    return hasattr(cls, method_name) and callable(getattr(cls, method_name))


def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
    """
    Return True if a given string can be converted to a numpy dtype.
    """
    try:
        # Attempt to convert the string to a numpy dtype
        np.dtype(dtype_str)
        return True
    except TypeError:
        # If a TypeError is raised, the string is not a valid dtype
        return False
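
# Example (illustrative):
#   is_valid_numpy_dtype_string("float32")      # -> True
#   is_valid_numpy_dtype_string("not_a_dtype")  # -> False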