amos1088 committed on
Commit
a2919a7
1 Parent(s): 78a6ecd
Files changed (35)
  1. .idea/vcs.xml +1 -1
  2. app.py +5 -2
  3. image_gen_aux/__init__.py +81 -0
  4. image_gen_aux/image_processor.py +233 -0
  5. image_gen_aux/modeling_utils.py +90 -0
  6. image_gen_aux/preprocessors/README.md +54 -0
  7. image_gen_aux/preprocessors/__init__.py +66 -0
  8. image_gen_aux/preprocessors/depth/README.md +45 -0
  9. image_gen_aux/preprocessors/depth/__init__.py +42 -0
  10. image_gen_aux/preprocessors/depth/depth_preprocessor.py +70 -0
  11. image_gen_aux/preprocessors/lineart/LICENSE.txt +21 -0
  12. image_gen_aux/preprocessors/lineart/README.md +27 -0
  13. image_gen_aux/preprocessors/lineart/__init__.py +36 -0
  14. image_gen_aux/preprocessors/lineart/lineart_preprocessor.py +101 -0
  15. image_gen_aux/preprocessors/lineart/model.py +87 -0
  16. image_gen_aux/preprocessors/lineart_standard/README.md +23 -0
  17. image_gen_aux/preprocessors/lineart_standard/__init__.py +36 -0
  18. image_gen_aux/preprocessors/lineart_standard/lineart_standard_preprocessor.py +82 -0
  19. image_gen_aux/preprocessors/preprocessor.py +75 -0
  20. image_gen_aux/preprocessors/teed/LICENSE.txt +21 -0
  21. image_gen_aux/preprocessors/teed/README.md +26 -0
  22. image_gen_aux/preprocessors/teed/__init__.py +36 -0
  23. image_gen_aux/preprocessors/teed/teed.py +323 -0
  24. image_gen_aux/preprocessors/teed/teed_preprocessor.py +121 -0
  25. image_gen_aux/upscalers/README.md +58 -0
  26. image_gen_aux/upscalers/__init__.py +36 -0
  27. image_gen_aux/upscalers/upscale_with_model.py +118 -0
  28. image_gen_aux/utils/__init__.py +25 -0
  29. image_gen_aux/utils/constants.py +1 -0
  30. image_gen_aux/utils/import_utils.py +128 -0
  31. image_gen_aux/utils/loading_utils.py +63 -0
  32. image_gen_aux/utils/logging.py +341 -0
  33. image_gen_aux/utils/model_utils.py +37 -0
  34. image_gen_aux/utils/tiling_utils.py +86 -0
  35. requirements.txt +2 -2
.idea/vcs.xml CHANGED
@@ -2,7 +2,7 @@
 <project version="4">
   <component name="VcsDirectoryMappings">
     <mapping directory="" vcs="Git" />
-    <mapping directory="$PROJECT_DIR$/depth-fm" vcs="Git" />
+    <mapping directory="$PROJECT_DIR$/_image_gen_aux" vcs="Git" />
     <mapping directory="$PROJECT_DIR$/test_gradio" vcs="Git" />
   </component>
 </project>
app.py CHANGED
@@ -11,6 +11,7 @@ from huggingface_hub import login
 import torch
 from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel
 from diffusers.utils import load_image
+from image_gen_aux import DepthPreprocessor

 # ----------------------------
 # Step 1: Download IP Adapter if not exists
@@ -70,9 +71,11 @@ pipe.init_ipadapter(
 @spaces.GPU
 def gui_generation(prompt,negative_prompt, ref_img, guidance_scale, ipadapter_scale):
     ref_img = load_image(ref_img.name).convert('RGB')
-    control_image = load_image(
-        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_depth.png")
+    image = load_image(ref_img.name)
+
+    depth_preprocessor = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf").to("cuda")
+    control_image = depth_preprocessor(image, invert=True)[0].convert("RGB")

     generator = torch.Generator(device="cpu").manual_seed(0)
     image = pipe(
         width=1024,
image_gen_aux/__init__.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from .utils import (
18
+ IMAGE_AUX_SLOW_IMPORT,
19
+ OptionalDependencyNotAvailable,
20
+ _LazyModule,
21
+ is_torch_available,
22
+ is_transformers_available,
23
+ )
24
+
25
+
26
+ # Lazy Import based on
27
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
28
+
29
+ # When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
30
+ # and is used to defer the actual importing for when the objects are requested.
31
+ # This way `import image_gen_aux` provides the names in the namespace without actually importing anything (and especially none of the backends).
32
+
33
+ _import_structure = {
34
+ "upscalers": [],
35
+ "preprocessors": ["LineArtStandardPreprocessor"],
36
+ "utils": [
37
+ "OptionalDependencyNotAvailable",
38
+ "is_torch_available",
39
+ "logging",
40
+ ],
41
+ }
42
+
43
+
44
+ try:
45
+ if not is_torch_available():
46
+ raise OptionalDependencyNotAvailable()
47
+ except OptionalDependencyNotAvailable:
48
+ ...
49
+ else:
50
+ _import_structure["upscalers"].extend(["UpscaleWithModel"])
51
+
52
+ _import_structure["preprocessors"].extend(["DepthPreprocessor", "LineArtPreprocessor", "TeedPreprocessor"])
53
+
54
+ if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
55
+ from .preprocessors import LineArtStandardPreprocessor
56
+
57
+ try:
58
+ if not is_torch_available():
59
+ raise OptionalDependencyNotAvailable()
60
+ except OptionalDependencyNotAvailable:
61
+ ...
62
+ else:
63
+ from .preprocessors import LineArtPreprocessor, TeedPreprocessor
64
+ from .upscalers import UpscaleWithModel
65
+
66
+ try:
67
+ if not (is_torch_available() and is_transformers_available()):
68
+ raise OptionalDependencyNotAvailable()
69
+ except OptionalDependencyNotAvailable:
70
+ ...
71
+ else:
72
+ from .preprocessors import DepthPreprocessor
73
+ else:
74
+ import sys
75
+
76
+ sys.modules[__name__] = _LazyModule(
77
+ __name__,
78
+ globals()["__file__"],
79
+ _import_structure,
80
+ module_spec=__spec__,
81
+ )
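A quick note on what the lazy-import structure above buys at runtime: importing the package is cheap, and the heavy backends are only pulled in when an attribute is first accessed. A minimal sketch (illustrative only; assumes the vendored package is importable as `image_gen_aux` and that `torch`/`transformers` are installed):

```python
import image_gen_aux

# At this point no torch/transformers code has been imported yet.
# Accessing an attribute resolves it through `_import_structure` on demand.
depth_cls = image_gen_aux.DepthPreprocessor
print(depth_cls.__name__)  # DepthPreprocessor
```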
image_gen_aux/image_processor.py ADDED
@@ -0,0 +1,233 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Tuple, Union
16
+
17
+ import cv2
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from PIL import Image
22
+
23
+
24
+ class ImageMixin:
25
+ """
26
+ A mixin class for converting images between different formats: PIL, NumPy, and PyTorch tensors.
27
+
28
+ This class provides methods to:
29
+ - Convert a PIL image or a NumPy array to a PyTorch tensor.
30
+ - Post-process a PyTorch tensor image and convert it to the specified return type.
31
+ - Convert a PIL image or a list of PIL images to NumPy arrays.
32
+ - Convert a NumPy image to a PyTorch tensor.
33
+ - Convert a PyTorch tensor to a NumPy image.
34
+ - Convert a NumPy image or a batch of images to a PIL image.
35
+ """
36
+
37
+ def convert_image_to_tensor(
38
+ self, image: Union[PIL.Image.Image, np.ndarray, List[PIL.Image.Image]], normalize: bool = True
39
+ ) -> torch.Tensor:
40
+ """
41
+ Convert a PIL image or a NumPy array to a PyTorch tensor.
42
+
43
+ Args:
44
+ image (Union[PIL.Image.Image, np.ndarray]): The input image, either as a PIL image, a NumPy array or a list of
45
+ PIL images.
46
+
47
+ Returns:
48
+ torch.Tensor: The converted image as a PyTorch tensor.
49
+ """
50
+ if isinstance(image, (PIL.Image.Image, list)):
51
+ # We expect that if it is a list, it only should contain pillow images
52
+ if isinstance(image, list):
53
+ for single_image in image:
54
+ if not isinstance(single_image, PIL.Image.Image):
55
+ raise ValueError("All images in the list must be Pillow images.")
56
+
57
+ image = self.pil_to_numpy(image, normalize)
58
+
59
+ return self.numpy_to_pt(image)
60
+
61
+ def post_process_image(self, image: torch.Tensor, return_type: str):
62
+ """
63
+ Post-process a PyTorch tensor image and convert it to the specified return type.
64
+
65
+ Args:
66
+ image (torch.Tensor): The input image as a PyTorch tensor.
67
+ return_type (str): The desired return type, either "pt" for PyTorch tensor, "np" for NumPy array, or "pil" for PIL image.
68
+
69
+ Returns:
70
+ Union[torch.Tensor, np.ndarray, List[PIL.Image.Image]]: The post-processed image in the specified return type.
71
+ """
72
+ if return_type == "pt":
73
+ return image
74
+
75
+ image = self.pt_to_numpy(image)
76
+ if return_type == "np":
77
+ return image
78
+
79
+ image = self.numpy_to_pil(image)
80
+ return image
81
+
82
+ @staticmethod
83
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image], normalize: bool = True) -> np.ndarray:
84
+ """
85
+ Convert a PIL image or a list of PIL images to NumPy arrays.
86
+
87
+ Args:
88
+ images (Union[List[PIL.Image.Image], PIL.Image.Image]): The input image(s) as PIL image(s).
89
+
90
+ Returns:
91
+ np.ndarray: The converted image(s) as a NumPy array.
92
+ """
93
+ if not isinstance(images, list):
94
+ images = [images]
95
+
96
+ if normalize:
97
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
98
+ else:
99
+ images = [np.array(image).astype(np.float32) for image in images]
100
+
101
+ images = np.stack(images, axis=0)
102
+
103
+ return images
104
+
105
+ @staticmethod
106
+ def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
107
+ """
108
+ Convert a NumPy image to a PyTorch tensor.
109
+
110
+ Args:
111
+ images (np.ndarray): The input image(s) as a NumPy array.
112
+
113
+ Returns:
114
+ torch.Tensor: The converted image(s) as a PyTorch tensor.
115
+ """
116
+ if images.ndim == 3:
117
+ images = images[..., None]
118
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2)).float()
119
+
120
+ return images
121
+
122
+ @staticmethod
123
+ def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
124
+ """
125
+ Convert a PyTorch tensor to a NumPy image.
126
+
127
+ Args:
128
+ images (torch.Tensor): The input image(s) as a PyTorch tensor.
129
+
130
+ Returns:
131
+ np.ndarray: The converted image(s) as a NumPy array.
132
+ """
133
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
134
+ return images
135
+
136
+ @staticmethod
137
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
138
+ """
139
+ Convert a NumPy image or a batch of images to a PIL image.
140
+
141
+ Args:
142
+ images (np.ndarray): The input image(s) as a NumPy array.
143
+
144
+ Returns:
145
+ List[PIL.Image.Image]: The converted image(s) as PIL images.
146
+ """
147
+ if images.ndim == 3:
148
+ images = images[None, ...]
149
+ images = (images * 255).round().astype("uint8")
150
+ if images.shape[-1] == 1:
151
+ # special case for grayscale (single channel) images
152
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
153
+ else:
154
+ pil_images = [Image.fromarray(image) for image in images]
155
+
156
+ return pil_images
157
+
158
+ @staticmethod
159
+ def scale_image(image: torch.Tensor, scale: float, multiple_factor: int = 8) -> Tuple[torch.Tensor, float]:
160
+ """
161
+ Scales an image while maintaining aspect ratio and ensuring dimensions are multiples of `multiple_factor`.
162
+
163
+ Args:
164
+ image (`torch.Tensor`): The input image tensor of shape (batch, channels, height, width).
165
+ scale (`float`): The scaling factor applied to the image dimensions.
166
+ multiple_factor (`int`, *optional*, defaults to 8): The factor by which the new dimensions should be divisible.
167
+
168
+ Returns:
169
+ `Tuple[torch.Tensor, float]`: The scaled image tensor and the effective scale that was applied.
170
+ """
171
+
172
+ if scale == 1.0:
173
+ return image, scale
174
+
175
+ _batch, _channels, height, width = image.shape
176
+
177
+ # Calculate new dimensions while maintaining aspect ratio
178
+ new_height = int(height * scale)
179
+ new_width = int(width * scale)
180
+
181
+ # Ensure new dimensions are multiples of multiple_factor
182
+ new_height = (new_height // multiple_factor) * multiple_factor
183
+ new_width = (new_width // multiple_factor) * multiple_factor
184
+
185
+ # if the final height and width changed because of the multiple_factor, we need to adjust the scale too
186
+ scale = new_height / height
187
+
188
+ # Resize the image using the calculated dimensions
189
+ resized_image = torch.nn.functional.interpolate(
190
+ image, size=(new_height, new_width), mode="bilinear", align_corners=False
191
+ )
192
+
193
+ return resized_image, scale
194
+
195
+ @staticmethod
196
+ def resize_numpy_image(image: np.ndarray, scale: float, multiple_factor: int = 8) -> Tuple[np.ndarray, float]:
197
+ """
198
+ Resizes a NumPy image while maintaining aspect ratio and ensuring dimensions are multiples of `multiple_factor`.
199
+
200
+ Args:
201
+ image (`np.ndarray`): The input image array of shape (batch, height, width, channels) or (height, width, channels) for a single image.
202
+ scale (`float`): The scaling factor applied to the image dimensions.
203
+ multiple_factor (`int`, *optional*, defaults to 8): The factor by which the new dimensions should be divisible.
204
+
205
+ Returns:
206
+ `Tuple[np.ndarray, float]`: The resized image array and the effective scale that was applied.
207
+ """
208
+ if len(image.shape) == 3: # Single image without batch dimension
209
+ image = np.expand_dims(image, axis=0)
210
+
211
+ batch_size, height, width, channels = image.shape
212
+
213
+ # Calculate new dimensions while maintaining aspect ratio
214
+ new_height = int(height * scale)
215
+ new_width = int(width * scale)
216
+
217
+ # Ensure new dimensions are multiples of multiple_factor
218
+ new_height = (new_height // multiple_factor) * multiple_factor
219
+ new_width = (new_width // multiple_factor) * multiple_factor
220
+
221
+ # if the final height and width changed because of the multiple_factor, we need to adjust the scale too
222
+ scale = new_height / height
223
+
224
+ # Resize each image in the batch
225
+ resized_images = []
226
+ for i in range(batch_size):
227
+ resized_image = cv2.resize(image[i], (new_width, new_height), interpolation=cv2.INTER_LINEAR)
228
+ resized_images.append(resized_image)
229
+
230
+ # Stack resized images back into a single array
231
+ resized_images = np.stack(resized_images, axis=0)
232
+
233
+ return resized_images, scale
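A minimal round-trip sketch of the `ImageMixin` conversion helpers added above (not part of the commit; assumes the file is importable as `image_gen_aux.image_processor`):

```python
import PIL.Image

from image_gen_aux.image_processor import ImageMixin

mixin = ImageMixin()

# PIL image -> (1, 3, H, W) float tensor normalized to [0, 1]
pil_image = PIL.Image.new("RGB", (64, 48), color=(200, 100, 50))
tensor = mixin.convert_image_to_tensor(pil_image)
print(tensor.shape)  # torch.Size([1, 3, 48, 64])

# tensor -> (1, H, W, 3) NumPy batch -> list with a single PIL image
numpy_batch = mixin.pt_to_numpy(tensor)
restored = mixin.numpy_to_pil(numpy_batch)
print(restored[0].size)  # (64, 48)
```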
image_gen_aux/modeling_utils.py ADDED
@@ -0,0 +1,90 @@
1
+ import itertools
2
+ from typing import List, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+
8
+ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
9
+ """
10
+ Gets the device of a PyTorch module's parameters or buffers.
11
+
12
+ Args:
13
+ parameter (`torch.nn.Module`): The PyTorch module from which to get the device.
14
+
15
+ Returns:
16
+ `torch.device`: The device of the module's parameters or buffers.
17
+ """
18
+ try:
19
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
20
+ return next(parameters_and_buffers).device
21
+ except StopIteration:
22
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
23
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
24
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
25
+ return tuples
26
+
27
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
28
+ first_tuple = next(gen)
29
+ return first_tuple[1].device
30
+
31
+
32
+ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
33
+ """
34
+ Gets the data type of a PyTorch module's parameters or buffers.
35
+
36
+ Args:
37
+ parameter (`torch.nn.Module`): The PyTorch module from which to get the data type.
38
+
39
+ Returns:
40
+ `torch.dtype`: The data type of the module's parameters or buffers.
41
+ """
42
+ try:
43
+ params = tuple(parameter.parameters())
44
+ if len(params) > 0:
45
+ return params[0].dtype
46
+
47
+ buffers = tuple(parameter.buffers())
48
+ if len(buffers) > 0:
49
+ return buffers[0].dtype
50
+
51
+ except StopIteration:
52
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
53
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
54
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
55
+ return tuples
56
+
57
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
58
+ first_tuple = next(gen)
59
+ return first_tuple[1].dtype
60
+
61
+
62
+ class ModelMixin(torch.nn.Module):
63
+ """
64
+ Provides convenient properties to access the device and data type
65
+ of a PyTorch module.
66
+
67
+ By inheriting from this class, your custom PyTorch modules can access these properties
68
+ without manual retrieval of device and data type information.
69
+
70
+ These properties assume that all module parameters and buffers reside
71
+ on the same device and have the same data type, respectively.
72
+ """
73
+
74
+ def __init__(self):
75
+ super().__init__()
76
+
77
+ @property
78
+ def device(self) -> torch.device:
79
+ """
80
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
81
+ device).
82
+ """
83
+ return get_parameter_device(self)
84
+
85
+ @property
86
+ def dtype(self) -> torch.dtype:
87
+ """
88
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
89
+ """
90
+ return get_parameter_dtype(self)
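To illustrate the `ModelMixin` properties above, here is a hedged sketch with a hypothetical module (the class name is made up for the example):

```python
import torch
import torch.nn as nn

from image_gen_aux.modeling_utils import ModelMixin


class TinyEdgeNet(ModelMixin):
    """Hypothetical module used only to demonstrate the mixin."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)


model = TinyEdgeNet().to(torch.float16)
print(model.device)  # cpu (device of the first parameter)
print(model.dtype)   # torch.float16 (dtype of the first parameter)
```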
image_gen_aux/preprocessors/README.md ADDED
@@ -0,0 +1,54 @@
1
+ # Preprocessors
2
+
3
+ Preprocessors in this context refer to the application of machine learning models to images or groups of images as a preliminary step for other tasks. A common use case is to feed preprocessed images into a controlnet to guide the generation of a diffusion model. Examples of preprocessors include depth maps, normals, and edge detection.
4
+
5
+ ## Supported preprocessors
6
+
7
+ This is a list of the currently supported preprocessors.
8
+
9
+ * [DepthPreprocessor](https://github.com/asomoza/image_gen_aux/blob/main/src/image_gen_aux/preprocessors/depth/README.md)
10
+ * [LineArtPreprocessor](https://github.com/asomoza/image_gen_aux/blob/main/src/image_gen_aux/preprocessors/lineart/README.md)
11
+ * [LineArtStandardPreprocessor](https://github.com/asomoza/image_gen_aux/blob/main/src/image_gen_aux/preprocessors/lineart_standard/README.md)
12
+ * [TeedPreprocessor](https://github.com/asomoza/image_gen_aux/blob/main/src/image_gen_aux/preprocessors/teed/README.md)
13
+
14
+ ## General preprocessor usage
15
+
16
+ ```python
17
+ from image_gen_aux import LineArtPreprocessor
18
+ from image_gen_aux.utils import load_image
19
+
20
+ input_image = load_image(
21
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/simple_upscale/hippowaffle.png"
22
+ )
23
+
24
+ lineart_preprocessor = LineArtPreprocessor.from_pretrained("OzzyGT/lineart").to("cuda")
25
+ image = lineart_preprocessor(input_image)[0]
26
+ image.save("lineart.png")
27
+ ```
28
+
29
+ Use the Hugging Face Hub model id, a URL, or a local path to load a model. You can also load pickle checkpoints, but it's not recommended due to the [vulnerabilities](https://docs.python.org/3/library/pickle.html) of pickled files. Loading [safetensor](https://hf.co/docs/safetensors/index) files is a more secure option.
30
+
31
+ If the checkpoint has custom imports (high risk!), you'll need to set the `weights_only` argument to `False`.
32
+
33
+ You can also specify a `filename` and a `subfolder` if needed.
34
+
35
+ ```python
36
+ lineart_preprocessor = LineArtPreprocessor.from_pretrained("lllyasviel/Annotators", filename="sk_model.pth", weights_only=False).to("cuda")
37
+ ```
38
+
39
+ ### List of safetensors checkpoints
40
+
41
+ This is the current list of safetensor checkpoints available on the Hub.
42
+
43
+ |Preprocessor|Repository|Author|
44
+ |---|---|---|
45
+ |Depth Anything V2 Small|depth-anything/Depth-Anything-V2-Small-hf|<https://depth-anything-v2.github.io/>|
46
+ |Depth Anything V2 Base|depth-anything/Depth-Anything-V2-Base-hf|<https://depth-anything-v2.github.io/>|
47
+ |Depth Anything V2 Large|depth-anything/Depth-Anything-V2-Large-hf|<https://depth-anything-v2.github.io/>|
48
+ |LineArt|OzzyGT/lineart|[Caroline Chan](https://github.com/carolineec)|
49
+ |Teed|OzzyGT/teed|<https://github.com/xavysp/TEED>|
50
+ |ZoeDepth NYU|Intel/zoedepth-nyu|<https://github.com/isl-org/ZoeDepth>|
51
+ |ZoeDepth KITTI|Intel/zoedepth-kitti|<https://github.com/isl-org/ZoeDepth>|
52
+ |ZoeDepth NYU and KITTI|Intel/zoedepth-nyu-kitti|<https://github.com/isl-org/ZoeDepth>|
53
+
54
+ If you own the model and want us to change the repository to your name/organization please open an issue.
image_gen_aux/preprocessors/__init__.py ADDED
@@ -0,0 +1,66 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ..utils import (
4
+ IMAGE_AUX_SLOW_IMPORT,
5
+ OptionalDependencyNotAvailable,
6
+ _LazyModule,
7
+ is_torch_available,
8
+ is_transformers_available,
9
+ )
10
+
11
+
12
+ _import_structure = {
13
+ "lineart": [],
14
+ "depth": [],
15
+ "teed": [],
16
+ "lineart_standard": ["LineArtStandardPreprocessor"],
17
+ }
18
+
19
+
20
+ try:
21
+ if not (is_torch_available()):
22
+ raise OptionalDependencyNotAvailable()
23
+ except OptionalDependencyNotAvailable:
24
+ ...
25
+ else:
26
+ _import_structure["lineart"] = ["LineArtPreprocessor"]
27
+ _import_structure["teed"] = ["TeedPreprocessor"]
28
+
29
+ try:
30
+ if not (is_torch_available() and is_transformers_available()):
31
+ raise OptionalDependencyNotAvailable()
32
+ except OptionalDependencyNotAvailable:
33
+ ...
34
+ else:
35
+ _import_structure["depth"] = [
36
+ "DepthPreprocessor",
37
+ ]
38
+
39
+ if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
40
+ from .lineart_standard import LineArtStandardPreprocessor
41
+
42
+ try:
43
+ if not is_torch_available():
44
+ raise OptionalDependencyNotAvailable()
45
+ except OptionalDependencyNotAvailable:
46
+ ...
47
+ else:
48
+ from .lineart import LineArtPreprocessor
49
+ from .teed import TeedPreprocessor
50
+
51
+ try:
52
+ if not (is_torch_available() and is_transformers_available()):
53
+ raise OptionalDependencyNotAvailable()
54
+ except OptionalDependencyNotAvailable:
55
+ ...
56
+ else:
57
+ from .depth import DepthPreprocessor
58
+ else:
59
+ import sys
60
+
61
+ sys.modules[__name__] = _LazyModule(
62
+ __name__,
63
+ globals()["__file__"],
64
+ _import_structure,
65
+ module_spec=__spec__,
66
+ )
image_gen_aux/preprocessors/depth/README.md ADDED
@@ -0,0 +1,45 @@
1
+ # DEPTH
2
+
3
+ Monocular depth estimation is a computer vision task that involves predicting the depth information of a scene from a single image. In other words, it is the process of estimating the distance of objects in a scene from a single camera viewpoint.
4
+
5
+ Monocular depth estimation has various applications, including 3D reconstruction, augmented reality, autonomous driving, and robotics. It is a challenging task as it requires the model to understand the complex relationships between objects in the scene and the corresponding depth information, which can be affected by factors such as lighting conditions, occlusion, and texture.
6
+
7
+ ## Usage
8
+
9
+ ```python
10
+ from image_gen_aux import DepthPreprocessor
11
+ from image_gen_aux.utils import load_image
12
+
13
+ input_image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/depth/coffee_ship.png")
14
+
15
+ depth_preprocessor = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf").to("cuda")
16
+ image = depth_preprocessor(input_image)[0]
17
+ image.save("depth.png")
18
+ ```
19
+
20
+ ## Models
21
+
22
+ The Depth Preprocessor works with any depth estimation model supported by the `transformers` library that doesn't impose a fixed input image size, but we mainly recommend and verify correct functionality for these models:
23
+
24
+ |Model|License|Project Page|
25
+ |---|---|---|
26
+ |Depth Anything V2|CC-BY-NC-4.0|<https://depth-anything-v2.github.io/>|
27
+ |ZoeDepth|MIT|<https://github.com/isl-org/ZoeDepth>|
28
+
29
+ Each model has different variations:
30
+
31
+ ### Depth Anything V2
32
+
33
+ |Variation|Repo ID|
34
+ |---|---|
35
+ |Small|depth-anything/Depth-Anything-V2-Small-hf|
36
+ |Base|depth-anything/Depth-Anything-V2-Base-hf|
37
+ |Large|depth-anything/Depth-Anything-V2-Large-hf|
38
+
39
+ ### ZoeDepth
40
+
41
+ |Variation|Repo ID|
42
+ |---|---|
43
+ |NYU|Intel/zoedepth-nyu|
44
+ |KITTI|Intel/zoedepth-kitti|
45
+ |NYU and KITTI|Intel/zoedepth-nyu-kitti|
image_gen_aux/preprocessors/depth/__init__.py ADDED
@@ -0,0 +1,42 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import (
4
+ IMAGE_AUX_SLOW_IMPORT,
5
+ OptionalDependencyNotAvailable,
6
+ _LazyModule,
7
+ is_torch_available,
8
+ is_transformers_available,
9
+ )
10
+
11
+
12
+ _import_structure = {}
13
+
14
+ try:
15
+ if not (is_torch_available() and is_transformers_available()):
16
+ raise OptionalDependencyNotAvailable()
17
+ except OptionalDependencyNotAvailable:
18
+ ...
19
+ else:
20
+ _import_structure["depth_preprocessor"] = [
21
+ "DepthPreprocessor",
22
+ ]
23
+
24
+ if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
25
+ try:
26
+ if not (is_torch_available() and is_transformers_available()):
27
+ raise OptionalDependencyNotAvailable()
28
+ except OptionalDependencyNotAvailable:
29
+ ...
30
+ else:
31
+ from .depth_preprocessor import (
32
+ DepthPreprocessor,
33
+ )
34
+ else:
35
+ import sys
36
+
37
+ sys.modules[__name__] = _LazyModule(
38
+ __name__,
39
+ globals()["__file__"],
40
+ _import_structure,
41
+ module_spec=__spec__,
42
+ )
image_gen_aux/preprocessors/depth/depth_preprocessor.py ADDED
@@ -0,0 +1,70 @@
1
+ import os
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import AutoModelForDepthEstimation
9
+
10
+ from ...image_processor import ImageMixin
11
+ from ..preprocessor import Preprocessor
12
+
13
+
14
+ class DepthPreprocessor(Preprocessor, ImageMixin):
15
+ """Preprocessor specifically designed for monocular depth estimation.
16
+
17
+ This class inherits from both `Preprocessor` and `ImageMixin`. Please refer to each
18
+ one to get more information.
19
+ """
20
+
21
+ @classmethod
22
+ def from_pretrained(cls, pretrained_model_or_path: Union[str, os.PathLike]):
23
+ model = AutoModelForDepthEstimation.from_pretrained(pretrained_model_or_path)
24
+
25
+ return cls(model)
26
+
27
+ @torch.inference_mode
28
+ def __call__(
29
+ self,
30
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]],
31
+ resolution_scale: float = 1.0,
32
+ invert: bool = False,
33
+ batch_size: int = 1,
34
+ return_type: str = "pil",
35
+ ):
36
+ if not isinstance(image, torch.Tensor):
37
+ image = self.convert_image_to_tensor(image).to(self.model.device)
38
+
39
+ image, resolution_scale = self.scale_image(image, resolution_scale)
40
+
41
+ processed_images = []
42
+
43
+ for i in range(0, len(image), batch_size):
44
+ batch = image[i : i + batch_size].to(self.model.device)
45
+
46
+ predicted_depth = self.model(batch).predicted_depth
47
+
48
+ # depth models returns only batch, height and width, so we add the channel
49
+ predicted_depth = predicted_depth.unsqueeze(1)
50
+
51
+ # models like depth anything can return a different size image
52
+ if batch.shape[2] != predicted_depth.shape[2] or batch.shape[3] != predicted_depth.shape[3]:
53
+ predicted_depth = F.interpolate(
54
+ predicted_depth, size=(batch.shape[2], batch.shape[3]), mode="bilinear", align_corners=False
55
+ )
56
+
57
+ if invert:
58
+ predicted_depth = 255 - predicted_depth
59
+ processed_images.append(predicted_depth.cpu())
60
+ predicted_depth = torch.cat(processed_images, dim=0)
61
+
62
+ if resolution_scale != 1.0:
63
+ predicted_depth, _ = self.scale_image(predicted_depth, 1 / resolution_scale)
64
+
65
+ predicted_depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
66
+ predicted_depth = predicted_depth.clamp(0, 1)
67
+
68
+ image = self.post_process_image(predicted_depth, return_type)
69
+
70
+ return image
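Beyond the basic README example, the `__call__` arguments above (`resolution_scale`, `invert`, `return_type`) can be combined as in this sketch (illustrative; downloads a model and an image, and assumes a CUDA device is available, otherwise swap "cuda" for "cpu"):

```python
from image_gen_aux import DepthPreprocessor
from image_gen_aux.utils import load_image

image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/depth/coffee_ship.png"
)

preprocessor = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf").to("cuda")

# Run at half resolution, invert the map, and return a tensor instead of a PIL image
depth = preprocessor(image, resolution_scale=0.5, invert=True, return_type="pt")
print(depth.shape)  # torch.Size([1, 1, H, W]), values normalized to [0, 1]
```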
image_gen_aux/preprocessors/lineart/LICENSE.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Caroline Chan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
image_gen_aux/preprocessors/lineart/README.md ADDED
@@ -0,0 +1,27 @@
1
+ # Line Art
2
+
3
+ Line Art is based on the original repository, [Informative Drawings: Learning to generate line drawings that convey geometry and semantics](https://github.com/carolineec/informative-drawings).
4
+
5
+ From the project page [summary](https://carolineec.github.io/informative_drawings/):
6
+
7
+ *Our method approaches line drawing generation as an unsupervised image translation problem which uses various losses to assess the information communicated in a line drawing. Our key idea is to view the problem as an encoding through a line drawing and to maximize the quality of this encoding through explicit geometry, semantic, and appearance decoding objectives. This evaluation is performed by deep learning methods which decode depth, semantics, and appearance from line drawings. The aim is for the extracted depth and semantic information to match the scene geometry and semantics of the input photographs.*
8
+
9
+ ## Usage
10
+
11
+ ```python
12
+ from image_gen_aux import LineArtPreprocessor
13
+ from image_gen_aux.utils import load_image
14
+
15
+ input_image = load_image(
16
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/simple_upscale/hippowaffle.png"
17
+ )
18
+
19
+ lineart_preprocessor = LineArtPreprocessor.from_pretrained("OzzyGT/lineart").to("cuda")
20
+ image = lineart_preprocessor(input_image)[0]
21
+ image.save("lineart.png")
22
+ ```
23
+
24
+ ## Additional resources
25
+
26
+ * [Project page](https://carolineec.github.io/informative_drawings/)
27
+ * [Paper](https://arxiv.org/abs/2203.12691)
image_gen_aux/preprocessors/lineart/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import IMAGE_AUX_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, is_torch_available
4
+
5
+
6
+ _import_structure = {}
7
+
8
+ try:
9
+ if not (is_torch_available()):
10
+ raise OptionalDependencyNotAvailable()
11
+ except OptionalDependencyNotAvailable:
12
+ ...
13
+ else:
14
+ _import_structure["lineart_preprocessor"] = [
15
+ "LineArtPreprocessor",
16
+ ]
17
+
18
+ if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
19
+ try:
20
+ if not is_torch_available():
21
+ raise OptionalDependencyNotAvailable()
22
+ except OptionalDependencyNotAvailable:
23
+ ...
24
+ else:
25
+ from .lineart_preprocessor import (
26
+ LineArtPreprocessor,
27
+ )
28
+ else:
29
+ import sys
30
+
31
+ sys.modules[__name__] = _LazyModule(
32
+ __name__,
33
+ globals()["__file__"],
34
+ _import_structure,
35
+ module_spec=__spec__,
36
+ )
image_gen_aux/preprocessors/lineart/lineart_preprocessor.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from typing import List, Union
16
+
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+ from safetensors.torch import load_file
21
+
22
+ from ...image_processor import ImageMixin
23
+ from ...utils import SAFETENSORS_FILE_EXTENSION, get_model_path
24
+ from ..preprocessor import Preprocessor
25
+ from .model import Generator
26
+
27
+
28
+ class LineArtPreprocessor(Preprocessor, ImageMixin):
29
+ """Preprocessor specifically designed for converting images to line art.
30
+
31
+ This class inherits from both `Preprocessor` and `ImageMixin`. Please refer to each
32
+ one to get more information.
33
+ """
34
+
35
+ @classmethod
36
+ def from_pretrained(
37
+ cls,
38
+ pretrained_model_or_path: Union[str, os.PathLike],
39
+ filename: str = None,
40
+ subfolder: str = None,
41
+ weights_only: bool = True,
42
+ ) -> "LineArtPreprocessor":
43
+ model_path = get_model_path(pretrained_model_or_path, filename, subfolder)
44
+
45
+ file_extension = os.path.basename(model_path).split(".")[-1]
46
+ if file_extension == SAFETENSORS_FILE_EXTENSION:
47
+ state_dict = load_file(model_path, device="cpu")
48
+ else:
49
+ state_dict = torch.load(model_path, map_location=torch.device("cpu"), weights_only=weights_only)
50
+
51
+ model = Generator()
52
+ model.load_state_dict(state_dict)
53
+
54
+ return cls(model)
55
+
56
+ @torch.inference_mode
57
+ def __call__(
58
+ self,
59
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]],
60
+ resolution_scale: float = 1.0,
61
+ invert: bool = True,
62
+ batch_size: int = 1,
63
+ return_type: str = "pil",
64
+ ):
65
+ """Preprocesses an image and generates line art using the pre-trained model.
66
+
67
+ Args:
68
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]]`): Input image as PIL Image,
69
+ NumPy array, PyTorch tensor format or a list of PIL Images
70
+ resolution_scale (`float`, optional, defaults to 1.0): Scale factor for image resolution during
71
+ preprocessing and post-processing. Defaults to 1.0 for no scaling.
72
+ invert (`bool`, *optional*, defaults to True): Inverts the generated image if True (white or black background).
73
+ batch_size (`int`, *optional*, defaults to 1): The number of images to process in each batch.
74
+ return_type (`str`, *optional*, defaults to "pil"): The desired return type, either "pt" for PyTorch tensor, "np" for NumPy array,
75
+ or "pil" for PIL image.
76
+
77
+ Returns:
78
+ `Union[PIL.Image.Image, np.ndarray, torch.Tensor]`: The generated line art in the
79
+ specified output format.
80
+ """
81
+ if not isinstance(image, torch.Tensor):
82
+ image = self.convert_image_to_tensor(image)
83
+
84
+ image, resolution_scale = self.scale_image(image, resolution_scale)
85
+
86
+ processed_images = []
87
+
88
+ for i in range(0, len(image), batch_size):
89
+ batch = image[i : i + batch_size].to(self.model.device)
90
+ lineart = self.model(batch)
91
+ if invert:
92
+ lineart = 255 - lineart
93
+ processed_images.append(lineart.cpu())
94
+ lineart = torch.cat(processed_images, dim=0)
95
+
96
+ if resolution_scale != 1.0:
97
+ lineart, _ = self.scale_image(lineart, 1 / resolution_scale)
98
+
99
+ image = self.post_process_image(lineart, return_type)
100
+
101
+ return image
image_gen_aux/preprocessors/lineart/model.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright (c) 2022 Caroline Chan
2
+ #
3
+ # MIT License
4
+ import torch.nn as nn
5
+
6
+ from ...modeling_utils import ModelMixin
7
+
8
+
9
+ norm_layer = nn.InstanceNorm2d
10
+
11
+
12
+ class ResidualBlock(nn.Module):
13
+ def __init__(self, in_features: int):
14
+ super(ResidualBlock, self).__init__()
15
+
16
+ conv_block = [
17
+ nn.ReflectionPad2d(1),
18
+ nn.Conv2d(in_features, in_features, 3),
19
+ norm_layer(in_features),
20
+ nn.ReLU(inplace=True),
21
+ nn.ReflectionPad2d(1),
22
+ nn.Conv2d(in_features, in_features, 3),
23
+ norm_layer(in_features),
24
+ ]
25
+
26
+ self.conv_block = nn.Sequential(*conv_block)
27
+
28
+ def forward(self, x):
29
+ return x + self.conv_block(x)
30
+
31
+
32
+ class Generator(ModelMixin):
33
+ def __init__(self, input_nc: int = 3, output_nc: int = 1, n_residual_blocks: int = 3, sigmoid: bool = True):
34
+ super(Generator, self).__init__()
35
+
36
+ # Initial convolution block
37
+ model0 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), norm_layer(64), nn.ReLU(inplace=True)]
38
+ self.model0 = nn.Sequential(*model0)
39
+
40
+ # Downsampling
41
+ model1 = []
42
+ in_features = 64
43
+ out_features = in_features * 2
44
+ for _ in range(2):
45
+ model1 += [
46
+ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
47
+ norm_layer(out_features),
48
+ nn.ReLU(inplace=True),
49
+ ]
50
+ in_features = out_features
51
+ out_features = in_features * 2
52
+ self.model1 = nn.Sequential(*model1)
53
+
54
+ model2 = []
55
+ # Residual blocks
56
+ for _ in range(n_residual_blocks):
57
+ model2 += [ResidualBlock(in_features)]
58
+ self.model2 = nn.Sequential(*model2)
59
+
60
+ # Upsampling
61
+ model3 = []
62
+ out_features = in_features // 2
63
+ for _ in range(2):
64
+ model3 += [
65
+ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
66
+ norm_layer(out_features),
67
+ nn.ReLU(inplace=True),
68
+ ]
69
+ in_features = out_features
70
+ out_features = in_features // 2
71
+ self.model3 = nn.Sequential(*model3)
72
+
73
+ # Output layer
74
+ model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
75
+ if sigmoid:
76
+ model4 += [nn.Sigmoid()]
77
+
78
+ self.model4 = nn.Sequential(*model4)
79
+
80
+ def forward(self, x):
81
+ out = self.model0(x)
82
+ out = self.model1(out)
83
+ out = self.model2(out)
84
+ out = self.model3(out)
85
+ out = self.model4(out)
86
+
87
+ return out
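A small shape-check sketch for the `Generator` defined above (randomly initialized weights; real usage goes through `LineArtPreprocessor.from_pretrained`):

```python
import torch

from image_gen_aux.preprocessors.lineart.model import Generator

model = Generator()
x = torch.randn(1, 3, 256, 256)

with torch.inference_mode():
    y = model(x)

# Single-channel line-art map, squashed to (0, 1) by the final Sigmoid
print(y.shape)  # torch.Size([1, 1, 256, 256])
```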
image_gen_aux/preprocessors/lineart_standard/README.md ADDED
@@ -0,0 +1,23 @@
1
+ # Line Art Standard
2
+
3
+ Line Art Standard was copied from the [comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux) repository.
4
+ This preprocessor applies a Gaussian blur to the image to reduce noise and detail, then calculates the intensity difference between the blurred and original images to highlight edges.
5
+
6
+ ## Usage
7
+
8
+ ```python
9
+ from image_gen_aux import LineArtStandardPreprocessor
10
+ from image_gen_aux.utils import load_image
11
+
12
+ input_image = load_image(
13
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/simple_upscale/hippowaffle.png"
14
+ )
15
+
16
+ lineart_standard_preprocessor = LineArtStandardPreprocessor()
17
+ image = lineart_standard_preprocessor(input_image)[0]
18
+ image.save("lineart_standard.png")
19
+ ```
20
+
21
+ ## Additional resources
22
+
23
+ * [Original implementation](https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/src/custom_controlnet_aux/lineart_standard/__init__.py)
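For reference, the blur-and-difference computation described above can be sketched as a standalone function that mirrors the defaults of the preprocessor file added below (illustrative only):

```python
import cv2
import numpy as np


def lineart_standard_edges(image: np.ndarray, gaussian_sigma: float = 6.0, intensity_threshold: int = 8) -> np.ndarray:
    """image: float32 array of shape (H, W, 3) with values in [0, 255]."""
    blurred = cv2.GaussianBlur(image, (0, 0), gaussian_sigma)
    # Per-pixel minimum channel difference between the blurred and the original image
    intensity = np.min(blurred - image, axis=2).clip(0, 255)
    # Normalize by the median of the stronger responses, then rescale to 8-bit range
    intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
    intensity *= 127
    return intensity.clip(0, 255).astype(np.uint8)
```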
image_gen_aux/preprocessors/lineart_standard/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import IMAGE_AUX_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, is_torch_available
4
+
5
+
6
+ _import_structure = {}
7
+
8
+ try:
9
+ if not (is_torch_available()):
10
+ raise OptionalDependencyNotAvailable()
11
+ except OptionalDependencyNotAvailable:
12
+ ...
13
+ else:
14
+ _import_structure["lineart_standard_preprocessor"] = [
15
+ "LineArtStandardPreprocessor",
16
+ ]
17
+
18
+ if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
19
+ try:
20
+ if not is_torch_available():
21
+ raise OptionalDependencyNotAvailable()
22
+ except OptionalDependencyNotAvailable:
23
+ ...
24
+ else:
25
+ from .lineart_standard_preprocessor import (
26
+ LineArtStandardPreprocessor,
27
+ )
28
+ else:
29
+ import sys
30
+
31
+ sys.modules[__name__] = _LazyModule(
32
+ __name__,
33
+ globals()["__file__"],
34
+ _import_structure,
35
+ module_spec=__spec__,
36
+ )
image_gen_aux/preprocessors/lineart_standard/lineart_standard_preprocessor.py ADDED
@@ -0,0 +1,82 @@
1
+ from typing import List, Union
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from ...image_processor import ImageMixin
8
+
9
+
10
+ class LineArtStandardPreprocessor(ImageMixin):
11
+ """Preprocessor specifically designed for converting images to line art standard.
12
+
13
+ Unlike the model-based preprocessors, this class only inherits from `ImageMixin` since it
13
+ does not need a model. Please refer to `ImageMixin` to get more information.
15
+ """
16
+
17
+ def __call__(
18
+ self,
19
+ image: Union[PIL.Image.Image, np.ndarray, List[PIL.Image.Image]],
20
+ resolution_scale: float = 1.0,
21
+ invert: bool = False,
22
+ gaussian_sigma=6.0,
23
+ intensity_threshold=8,
24
+ return_type: str = "pil",
25
+ ):
26
+ """Preprocesses an image and generates standard line art using OpenCV.
27
+
28
+ Args:
29
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]]`): Input image as PIL Image,
30
+ NumPy array, PyTorch tensor format or a list of PIL Images
31
+ resolution_scale (`float`, optional, defaults to 1.0): Scale factor for image resolution during
32
+ preprocessing and post-processing. Defaults to 1.0 for no scaling.
33
+ invert (`bool`, *optional*, defaults to False): Inverts the generated image if True (white or black background).
34
+ gaussian_sigma (`float`, *optional*, defaults to 6.0): Sigma value for the Gaussian blur.
35
+ intensity_threshold (`int`, *optional*, defaults to 8): Threshold for intensity clipping.
36
+ return_type (`str`, *optional*, defaults to "pil"): The desired return type, either "pt" for PyTorch tensor, "np" for NumPy array,
37
+ or "pil" for PIL image.
38
+
39
+ Returns:
40
+ `Union[PIL.Image.Image, np.ndarray, torch.Tensor]`: The generated line art (standard) in the
41
+ specified output format.
42
+ """
43
+ if not isinstance(image, np.ndarray):
44
+ image = np.array(image).astype(np.float32)
45
+
46
+ # check if image has batch, if not, add it
47
+ if len(image.shape) == 3:
48
+ image = image[None, ...]
49
+
50
+ if resolution_scale != 1.0:
51
+ image, resolution_scale = self.resize_numpy_image(image, resolution_scale)
53
+
54
+ batch_size, height, width, _channels = image.shape
55
+ processed_images = np.empty((batch_size, height, width), dtype=np.uint8)
56
+
57
+ # since we're using just cv2, we can't do batch processing
58
+ for i in range(batch_size):
59
+ gaussian = cv2.GaussianBlur(image[i], (0, 0), gaussian_sigma)
60
+ intensity = np.min(gaussian - image[i], axis=2).clip(0, 255)
61
+ intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
62
+ intensity *= 127
63
+ edges = intensity.clip(0, 255).astype(np.uint8)
64
+
65
+ processed_images[i] = edges
66
+
67
+ if invert:
68
+ processed_images = 255 - processed_images
69
+
70
+ processed_images = processed_images[..., None]
71
+
72
+ if resolution_scale != 1.0:
73
+ processed_images, _ = self.resize_numpy_image(processed_images, 1 / resolution_scale)
74
+ processed_images = processed_images[..., None] # cv2 resize removes the channel dimension if grayscale
75
+
76
+ if return_type == "np":
77
+ return processed_images
78
+
79
+ processed_images = np.transpose(processed_images, (0, 3, 1, 2))
80
+ processed_images = [PIL.Image.fromarray(image.squeeze(), mode="L") for image in processed_images]
81
+
82
+ return processed_images
image_gen_aux/preprocessors/preprocessor.py ADDED
@@ -0,0 +1,75 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+
8
+
9
+ class Preprocessor(ABC):
10
+ """
11
+ This abstract base class defines the interface for image preprocessors.
12
+
13
+ Subclasses should implement the abstract methods `from_pretrained` and
14
+ `__call__` to provide specific loading and preprocessing logic for their
15
+ respective models.
16
+
17
+ Args:
18
+ model (`nn.Module`): The torch model to use.
19
+ """
20
+
21
+ def __init__(self, model):
22
+ self.model = model
23
+
24
+ def to(self, device):
25
+ """
26
+ Moves the underlying model to the specified device
27
+ (e.g., CPU or GPU).
28
+
29
+ Args:
30
+ device (`torch.device`): The target device.
31
+
32
+ Returns:
33
+ `Preprocessor`: The preprocessor object itself (for method chaining).
34
+ """
35
+ self.model = self.model.to(device)
36
+ return self
37
+
38
+ @abstractmethod
39
+ def from_pretrained(self):
40
+ """
41
+ This abstract method defines how the preprocessor loads pre-trained
42
+ weights or configurations specific to the model it supports. Subclasses
43
+ must implement this method to handle model-specific loading logic.
44
+
45
+ This method might download pre-trained weights from a repository or
46
+ load them from a local file depending on the model's requirements.
47
+ """
48
+ pass
49
+
50
+ @abstractmethod
51
+ def __call__(
52
+ self,
53
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
54
+ resolution_scale: float = 1.0,
55
+ invert: bool = True,
56
+ return_type: str = "pil",
57
+ ):
58
+ """
59
+ Preprocesses an image for use with the underlying model.
60
+
61
+ Args:
62
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): Input image as PIL Image,
63
+ NumPy array, or PyTorch tensor format.
64
+ resolution_scale (`float`, *optional*, defaults to 1.0): Scale factor for image resolution during
66
+ preprocessing and post-processing. Defaults to 1.0 for no scaling.
67
+ invert (`bool`, *optional*, defaults to True): Inverts the generated image if True.
68
+ return_type (`str`, *optional*, defaults to "pil"): The desired return type, either "pt" for PyTorch tensor,
69
+ "np" for NumPy array, or "pil" for PIL image.
70
+
71
+ Returns:
72
+ `Union[PIL.Image.Image, torch.Tensor]`: The preprocessed image in the
73
+ specified format.
74
+ """
75
+ pass
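To make the abstract contract above concrete, here is a hedged sketch of a trivial subclass (hypothetical class, not part of this commit):

```python
import PIL.Image
import torch

from image_gen_aux.image_processor import ImageMixin
from image_gen_aux.preprocessors.preprocessor import Preprocessor


class IdentityPreprocessor(Preprocessor, ImageMixin):
    """Returns the input unchanged; only meant to show the interface."""

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path=None):
        # No weights to load; a trivial module keeps `.to(device)` working.
        return cls(torch.nn.Identity())

    @torch.inference_mode()
    def __call__(self, image, resolution_scale=1.0, invert=False, return_type="pil"):
        tensor = self.convert_image_to_tensor(image)
        return self.post_process_image(tensor, return_type)


preprocessor = IdentityPreprocessor.from_pretrained()
out = preprocessor(PIL.Image.new("RGB", (32, 32)))[0]
print(out.size)  # (32, 32)
```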
image_gen_aux/preprocessors/teed/LICENSE.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Xavier Soria Poma
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
image_gen_aux/preprocessors/teed/README.md ADDED
@@ -0,0 +1,26 @@
1
+ # TEED: Tiny and Efficient Model for the Edge Detection Generalization
2
+
3
+ Tiny and Efficient Edge Detector (TEED) is a light convolutional neural network with only 58K parameters, less than 0.2% of the state-of-the-art models. Training on the [BIPED](https://www.kaggle.com/datasets/xavysp/biped) dataset takes less than 30 minutes,
4
+ with each epoch requiring less than 5 minutes. The model is easy to train and quickly converges within the very first few
5
+ epochs, and the predicted edge maps are crisp and of high quality.
6
+ [This paper has been accepted by ICCV 2023-Workshop RCV](https://arxiv.org/abs/2308.06468).
7
+
8
+ ## Usage
9
+
10
+ ```python
11
+ from image_gen_aux import TeedPreprocessor
12
+ from image_gen_aux.utils import load_image
13
+
14
+ input_image = load_image(
15
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/teed/20240922043215.png"
16
+ )
17
+
18
+ teed_preprocessor = TeedPreprocessor.from_pretrained("OzzyGT/teed").to("cuda")
19
+ image = teed_preprocessor(input_image)[0]
20
+ image.save("teed.png")
21
+ ```
22
+
23
+ ## Additional resources
24
+
25
+ * [Project page](https://github.com/xavysp/TEED)
26
+ * [Paper](https://arxiv.org/abs/2308.06468)
image_gen_aux/preprocessors/teed/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import IMAGE_AUX_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, is_torch_available
4
+
5
+
6
+ _import_structure = {}
7
+
8
+ try:
9
+ if not (is_torch_available()):
10
+ raise OptionalDependencyNotAvailable()
11
+ except OptionalDependencyNotAvailable:
12
+ ...
13
+ else:
14
+ _import_structure["teed_preprocessor"] = [
15
+ "TeedPreprocessor",
16
+ ]
17
+
18
+ if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
19
+ try:
20
+ if not is_torch_available():
21
+ raise OptionalDependencyNotAvailable()
22
+ except OptionalDependencyNotAvailable:
23
+ ...
24
+ else:
25
+ from .teed_preprocessor import (
26
+ TeedPreprocessor,
27
+ )
28
+ else:
29
+ import sys
30
+
31
+ sys.modules[__name__] = _LazyModule(
32
+ __name__,
33
+ globals()["__file__"],
34
+ _import_structure,
35
+ module_spec=__spec__,
36
+ )
image_gen_aux/preprocessors/teed/teed.py ADDED
@@ -0,0 +1,323 @@
1
+ # Original from: https://github.com/xavysp/TEED
2
+ # TEED: is a Tiny but Efficient Edge Detection, it comes from the LDC-B3
3
+ # with a Slightly modification
4
+ # LDC parameters:
5
+ # 155665
6
+ # TED > 58K
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from ...modeling_utils import ModelMixin
13
+
14
+
15
+ """
16
+ smish_function and Smish script based on:
17
+ Wang, Xueliang, Honge Ren, and Achuan Wang.
18
+ "Smish: A Novel Activation Function for Deep Learning Methods.
19
+ " Electronics 11.4 (2022): 540.
20
+ smish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + sigmoid(x)))
21
+ """
22
+
23
+
24
+ @torch.jit.script
25
+ def smish_function(input):
26
+ """
27
+ Applies the smish function element-wise:
29
+ smish(x) = x * tanh(ln(1 + sigmoid(x)))
30
+ See additional documentation for the Smish class.
30
+ """
31
+ return input * torch.tanh(torch.log(1 + torch.sigmoid(input)))
32
+
33
+
34
+ class Smish(nn.Module):
35
+ """
36
+ Applies the smish function element-wise:
37
+ smish(x) = x * tanh(ln(1 + sigmoid(x)))
38
+ Shape:
39
+ - Input: (N, *) where * means, any number of additional
40
+ dimensions
41
+ - Output: (N, *), same shape as the input
42
+ Examples:
43
+ >>> m = Smish()
44
+ >>> input = torch.randn(2)
45
+ >>> output = m(input)
46
+ Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
47
+ """
48
+
49
+ def __init__(self):
50
+ """
51
+ Init method.
52
+ """
53
+ super().__init__()
54
+
55
+ def forward(self, input):
56
+ """
57
+ Forward pass of the function.
58
+ """
59
+ return smish_function(input)
60
+
61
+
62
+ def weight_init(m):
63
+ if isinstance(m, (nn.Conv2d,)):
64
+ torch.nn.init.xavier_normal_(m.weight, gain=1.0)
65
+
66
+ if m.bias is not None:
67
+ torch.nn.init.zeros_(m.bias)
68
+
69
+ # for fusion layer
70
+ if isinstance(m, (nn.ConvTranspose2d,)):
71
+ torch.nn.init.xavier_normal_(m.weight, gain=1.0)
72
+ if m.bias is not None:
73
+ torch.nn.init.zeros_(m.bias)
74
+
75
+
76
+ class CoFusion(nn.Module):
77
+ # from LDC
78
+
79
+ def __init__(self, in_ch, out_ch):
80
+ super(CoFusion, self).__init__()
81
+ self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1) # before 64
82
+ self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3, stride=1, padding=1) # before 64 instead of 32
83
+ self.relu = nn.ReLU()
84
+ self.norm_layer1 = nn.GroupNorm(4, 32) # before 64
85
+
86
+ def forward(self, x):
87
+ # fusecat = torch.cat(x, dim=1)
88
+ attn = self.relu(self.norm_layer1(self.conv1(x)))
89
+ attn = F.softmax(self.conv3(attn), dim=1)
90
+ return ((x * attn).sum(1)).unsqueeze(1)
91
+
92
+
93
+ class CoFusion2(nn.Module):
94
+ # TEDv14-3
95
+ def __init__(self, in_ch, out_ch):
96
+ super(CoFusion2, self).__init__()
97
+ self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1) # before 64
98
+ self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3, stride=1, padding=1) # before 64 instead of 32
99
+ self.smish = Smish() # nn.ReLU(inplace=True)
100
+
101
+ def forward(self, x):
102
+ attn = self.conv1(self.smish(x))
103
+ attn = self.conv3(self.smish(attn)) # before , )dim=1)
104
+
105
+ return ((x * attn).sum(1)).unsqueeze(1)
106
+
107
+
108
+ class DoubleFusion(nn.Module):
109
+ # TED fusion before the final edge map prediction
110
+ def __init__(self, in_ch, out_ch):
111
+ super(DoubleFusion, self).__init__()
112
+ self.DWconv1 = nn.Conv2d(in_ch, in_ch * 8, kernel_size=3, stride=1, padding=1, groups=in_ch) # before 64
113
+ self.PSconv1 = nn.PixelShuffle(1)
114
+
115
+ self.DWconv2 = nn.Conv2d(24, 24 * 1, kernel_size=3, stride=1, padding=1, groups=24) # before 64 instead of 32
116
+
117
+ self.AF = Smish() # XAF() #nn.Tanh()# XAF() # # Smish()#
118
+
119
+ def forward(self, x):
120
+ attn = self.PSconv1(self.DWconv1(self.AF(x))) # #TEED best res TEDv14 [8, 32, 352, 352]
121
+
122
+ attn2 = self.PSconv1(self.DWconv2(self.AF(attn))) # #TEED best res TEDv14[8, 3, 352, 352]
123
+
124
+ return smish_function(((attn2 + attn).sum(1)).unsqueeze(1)) # TED best res
125
+
126
+
127
+ class _DenseLayer(nn.Sequential):
128
+ def __init__(self, input_features, out_features):
129
+ super(_DenseLayer, self).__init__()
130
+
131
+ (
132
+ self.add_module(
133
+ "conv1",
134
+ nn.Conv2d(
135
+ input_features,
136
+ out_features,
137
+ kernel_size=3,
138
+ stride=1,
139
+ padding=2,
140
+ bias=True,
141
+ ),
142
+ ),
143
+ )
144
+ (self.add_module("smish1", Smish()),)
145
+ self.add_module(
146
+ "conv2",
147
+ nn.Conv2d(out_features, out_features, kernel_size=3, stride=1, bias=True),
148
+ )
149
+
150
+ def forward(self, x):
151
+ x1, x2 = x
152
+
153
+ new_features = super(_DenseLayer, self).forward(smish_function(x1)) # F.relu()
154
+
155
+ return 0.5 * (new_features + x2), x2
156
+
157
+
158
+ class _DenseBlock(nn.Sequential):
159
+ def __init__(self, num_layers, input_features, out_features):
160
+ super(_DenseBlock, self).__init__()
161
+ for i in range(num_layers):
162
+ layer = _DenseLayer(input_features, out_features)
163
+ self.add_module("denselayer%d" % (i + 1), layer)
164
+ input_features = out_features
165
+
166
+
167
+ class UpConvBlock(nn.Module):
168
+ def __init__(self, in_features, up_scale):
169
+ super(UpConvBlock, self).__init__()
170
+ self.up_factor = 2
171
+ self.constant_features = 16
172
+
173
+ layers = self.make_deconv_layers(in_features, up_scale)
174
+ assert layers is not None, layers
175
+ self.features = nn.Sequential(*layers)
176
+
177
+ def make_deconv_layers(self, in_features, up_scale):
178
+ layers = []
179
+ all_pads = [0, 0, 1, 3, 7]
180
+ for i in range(up_scale):
181
+ kernel_size = 2**up_scale
182
+ pad = all_pads[up_scale] # kernel_size-1
183
+ out_features = self.compute_out_features(i, up_scale)
184
+ layers.append(nn.Conv2d(in_features, out_features, 1))
185
+ layers.append(Smish())
186
+ layers.append(nn.ConvTranspose2d(out_features, out_features, kernel_size, stride=2, padding=pad))
187
+ in_features = out_features
188
+ return layers
189
+
190
+ def compute_out_features(self, idx, up_scale):
191
+ return 1 if idx == up_scale - 1 else self.constant_features
192
+
193
+ def forward(self, x):
194
+ return self.features(x)
195
+
196
+
197
+ class SingleConvBlock(nn.Module):
198
+ def __init__(self, in_features, out_features, stride, use_ac=False):
199
+ super(SingleConvBlock, self).__init__()
200
+ self.use_ac = use_ac
201
+ self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, bias=True)
202
+ if self.use_ac:
203
+ self.smish = Smish()
204
+
205
+ def forward(self, x):
206
+ x = self.conv(x)
207
+ if self.use_ac:
208
+ return self.smish(x)
209
+ else:
210
+ return x
211
+
212
+
213
+ class DoubleConvBlock(nn.Module):
214
+ def __init__(self, in_features, mid_features, out_features=None, stride=1, use_act=True):
215
+ super(DoubleConvBlock, self).__init__()
216
+
217
+ self.use_act = use_act
218
+ if out_features is None:
219
+ out_features = mid_features
220
+ self.conv1 = nn.Conv2d(in_features, mid_features, 3, padding=1, stride=stride)
221
+ self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
222
+ self.smish = Smish() # nn.ReLU(inplace=True)
223
+
224
+ def forward(self, x):
225
+ x = self.conv1(x)
226
+ x = self.smish(x)
227
+ x = self.conv2(x)
228
+ if self.use_act:
229
+ x = self.smish(x)
230
+ return x
231
+
232
+
233
+ class TEED(ModelMixin):
234
+ """Definition of Tiny and Efficient Edge Detector
235
+ model
236
+ """
237
+
238
+ def __init__(self):
239
+ super(TEED, self).__init__()
240
+ self.block_1 = DoubleConvBlock(
241
+ 3,
242
+ 16,
243
+ 16,
244
+ stride=2,
245
+ )
246
+ self.block_2 = DoubleConvBlock(16, 32, use_act=False)
247
+ self.dblock_3 = _DenseBlock(1, 32, 48) # [32,48,100,100] before (2, 32, 64)
248
+
249
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
250
+
251
+ # skip1 connection, see fig. 2
252
+ self.side_1 = SingleConvBlock(16, 32, 2)
253
+
254
+ # skip2 connection, see fig. 2
255
+ self.pre_dense_3 = SingleConvBlock(32, 48, 1) # before (32, 64, 1)
256
+
257
+ # USNet
258
+ self.up_block_1 = UpConvBlock(16, 1)
259
+ self.up_block_2 = UpConvBlock(32, 1)
260
+ self.up_block_3 = UpConvBlock(48, 2) # (32, 64, 1)
261
+
262
+ self.block_cat = DoubleFusion(3, 3) # TEED: DoubleFusion
263
+
264
+ self.apply(weight_init)
265
+
266
+ def slice(self, tensor, slice_shape):
267
+ t_shape = tensor.shape
268
+ img_h, img_w = slice_shape
269
+ if img_w != t_shape[-1] or img_h != t_shape[2]:
270
+ new_tensor = F.interpolate(tensor, size=(img_h, img_w), mode="bicubic", align_corners=False)
271
+
272
+ else:
273
+ new_tensor = tensor
274
+ # tensor[..., :height, :width]
275
+ return new_tensor
276
+
277
+ def resize_input(self, tensor):
278
+ t_shape = tensor.shape
279
+ if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0:
280
+ img_w = ((t_shape[3] // 8) + 1) * 8
281
+ img_h = ((t_shape[2] // 8) + 1) * 8
282
+ new_tensor = F.interpolate(tensor, size=(img_h, img_w), mode="bicubic", align_corners=False)
283
+ else:
284
+ new_tensor = tensor
285
+ return new_tensor
286
+
287
+ def crop_bdcn(data1, h, w, crop_h, crop_w):
288
+ # Based on BDCN Implementation @ https://github.com/pkuCactus/BDCN
289
+ _, _, h1, w1 = data1.size()
290
+ assert h <= h1 and w <= w1
291
+ data = data1[:, :, crop_h : crop_h + h, crop_w : crop_w + w]
292
+ return data
293
+
294
+ def forward(self, x, single_test=False):
295
+ assert x.ndim == 4, x.shape
296
+ # suppose the input image size is 352x352
297
+
298
+ # Block 1
299
+ block_1 = self.block_1(x) # [8,16,176,176]
300
+ block_1_side = self.side_1(block_1) # 16 [8,32,88,88]
301
+
302
+ # Block 2
303
+ block_2 = self.block_2(block_1) # 32 # [8,32,176,176]
304
+ block_2_down = self.maxpool(block_2) # [8,32,88,88]
305
+ block_2_add = block_2_down + block_1_side # [8,32,88,88]
306
+
307
+ # Block 3
308
+ block_3_pre_dense = self.pre_dense_3(block_2_down) # [8,64,88,88] block 3 L connection
309
+ block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense]) # [8,64,88,88]
310
+
311
+ # upsampling blocks
312
+ out_1 = self.up_block_1(block_1)
313
+ out_2 = self.up_block_2(block_2)
314
+ out_3 = self.up_block_3(block_3)
315
+
316
+ results = [out_1, out_2, out_3]
317
+
318
+ # concatenate multiscale outputs
319
+ block_cat = torch.cat(results, dim=1) # Bx3xHxW, three single-channel side outputs
320
+ block_cat = self.block_cat(block_cat) # Bx1xHxW DoubleFusion
321
+
322
+ results.append(block_cat)
323
+ return results
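For a quick sanity check of the TEED architecture above, the model can be instantiated with random weights and run on a dummy batch. This is only a sketch; it assumes the module path shown in this commit (`image_gen_aux.preprocessors.teed.teed`) and that `ModelMixin` behaves like a regular `nn.Module`.

```python
import torch

from image_gen_aux.preprocessors.teed.teed import TEED

model = TEED().eval()  # random weights are enough for a shape check
with torch.inference_mode():
    dummy = torch.rand(1, 3, 352, 352)  # RGB batch in [0, 1], the size the comments assume
    outputs = model(dummy)  # three side outputs plus the fused DoubleFusion map

print(len(outputs), [o.shape for o in outputs])  # 4 maps, each Bx1xHxW
```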
image_gen_aux/preprocessors/teed/teed_preprocessor.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from typing import List, Union
16
+
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+ from safetensors.torch import load_file
21
+
22
+ from ...image_processor import ImageMixin
23
+ from ...utils import SAFETENSORS_FILE_EXTENSION, get_model_path
24
+ from ..preprocessor import Preprocessor
25
+ from .teed import TEED
26
+
27
+
28
+ class TeedPreprocessor(Preprocessor, ImageMixin):
29
+ """Preprocessor specifically designed for detecting edges in images.
30
+
31
+ This class inherits from both `Preprocessor` and `ImageMixin`. Please refer to each
32
+ one to get more information.
33
+ """
34
+
35
+ @classmethod
36
+ def from_pretrained(
37
+ cls,
38
+ pretrained_model_or_path: Union[str, os.PathLike],
39
+ filename: str = None,
40
+ subfolder: str = None,
41
+ weights_only: bool = True,
42
+ ) -> "TeedPreprocessor":
43
+ model_path = get_model_path(pretrained_model_or_path, filename, subfolder)
44
+
45
+ file_extension = os.path.basename(model_path).split(".")[-1]
46
+ if file_extension == SAFETENSORS_FILE_EXTENSION:
47
+ state_dict = load_file(model_path, device="cpu")
48
+ else:
49
+ state_dict = torch.load(model_path, map_location=torch.device("cpu"), weights_only=weights_only)
50
+
51
+ model = TEED()
52
+ model.load_state_dict(state_dict)
53
+
54
+ return cls(model)
55
+
56
+ @torch.inference_mode
57
+ def __call__(
58
+ self,
59
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]],
60
+ resolution_scale: float = 1.0,
61
+ invert: bool = False,
62
+ safe_steps: int = 2,
63
+ batch_size: int = 1,
64
+ return_type: str = "pil",
65
+ ):
66
+ """Preprocesses an image and detects the edges using the pre-trained model.
67
+
68
+ Args:
69
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]]`): Input image as PIL Image,
70
+ NumPy array, PyTorch tensor format or a list of PIL Images
71
+ resolution_scale (`float`, optional, defaults to 1.0): Scale factor for image resolution during
72
+ preprocessing and post-processing. Defaults to 1.0 for no scaling.
73
+ invert (`bool`, *optional*, defaults to `False`): Inverts the generated edge map when `True` (switches between a white and a black background).
74
+ safe_steps (int, optional):
75
+ Number of safe steps for the TEED model. Defaults to 2.
76
+ batch_size (`int`, *optional*, defaults to 1): The number of images to process in each batch.
77
+ return_type (`str`, *optional*, defaults to "pil"): The desired return type, either "pt" for PyTorch tensor, "np" for NumPy array,
78
+ or "pil" for PIL image.
79
+
80
+ Returns:
81
+ `Union[PIL.Image.Image, np.ndarray, torch.Tensor]`: The detected edge map in the
82
+ specified output format.
83
+ """
84
+ if not isinstance(image, torch.Tensor):
85
+ image = self.convert_image_to_tensor(image, normalize=False)
86
+
87
+ image, resolution_scale = self.scale_image(image, resolution_scale)
88
+
89
+ processed_images = []
90
+
91
+ for i in range(0, len(image), batch_size):
92
+ batch = image[i : i + batch_size].to(self.model.device)
93
+
94
+ edges = self.model(batch)
95
+ edges = torch.stack([e[0, 0] for e in edges], dim=2)
96
+ mean_edges = torch.mean(edges, dim=2)
97
+ edge = torch.sigmoid(mean_edges)
98
+
99
+ if safe_steps != 0:
100
+ edge = self.safe_step(edge, safe_steps)
101
+
102
+ if invert:
103
+ edge = 1 - edge
104
+
105
+ processed_images.append(edge.unsqueeze(0).cpu())
106
+ teed = torch.cat(processed_images, dim=0)
107
+
108
+ # add missing channel
109
+ teed = teed.unsqueeze(1)
110
+
111
+ if resolution_scale != 1.0:
112
+ teed, _ = self.scale_image(teed, 1 / resolution_scale)
113
+
114
+ image = self.post_process_image(teed, return_type)
115
+
116
+ return image
117
+
118
+ def safe_step(self, x, step=2):
119
+ y = x.float() * float(step + 1)
120
+ y = y.int().float() / float(step)
121
+ return y
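A minimal usage sketch of `TeedPreprocessor`. The repo id is a placeholder (any Hub repo or local folder with a TEED safetensors checkpoint works), the test image is the one used in the upscaler README below, and it assumes the top-level package re-exports the class like the other preprocessors.

```python
from image_gen_aux import TeedPreprocessor
from image_gen_aux.utils import load_image

image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/simple_upscale/hippowaffle_small.png"
)

teed = TeedPreprocessor.from_pretrained("your-username/teed-checkpoint")  # placeholder repo id
edges = teed(image, invert=True, return_type="pil")  # may be a list for batched inputs
```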
image_gen_aux/upscalers/README.md ADDED
@@ -0,0 +1,58 @@
1
+ # UPSCALERS
2
+
3
+ ## Upscale with model
4
+
5
+ Most super-resolution models are distributed as `pickle` checkpoints, which are considered unsafe. We encourage the use of safetensors checkpoints and, for convenience, recommend loading them from the Hugging Face Hub; a locally downloaded checkpoint also works.
6
+
7
+ Most of the super-resolution models are provided as `pickle` checkpoints, which are considered unsafe. We promote the use of safetensor checkpoints, and for convenient use, we recommend using the Hugging Face Hub. You can still use a locally downloaded model.
8
+
9
+ ### Space
10
+
11
+ You can try the currently supported super-resolution models in this [Hugging Face Space](https://huggingface.co/spaces/OzzyGT/basic_upscaler).
12
+
13
+ ### How to use
14
+
15
+ ```python
16
+ from image_gen_aux import UpscaleWithModel
17
+ from image_gen_aux.utils import load_image
18
+
19
+ original = load_image(
20
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/simple_upscale/hippowaffle_small.png"
21
+ )
22
+
23
+ upscaler = UpscaleWithModel.from_pretrained("Kim2091/UltraSharp").to("cuda")
24
+ image = upscaler(original)
25
+ image.save("upscaled.png")
26
+ ```
27
+
28
+ ### Tiling
29
+
30
+ Tiling can be enabled to reduce memory usage by processing the image in overlapping tiles.
31
+
32
+ ```python
33
+ image = upscaler(original, tiling=True, tile_width=768, tile_height=768, overlap=8)
34
+ ```
35
+
36
+ ### Scale
37
+
38
+ The scale is automatically obtained from the model but can be overridden with the `scale` argument:
39
+
40
+ ```python
41
+ image = upscaler(original, scale=2)
42
+ ```
43
+
44
+ ### List of safetensors checkpoints
45
+
46
+ This is the current list of safetensors checkpoints you can use from the Hub.
47
+
48
+ |Model|Scale|Repository|Owner|
49
+ |---|---|---|---|
50
+ |UltraSharp|4X|Kim2091/UltraSharp|[Kim2091](https://huggingface.co/Kim2091)|
51
+ |DAT|2X|OzzyGT/DAT_X2|[zhengchen1999](https://github.com/zhengchen1999)|
52
+ |DAT|3X|OzzyGT/DAT_X3|[zhengchen1999](https://github.com/zhengchen1999)|
53
+ |DAT|4X|OzzyGT/DAT_X4|[zhengchen1999](https://github.com/zhengchen1999)|
54
+ |RealPLKSR|4X|Phips/4xNomosWebPhoto_RealPLKSR|[Philip Hofmann](https://huggingface.co/Phips)|
55
+ |DAT-2|4X|Phips/4xRealWebPhoto_v4_dat2|[Philip Hofmann](https://huggingface.co/Phips)|
56
+ |BSRGAN|4X|OzzyGT/4xRemacri|[FoolhardyVEVO](https://openmodeldb.info/users/foolhardy)|
57
+
58
+ If you own the model and want us to change the repository to your name/organization please open an issue.
image_gen_aux/upscalers/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ..utils import IMAGE_AUX_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, is_torch_available
4
+
5
+
6
+ _import_structure = {}
7
+
8
+ try:
9
+ if not (is_torch_available()):
10
+ raise OptionalDependencyNotAvailable()
11
+ except OptionalDependencyNotAvailable:
12
+ ...
13
+ else:
14
+ _import_structure["upscale_with_model"] = [
15
+ "UpscaleWithModel",
16
+ ]
17
+
18
+ if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
19
+ try:
20
+ if not is_torch_available():
21
+ raise OptionalDependencyNotAvailable()
22
+ except OptionalDependencyNotAvailable:
23
+ ...
24
+ else:
25
+ from .upscale_with_model import (
26
+ UpscaleWithModel,
27
+ )
28
+ else:
29
+ import sys
30
+
31
+ sys.modules[__name__] = _LazyModule(
32
+ __name__,
33
+ globals()["__file__"],
34
+ _import_structure,
35
+ module_spec=__spec__,
36
+ )
image_gen_aux/upscalers/upscale_with_model.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from typing import Union
16
+
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+ from spandrel import ImageModelDescriptor, ModelLoader
21
+
22
+ from ..image_processor import ImageMixin
23
+ from ..utils import get_model_path, tiled_upscale
24
+
25
+
26
+ class UpscaleWithModel(ImageMixin):
27
+ r"""
28
+ Upscaler class that uses a pytorch model.
29
+
30
+ Args:
31
+ model ([`ImageModelDescriptor`]):
32
+ Upscaler model, must be supported by spandrel.
33
+ scale (`int`, defaults to the scale of the model):
34
+ The number of times to scale the image, it is recommended to use the model default scale which
35
+ usually is what the model was trained for.
36
+ """
37
+
38
+ def __init__(self, model: ImageModelDescriptor, scale: int = None):
39
+ super().__init__()
40
+ self.model = model
41
+
42
+ def to(self, device):
43
+ self.model.to(device)
44
+ return self
45
+
46
+ @classmethod
47
+ def from_pretrained(
48
+ cls, pretrained_model_or_path: Union[str, os.PathLike], filename: str = None, subfolder: str = None
49
+ ) -> "UpscaleWithModel":
50
+ r"""
51
+ Instantiate the Upscaler class from pretrained weights.
52
+
53
+ Parameters:
54
+ pretrained_model_or_path (`str` or `os.PathLike`):
55
+ Can be either:
56
+
57
+ - A string, the *repo id* (for example `OzzyGT/UltraSharp`) of a pretrained model
58
+ hosted on the Hub, must be saved in safetensors. If there's more than one checkpoint
59
+ in the repository and the filename wasn't specified, the first one found will be loaded.
60
+ - A path to a *directory* (for example `./upscaler_model/`) containing a pretrained
61
+ upscaler checkpoint.
62
+ filename (`str`, *optional*):
63
+ The name of the file in the repo.
64
+ subfolder (`str`, *optional*):
65
+ An optional value corresponding to a folder inside the model repo.
66
+ """
67
+ model_path = get_model_path(pretrained_model_or_path, filename, subfolder)
68
+ model = ModelLoader().load_from_file(model_path)
69
+
70
+ # validate that it's the correct model
71
+ assert isinstance(model, ImageModelDescriptor)
72
+
73
+ return cls(model)
74
+
75
+ @torch.inference_mode
76
+ def __call__(
77
+ self,
78
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
79
+ tiling: bool = False,
80
+ tile_width: int = 512,
81
+ tile_height: int = 512,
82
+ overlap: int = 8,
83
+ return_type: str = "pil",
84
+ ) -> Union[torch.Tensor, PIL.Image.Image, np.ndarray]:
85
+ r"""
86
+ Upscales the given image, optionally using tiling.
87
+
88
+ Args:
89
+ image (Union[PIL.Image.Image, np.ndarray, torch.Tensor]):
90
+ The image to be upscaled. Can be a PIL Image, NumPy array, or PyTorch tensor.
91
+ tiling (bool, optional):
92
+ Whether to use tiling for upscaling. Default is False.
93
+ tile_width (int, optional):
94
+ The width of each tile if tiling is used. Default is 512.
95
+ tile_height (int, optional):
96
+ The height of each tile if tiling is used. Default is 512.
97
+ overlap (int, optional):
98
+ The overlap between tiles if tiling is used. Default is 8.
99
+ return_type (str, optional):
100
+ The type of the returned image. Can be 'pil', 'numpy', or 'tensor'. Default is 'pil'.
101
+
102
+ Returns:
103
+ Union[torch.Tensor, PIL.Image.Image, np.ndarray]:
104
+ The upscaled image, in the format specified by `return_type`.
105
+ """
106
+ if not isinstance(image, torch.Tensor):
107
+ image = self.convert_image_to_tensor(image)
108
+
109
+ image = image.to(self.model.device)
110
+
111
+ if tiling:
112
+ upscaled_tensor = tiled_upscale(image, self.model, self.model.scale, tile_width, tile_height, overlap)
113
+ else:
114
+ upscaled_tensor = self.model(image)
115
+
116
+ image = self.post_process_image(upscaled_tensor, return_type)[0]
117
+
118
+ return image
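Because `from_pretrained` resolves paths through `get_model_path`, a locally downloaded safetensors checkpoint can be passed directly instead of a Hub repo id. A hedged sketch; both paths below are placeholders, and it assumes the top-level package re-exports `UpscaleWithModel`.

```python
from image_gen_aux import UpscaleWithModel
from image_gen_aux.utils import load_image

# placeholder path to a locally downloaded checkpoint in a format supported by spandrel
upscaler = UpscaleWithModel.from_pretrained("./checkpoints/4x-UltraSharp.safetensors").to("cuda")

original = load_image("low_res.png")  # placeholder local image
image = upscaler(original, tiling=True, tile_width=768, tile_height=768, overlap=8)
image.save("upscaled.png")
```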
image_gen_aux/utils/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .constants import SAFETENSORS_FILE_EXTENSION
16
+ from .import_utils import (
17
+ IMAGE_AUX_SLOW_IMPORT,
18
+ OptionalDependencyNotAvailable,
19
+ _LazyModule,
20
+ is_torch_available,
21
+ is_transformers_available,
22
+ )
23
+ from .loading_utils import load_image
24
+ from .model_utils import get_model_path
25
+ from .tiling_utils import create_gradient_mask, tiled_upscale
image_gen_aux/utils/constants.py ADDED
@@ -0,0 +1 @@
1
+ SAFETENSORS_FILE_EXTENSION = "safetensors"
image_gen_aux/utils/import_utils.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Import utilities: Utilities related to imports and our lazy inits.
16
+ """
17
+ import importlib.metadata as importlib_metadata
18
+ import importlib.util
19
+ import os
20
+ from itertools import chain
21
+ from types import ModuleType
22
+ from typing import Any
23
+
24
+ from . import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
30
+ ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
31
+
32
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
33
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
34
+ IMAGE_AUX_SLOW_IMPORT = os.environ.get("IMAGE_AUX_SLOW_IMPORT", "FALSE").upper()
35
+ IMAGE_AUX_SLOW_IMPORT = IMAGE_AUX_SLOW_IMPORT in ENV_VARS_TRUE_VALUES
36
+
37
+
38
+ _torch_version = "N/A"
39
+ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
40
+ _torch_available = importlib.util.find_spec("torch") is not None
41
+ if _torch_available:
42
+ try:
43
+ _torch_version = importlib_metadata.version("torch")
44
+ logger.info(f"PyTorch version {_torch_version} available.")
45
+ except importlib_metadata.PackageNotFoundError:
46
+ _torch_available = False
47
+ else:
48
+ logger.info("Disabling PyTorch because USE_TORCH is set")
49
+ _torch_available = False
50
+
51
+ _transformers_available = importlib.util.find_spec("transformers") is not None
52
+ try:
53
+ _transformers_version = importlib_metadata.version("transformers")
54
+ logger.debug(f"Successfully imported transformers version {_transformers_version}")
55
+ except importlib_metadata.PackageNotFoundError:
56
+ _transformers_available = False
57
+
58
+
59
+ def is_torch_available():
60
+ return _torch_available
61
+
62
+
63
+ def is_transformers_available():
64
+ return _transformers_available
65
+
66
+
67
+ class OptionalDependencyNotAvailable(BaseException):
68
+ """An error indicating that an optional dependency of Diffusers was not found in the environment."""
69
+
70
+
71
+ class _LazyModule(ModuleType):
72
+ """
73
+ Module class that surfaces all objects but only performs associated imports when the objects are requested.
74
+ """
75
+
76
+ # Very heavily inspired by optuna.integration._IntegrationModule
77
+ # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
78
+ def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
79
+ super().__init__(name)
80
+ self._modules = set(import_structure.keys())
81
+ self._class_to_module = {}
82
+ for key, values in import_structure.items():
83
+ for value in values:
84
+ self._class_to_module[value] = key
85
+ # Needed for autocompletion in an IDE
86
+ self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
87
+ self.__file__ = module_file
88
+ self.__spec__ = module_spec
89
+ self.__path__ = [os.path.dirname(module_file)]
90
+ self._objects = {} if extra_objects is None else extra_objects
91
+ self._name = name
92
+ self._import_structure = import_structure
93
+
94
+ # Needed for autocompletion in an IDE
95
+ def __dir__(self):
96
+ result = super().__dir__()
97
+ # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
98
+ # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
99
+ for attr in self.__all__:
100
+ if attr not in result:
101
+ result.append(attr)
102
+ return result
103
+
104
+ def __getattr__(self, name: str) -> Any:
105
+ if name in self._objects:
106
+ return self._objects[name]
107
+ if name in self._modules:
108
+ value = self._get_module(name)
109
+ elif name in self._class_to_module.keys():
110
+ module = self._get_module(self._class_to_module[name])
111
+ value = getattr(module, name)
112
+ else:
113
+ raise AttributeError(f"module {self.__name__} has no attribute {name}")
114
+
115
+ setattr(self, name, value)
116
+ return value
117
+
118
+ def _get_module(self, module_name: str):
119
+ try:
120
+ return importlib.import_module("." + module_name, self.__name__)
121
+ except Exception as e:
122
+ raise RuntimeError(
123
+ f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
124
+ f" traceback):\n{e}"
125
+ ) from e
126
+
127
+ def __reduce__(self):
128
+ return (self.__class__, (self._name, self.__file__, self._import_structure))
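The lazy-import machinery above is driven by the `IMAGE_AUX_SLOW_IMPORT` environment variable; setting it to a truthy value before the first import of the package switches to eager imports. A short sketch:

```python
import os

# must be set before image_gen_aux is imported for the first time
os.environ["IMAGE_AUX_SLOW_IMPORT"] = "1"  # accepted truthy values: 1, ON, YES, TRUE

from image_gen_aux.utils.import_utils import is_torch_available, is_transformers_available

print(is_torch_available(), is_transformers_available())
```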
image_gen_aux/utils/loading_utils.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from typing import Callable, Union
17
+
18
+ import PIL.Image
19
+ import PIL.ImageOps
20
+ import requests
21
+
22
+
23
+ def load_image(
24
+ image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
25
+ ) -> PIL.Image.Image:
26
+ """
27
+ Loads `image` to a PIL Image.
28
+
29
+ Args:
30
+ image (`str` or `PIL.Image.Image`):
31
+ The image to convert to the PIL Image format.
32
+ convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
33
+ A conversion method to apply to the image after loading it. When set to `None` the image will be converted
34
+ "RGB".
35
+
36
+ Returns:
37
+ `PIL.Image.Image`:
38
+ A PIL Image.
39
+ """
40
+ if isinstance(image, str):
41
+ if image.startswith("http://") or image.startswith("https://"):
42
+ image = PIL.Image.open(requests.get(image, stream=True).raw)
43
+ elif os.path.isfile(image):
44
+ image = PIL.Image.open(image)
45
+ else:
46
+ raise ValueError(
47
+ f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
48
+ )
49
+ elif isinstance(image, PIL.Image.Image):
50
+ image = image
51
+ else:
52
+ raise ValueError(
53
+ "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
54
+ )
55
+
56
+ image = PIL.ImageOps.exif_transpose(image)
57
+
58
+ if convert_method is not None:
59
+ image = convert_method(image)
60
+ else:
61
+ image = image.convert("RGB")
62
+
63
+ return image
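A small sketch of `load_image` with a custom `convert_method`; the URL is the test image from the upscaler README above.

```python
from image_gen_aux.utils import load_image

# keep the alpha channel instead of the default RGB conversion
image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/simple_upscale/hippowaffle_small.png",
    convert_method=lambda img: img.convert("RGBA"),
)
print(image.mode, image.size)
```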
image_gen_aux/utils/logging.py ADDED
@@ -0,0 +1,341 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Optuna, Hugging Face
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Logging utilities."""
16
+
17
+ import logging
18
+ import os
19
+ import sys
20
+ import threading
21
+ from logging import (
22
+ CRITICAL, # NOQA
23
+ DEBUG, # NOQA
24
+ ERROR, # NOQA
25
+ FATAL, # NOQA
26
+ INFO, # NOQA
27
+ NOTSET, # NOQA
28
+ WARN, # NOQA
29
+ WARNING, # NOQA
30
+ )
31
+ from typing import Dict, Optional
32
+
33
+ from tqdm import auto as tqdm_lib
34
+
35
+
36
+ _lock = threading.Lock()
37
+ _default_handler: Optional[logging.Handler] = None
38
+
39
+ log_levels = {
40
+ "debug": logging.DEBUG,
41
+ "info": logging.INFO,
42
+ "warning": logging.WARNING,
43
+ "error": logging.ERROR,
44
+ "critical": logging.CRITICAL,
45
+ }
46
+
47
+ _default_log_level = logging.WARNING
48
+
49
+ _tqdm_active = True
50
+
51
+
52
+ def _get_default_logging_level() -> int:
53
+ """
54
+ If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
55
+ not - fall back to `_default_log_level`
56
+ """
57
+ env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
58
+ if env_level_str:
59
+ if env_level_str in log_levels:
60
+ return log_levels[env_level_str]
61
+ else:
62
+ logging.getLogger().warning(
63
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
64
+ f"has to be one of: { ', '.join(log_levels.keys()) }"
65
+ )
66
+ return _default_log_level
67
+
68
+
69
+ def _get_library_name() -> str:
70
+ return __name__.split(".")[0]
71
+
72
+
73
+ def _get_library_root_logger() -> logging.Logger:
74
+ return logging.getLogger(_get_library_name())
75
+
76
+
77
+ def _configure_library_root_logger() -> None:
78
+ global _default_handler
79
+
80
+ with _lock:
81
+ if _default_handler:
82
+ # This library has already configured the library root logger.
83
+ return
84
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
85
+
86
+ if sys.stderr: # only if sys.stderr exists, e.g. when not using pythonw in windows
87
+ _default_handler.flush = sys.stderr.flush
88
+
89
+ # Apply our default configuration to the library root logger.
90
+ library_root_logger = _get_library_root_logger()
91
+ library_root_logger.addHandler(_default_handler)
92
+ library_root_logger.setLevel(_get_default_logging_level())
93
+ library_root_logger.propagate = False
94
+
95
+
96
+ def _reset_library_root_logger() -> None:
97
+ global _default_handler
98
+
99
+ with _lock:
100
+ if not _default_handler:
101
+ return
102
+
103
+ library_root_logger = _get_library_root_logger()
104
+ library_root_logger.removeHandler(_default_handler)
105
+ library_root_logger.setLevel(logging.NOTSET)
106
+ _default_handler = None
107
+
108
+
109
+ def get_log_levels_dict() -> Dict[str, int]:
110
+ return log_levels
111
+
112
+
113
+ def get_logger(name: Optional[str] = None) -> logging.Logger:
114
+ """
115
+ Return a logger with the specified name.
116
+
117
+ This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
118
+ """
119
+
120
+ if name is None:
121
+ name = _get_library_name()
122
+
123
+ _configure_library_root_logger()
124
+ return logging.getLogger(name)
125
+
126
+
127
+ def get_verbosity() -> int:
128
+ """
129
+ Return the current level for the 🤗 Diffusers' root logger as an `int`.
130
+
131
+ Returns:
132
+ `int`:
133
+ Logging level integers which can be one of:
134
+
135
+ - `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
136
+ - `40`: `diffusers.logging.ERROR`
137
+ - `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
138
+ - `20`: `diffusers.logging.INFO`
139
+ - `10`: `diffusers.logging.DEBUG`
140
+
141
+ """
142
+
143
+ _configure_library_root_logger()
144
+ return _get_library_root_logger().getEffectiveLevel()
145
+
146
+
147
+ def set_verbosity(verbosity: int) -> None:
148
+ """
149
+ Set the verbosity level for the 🤗 Diffusers' root logger.
150
+
151
+ Args:
152
+ verbosity (`int`):
153
+ Logging level which can be one of:
154
+
155
+ - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
156
+ - `diffusers.logging.ERROR`
157
+ - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
158
+ - `diffusers.logging.INFO`
159
+ - `diffusers.logging.DEBUG`
160
+ """
161
+
162
+ _configure_library_root_logger()
163
+ _get_library_root_logger().setLevel(verbosity)
164
+
165
+
166
+ def set_verbosity_info() -> None:
167
+ """Set the verbosity to the `INFO` level."""
168
+ return set_verbosity(INFO)
169
+
170
+
171
+ def set_verbosity_warning() -> None:
172
+ """Set the verbosity to the `WARNING` level."""
173
+ return set_verbosity(WARNING)
174
+
175
+
176
+ def set_verbosity_debug() -> None:
177
+ """Set the verbosity to the `DEBUG` level."""
178
+ return set_verbosity(DEBUG)
179
+
180
+
181
+ def set_verbosity_error() -> None:
182
+ """Set the verbosity to the `ERROR` level."""
183
+ return set_verbosity(ERROR)
184
+
185
+
186
+ def disable_default_handler() -> None:
187
+ """Disable the default handler of the 🤗 Diffusers' root logger."""
188
+
189
+ _configure_library_root_logger()
190
+
191
+ assert _default_handler is not None
192
+ _get_library_root_logger().removeHandler(_default_handler)
193
+
194
+
195
+ def enable_default_handler() -> None:
196
+ """Enable the default handler of the 🤗 Diffusers' root logger."""
197
+
198
+ _configure_library_root_logger()
199
+
200
+ assert _default_handler is not None
201
+ _get_library_root_logger().addHandler(_default_handler)
202
+
203
+
204
+ def add_handler(handler: logging.Handler) -> None:
205
+ """adds a handler to the HuggingFace Diffusers' root logger."""
206
+
207
+ _configure_library_root_logger()
208
+
209
+ assert handler is not None
210
+ _get_library_root_logger().addHandler(handler)
211
+
212
+
213
+ def remove_handler(handler: logging.Handler) -> None:
214
+ """removes given handler from the HuggingFace Diffusers' root logger."""
215
+
216
+ _configure_library_root_logger()
217
+
218
+ assert handler is not None and handler in _get_library_root_logger().handlers
219
+ _get_library_root_logger().removeHandler(handler)
220
+
221
+
222
+ def disable_propagation() -> None:
223
+ """
224
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
225
+ """
226
+
227
+ _configure_library_root_logger()
228
+ _get_library_root_logger().propagate = False
229
+
230
+
231
+ def enable_propagation() -> None:
232
+ """
233
+ Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
234
+ double logging if the root logger has been configured.
235
+ """
236
+
237
+ _configure_library_root_logger()
238
+ _get_library_root_logger().propagate = True
239
+
240
+
241
+ def enable_explicit_format() -> None:
242
+ """
243
+ Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows:
244
+ ```
245
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
246
+ ```
247
+ All handlers currently bound to the root logger are affected by this method.
248
+ """
249
+ handlers = _get_library_root_logger().handlers
250
+
251
+ for handler in handlers:
252
+ formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
253
+ handler.setFormatter(formatter)
254
+
255
+
256
+ def reset_format() -> None:
257
+ """
258
+ Resets the formatting for 🤗 Diffusers' loggers.
259
+
260
+ All handlers currently bound to the root logger are affected by this method.
261
+ """
262
+ handlers = _get_library_root_logger().handlers
263
+
264
+ for handler in handlers:
265
+ handler.setFormatter(None)
266
+
267
+
268
+ def warning_advice(self, *args, **kwargs) -> None:
269
+ """
270
+ This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
271
+ warning will not be printed
272
+ """
273
+ no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
274
+ if no_advisory_warnings:
275
+ return
276
+ self.warning(*args, **kwargs)
277
+
278
+
279
+ logging.Logger.warning_advice = warning_advice
280
+
281
+
282
+ class EmptyTqdm:
283
+ """Dummy tqdm which doesn't do anything."""
284
+
285
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
286
+ self._iterator = args[0] if args else None
287
+
288
+ def __iter__(self):
289
+ return iter(self._iterator)
290
+
291
+ def __getattr__(self, _):
292
+ """Return empty function."""
293
+
294
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
295
+ return
296
+
297
+ return empty_fn
298
+
299
+ def __enter__(self):
300
+ return self
301
+
302
+ def __exit__(self, type_, value, traceback):
303
+ return
304
+
305
+
306
+ class _tqdm_cls:
307
+ def __call__(self, *args, **kwargs):
308
+ if _tqdm_active:
309
+ return tqdm_lib.tqdm(*args, **kwargs)
310
+ else:
311
+ return EmptyTqdm(*args, **kwargs)
312
+
313
+ def set_lock(self, *args, **kwargs):
314
+ self._lock = None
315
+ if _tqdm_active:
316
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
317
+
318
+ def get_lock(self):
319
+ if _tqdm_active:
320
+ return tqdm_lib.tqdm.get_lock()
321
+
322
+
323
+ tqdm = _tqdm_cls()
324
+
325
+
326
+ def is_progress_bar_enabled() -> bool:
327
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
328
+ global _tqdm_active
329
+ return bool(_tqdm_active)
330
+
331
+
332
+ def enable_progress_bar() -> None:
333
+ """Enable tqdm progress bar."""
334
+ global _tqdm_active
335
+ _tqdm_active = True
336
+
337
+
338
+ def disable_progress_bar() -> None:
339
+ """Disable tqdm progress bar."""
340
+ global _tqdm_active
341
+ _tqdm_active = False
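Typical use of the helpers above; `image_gen_aux.utils.logging` is used as a module, following the diffusers convention it was copied from.

```python
from image_gen_aux.utils import logging

logging.set_verbosity_debug()  # raise the verbosity of the library root logger
logger = logging.get_logger(__name__)
logger.info("logging configured")

logging.disable_progress_bar()  # turn the tqdm wrapper defined above into a no-op
```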
image_gen_aux/utils/model_utils.py ADDED
@@ -0,0 +1,37 @@
1
+ import os
2
+
3
+ from huggingface_hub import hf_hub_download, model_info
4
+
5
+
6
+ def get_model_path(pretrained_model_or_path, filename=None, subfolder=None):
7
+ """
8
+ Retrieves the path to the model file.
9
+
10
+ If `pretrained_model_or_path` is a file, it returns the path directly.
11
+ Otherwise, it attempts to find a `.safetensors` file associated with the given model path.
12
+ If no `.safetensors` file is found, it raises a `FileNotFoundError`.
13
+
14
+ Parameters:
15
+ - pretrained_model_or_path (str): Path to the pretrained model or directory containing the model.
16
+ - filename (str, optional): Specific filename to load. If not provided, the function will search for a `.safetensors` file.
17
+ - subfolder (str, optional): Subfolder within the model directory to look for the file.
18
+
19
+ Returns:
20
+ - str: Path to the model file.
21
+
22
+ Raises:
23
+ - FileNotFoundError: If no `.safetensors` file is found when `filename` is not provided.
24
+ """
25
+ if os.path.isfile(pretrained_model_or_path):
26
+ return pretrained_model_or_path
27
+
28
+ if filename is None:
29
+ # If the filename is not passed, we only try to load a safetensor
30
+ info = model_info(pretrained_model_or_path)
31
+ filename = next(
32
+ (sibling.rfilename for sibling in info.siblings if sibling.rfilename.endswith(".safetensors")), None
33
+ )
34
+ if filename is None:
35
+ raise FileNotFoundError("No safetensors checkpoint found.")
36
+
37
+ return hf_hub_download(pretrained_model_or_path, filename, subfolder=subfolder)
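A sketch of how `get_model_path` resolves checkpoints: a local file is returned unchanged, while a repo id is resolved to the first `.safetensors` sibling (or an explicit `filename`) and fetched through `hf_hub_download`. The repo id comes from the upscaler README; the filename and local path are placeholders.

```python
from image_gen_aux.utils import get_model_path

# Hub repo id: the first .safetensors checkpoint found is downloaded and its cached path returned
hub_path = get_model_path("Kim2091/UltraSharp")

# explicit file inside a repo (placeholder filename)
named_path = get_model_path("Kim2091/UltraSharp", filename="4x-UltraSharp.safetensors")

# a local file path is returned as-is
local_path = get_model_path("./checkpoints/model.safetensors")
```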
image_gen_aux/utils/tiling_utils.py ADDED
@@ -0,0 +1,86 @@
1
+ from typing import Callable, Tuple
2
+
3
+ import torch
4
+
5
+
6
+ def create_gradient_mask(shape: Tuple, feather: int, device="cpu") -> torch.Tensor:
7
+ """
8
+ Create a gradient mask for smooth blending of tiles.
9
+
10
+ Args:
11
+ shape (tuple): Shape of the mask (batch, channels, height, width)
12
+ feather (int): Width of the feathered edge
13
+
14
+ Returns:
15
+ torch.Tensor: Gradient mask
16
+ """
17
+ mask = torch.ones(shape).to(device)
18
+ _, _, h, w = shape
19
+ for feather_step in range(feather):
20
+ factor = (feather_step + 1) / feather
21
+ mask[:, :, feather_step, :] *= factor
22
+ mask[:, :, h - 1 - feather_step, :] *= factor
23
+ mask[:, :, :, feather_step] *= factor
24
+ mask[:, :, :, w - 1 - feather_step] *= factor
25
+ return mask
26
+
27
+
28
+ def tiled_upscale(
29
+ samples: torch.Tensor,
30
+ function: Callable,
31
+ scale: int,
32
+ tile_width: int = 512,
33
+ tile_height: int = 512,
34
+ overlap: int = 8,
35
+ ) -> torch.Tensor:
36
+ """
37
+ Apply a scaling function to image samples in a tiled manner.
38
+
39
+ Args:
40
+ samples (torch.Tensor): Input tensor of shape (batch_size, channels, height, width)
41
+ function (Callable): The scaling function to apply to each tile
42
+ scale (int): Factor by which to upscale the image
43
+ tile_width (int): Width of each tile
44
+ tile_height (int): Height of each tile
45
+ overlap (int): Overlap between tiles
46
+
47
+ Returns:
48
+ torch.Tensor: Upscaled and processed output tensor
49
+ """
50
+ _batch, _channels, height, width = samples.shape
51
+ out_height, out_width = round(height * scale), round(width * scale)
52
+ output_device = samples.device
53
+
54
+ # Initialize output tensors
55
+ output = torch.empty((1, 3, out_height, out_width), device=output_device)
56
+ out = torch.zeros((1, 3, out_height, out_width), device=output_device)
57
+ out_div = torch.zeros_like(output)
58
+
59
+ # Process the image in tiles
60
+ for y in range(0, height, tile_height - overlap):
61
+ for x in range(0, width, tile_width - overlap):
62
+ # Ensure we don't go out of bounds
63
+ x_end = min(x + tile_width, width)
64
+ y_end = min(y + tile_height, height)
65
+ x = max(0, x_end - tile_width)
66
+ y = max(0, y_end - tile_height)
67
+
68
+ # Extract and process the tile
69
+ tile = samples[:, :, y:y_end, x:x_end]
70
+ processed_tile = function(tile).to(output_device)
71
+
72
+ # Calculate the position in the output tensor
73
+ out_y, out_x = round(y * scale), round(x * scale)
74
+ out_h, out_w = processed_tile.shape[2:]
75
+
76
+ # Create a feathered mask for smooth blending
77
+ mask = create_gradient_mask(processed_tile.shape, overlap * scale, device=output_device)
78
+
79
+ # Add the processed tile to the output
80
+ out[:, :, out_y : out_y + out_h, out_x : out_x + out_w] += processed_tile * mask
81
+ out_div[:, :, out_y : out_y + out_h, out_x : out_x + out_w] += mask
82
+
83
+ # Normalize the output
84
+ output = out / out_div
85
+
86
+ return output
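A self-contained sketch of `tiled_upscale` with a stand-in "model": any callable that maps a tile to a tile enlarged by `scale` fits the interface (the upscaler passes the spandrel model here).

```python
import torch
import torch.nn.functional as F

from image_gen_aux.utils import tiled_upscale


def upscale_2x(tile: torch.Tensor) -> torch.Tensor:
    # stand-in for a real super-resolution model
    return F.interpolate(tile, scale_factor=2, mode="nearest")


image = torch.rand(1, 3, 600, 800)
result = tiled_upscale(image, upscale_2x, scale=2, tile_width=256, tile_height=256, overlap=8)
print(result.shape)  # torch.Size([1, 3, 1200, 1600])
```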
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  einops
2
- diffusers==0.31.0
3
  transformers
4
  accelerate
5
  gradio
@@ -8,4 +7,5 @@ spaces
8
  pillow
9
  peft
10
  openai
11
- torch
 
 
1
  einops
 
2
  transformers
3
  accelerate
4
  gradio
 
7
  pillow
8
  peft
9
  openai
10
+ torch
11
+ git+https://github.com/huggingface/diffusers