Spaces:
Runtime error
Runtime error
import csv
import os
from dataclasses import dataclass

import torch

from editings.styleclip_directions.styleclip_mapper_network import LevelsMapper
from options import Settings
@dataclass
class Options:
    """Per-level switches handed to ``LevelsMapper`` when building a mapper.

    The three flags are filled positionally from one row of
    ``styleclip_mapping_configs.csv``.  Judging by the names, each flag being
    ``True`` presumably disables the mapper for that group of style layers
    (coarse / medium / fine) -- NOTE(review): confirm against ``LevelsMapper``.
    """

    no_coarse_mapper: bool
    no_medium_mapper: bool
    no_fine_mapper: bool
class StyleClip:
    """Apply StyleCLIP edits (latent mappers or global directions) to latents.

    Mapper networks are built lazily on first use and cached on the instance
    under the attribute name ``<editname>_mapper``.
    """

    def __init__(self) -> None:
        """Read per-edit mapper layer flags from ``styleclip_mapping_configs.csv``.

        The file lives in ``Settings.styleclip_settings``; each row is
        ``editname, flag, flag, ...``.
        """
        config_path = os.path.join(Settings.styleclip_settings, 'styleclip_mapping_configs.csv')
        # newline="" is what the csv module expects for files it parses.
        with open(config_path, "r", newline="", encoding="utf-8") as f:
            self.styleclip_mapping_configs = self._parse_mapping_configs(csv.reader(f))

    @staticmethod
    def _parse_mapping_configs(rows):
        """Turn CSV rows ``name, flag, flag, ...`` into ``{name: [bool, ...]}``.

        A flag is ``True`` only when its text is exactly ``"True"`` after
        stripping surrounding whitespace; any other text yields ``False``.
        Blank rows are skipped instead of raising.
        """
        configs = {}
        for row in rows:
            if not row:
                continue  # csv.reader yields [] for blank lines
            name, *flags = row
            configs[name.strip()] = [flag.strip() == "True" for flag in flags]
        return configs

    def edit(self, latent, cfg):
        """Return ``latent`` shifted by the StyleCLIP edit described by ``cfg``.

        Args:
            latent: latent tensor to edit. NOTE(review): presumably a W/W+
                latent on ``Settings.device``; confirm with callers.
            cfg: object with ``type`` (``'mapper'`` or ``'global'``), ``edit``
                (edit name, also the ``.pt`` file stem) and ``strength`` (scale).

        Returns:
            ``latent + cfg.strength * delta``, where ``delta`` is the mapper
            output or the loaded global direction. Computed under ``no_grad``.

        Raises:
            ValueError: if ``cfg.type`` is neither ``'mapper'`` nor ``'global'``.
        """
        with torch.no_grad():
            if cfg.type == 'mapper':
                mapper = self.build_mapper(cfg.edit)
                return latent + cfg.strength * mapper(latent)
            if cfg.type == 'global':
                # Honour the requested edit and strength, as the mapper branch
                # does, instead of a fixed 'makeup.pt' scaled by 10.
                direction = self._load_global_direction(cfg.edit)
                return latent + cfg.strength * direction
        raise ValueError(f"Unknown StyleCLIP edit type: {cfg.type!r}")

    def _load_global_direction(self, editname):
        """Load the global direction ``<editname>.pt`` onto ``Settings.device``."""
        path = os.path.join(Settings.styleclip_global_directions, f'{editname}.pt')
        return torch.load(path, map_location=Settings.device)

    def build_mapper(self, editname):
        """Return the frozen, eval-mode ``LevelsMapper`` for ``editname``.

        The mapper is built and loaded only on the first call for a given edit;
        later calls return the instance cached as ``self.<editname>_mapper``.

        Raises:
            KeyError: if ``editname`` has no row in the mapping-config CSV.
        """
        attr_name = f"{editname}_mapper"
        mapper = getattr(self, attr_name, None)
        if mapper is not None:
            return mapper

        try:
            level_flags = self.styleclip_mapping_configs[editname]
        except KeyError:
            raise KeyError(f"No StyleCLIP mapper config for edit {editname!r}") from None

        mapper = LevelsMapper(Options(*level_flags))
        # Deserialize onto CPU so GPU-saved checkpoints also load on CPU-only
        # hosts; the module is moved to the target device right afterwards.
        ckpt_path = os.path.join(Settings.styleclip_mapper_directions, f'{editname}.pt')
        ckpt = torch.load(ckpt_path, map_location='cpu')
        mapper.load_state_dict(ckpt, strict=True)
        mapper.to(device=Settings.device)
        mapper.requires_grad_(False)  # inference only
        mapper.eval()
        setattr(self, attr_name, mapper)
        return mapper