from typing import List, Optional

from transformers import PretrainedConfig
|
|
|
# Example mapping from class index to label name; it supplies the default
# value of the ``classes`` argument of ``InceptionV3Config``.
classes_example = {
    0: "nsfw_gore",
    1: "nsfw_suggestive",
    2: "safe",
}
|
|
|
class InceptionV3Config(PretrainedConfig):
    """Configuration object for an Inception v3 model.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name; remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``. This class performs no validation, so the
    meanings given below are inferred from the argument names and defaults
    and should be confirmed against the model code that consumes them.

    List/dict arguments default to ``None`` and are replaced by a *fresh*
    copy of the documented default, so instances never share (and cannot
    accidentally mutate) a single default object.

    Args:
        model_name: Name of the model architecture.
        input_channels: Number of input channels.
        num_classes: Number of output classes.
        input_size: Input shape. Defaults to ``[3, 299, 299]``
            (presumably channels, height, width -- TODO confirm).
        pool_size: Defaults to ``[8, 8, 2048]`` (presumably the shape of the
            final pooled feature map -- TODO confirm).
        crop_pct: Presumably the crop fraction used in preprocessing.
        interpolation: Presumably the resize interpolation mode name.
        mean: Presumably per-channel normalization means. Defaults to
            ``[0.5, 0.5, 0.5]``.
        std: Presumably per-channel normalization standard deviations.
            Defaults to ``[0.5, 0.5, 0.5]``.
        first_conv: Presumably the name of the first convolution module.
        classifier: Presumably the name of the classifier module.
        has_aux: Presumably whether the model has an auxiliary head.
        label_offset: Offset applied to labels -- TODO confirm semantics.
        classes: Mapping from class index to label name. Defaults to a copy
            of the module-level ``classes_example``.
        output_channels: Presumably the feature width before the classifier.
        use_jit: Presumably whether to use a JIT-compiled model.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "inceptionv3"

    def __init__(
        self,
        model_name: str = "inception_v3",
        input_channels: int = 3,
        num_classes: int = 3,
        input_size: Optional[List[int]] = None,
        pool_size: Optional[List[int]] = None,
        crop_pct: float = 0.875,
        interpolation: str = "bicubic",
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        first_conv: str = "Conv2d_1a_3x3.conv",
        classifier: str = "fc",
        has_aux: bool = True,
        label_offset: int = 1,
        classes: Optional[dict] = None,
        output_channels: int = 2048,
        use_jit: bool = False,
        **kwargs,
    ):
        self.model_name = model_name
        self.input_channels = input_channels
        self.num_classes = num_classes
        # Mutable defaults are materialized per instance (see class docstring);
        # caller-supplied objects are stored as given, exactly as before.
        self.input_size = [3, 299, 299] if input_size is None else input_size
        self.pool_size = [8, 8, 2048] if pool_size is None else pool_size
        self.crop_pct = crop_pct
        self.interpolation = interpolation
        self.mean = [0.5, 0.5, 0.5] if mean is None else mean
        self.std = [0.5, 0.5, 0.5] if std is None else std
        self.first_conv = first_conv
        self.classifier = classifier
        self.has_aux = has_aux
        self.label_offset = label_offset
        # NOTE(review): integer keys are written as strings by JSON
        # serialization, so a reloaded config presumably carries "0"/"1"/"2"
        # keys -- confirm that consumers handle both forms.
        self.classes = dict(classes_example) if classes is None else classes
        self.output_channels = output_channels
        self.use_jit = use_jit
        super().__init__(**kwargs)
|
|
|
"""
Example usage:

inceptionv3_config = InceptionV3Config()

inceptionv3_config.save_pretrained("inceptionv3_config")
"""