Spaces:
Running
Running
File size: 2,995 Bytes
fcc02a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import os
from collections import OrderedDict
from safetensors.torch import save_file
from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors
from typing import ForwardRef
from toolkit.train_tools import get_torch_dtype
class BaseExtractProcess(BaseProcess):
    """Base class for processes that extract weights and save them as safetensors.

    Child classes implement the extraction in ``run`` (calling
    ``super().run()`` first so ``self.output_path`` is resolved) and persist
    the result with ``save``.
    """

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        # Attribute declarations. output_filename is not assigned here; it is
        # presumably set by a child class before get_output_path() needs it.
        self.config: OrderedDict
        self.output_folder: str
        self.output_filename: str
        self.output_path: str
        self.process_id = process_id
        self.job = job
        self.config = config
        # Each setting is read from the process config via get_conf, with the
        # job's same-named attribute passed as the second argument (presumably
        # the default when the key is absent — get_conf lives in BaseProcess).
        self.dtype = self.get_conf('dtype', self.job.dtype)
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.extract_unet = self.get_conf('extract_unet', self.job.extract_unet)
        self.extract_text_encoder = self.get_conf('extract_text_encoder', self.job.extract_text_encoder)

    def run(self):
        """Resolve ``self.output_path``; child classes perform the extraction.

        This is done here rather than in ``__init__`` because the child's
        ``__init__`` must run first (e.g. so ``output_filename`` is set).
        Overrides should call ``super().run()`` before doing their own work.
        """
        self.output_path = self.get_output_path()

    def get_output_path(self, prefix=None, suffix=None):
        """Return the file path the extracted weights will be written to.

        Resolution order:
          1. ``output_path`` from the config, with ``[name]`` replaced by the
             job name.
          2. ``filename`` from the config, joined onto the job's output folder.
          3. ``{prefix}{self.output_filename}{suffix}`` in the job's output
             folder.

        Child classes may override this, or call
        ``super().get_output_path(prefix="your_prefix_", suffix="_your_suffix")``
        to customize the generated name.

        Args:
            prefix: Prepended to ``self.output_filename``; ``''`` when None.
            suffix: Appended to ``self.output_filename``. When None it becomes
                ``_{process_id}`` if ``self.config['process']`` has more than
                one entry, otherwise ``''``.

        Returns:
            The output file path as a string.
        """
        config_output_path = self.get_conf('output_path', None)
        config_filename = self.get_conf('filename', None)

        if config_output_path is not None:
            return config_output_path.replace('[name]', self.job.name)

        if config_filename is not None:
            return os.path.join(self.job.output_folder, config_filename)

        # No explicit path configured: build one from output_filename.
        if suffix is None:
            # Tag the file with the process id only when several processes
            # exist and the caller gave no suffix, so outputs don't collide.
            # NOTE(review): this reads 'process' from this process's own
            # config; if the process list lives on the job config instead,
            # this raises KeyError — confirm against the config layout.
            suffix = f"_{self.process_id}" if len(self.config['process']) > 1 else ''
        if prefix is None:
            prefix = ''

        output_filename = f"{prefix}{self.output_filename}{suffix}"
        return os.path.join(self.job.output_folder, output_filename)

    def save(self, state_dict):
        """Write ``state_dict`` to ``self.output_path`` as a safetensors file.

        Every tensor is detached, copied to CPU, cast to ``self.torch_dtype``
        and made contiguous before saving. ``state_dict`` is updated in place
        with these converted tensors, so callers see them afterwards.

        Args:
            state_dict: Mapping of key -> tensor to save. Mutated in place.
        """
        save_meta = get_meta_for_safetensors(self.meta, self.job.name)

        # dirname() is '' for a bare filename, and os.makedirs('') raises, so
        # only create a directory when the path actually names one.
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        for key in list(state_dict.keys()):
            # copy=True always yields a fresh tensor (even when it is already on
            # CPU in the target dtype), so no two entries share storage, which
            # safetensors rejects. contiguous() because safetensors also
            # rejects non-contiguous tensors such as transposes. Moving to CPU
            # before copying avoids a temporary duplicate on the source device.
            state_dict[key] = (
                state_dict[key]
                .detach()
                .to("cpu", dtype=self.torch_dtype, copy=True)
                .contiguous()
            )

        # NOTE: the original author left a remark here about "having issues
        # with meta"; the nature of the issue is not visible from this file.
        save_file(state_dict, self.output_path, save_meta)
        print(f"Saved to {self.output_path}")
|