Sumsub-ffs-demo / model_loader.py
RomanShnurov's picture
add new synthetic detector
f3b2c5b
raw
history blame
2 kB
from enum import Enum
import torch
from model_classes import Model200M, Model5M, SyntheticV2
from model_transforms import transform_200M, transform_5M, transform_synthetic
class ModelType(str, Enum):
    """Identifiers of the detector variants this module knows how to load.

    Each member's value is the variant's public string id. Because the enum
    mixes in ``str``, a member compares equal to its plain-string id.
    """

    MIDJOURNEY_200M = "midjourney_200M"
    DIFFUSIONS_200M = "diffusions_200M"
    MIDJOURNEY_5M = "midjourney_5M"
    DIFFUSIONS_5M = "diffusions_5M"
    SYNTHETIC_DETECTOR_V2 = "synthetic_detector_v2"

    def __str__(self):
        """Render the member as its bare string id rather than ``ModelType.X``."""
        return str(self.value)

    @staticmethod
    def get_list():
        """Return the string ids of every member, in declaration order."""
        return [member.value for member in ModelType]
def load_model(value: ModelType):
    """Load the checkpoint for *value* into its pre-built network and return it.

    The network instance is taken from the module-level ``type_to_class`` table
    and the checkpoint path from ``type_to_path``. Both tables must already be
    populated when this is called. The weights are mapped onto the CPU, and the
    network is then switched to evaluation mode.

    The returned object is the very instance stored in ``type_to_class``. It is
    modified in place, not copied.

    Args:
        value: The ``ModelType`` variant to load.

    Returns:
        The network for *value* with its checkpoint weights loaded, in eval mode.

    Raises:
        KeyError: If *value* has no entry in ``type_to_class`` or ``type_to_path``.
        RuntimeError: Raised by ``load_state_dict`` (strict by default) when the
            checkpoint keys do not match the network's parameters.
    """
    model = type_to_class[value]
    path = type_to_path[value]
    # The loaded object only ever goes to load_state_dict, so it needs nothing
    # beyond tensors and plain containers. weights_only=True restricts unpickling
    # to exactly that, instead of allowing arbitrary pickled objects/code.
    state_dict = torch.load(path, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Network instance for each detector variant. Despite the name, the values are
# instances (constructed here), not classes.
# Each variant gets its OWN instance, even when two variants share an
# architecture. load_model() loads a checkpoint into this object in place, so a
# shared instance would let the second checkpoint overwrite the first.
type_to_class = {
    model_type: architecture()
    for model_type, architecture in (
        (ModelType.MIDJOURNEY_200M, Model200M),
        (ModelType.DIFFUSIONS_200M, Model200M),
        (ModelType.MIDJOURNEY_5M, Model5M),
        (ModelType.DIFFUSIONS_5M, Model5M),
        (ModelType.SYNTHETIC_DETECTOR_V2, SyntheticV2),
    )
}
# Checkpoint file for each detector variant; load_model() passes it to torch.load.
# These are relative paths, so they resolve against the process's current
# working directory, not this module's location. Presumably the app is started
# from the directory that contains `models/`; confirm before relocating.
type_to_path = {}
type_to_path[ModelType.MIDJOURNEY_200M] = 'models/midjourney200M.pt'
type_to_path[ModelType.DIFFUSIONS_200M] = 'models/diffusions200M.pt'
type_to_path[ModelType.MIDJOURNEY_5M] = 'models/midjourney5M.pt'
type_to_path[ModelType.DIFFUSIONS_5M] = 'models/diffusions5M.pt'
type_to_path[ModelType.SYNTHETIC_DETECTOR_V2] = 'models/synthetic_detector_v2.pt'
# Ready-to-use network for each detector variant.
# This runs at import time, so importing the module reads every checkpoint file
# from disk. Each value is the same object held in type_to_class, because
# load_model() returns that instance after loading weights into it.
type_to_loaded_model = {
    model_type: load_model(model_type)
    for model_type in (
        ModelType.MIDJOURNEY_200M,
        ModelType.DIFFUSIONS_200M,
        ModelType.MIDJOURNEY_5M,
        ModelType.DIFFUSIONS_5M,
        ModelType.SYNTHETIC_DETECTOR_V2,
    )
}
# Transform paired with each detector variant. This is presumably the input
# preprocessing applied before inference; confirm against the caller.
# Both variants of a given model size share a single transform object.
type_to_transforms = {
    **dict.fromkeys((ModelType.MIDJOURNEY_200M, ModelType.DIFFUSIONS_200M), transform_200M),
    **dict.fromkeys((ModelType.MIDJOURNEY_5M, ModelType.DIFFUSIONS_5M), transform_5M),
    ModelType.SYNTHETIC_DETECTOR_V2: transform_synthetic,
}