from torch import nn
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from torchvision import transforms
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_MAPPING


class ImageToImageConfig(PretrainedConfig):
    """Configuration for :class:`ImageToImageModel`.

    Args:
        in_channels: Number of channels fed to the first convolution.
        out_channels: Number of channels produced after the pixel shuffle.
        in_resolution: Side length (pixels) the input is resized to before
            the network runs.
        out_resolution: Side length (pixels) the output image is resized to
            after the network runs.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # NOTE: the spelling is part of the registered/serialized model type and
    # must not be "corrected" without migrating existing checkpoints.
    model_type = "upscaleing"

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        in_resolution=256,
        out_resolution=768,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_resolution = in_resolution
        self.out_resolution = out_resolution


class ImageToImageModel(PreTrainedModel):
    """Small convolutional upscaler built from three convolutions and a pixel shuffle.

    The network itself always upscales by a fixed factor of 3; the pre- and
    post-processing transforms resize to ``config.in_resolution`` and
    ``config.out_resolution`` respectively.
    """

    config_class = ImageToImageConfig

    def __init__(self, config):
        """Build the network and the pre/post-processing pipelines from ``config``."""
        super().__init__(config)

        # Spatial upscale factor applied by nn.PixelShuffle. The last
        # convolution must therefore emit out_channels * factor**2 channels.
        upscale_factor = 3

        self.model = nn.Sequential(
            nn.Conv2d(
                in_channels=config.in_channels,
                out_channels=64,
                kernel_size=5,
                padding=2,
            ),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            # Previously hard-coded to 3 * 3 * 3, which silently ignored
            # config.out_channels. For the default out_channels=3 this is
            # still 27, so default checkpoints keep the same weight shapes.
            nn.Conv2d(
                in_channels=32,
                out_channels=config.out_channels * upscale_factor ** 2,
                kernel_size=3,
                padding=1,
            ),
            nn.PixelShuffle(upscale_factor),
        )

        # Pre-processing: resize to the configured input size, then convert
        # to a tensor.
        self.transform1 = transforms.Compose(
            transforms=(
                transforms.Resize((config.in_resolution, config.in_resolution)),
                transforms.ToTensor(),
            )
        )

        # Post-processing: convert the network output back to a PIL image and
        # resize it to the configured output size.
        self.transform2 = transforms.Compose(
            transforms=(
                transforms.ToPILImage(),
                transforms.Resize((config.out_resolution, config.out_resolution)),
            )
        )

    def forward(self, image):
        """Upscale a single image.

        Args:
            image: Input accepted by ``self.transform1`` (presumably a PIL
                image, since ``ToTensor`` is applied to it — TODO confirm
                against callers).

        Returns:
            The result of ``self.transform2`` applied to the network output,
            i.e. the output of ``ToPILImage`` followed by ``Resize``.
        """
        x = self.transform1(image)
        x = self.model(x)
        x = self.transform2(x)
        return x


# Register the custom config/model pair with the Transformers auto-mappings.
CONFIG_MAPPING.register("upscaleing", ImageToImageConfig)
MODEL_MAPPING.register(ImageToImageConfig, ImageToImageModel)

# Instantiate a model with the default configuration and write it to disk.
# NOTE(review): this runs at import time; kept as-is to preserve the module's
# existing side effects.
config = ImageToImageConfig()
model = ImageToImageModel(config)
model.save_pretrained("AIupscaleing")