# Spaces: Paused
from abc import ABC, abstractmethod
from typing import Literal, Union

import numpy as np
import PIL.Image
import torch
class Preprocessor(ABC):
    """
    Abstract base class defining the interface for image preprocessors.

    Subclasses must implement the abstract methods `from_pretrained` and
    `__call__` to provide the loading and preprocessing logic specific to
    their model. Both methods are decorated with `abstractmethod`, so a
    subclass that omits either one cannot be instantiated.

    Args:
        model (`nn.Module`): The torch model to use.
    """

    def __init__(self, model):
        self.model = model

    def to(self, device: Union[str, torch.device]) -> "Preprocessor":
        """
        Moves the underlying model to the specified device (e.g., CPU or GPU).

        Args:
            device (`Union[str, torch.device]`): The target device.

        Returns:
            `Preprocessor`: The preprocessor object itself (for method chaining).
        """
        # Store whatever `.to()` returns, so this works whether the move
        # happens in place or produces a new object.
        self.model = self.model.to(device)
        return self

    @abstractmethod
    def from_pretrained(self):
        """
        Loads pre-trained weights or configurations specific to the model
        this preprocessor supports.

        Subclasses must implement this method to handle model-specific
        loading logic. Depending on the model's requirements, an
        implementation might download pre-trained weights from a repository
        or load them from a local file.
        """

    @abstractmethod
    def __call__(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        resolution_scale: float = 1.0,
        invert: bool = True,
        return_type: Literal["pt", "np", "pil"] = "pil",
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        Preprocesses an image for use with the underlying model.

        Args:
            image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): Input image as a PIL Image,
                NumPy array, or PyTorch tensor.
            resolution_scale (`float`, *optional*, defaults to 1.0): Scale factor for image resolution during
                preprocessing and post-processing. Defaults to 1.0 for no scaling.
            invert (`bool`, *optional*, defaults to True): Inverts the generated image if True.
            return_type (`str`, *optional*, defaults to "pil"): The desired return type, either "pt" for PyTorch tensor,
                "np" for NumPy array, or "pil" for PIL image.

        Returns:
            `Union[PIL.Image.Image, np.ndarray, torch.Tensor]`: The preprocessed image in the
            format selected by `return_type`.
        """