# File size: 2,584 Bytes
# a2919a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from abc import ABC, abstractmethod
from typing import Union

import numpy as np
import PIL.Image
import torch


class Preprocessor(ABC):
    """
    Common interface shared by all image preprocessors.

    A preprocessor wraps a torch model. Concrete subclasses provide the
    model-specific behaviour by overriding the two abstract methods:
    `from_pretrained` (loading) and `__call__` (preprocessing).

    Args:
        model (`nn.Module`): The torch model wrapped by this preprocessor.
    """

    def __init__(self, model):
        # Keep a reference to the wrapped model; `to` rebinds it later.
        self.model = model

    def to(self, device):
        """
        Transfers the wrapped model to `device` (e.g., CPU or GPU).

        Whatever `self.model.to(device)` returns is stored back on
        `self.model`.

        Args:
            device (`torch.device`): The device to move the model to.

        Returns:
            `Preprocessor`: This same instance, so that calls can be chained.
        """
        relocated = self.model.to(device)
        self.model = relocated
        return self

    @abstractmethod
    def from_pretrained(self):
        """
        Loads the pre-trained weights or configuration of the supported model.

        Every subclass has to override this method with its own
        model-specific loading logic. Depending on what the model requires,
        an implementation might fetch the weights from a repository or read
        them from a local file.
        """
        ...

    @abstractmethod
    def __call__(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        resolution_scale: float = 1.0,
        invert: bool = True,
        return_type: str = "pil",
    ):
        """
        Runs the preprocessing of a single image for the wrapped model.

        Every subclass has to override this method.

        Args:
            image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): The image to preprocess, given
                as a PIL image, a NumPy array or a PyTorch tensor.
            resolution_scale (`float`, *optional*, defaults to 1.0): Factor by which the image resolution
                is scaled during preprocessing and post-processing. 1.0 means no scaling.
            invert (`bool`, *optional*, defaults to True): If True, the generated image is inverted.
            return_type (`str`, *optional*, defaults to "pil"): Format of the result: "pt" for a PyTorch
                tensor, "np" for a NumPy array, or "pil" for a PIL image.

        Returns:
            `Union[PIL.Image.Image, np.ndarray, torch.Tensor]`: The preprocessed image in the format
                selected by `return_type`.
        """
        ...