Spaces:
Sleeping
Sleeping
File size: 5,957 Bytes
5caedb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import logging
import os
import pickle
import random
import zipfile
from typing import Any, Optional
import numpy as np
import psutil
import torch
# Module-level logger used by the helper functions in this file.
logger = logging.getLogger(__name__)
def set_seed(seed: int = 1234, *, deterministic: bool = False) -> None:
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    By default cuDNN is left in its fast autotuning mode
    (``benchmark=True``, ``deterministic=False``), so GPU results may still
    differ slightly between runs. Pass ``deterministic=True`` to trade speed
    for deterministic cuDNN kernels.

    Args:
        seed (int, optional): seed value. Defaults to 1234.
        deterministic (bool, optional): if True, force deterministic cuDNN
            kernels and disable cuDNN benchmarking. Defaults to False, which
            keeps the previous behaviour.
    """
    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter startup, so this only affects
    # Python interpreters launched afterwards that inherit this environment;
    # str hashing in the current process is already fixed.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without CUDA; PyTorch silently ignores it in that case.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
def check_metric(cfg):
    """
    Fall back to the BLEU metric when a GPT metric cannot be used.

    A GPT-based metric (any ``cfg.prediction.metric`` containing "GPT")
    requires the ``OPENAI_API_KEY`` environment variable. If that variable is
    missing or empty, a warning is logged and the metric is switched to BLEU
    in place.

    Returns:
        The same ``cfg`` object, possibly modified.
    """
    wants_gpt_metric = "GPT" in cfg.prediction.metric
    api_key_missing = os.getenv("OPENAI_API_KEY", "") == ""
    if wants_gpt_metric and api_key_missing:
        logger.warning("No OpenAI API Key set. Setting metric to BLEU. ")
        cfg.prediction.metric = "BLEU"
    return cfg
def kill_child_processes(current_pid: int, exclude: Optional[int] = None) -> bool:
    """
    Kill all child processes (recursively) of the given process.

    Optionally excludes one PID. Children that exit on their own between
    being listed and being killed are skipped, so one vanished child no
    longer prevents the remaining children from being killed.

    Args:
        current_pid: id of the process whose descendants should be killed
        exclude: process id to leave alive (for example the caller itself)

    Returns:
        False if ``current_pid`` does not exist (or disappears while its
        children are being listed) or is a zombie, True otherwise.
        Other psutil errors (e.g. ``AccessDenied``) propagate.
    """
    logger.debug(f"Killing process id: {current_pid}")
    try:
        current_process = psutil.Process(current_pid)
        if current_process.status() == psutil.STATUS_ZOMBIE:
            return False
        children = current_process.children(recursive=True)
    except psutil.NoSuchProcess:
        logger.warning(f"Cannot kill process id: {current_pid}. No such process.")
        return False

    for child in children:
        if child.pid == exclude:
            continue
        try:
            child.kill()
        except psutil.NoSuchProcess:
            # The child exited after children() listed it; nothing left to kill.
            continue
    return True
def kill_child_processes_and_current(current_pid: Optional[int] = None) -> bool:
    """
    Kill every child process of a process, then the process itself.

    Args:
        current_pid: id of the process to terminate together with its
            children. Defaults to the calling process (``os.getpid()``).

    Returns:
        True once the kill has been issued, or False if the process no
        longer exists. When the target is the calling process, the final
        kill terminates the caller, so in that case the call does not
        normally return.
    """
    target_pid = os.getpid() if current_pid is None else current_pid
    kill_child_processes(target_pid)
    try:
        psutil.Process(target_pid).kill()
    except psutil.NoSuchProcess:
        logger.warning(f"Cannot kill process id: {target_pid}. No such process.")
        return False
    return True
def kill_sibling_ddp_processes() -> None:
    """
    Tear down all DDP worker processes, starting from inside one of them.

    Every descendant of this process's parent is killed except the calling
    process itself. That includes the siblings and also the caller's own
    children, because only the caller's PID is excluded. Finally the calling
    process kills itself.
    """
    own_pid = os.getpid()
    kill_child_processes(os.getppid(), exclude=own_pid)
    psutil.Process(own_pid).kill()
def add_file_to_zip(zf: zipfile.ZipFile, path: str, folder=None) -> None:
    """Add a single file to an already open zip archive, best effort.

    The archive entry is named after the file's basename, nested under
    ``folder`` when one is given. If writing fails for any reason (for
    example the file does not exist), a warning is logged instead of raising.

    Args:
        zf: open, writable zipfile object to add to
        path: filesystem path of the file to add
        folder: optional folder inside the zip to place the file in
    """
    try:
        base_name = os.path.basename(path)
        arcname = base_name if folder is None else os.path.join(folder, base_name)
        zf.write(path, arcname)
    except Exception:
        logger.warning(f"File {path} could not be added to zip.")
def save_pickle(path: str, obj: Any, protocol: int = 4) -> None:
    """Serialize ``obj`` with pickle and write it to ``path``.

    Any existing file at ``path`` is overwritten.

    Args:
        path: destination file path
        obj: object to serialize
        protocol: pickle protocol version to use (default 4)
    """
    with open(path, "wb") as out_file:
        pickle.Pickler(out_file, protocol=protocol).dump(obj)
class DisableLogger:
    """Context manager that temporarily disables logging at ``level`` and below.

    On exit the previously active global disable level is restored rather
    than unconditionally reset to ``logging.NOTSET``. Nested uses, or code
    that had already disabled logging, therefore keep their setting.
    """

    def __init__(self, level: int = logging.INFO):
        self.level = level
        self._previous_level = logging.NOTSET

    def __enter__(self):
        # logging.disable() stores its threshold on the root logger's manager;
        # remember it so __exit__ can put it back.
        self._previous_level = logging.root.manager.disable
        logging.disable(self.level)
        return self

    def __exit__(self, exit_type, exit_value, exit_traceback):
        logging.disable(self._previous_level)
class PatchedAttribute:
    """
    Patches an attribute of an object for the duration of this context manager.

    Similar to unittest.mock.patch, but also works for attributes that are
    not present on the object yet; such attributes are deleted again on exit.
    The original state is captured when the ``with`` block is entered, so
    changes made to the object between construction and entry are respected.

    >>> class MyObj:
    ...     attr = 'original'
    >>> my_obj = MyObj()
    >>> with PatchedAttribute(my_obj, 'attr', 'patched'):
    ...     print(my_obj.attr)
    patched
    >>> print(my_obj.attr)
    original
    >>> with PatchedAttribute(my_obj, 'new_attr', 'new_patched'):
    ...     print(my_obj.new_attr)
    new_patched
    >>> assert not hasattr(my_obj, 'new_attr')
    """

    def __init__(self, obj, attribute, new_value):
        self.obj = obj
        self.attribute = attribute
        self.new_value = new_value
        # Kept for backward compatibility (callers may inspect these before
        # entering); refreshed again in __enter__.
        self._snapshot()

    def _snapshot(self):
        """Record whether the attribute currently exists and, if so, its value."""
        self.original_exists = hasattr(self.obj, self.attribute)
        if self.original_exists:
            self.original_value = getattr(self.obj, self.attribute)

    def __enter__(self):
        # Re-snapshot at entry so we restore the state as it was right before
        # patching, not as it was when this object was constructed.
        self._snapshot()
        setattr(self.obj, self.attribute, self.new_value)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.original_exists:
            setattr(self.obj, self.attribute, self.original_value)
        else:
            delattr(self.obj, self.attribute)
def create_symlinks_in_parent_folder(directory):
    """Create a symlink in the parent folder for each item in ``directory``.

    Only items at the root level of ``directory`` are linked (no recursion).
    Existing files or symlinks with the same name in the parent folder are
    replaced, including dangling symlinks. Link targets are absolute paths.

    Args:
        directory: folder whose top-level entries should be linked into its
            parent folder.

    Raises:
        FileNotFoundError: if ``directory`` does not exist.
    """
    if not os.path.exists(directory):
        raise FileNotFoundError(f"Directory {directory} does not exist.")

    # Normalise first. With a trailing separator, dirname("a/b/") == "a/b",
    # which would make the "parent" the directory itself and delete its files.
    # Absolute targets also keep links valid: a relative symlink target is
    # resolved against the link's own folder, not the current working dir.
    abs_directory = os.path.abspath(directory)
    parent_directory = os.path.dirname(abs_directory)

    for name in os.listdir(abs_directory):
        src = os.path.join(abs_directory, name)
        dst = os.path.join(parent_directory, name)
        # lexists() is also True for dangling symlinks, which exists() misses
        # and which would otherwise make os.symlink() fail with FileExistsError.
        if os.path.lexists(dst):
            os.remove(dst)
        os.symlink(src, dst)
|