Spaces:
Sleeping
Sleeping
import logging | |
import os | |
import pickle | |
import random | |
import zipfile | |
from typing import Any, Optional | |
import numpy as np | |
import psutil | |
import torch | |
logger = logging.getLogger(__name__) | |
def set_seed(seed: int = 1234) -> None:
    """
    Seed the Python, NumPy and PyTorch random number generators.

    Also exports the seed as ``PYTHONHASHSEED`` and configures cuDNN with
    ``deterministic = False`` and ``benchmark = True``.

    NOTE(review): with these cuDNN flags GPU runs are presumably not
    guaranteed to be bit-for-bit reproducible; confirm this is intended.

    Args:
        seed (int, optional): seed value. Defaults to 1234.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Every seeding entry point takes the same integer seed.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True
def check_metric(cfg):
    """
    Fall back to the BLEU metric when a GPT metric cannot be used.

    If ``cfg.prediction.metric`` contains "GPT" and the ``OPENAI_API_KEY``
    environment variable is unset or empty, a warning is logged and the
    metric is replaced by "BLEU". The configuration is modified in place.

    Args:
        cfg: configuration object exposing ``prediction.metric``

    Returns:
        The same configuration object that was passed in
    """
    wants_gpt = "GPT" in cfg.prediction.metric
    if wants_gpt and not os.getenv("OPENAI_API_KEY", ""):
        logger.warning("No OpenAI API Key set. Setting metric to BLEU. ")
        cfg.prediction.metric = "BLEU"
    return cfg
def kill_child_processes(current_pid: int, exclude=None) -> bool:
    """
    Kill all (recursive) child processes of the given process.
    Optionally, excludes one PID.

    Args:
        current_pid: id of the process whose descendants are killed
        exclude: process id of a descendant to leave running

    Returns:
        True if the process was found and its children were processed,
        False if the process is a zombie or does not exist
    """
    logger.debug(f"Killing process id: {current_pid}")
    try:
        current_process = psutil.Process(current_pid)
        if current_process.status() == psutil.STATUS_ZOMBIE:
            return False
        children = current_process.children(recursive=True)
    except psutil.NoSuchProcess:
        logger.warning(f"Cannot kill process id: {current_pid}. No such process.")
        return False

    for child in children:
        if child.pid == exclude:
            continue
        try:
            child.kill()
        except psutil.NoSuchProcess:
            # The child exited between enumeration and kill(); that is the
            # desired end state, so keep going with the remaining children
            # instead of aborting and reporting the parent as missing.
            continue
    return True
def kill_child_processes_and_current(current_pid: Optional[int] = None) -> bool:
    """
    Kill the children of a process and then the process itself.

    When no process id is given, the id of the calling process is used.

    Args:
        current_pid: process id to terminate (default: None)

    Returns:
        True if the process was killed, False if it no longer exists
    """
    current_pid = os.getpid() if current_pid is None else current_pid

    # The result of the child cleanup is intentionally not inspected here.
    kill_child_processes(current_pid)

    try:
        psutil.Process(current_pid).kill()
    except psutil.NoSuchProcess:
        logger.warning(f"Cannot kill process id: {current_pid}. No such process.")
        return False
    return True
def kill_sibling_ddp_processes() -> None:
    """
    Kill every other child of this process's parent, then kill this process.

    Intended to be called from a single DDP process to take down its siblings.
    """
    own_pid = os.getpid()
    # Spare ourselves during the sweep so we can finish the job below.
    kill_child_processes(os.getppid(), exclude=own_pid)
    psutil.Process(own_pid).kill()
def add_file_to_zip(zf: zipfile.ZipFile, path: str, folder=None) -> None:
    """Add a file to an open zip archive on a best-effort basis.

    The file is stored under its base name, optionally inside ``folder``.
    Any failure (for example a missing file) is logged as a warning and
    otherwise ignored.

    Args:
        zf: zipfile object to add to
        path: path to the file to add
        folder: folder in the zip to add the file to
    """
    try:
        file_name = os.path.basename(path)
        zip_path = file_name if folder is None else os.path.join(folder, file_name)
        zf.write(path, zip_path)
    except Exception:
        logger.warning(f"File {path} could not be added to zip.")
def save_pickle(path: str, obj: Any, protocol: int = 4) -> None:
    """Serialize an object to a pickle file, overwriting any existing file.

    Args:
        path: path of file to save
        obj: object to save
        protocol: protocol to use when saving pickle
    """
    with open(path, mode="wb") as sink:
        pickle.dump(obj, sink, protocol)
class DisableLogger:
    """
    Context manager that suppresses log records at or below ``level``.

    On exit, the process-wide disable level that was active when the block
    was entered is restored. Nested uses, or a ``logging.disable(...)`` call
    made before entering, are therefore not cancelled.

    Args:
        level: highest severity to suppress. Defaults to logging.INFO.
    """

    def __init__(self, level: int = logging.INFO):
        self.level = level
        # Stack, so that the same instance can be entered more than once.
        self._previous_levels = []

    def __enter__(self):
        # logging has no public getter for the disable level; the manager
        # attribute is the value that logging.disable() itself assigns.
        self._previous_levels.append(logging.root.manager.disable)
        logging.disable(self.level)

    def __exit__(self, exit_type, exit_value, exit_traceback):
        if self._previous_levels:
            previous_level = self._previous_levels.pop()
        else:
            previous_level = logging.NOTSET
        logging.disable(previous_level)
class PatchedAttribute:
    """
    Temporarily replace an attribute of an object inside a ``with`` block.

    Comparable to unittest.mock.patch, but it also accepts attributes that
    the object does not have yet; those are removed again on exit.
    The original value is captured when the manager is constructed.

    >>> class MyObj:
    ...     attr = 'original'
    >>> my_obj = MyObj()
    >>> with PatchedAttribute(my_obj, 'attr', 'patched'):
    ...     print(my_obj.attr)
    patched
    >>> print(my_obj.attr)
    original
    >>> with PatchedAttribute(my_obj, 'new_attr', 'new_patched'):
    ...     print(my_obj.new_attr)
    new_patched
    >>> assert not hasattr(my_obj, 'new_attr')
    """

    def __init__(self, obj, attribute, new_value):
        self.obj, self.attribute, self.new_value = obj, attribute, new_value
        self.original_exists = hasattr(obj, attribute)
        if self.original_exists:
            # Only recorded when there is something to restore later.
            self.original_value = getattr(obj, attribute)

    def __enter__(self):
        setattr(self.obj, self.attribute, self.new_value)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.original_exists:
            # The attribute was introduced by the patch; remove it again.
            delattr(self.obj, self.attribute)
            return
        setattr(self.obj, self.attribute, self.original_value)
def create_symlinks_in_parent_folder(directory):
    """Creates symlinks for each item in a folder in the parent folder

    Only creates symlinks for items at the root level of the directory.
    An existing file or symlink of the same name in the parent folder is
    replaced, including a dangling symlink.

    Args:
        directory: folder whose entries are linked into its parent folder

    Raises:
        FileNotFoundError: if ``directory`` does not exist
        OSError: if an existing entry in the parent folder cannot be removed
            (e.g. it is a directory) or the symlink cannot be created
    """
    if not os.path.exists(directory):
        raise FileNotFoundError(f"Directory {directory} does not exist.")

    # Normalise first: dirname("a/b/") is "a/b", which would make every
    # destination equal to its own source, so the real file would be
    # deleted and replaced by a self-referencing link.
    parent_directory = os.path.dirname(os.path.normpath(directory))

    for file in os.listdir(directory):
        src = os.path.join(directory, file)
        dst = os.path.join(parent_directory, file)
        # lexists() is also True for dangling symlinks, which exists()
        # misses and which would make os.symlink raise FileExistsError.
        if os.path.lexists(dst):
            os.remove(dst)
        os.symlink(src, dst)