|
import warnings
from typing import Dict, List, Optional

from transformers import PretrainedConfig
|
# Install a catch-all "ignore" filter at import time. No category, message or
# module restriction is given, so every warning in the process is suppressed,
# including warnings raised by code that merely imports this module.
# NOTE(review): consider narrowing this to the specific warning(s) it was
# meant to hide.
warnings.simplefilter("ignore")
|
|
|
class InceptionV3Config(PretrainedConfig):
    """Configuration for an Inception-v3 image classification model.

    Every constructor argument is stored as an attribute of the same name.
    Any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    List/dict parameters default to ``None`` and are replaced by a freshly
    built default on each call. A literal list/dict default would be one
    object shared by every call. Mutating one instance's ``input_size``,
    ``mean``, ``std`` or ``classes`` would then silently change the defaults
    seen by later instances. Passing ``None`` explicitly selects the default.

    Args:
        model_name: Model identifier; default ``"inception_v3"``.
        num_classes: Number of output classes; default 3, matching the three
            entries of the default ``classes`` mapping.
        input_size: Input size; default ``[3, 299, 299]`` (presumably
            channels, height, width; confirm against the model code).
        interpolation: Interpolation mode name; default ``"bicubic"``.
        mean: Normalization mean; default ``[0.5, 0.5, 0.5]``.
        std: Normalization standard deviation; default ``[0.5, 0.5, 0.5]``.
        classifier: Classifier identifier; default ``"fc"``.
        has_aux: Auxiliary-head flag; default ``True``. How it is consumed
            is not visible in this file.
        label_offset: Stored as-is; default 1. How it is applied is not
            visible in this file.
        classes: Mapping from class index (as a string) to label name;
            default ``{'0': 'nsfw_gore', '1': 'nsfw_suggestive', '2': 'safe'}``.
        output_channels: Output channel count; default 2048.
        use_jit: Flag stored for consumers of the config; default ``False``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "inceptionv3"

    def __init__(
        self,
        model_name: str = "inception_v3",
        num_classes: int = 3,
        input_size: Optional[List[int]] = None,
        interpolation: str = "bicubic",
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        classifier: str = "fc",
        has_aux: bool = True,
        label_offset: int = 1,
        classes: Optional[Dict[str, str]] = None,
        output_channels: int = 2048,
        use_jit: bool = False,
        **kwargs,
    ):
        self.model_name = model_name
        self.num_classes = num_classes
        # Fresh objects per instance: never share a mutable default across configs.
        self.input_size = [3, 299, 299] if input_size is None else input_size
        self.interpolation = interpolation
        self.mean = [0.5, 0.5, 0.5] if mean is None else mean
        self.std = [0.5, 0.5, 0.5] if std is None else std
        self.classifier = classifier
        self.has_aux = has_aux
        self.label_offset = label_offset
        self.classes = (
            {'0': 'nsfw_gore', '1': 'nsfw_suggestive', '2': 'safe'}
            if classes is None
            else classes
        )
        self.output_channels = output_channels
        self.use_jit = use_jit

        # Attributes are set before the base initializer runs, so the base
        # class sees them when it processes the remaining kwargs.
        super().__init__(**kwargs)
|
|
|
""" |
|
|
|
inceptionv3_config = InceptionV3Config() |
|
inceptionv3_config.save_pretrained("inceptionv3_config") |
|
|
|
""" |