from transformers import PreTrainedModel | |
from facenet_pytorch import MTCNN, InceptionResnetV1 | |
from .deepfakeconfig import DeepFakeConfig | |
class DeepFakeModel(PreTrainedModel):
    """Single-output face classifier built on facenet_pytorch's InceptionResnetV1.

    Wraps an ``InceptionResnetV1`` (created with ``pretrained="vggface2"``,
    ``classify=True`` and ``num_classes=1``) inside a
    ``transformers.PreTrainedModel`` so it can go through the Hugging Face
    save/load and Auto* machinery.

    Args:
        config: A ``DeepFakeConfig``. Its ``DEVICE`` attribute is forwarded to
            ``InceptionResnetV1`` as the ``device`` argument.
    """

    config_class = DeepFakeConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = InceptionResnetV1(
            pretrained="vggface2",
            classify=True,
            # A single output unit per image. NOTE(review): presumably a
            # real-vs-fake score, given the class name; confirm with the
            # training code.
            num_classes=1,
            device=config.DEVICE,
        )

    def forward(self, pixel_values):
        """Run the wrapped InceptionResnetV1 on a batch of images.

        ``PreTrainedModel`` does not supply a ``forward`` of its own; without
        this method, calling the model falls through to ``nn.Module``'s stub
        and raises ``NotImplementedError``.

        Args:
            pixel_values: Batch of images to classify. The name matches the
                keyword transformers image processors emit.
                NOTE(review): presumably a float tensor of face crops shaped
                (batch, 3, H, W), e.g. from facenet_pytorch's MTCNN. Confirm
                against callers.

        Returns:
            dict: ``{"logits": output}``, where ``output`` is whatever
            ``self.model`` returns. With ``classify=True`` and
            ``num_classes=1`` that is one raw (no sigmoid applied) value per
            image. The dict form follows the transformers convention of
            exposing model outputs under a ``"logits"`` key.
        """
        return {"logits": self.model(pixel_values)}
# Register the custom classes with the transformers Auto* API:
# the config under AutoConfig (the method's default, spelled out here) and the
# model under AutoModelForImageClassification. Per the transformers custom-model
# docs, this is what lets the code be saved alongside a checkpoint and loaded
# with trust_remote_code=True.
DeepFakeConfig.register_for_auto_class(auto_class="AutoConfig")
DeepFakeModel.register_for_auto_class(auto_class="AutoModelForImageClassification")