|
from enum import Enum |
|
import torch |
|
|
|
from model_classes import Model200M, Model5M, SyntheticV2 |
|
from model_transforms import transform_200M, transform_5M, transform_synthetic |
|
|
|
class ModelType(str, Enum):
    """Names of the detector models this module knows how to load.

    The class also derives from ``str``, so every member is itself a string
    and compares equal to its plain value
    (``ModelType.DIFFUSIONS_5M == "diffusions_5M"``).
    """

    MIDJOURNEY_200M = "midjourney_200M"
    DIFFUSIONS_200M = "diffusions_200M"
    MIDJOURNEY_5M = "midjourney_5M"
    DIFFUSIONS_5M = "diffusions_5M"
    SYNTHETIC_DETECTOR_V2 = "synthetic_detector_v2"

    def __str__(self):
        """Render the member as its bare value, e.g. ``"midjourney_5M"``."""
        # Without this override str() yields "ModelType.MIDJOURNEY_5M".
        return str(self.value)

    @classmethod
    def get_list(cls):
        """Return the string value of every member, in declaration order."""
        return [member.value for member in cls]
|
|
|
def load_model(value: ModelType):
    """Load the checkpoint for ``value`` into its model and return the model.

    The model object is the pre-built instance stored in ``type_to_class``.
    It is filled in place from the file listed in ``type_to_path`` and then
    switched to evaluation mode with ``eval()``. Calling this twice for the
    same type therefore reloads and returns the very same object.

    Args:
        value: The model type whose weights should be loaded.

    Returns:
        The instance from ``type_to_class`` with its state dict loaded and
        ``eval()`` applied.

    Raises:
        KeyError: If ``value`` has no entry in ``type_to_class`` or
            ``type_to_path``. The message lists the known model types.
    """
    try:
        model = type_to_class[value]
        path = type_to_path[value]
    except KeyError:
        # Replace the bare-key error with one that names the valid choices;
        # still a KeyError so callers that already catch it keep working.
        raise KeyError(
            f"No model registered for {value!r}; "
            f"expected one of {ModelType.get_list()}"
        ) from None

    # Map all stored tensors onto the CPU, whatever device they were saved from.
    # NOTE(review): torch.load unpickles the file unless weights_only=True is
    # passed; only point type_to_path at trusted checkpoint files.
    ckpt = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt)
    model.eval()
    return model
|
|
|
# One freshly constructed model object per model type. Despite the name, the
# values are instances, not classes; load_model() loads checkpoint weights
# into these same objects in place. Types that share a model class
# (the MIDJOURNEY/DIFFUSIONS pairs) still each get their own instance, so
# their weights stay independent.
type_to_class = {
    model_type: build()
    for model_type, build in (
        (ModelType.MIDJOURNEY_200M, Model200M),
        (ModelType.DIFFUSIONS_200M, Model200M),
        (ModelType.MIDJOURNEY_5M, Model5M),
        (ModelType.DIFFUSIONS_5M, Model5M),
        (ModelType.SYNTHETIC_DETECTOR_V2, SyntheticV2),
    )
}
|
|
|
# Checkpoint file read by load_model() for each model type.
# NOTE(review): these are relative paths, so they resolve against the
# process's current working directory, not this module's directory.
type_to_path = {
    ModelType.MIDJOURNEY_200M: "models/midjourney200M.pt",
    ModelType.DIFFUSIONS_200M: "models/diffusions200M.pt",
    ModelType.MIDJOURNEY_5M: "models/midjourney5M.pt",
    ModelType.DIFFUSIONS_5M: "models/diffusions5M.pt",
    ModelType.SYNTHETIC_DETECTOR_V2: "models/synthetic_detector_v2.pt",
}
|
|
|
# Load every model eagerly at import time, in ModelType declaration order.
# This must come after type_to_class and type_to_path, because load_model()
# reads both when it is called. Each value is the same object held in
# type_to_class, since load_model() fills that instance in place.
type_to_loaded_model = {
    model_type: load_model(model_type)
    for model_type in ModelType
}
|
|
|
# Transform (from model_transforms) associated with each model type;
# presumably the input preprocessing for that model — confirm at call sites.
# Types that share a model class share the same transform object.
type_to_transforms = {
    **dict.fromkeys(
        (ModelType.MIDJOURNEY_200M, ModelType.DIFFUSIONS_200M), transform_200M
    ),
    **dict.fromkeys(
        (ModelType.MIDJOURNEY_5M, ModelType.DIFFUSIONS_5M), transform_5M
    ),
    ModelType.SYNTHETIC_DETECTOR_V2: transform_synthetic,
}