File size: 4,207 Bytes
e97054d
 
 
 
8e8cbdc
e97054d
 
 
 
8e8cbdc
 
 
e97054d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e8cbdc
e97054d
 
8e8cbdc
e97054d
 
 
 
 
8e8cbdc
e97054d
 
 
8e8cbdc
e97054d
 
8e8cbdc
 
e97054d
 
 
 
8e8cbdc
 
 
 
 
 
 
 
 
e97054d
 
8e8cbdc
 
 
 
 
 
 
 
 
e97054d
8e8cbdc
 
 
 
 
 
 
 
e97054d
8e8cbdc
 
e97054d
 
8e8cbdc
 
 
 
e97054d
8e8cbdc
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import List, Optional, Union
from torchvision import transforms
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoImageProcessor, AutoModel
import os
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils import add_end_docstrings
from transformers.pipelines.base import Pipeline, build_pipeline_init_args
class SscdImageProcessor(BaseImageProcessor):
    """Convert a single PIL image into a normalized, batched tensor for SSCD.

    The pipeline is: optional RGB conversion -> optional resize ->
    ``ToTensor`` -> per-channel normalization -> add a leading batch axis.
    """

    def __init__(
            self,
            do_resize: bool = True,
            size: int = 288,
            image_mean: Optional[Union[float, List[float]]] = None,
            image_std: Optional[Union[float, List[float]]] = None,
            do_convert_rgb: bool = True,
            **kwargs,
    ) -> None:
        """Store the preprocessing options.

        Args:
            do_resize: Whether ``preprocess`` resizes the image by default.
            size: Target size handed to ``transforms.Resize``.
            image_mean: Per-channel mean for normalization; defaults to
                ``[0.485, 0.456, 0.406]`` when ``None``.
            image_std: Per-channel std for normalization; defaults to
                ``[0.229, 0.224, 0.225]`` when ``None``.
            do_convert_rgb: Whether to call ``image.convert('RGB')`` first.
            **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.size = size
        self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]
        self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]
        self.do_convert_rgb = do_convert_rgb
        self.do_resize = do_resize

    def preprocess(
            self,
            image: Image.Image,
            do_resize: Optional[bool] = None,
            **kwargs,
    ):
        """Preprocess one image and return it with a leading batch axis.

        Args:
            image: The PIL image to process.
            do_resize: Per-call override of ``self.do_resize``; ``None``
                means "use the instance setting".
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The tensor produced by the transform chain with ``unsqueeze(0)``
            applied, i.e. one extra dimension at the front.
        """
        if do_resize is None:
            do_resize = self.do_resize

        # Resize must act on the PIL image, so it goes before ToTensor and
        # Normalize. It is only included when resizing is actually enabled.
        steps = []
        if do_resize:
            steps.append(transforms.Resize(self.size))
        steps.append(transforms.ToTensor())
        steps.append(
            transforms.Normalize(
                mean=self.image_mean, std=self.image_std,
            )
        )

        if self.do_convert_rgb:
            image = image.convert('RGB')
        return transforms.Compose(steps)(image).unsqueeze(0)


class SscdConfig(PretrainedConfig):
    """Configuration for the SSCD copy-detection model.

    The only model-specific setting is ``model_path``: the file name of the
    TorchScript archive, forwarded to ``PretrainedConfig.__init__`` as a
    keyword argument together with any extra ``kwargs``.
    """

    model_type = 'sscd-copy-detection'

    def __init__(self, model_path: str = None, **kwargs):
        """Create the config, substituting the default archive name for ``None``."""
        resolved = 'sscd_disc_mixup.torchscript.pt' if model_path is None else model_path
        super().__init__(model_path=resolved, **kwargs)


class SscdModel(PreTrainedModel):
    """Thin ``PreTrainedModel`` wrapper around a TorchScript SSCD archive.

    The actual network is loaded with ``torch.jit.load`` from a file that is
    either inside a local directory or fetched from the Hugging Face Hub.
    """

    config_class = SscdConfig

    def __init__(self, config, model_path: str = None):
        """Locate the TorchScript file and load it.

        Args:
            config: Config whose ``name_or_path`` is either a local directory
                or a Hub repo id, and whose ``model_path`` is the default
                archive file name.
            model_path: Optional archive file name overriding
                ``config.model_path``.

        Side effects:
            Sets ``config.base_path`` to the directory containing the archive.
        """
        super().__init__(config)
        # NOTE(review): zero-element parameter; presumably here so the module
        # owns at least one parameter for device/dtype lookups -- confirm.
        self.dummy_param = nn.Parameter(torch.zeros(0))
        if model_path is None:
            model_path = config.model_path

        if os.path.isdir(config.name_or_path):
            config.base_path = config.name_or_path
            resolved_path = os.path.join(config.base_path, model_path)
        else:
            # Use the path returned by the Hub client directly: re-joining
            # dirname + model_path would duplicate any subfolder contained in
            # ``model_path``.
            resolved_path = hf_hub_download(repo_id=config.name_or_path, filename=model_path)
            config.base_path = os.path.dirname(resolved_path)

        self.model = torch.jit.load(resolved_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model from the config found at the given name or path.

        Only the config is loaded here (via ``AutoConfig.from_pretrained`` with
        ``kwargs``); the weights come from the TorchScript file resolved in
        ``__init__``. ``model_args`` is accepted but not used.
        """
        return cls(AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs))

    def forward(self, inputs):
        """Run the TorchScript model and return index ``[0, :]`` of its output.

        NOTE(review): only the first row of the output is returned, so any
        additional rows are dropped -- confirm this is intended for batches.
        """
        return self.model(inputs)[0, :]



@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class SscdPipeline(Pipeline):
    """Pipeline that feeds an image through the image processor and the model.

    No call-time parameters are supported and the model output is returned
    unchanged.
    """

    def __init__(self, model, **kwargs):
        """Remember the requested device and defer to ``Pipeline.__init__``.

        ``device`` is optional: ``device_id`` is ``None`` when the caller did
        not pass one, instead of raising ``KeyError``.
        """
        self.device_id = kwargs.get('device')
        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess/forward/postprocess parameter dicts."""
        return {}, {}, {}

    def preprocess(self, input):
        """Delegate to ``self.image_processor.preprocess``."""
        return self.image_processor.preprocess(input)

    def _forward(self, inputs):
        """Call the model on the preprocessed inputs."""
        return self.model(inputs)

    def postprocess(self, model_outputs):
        """Return the model outputs as-is."""
        return model_outputs


# Register the custom classes with the transformers Auto* factories so the
# 'sscd-copy-detection' model type resolves to the classes defined above.
AutoConfig.register('sscd-copy-detection', SscdConfig)
AutoModel.register(SscdConfig, SscdModel)
AutoImageProcessor.register(SscdConfig, slow_image_processor_class=SscdImageProcessor)
# NOTE(review): this runs at import time and must come after the registrations
# above. Unless 'm3/sscd-copy-detection' is a local directory, SscdModel
# fetches the TorchScript file via hf_hub_download -- confirm that loading a
# model as an import side effect is intended.
models = AutoModel.from_pretrained('m3/sscd-copy-detection')

# Expose SscdPipeline under the 'sscd-copy-detection' task name, with
# SscdModel as its PyTorch model class.
PIPELINE_REGISTRY.register_pipeline(
    task='sscd-copy-detection',
    pipeline_class=SscdPipeline,
    pt_model=SscdModel
)