uuu
Files changed:

- .idea/vcs.xml +1 -1
- app.py +5 -2
- image_gen_aux/__init__.py +81 -0
- image_gen_aux/image_processor.py +233 -0
- image_gen_aux/modeling_utils.py +90 -0
- image_gen_aux/preprocessors/README.md +54 -0
- image_gen_aux/preprocessors/__init__.py +66 -0
- image_gen_aux/preprocessors/depth/README.md +45 -0
- image_gen_aux/preprocessors/depth/__init__.py +42 -0
- image_gen_aux/preprocessors/depth/depth_preprocessor.py +70 -0
- image_gen_aux/preprocessors/lineart/LICENSE.txt +21 -0
- image_gen_aux/preprocessors/lineart/README.md +27 -0
- image_gen_aux/preprocessors/lineart/__init__.py +36 -0
- image_gen_aux/preprocessors/lineart/lineart_preprocessor.py +101 -0
- image_gen_aux/preprocessors/lineart/model.py +87 -0
- image_gen_aux/preprocessors/lineart_standard/README.md +23 -0
- image_gen_aux/preprocessors/lineart_standard/__init__.py +36 -0
- image_gen_aux/preprocessors/lineart_standard/lineart_standard_preprocessor.py +82 -0
- image_gen_aux/preprocessors/preprocessor.py +75 -0
- image_gen_aux/preprocessors/teed/LICENSE.txt +21 -0
- image_gen_aux/preprocessors/teed/README.md +26 -0
- image_gen_aux/preprocessors/teed/__init__.py +36 -0
- image_gen_aux/preprocessors/teed/teed.py +323 -0
- image_gen_aux/preprocessors/teed/teed_preprocessor.py +121 -0
- image_gen_aux/upscalers/README.md +58 -0
- image_gen_aux/upscalers/__init__.py +36 -0
- image_gen_aux/upscalers/upscale_with_model.py +118 -0
- image_gen_aux/utils/__init__.py +25 -0
- image_gen_aux/utils/constants.py +1 -0
- image_gen_aux/utils/import_utils.py +128 -0
- image_gen_aux/utils/loading_utils.py +63 -0
- image_gen_aux/utils/logging.py +341 -0
- image_gen_aux/utils/model_utils.py +37 -0
- image_gen_aux/utils/tiling_utils.py +86 -0
- requirements.txt +2 -2
.idea/vcs.xml
CHANGED
@@ -2,7 +2,7 @@
 <project version="4">
   <component name="VcsDirectoryMappings">
     <mapping directory="" vcs="Git" />
-    <mapping directory="$PROJECT_DIR$/
+    <mapping directory="$PROJECT_DIR$/_image_gen_aux" vcs="Git" />
     <mapping directory="$PROJECT_DIR$/test_gradio" vcs="Git" />
   </component>
 </project>
app.py
CHANGED
@@ -11,6 +11,7 @@ from huggingface_hub import login
 import torch
 from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel
 from diffusers.utils import load_image
+from image_gen_aux import DepthPreprocessor
 
 # ----------------------------
 # Step 1: Download IP Adapter if not exists
@@ -70,9 +71,11 @@ pipe.init_ipadapter(
 @spaces.GPU
 def gui_generation(prompt,negative_prompt, ref_img, guidance_scale, ipadapter_scale):
     ref_img = load_image(ref_img.name).convert('RGB')
-
-    control_image = load_image(
-        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_depth.png")
+    image = load_image(ref_img.name)
+
+    depth_preprocessor = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf").to("cuda")
+    control_image = depth_preprocessor(image, invert=True)[0].convert("RGB")
+
     generator = torch.Generator(device="cpu").manual_seed(0)
     image = pipe(
         width=1024,
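
The change above swaps the hard-coded Marigold depth map for a depth map computed from the uploaded reference image. A minimal sketch of that control-image path in isolation (not part of the commit; the local file name is a placeholder):

```python
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor

# "reference.png" is an illustrative local path standing in for the Gradio upload.
ref = load_image("reference.png").convert("RGB")

depth = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf").to("cuda")
control_image = depth(ref, invert=True)[0].convert("RGB")
control_image.save("control_depth.png")
```
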
image_gen_aux/__init__.py
ADDED
@@ -0,0 +1,81 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from .utils import (
    IMAGE_AUX_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_transformers_available,
)


# Lazy Import based on
# https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py

# When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
# and is used to defer the actual importing for when the objects are requested.
# This way `import image_gen_aux` provides the names in the namespace without actually importing anything (and especially none of the backends).

_import_structure = {
    "upscalers": [],
    "preprocessors": ["LineArtStandardPreprocessor"],
    "utils": [
        "OptionalDependencyNotAvailable",
        "is_torch_available",
        "logging",
    ],
}


try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    ...
else:
    _import_structure["upscalers"].extend(["UpscaleWithModel"])

    _import_structure["preprocessors"].extend(["DepthPreprocessor", "LineArtPreprocessor", "TeedPreprocessor"])

if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
    from .preprocessors import LineArtStandardPreprocessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        ...
    else:
        from .preprocessors import LineArtPreprocessor, TeedPreprocessor
        from .upscalers import UpscaleWithModel

    try:
        if not (is_torch_available() and is_transformers_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        ...
    else:
        from .preprocessors import DepthPreprocessor
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
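
The `_import_structure` / `_LazyModule` pattern above mirrors the transformers-style lazy init it links to. A small sketch of what that buys at import time (assuming the package is importable from the Space's working directory):

```python
import image_gen_aux

# Importing the package does not import torch, transformers or any preprocessor module yet;
# names are only resolved on first attribute access through _LazyModule.
print(type(image_gen_aux))  # the package module has been replaced by a _LazyModule instance

LineArtStandardPreprocessor = image_gen_aux.LineArtStandardPreprocessor  # triggers the submodule import
print(LineArtStandardPreprocessor)
```
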
image_gen_aux/image_processor.py
ADDED
@@ -0,0 +1,233 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

import cv2
import numpy as np
import PIL.Image
import torch
from PIL import Image


class ImageMixin:
    """
    A mixin class for converting images between different formats: PIL, NumPy, and PyTorch tensors.

    This class provides methods to:
    - Convert a PIL image or a NumPy array to a PyTorch tensor.
    - Post-process a PyTorch tensor image and convert it to the specified return type.
    - Convert a PIL image or a list of PIL images to NumPy arrays.
    - Convert a NumPy image to a PyTorch tensor.
    - Convert a PyTorch tensor to a NumPy image.
    - Convert a NumPy image or a batch of images to a PIL image.
    """

    def convert_image_to_tensor(
        self, image: Union[PIL.Image.Image, np.ndarray, List[PIL.Image.Image]], normalize: bool = True
    ) -> torch.Tensor:
        """
        Convert a PIL image or a NumPy array to a PyTorch tensor.

        Args:
            image (Union[PIL.Image.Image, np.ndarray]): The input image, either as a PIL image, a NumPy array or a list of
                PIL images.

        Returns:
            torch.Tensor: The converted image as a PyTorch tensor.
        """
        if isinstance(image, (PIL.Image.Image, list)):
            # We expect that if it is a list, it only should contain pillow images
            if isinstance(image, list):
                for single_image in image:
                    if not isinstance(single_image, PIL.Image.Image):
                        raise ValueError("All images in the list must be Pillow images.")

            image = self.pil_to_numpy(image, normalize)

        return self.numpy_to_pt(image)

    def post_process_image(self, image: torch.Tensor, return_type: str):
        """
        Post-process a PyTorch tensor image and convert it to the specified return type.

        Args:
            image (torch.Tensor): The input image as a PyTorch tensor.
            return_type (str): The desired return type, either "pt" for PyTorch tensor, "np" for NumPy array, or "pil" for PIL image.

        Returns:
            Union[torch.Tensor, np.ndarray, List[PIL.Image.Image]]: The post-processed image in the specified return type.
        """
        if return_type == "pt":
            return image

        image = self.pt_to_numpy(image)
        if return_type == "np":
            return image

        image = self.numpy_to_pil(image)
        return image

    @staticmethod
    def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image], normalize: bool = True) -> np.ndarray:
        """
        Convert a PIL image or a list of PIL images to NumPy arrays.

        Args:
            images (Union[List[PIL.Image.Image], PIL.Image.Image]): The input image(s) as PIL image(s).

        Returns:
            np.ndarray: The converted image(s) as a NumPy array.
        """
        if not isinstance(images, list):
            images = [images]

        if normalize:
            images = [np.array(image).astype(np.float32) / 255.0 for image in images]
        else:
            images = [np.array(image).astype(np.float32) for image in images]

        images = np.stack(images, axis=0)

        return images

    @staticmethod
    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
        """
        Convert a NumPy image to a PyTorch tensor.

        Args:
            images (np.ndarray): The input image(s) as a NumPy array.

        Returns:
            torch.Tensor: The converted image(s) as a PyTorch tensor.
        """
        if images.ndim == 3:
            images = images[..., None]
        images = torch.from_numpy(images.transpose(0, 3, 1, 2)).float()

        return images

    @staticmethod
    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
        """
        Convert a PyTorch tensor to a NumPy image.

        Args:
            images (torch.Tensor): The input image(s) as a PyTorch tensor.

        Returns:
            np.ndarray: The converted image(s) as a NumPy array.
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
        """
        Convert a NumPy image or a batch of images to a PIL image.

        Args:
            images (np.ndarray): The input image(s) as a NumPy array.

        Returns:
            List[PIL.Image.Image]: The converted image(s) as PIL images.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    @staticmethod
    def scale_image(image: torch.Tensor, scale: float, mutiple_factor: int = 8) -> torch.Tensor:
        """
        Scales an image while maintaining aspect ratio and ensuring dimensions are multiples of `multiple_factor`.

        Args:
            image (`torch.Tensor`): The input image tensor of shape (batch, channels, height, width).
            scale (`float`): The scaling factor applied to the image dimensions.
            multiple_factor (`int`, *optional*, defaults to 8): The factor by which the new dimensions should be divisible.

        Returns:
            `torch.Tensor`: The scaled image tensor.
        """

        if scale == 1.0:
            return image, scale

        _batch, _channels, height, width = image.shape

        # Calculate new dimensions while maintaining aspect ratio
        new_height = int(height * scale)
        new_width = int(width * scale)

        # Ensure new dimensions are multiples of mutiple_factor
        new_height = (new_height // mutiple_factor) * mutiple_factor
        new_width = (new_width // mutiple_factor) * mutiple_factor

        # if the final height and widht changed because of the multiple_factor, we need to set the scale too
        scale = new_height / height

        # Resize the image using the calculated dimensions
        resized_image = torch.nn.functional.interpolate(
            image, size=(new_height, new_width), mode="bilinear", align_corners=False
        )

        return resized_image, scale

    @staticmethod
    def resize_numpy_image(image: np.ndarray, scale: float, multiple_factor: int = 8) -> np.ndarray:
        """
        Resizes a NumPy image while maintaining aspect ratio and ensuring dimensions are multiples of `multiple_factor`.

        Args:
            image (`np.ndarray`): The input image array of shape (height, width, channels) or (height, width) for grayscale.
            scale (`float`): The scaling factor applied to the image dimensions.
            multiple_factor (`int`, *optional*, defaults to 8): The factor by which the new dimensions should be divisible.

        Returns:
            `np.ndarray`: The resized image array.
        """
        if len(image.shape) == 3:  # Single image without batch dimension
            image = np.expand_dims(image, axis=0)

        batch_size, height, width, channels = image.shape

        # Calculate new dimensions while maintaining aspect ratio
        new_height = int(height * scale)
        new_width = int(width * scale)

        # Ensure new dimensions are multiples of multiple_factor
        new_height = (new_height // multiple_factor) * multiple_factor
        new_width = (new_width // multiple_factor) * multiple_factor

        # if the final height and widht changed because of the multiple_factor, we need to set the scale too
        scale = new_height / height

        # Resize each image in the batch
        resized_images = []
        for i in range(batch_size):
            resized_image = cv2.resize(image[i], (new_width, new_height), interpolation=cv2.INTER_LINEAR)
            resized_images.append(resized_image)

        # Stack resized images back into a single array
        resized_images = np.stack(resized_images, axis=0)

        return resized_images, scale
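
A short sketch of how the `ImageMixin` helpers above fit together, including the multiple-of-8 snapping done by `scale_image` (illustrative values, not part of the commit):

```python
import numpy as np
import PIL.Image

from image_gen_aux.image_processor import ImageMixin

mixin = ImageMixin()
pil = PIL.Image.fromarray(np.zeros((96, 100, 3), dtype=np.uint8))  # width 100 is not a multiple of 8

tensor = mixin.convert_image_to_tensor(pil)        # (1, 3, 96, 100), float32 in [0, 1]
scaled, scale = mixin.scale_image(tensor, 0.5)     # dimensions snapped down to multiples of 8 -> (1, 3, 48, 48)
back = mixin.post_process_image(scaled, "pil")[0]  # back to a PIL image

print(tuple(tensor.shape), tuple(scaled.shape), back.size, scale)
```
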
image_gen_aux/modeling_utils.py
ADDED
@@ -0,0 +1,90 @@
import itertools
from typing import List, Tuple

import torch
from torch import Tensor


def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
    """
    Gets the device of a PyTorch module's parameters or buffers.

    Args:
        parameter (`torch.nn.Module`): The PyTorch module from which to get the device.

    Returns:
        `torch.device`: The device of the module's parameters or buffers.
    """
    try:
        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
        return next(parameters_and_buffers).device
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5
        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device


def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
    """
    Gets the data type of a PyTorch module's parameters or buffers.

    Args:
        parameter (`torch.nn.Module`): The PyTorch module from which to get the data type.

    Returns:
        `torch.dtype`: The data type of the module's parameters or buffers.
    """
    try:
        params = tuple(parameter.parameters())
        if len(params) > 0:
            return params[0].dtype

        buffers = tuple(parameter.buffers())
        if len(buffers) > 0:
            return buffers[0].dtype

    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5
        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype


class ModelMixin(torch.nn.Module):
    """
    Provides convenient properties to access the device and data type
    of a PyTorch module.

    By inheriting from this class, your custom PyTorch modules can access these properties
    without manual retrieval of device and data type information.

    These properties assume that all module parameters and buffers reside
    on the same device and have the same data type, respectively.
    """

    def __init__(self):
        super().__init__()

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)
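
A toy example of the convenience the `ModelMixin` properties above provide (sketch, not part of the commit):

```python
import torch
import torch.nn as nn

from image_gen_aux.modeling_utils import ModelMixin


class TinyNet(ModelMixin):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)


net = TinyNet().to(torch.float16)
# Assuming every parameter shares one device and dtype, these resolve without manual bookkeeping.
print(net.device, net.dtype)  # e.g. cpu torch.float16
```
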
image_gen_aux/preprocessors/README.md
ADDED
@@ -0,0 +1,54 @@
# Preprocessors

Preprocessors in this context refer to the application of machine learning models to images or groups of images as a preliminary step for other tasks. A common use case is to feed preprocessed images into a controlnet to guide the generation of a diffusion model. Examples of preprocessors include depth maps, normals, and edge detection.

## Supported preprocessors

This is a list of the currently supported preprocessors.

* [DepthPreprocessor](https://github.com/asomoza/image_gen_aux/blob/main/src/image_gen_aux/preprocessors/depth/README.md)
* [LineArtPreprocessor](https://github.com/asomoza/image_gen_aux/blob/main/src/image_gen_aux/preprocessors/lineart/README.md)
* [LineArtStandardPreprocessor](https://github.com/asomoza/image_gen_aux/blob/main/src/image_gen_aux/preprocessors/lineart_standard/README.md)
* [TeedPreprocessor](https://github.com/asomoza/image_gen_aux/blob/main/src/image_gen_aux/preprocessors/teed/README.md)

## General preprocessor usage

```python
from image_gen_aux import LineArtPreprocessor
from image_gen_aux.utils import load_image

input_image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/simple_upscale/hippowaffle.png"
)

lineart_preprocessor = LineArtPreprocessor.from_pretrained("OzzyGT/lineart").to("cuda")
image = lineart_preprocessor(input_image)
image.save("lineart.png")
```

Use the Hugging Face Hub model id, a URL, or a local path to load a model. You can also load pickle checkpoints, but it's not recommended due to the [vulnerabilities](https://docs.python.org/3/library/pickle.html) of pickled files. Loading [safetensor](https://hf.co/docs/safetensors/index) files is a more secure option.

If the checkpoint has custom imports (high risk!), you'll need to set the `weights_only` argument to `False`.

You can also specify a `filename` and a `subfolder` if needed.

```python
lineart_preprocessor = LineArtPreprocessor.from_pretrained("lllyasviel/Annotators", filename="sk_model.pth", weights_only=False).to("cuda")
```

### List of safetensors checkpoints

This is the current list of safetensor checkpoints available on the Hub.

|Preprocessor|Repository|Author|
|---|---|---|
|Depth Anything V2 Small|depth-anything/Depth-Anything-V2-Small-hf|<https://depth-anything-v2.github.io/>|
|Depth Anything V2 Base|depth-anything/Depth-Anything-V2-Base-hf|<https://depth-anything-v2.github.io/>|
|Depth Anything V2 Large|depth-anything/Depth-Anything-V2-Large-hf|<https://depth-anything-v2.github.io/>|
|LineArt|OzzyGT/lineart|[Caroline Chan](https://github.com/carolineec)|
|Teed|OzzyGT/teed|<https://github.com/xavysp/TEED>|
|ZoeDepth NYU|Intel/zoedepth-nyu|<https://github.com/isl-org/ZoeDepth>|
|ZoeDepth KITTI|Intel/zoedepth-kitti|<https://github.com/isl-org/ZoeDepth>|
|ZoeDepth NYU and KITTI|Intel/zoedepth-nyu-kitti|<https://github.com/isl-org/ZoeDepth>|

If you own the model and want us to change the repository to your name/organization please open an issue.
image_gen_aux/preprocessors/__init__.py
ADDED
@@ -0,0 +1,66 @@
from typing import TYPE_CHECKING

from ..utils import (
    IMAGE_AUX_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_transformers_available,
)


_import_structure = {
    "lineart": [],
    "depth": [],
    "teed": [],
    "lineart_standard": ["LineArtStandardPreprocessor"],
}


try:
    if not (is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    ...
else:
    _import_structure["lineart"] = ["LineArtPreprocessor"]
    _import_structure["teed"] = ["TeedPreprocessor"]

try:
    if not (is_torch_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    ...
else:
    _import_structure["depth"] = [
        "DepthPreprocessor",
    ]

if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
    from .lineart_standard import LineArtStandardPreprocessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        ...
    else:
        from .lineart import LineArtPreprocessor
        from .teed import TeedPreprocessor

    try:
        if not (is_torch_available() and is_transformers_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        ...
    else:
        from .depth import DepthPreprocessor
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
image_gen_aux/preprocessors/depth/README.md
ADDED
@@ -0,0 +1,45 @@
# DEPTH

Monocular depth estimation is a computer vision task that involves predicting the depth information of a scene from a single image. In other words, it is the process of estimating the distance of objects in a scene from a single camera viewpoint.

Monocular depth estimation has various applications, including 3D reconstruction, augmented reality, autonomous driving, and robotics. It is a challenging task as it requires the model to understand the complex relationships between objects in the scene and the corresponding depth information, which can be affected by factors such as lighting conditions, occlusion, and texture.

## Usage

```python
from image_gen_aux import DepthPreprocessor
from image_gen_aux.utils import load_image

input_image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/depth/coffee_ship.png")

depth_preprocessor = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf").to("cuda")
image = depth_preprocessor(input_image)[0]
image.save("depth.png")
```

## Models

The Depth Preprocessor supports any depth estimation model that the `transformers` library supports that doesn't have a fixed image size restriction, but we mainly recommend and ensure the correct functionality for these models:

|Model|License|Project Page|
|---|---|---|
|Depth Anything V2|CC-BY-NC-4.0|<https://depth-anything-v2.github.io/>|
|ZoeDepth|MIT|<https://github.com/isl-org/ZoeDepth>|

Each model has different variations:

### Depth Anything V2

|Variation|Repo ID|
|---|---|
|Small|depth-anything/Depth-Anything-V2-Small-hf|
|Base|depth-anything/Depth-Anything-V2-Base-hf|
|Large|depth-anything/Depth-Anything-V2-Large-hf|

### ZoeDepth

|Variation|Repo ID|
|---|---|
|NYU|Intel/zoedepth-nyu|
|KITTI|Intel/zoedepth-kitti|
|NYU and KITTI|Intel/zoedepth-nyu-kitti|
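
Since any `transformers` depth-estimation checkpoint without a fixed input size should work, the usage snippet above can be pointed at the ZoeDepth repos from the table, for example (sketch, same API as above):

```python
from image_gen_aux import DepthPreprocessor
from image_gen_aux.utils import load_image

input_image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/depth/coffee_ship.png")

# Swapping the repo id is the only change needed to use ZoeDepth instead of Depth Anything V2.
depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti").to("cuda")
image = depth_preprocessor(input_image)[0]
image.save("zoedepth.png")
```
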
image_gen_aux/preprocessors/depth/__init__.py
ADDED
@@ -0,0 +1,42 @@
from typing import TYPE_CHECKING

from ...utils import (
    IMAGE_AUX_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_transformers_available,
)


_import_structure = {}

try:
    if not (is_torch_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    ...
else:
    _import_structure["depth_preprocessor"] = [
        "DepthPreprocessor",
    ]

if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
    try:
        if not (is_torch_available() and is_transformers_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        ...
    else:
        from .depth_preprocessor import (
            DepthPreprocessor,
        )
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
image_gen_aux/preprocessors/depth/depth_preprocessor.py
ADDED
@@ -0,0 +1,70 @@
import os
from typing import List, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import AutoModelForDepthEstimation

from ...image_processor import ImageMixin
from ..preprocessor import Preprocessor


class DepthPreprocessor(Preprocessor, ImageMixin):
    """Preprocessor specifically designed for monocular depth estimation.

    This class inherits from both `Preprocessor` and `ImageMixin`. Please refer to each
    one to get more information.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path: Union[str, os.PathLike]):
        model = AutoModelForDepthEstimation.from_pretrained(pretrained_model_or_path)

        return cls(model)

    @torch.inference_mode
    def __call__(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]],
        resolution_scale: float = 1.0,
        invert: bool = False,
        batch_size: int = 1,
        return_type: str = "pil",
    ):
        if not isinstance(image, torch.Tensor):
            image = self.convert_image_to_tensor(image).to(self.model.device)

        image, resolution_scale = self.scale_image(image, resolution_scale)

        processed_images = []

        for i in range(0, len(image), batch_size):
            batch = image[i : i + batch_size].to(self.model.device)

            predicted_depth = self.model(batch).predicted_depth

            # depth models returns only batch, height and width, so we add the channel
            predicted_depth = predicted_depth.unsqueeze(1)

            # models like depth anything can return a different size image
            if batch.shape[2] != predicted_depth.shape[2] or batch.shape[3] != predicted_depth.shape[3]:
                predicted_depth = F.interpolate(
                    predicted_depth, size=(batch.shape[2], batch.shape[3]), mode="bilinear", align_corners=False
                )

            if invert:
                predicted_depth = 255 - predicted_depth
            processed_images.append(predicted_depth.cpu())
        predicted_depth = torch.cat(processed_images, dim=0)

        if resolution_scale != 1.0:
            predicted_depth, _ = self.scale_image(predicted_depth, 1 / resolution_scale)

        predicted_depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
        predicted_depth = predicted_depth.clamp(0, 1)

        image = self.post_process_image(predicted_depth, return_type)

        return image
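
A sketch of the batching and NumPy output path exposed by `__call__` above (frame file names are placeholders; all inputs are assumed to share one resolution so they can be stacked):

```python
from image_gen_aux import DepthPreprocessor
from image_gen_aux.utils import load_image

frames = [load_image("frame_000.png"), load_image("frame_001.png")]  # placeholder paths, same size

depth = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf").to("cuda")

# Two frames processed per forward pass, returned as one normalized array of shape (2, H, W, 1).
depth_maps = depth(frames, batch_size=2, return_type="np")
print(depth_maps.shape, depth_maps.min(), depth_maps.max())
```
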
image_gen_aux/preprocessors/lineart/LICENSE.txt
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Caroline Chan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
image_gen_aux/preprocessors/lineart/README.md
ADDED
@@ -0,0 +1,27 @@
# Line Art

Line Art is based on the original repository, [Informative Drawings: Learning to generate line drawings that convey geometry and semantics](https://github.com/carolineec/informative-drawings).

From the project page [summary](https://carolineec.github.io/informative_drawings/):

*Our method approaches line drawing generation as an unsupervised image translation problem which uses various losses to assess the information communicated in a line drawing. Our key idea is to view the problem as an encoding through a line drawing and to maximize the quality of this encoding through explicit geometry, semantic, and appearance decoding objectives. This evaluation is performed by deep learning methods which decode depth, semantics, and appearance from line drawings. The aim is for the extracted depth and semantic information to match the scene geometry and semantics of the input photographs.*

## Usage

```python
from image_gen_aux import LineArtPreprocessor
from image_gen_aux.utils import load_image

input_image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/simple_upscale/hippowaffle.png"
)

lineart_preprocessor = LineArtPreprocessor.from_pretrained("OzzyGT/lineart").to("cuda")
image = lineart_preprocessor(input_image)[0]
image.save("lineart.png")
```

## Additional resources

* [Project page](https://carolineec.github.io/informative_drawings/)
* [Paper](https://arxiv.org/abs/2203.12691)
image_gen_aux/preprocessors/lineart/__init__.py
ADDED
@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING

from ...utils import IMAGE_AUX_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {}

try:
    if not (is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    ...
else:
    _import_structure["lineart_preprocessor"] = [
        "LineArtPreprocessor",
    ]

if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        ...
    else:
        from .lineart_preprocessor import (
            LineArtPreprocessor,
        )
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
image_gen_aux/preprocessors/lineart/lineart_preprocessor.py
ADDED
@@ -0,0 +1,101 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Union

import numpy as np
import PIL.Image
import torch
from safetensors.torch import load_file

from ...image_processor import ImageMixin
from ...utils import SAFETENSORS_FILE_EXTENSION, get_model_path
from ..preprocessor import Preprocessor
from .model import Generator


class LineArtPreprocessor(Preprocessor, ImageMixin):
    """Preprocessor specifically designed for converting images to line art.

    This class inherits from both `Preprocessor` and `ImageMixin`. Please refer to each
    one to get more information.
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_or_path: Union[str, os.PathLike],
        filename: str = None,
        subfolder: str = None,
        weights_only: bool = True,
    ) -> Generator:
        model_path = get_model_path(pretrained_model_or_path, filename, subfolder)

        file_extension = os.path.basename(model_path).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            state_dict = load_file(model_path, device="cpu")
        else:
            state_dict = torch.load(model_path, map_location=torch.device("cpu"), weights_only=weights_only)

        model = Generator()
        model.load_state_dict(state_dict)

        return cls(model)

    @torch.inference_mode
    def __call__(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]],
        resolution_scale: float = 1.0,
        invert: bool = True,
        batch_size: int = 1,
        return_type: str = "pil",
    ):
        """Preprocesses an image and generates line art using the pre-trained model.

        Args:
            image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]]`): Input image as PIL Image,
                NumPy array, PyTorch tensor format or a list of PIL Images
            resolution_scale (`float`, optional, defaults to 1.0): Scale factor for image resolution during
                preprocessing and post-processing. Defaults to 1.0 for no scaling.
            invert (`bool`, *optional*, defaults to True): Inverts the generated image if True (white or black background).
            batch_size (`int`, *optional*, defaults to 1): The number of images to process in each batch.
            return_type (`str`, *optional*, defaults to "pil"): The desired return type, either "pt" for PyTorch tensor, "np" for NumPy array,
                or "pil" for PIL image.

        Returns:
            `Union[PIL.Image.Image, np.ndarray, torch.Tensor]`: The generated line art in the
            specified output format.
        """
        if not isinstance(image, torch.Tensor):
            image = self.convert_image_to_tensor(image)

        image, resolution_scale = self.scale_image(image, resolution_scale)

        processed_images = []

        for i in range(0, len(image), batch_size):
            batch = image[i : i + batch_size].to(self.model.device)
            lineart = self.model(batch)
            if invert:
                lineart = 255 - lineart
            processed_images.append(lineart.cpu())
        lineart = torch.cat(processed_images, dim=0)

        if resolution_scale != 1.0:
            lineart, _ = self.scale_image(lineart, 1 / resolution_scale)

        image = self.post_process_image(lineart, return_type)

        return image
image_gen_aux/preprocessors/lineart/model.py
ADDED
@@ -0,0 +1,87 @@
# Copyright (c) 2022 Caroline Chan
#
# MIT License
import torch.nn as nn

from ...modeling_utils import ModelMixin


norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    def __init__(self, in_features: int):
        super(ResidualBlock, self).__init__()

        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(ModelMixin):
    def __init__(self, input_nc: int = 3, output_nc: int = 1, n_residual_blocks: int = 3, sigmoid: bool = True):
        super(Generator, self).__init__()

        # Initial convolution block
        model0 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), norm_layer(64), nn.ReLU(inplace=True)]
        self.model0 = nn.Sequential(*model0)

        # Downsampling
        model1 = []
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model1 += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)

        model2 = []
        # Residual blocks
        for _ in range(n_residual_blocks):
            model2 += [ResidualBlock(in_features)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling
        model3 = []
        out_features = in_features // 2
        for _ in range(2):
            model3 += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)

        # Output layer
        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]

        self.model4 = nn.Sequential(*model4)

    def forward(self, x):
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)

        return out
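
A quick shape check for the `Generator` above: 3-channel input, 1-channel line-art output at the same resolution (random weights here; the preprocessor loads the real checkpoint via `from_pretrained`):

```python
import torch

from image_gen_aux.preprocessors.lineart.model import Generator

model = Generator()
with torch.no_grad():
    out = model(torch.rand(1, 3, 256, 256))
# Two stride-2 downsampling convs are undone by two transposed convs, so the spatial size is preserved.
print(out.shape)  # torch.Size([1, 1, 256, 256])
```
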
image_gen_aux/preprocessors/lineart_standard/README.md
ADDED
@@ -0,0 +1,23 @@
# Line Art Standard

Line Art Standard was copied from the [comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux) repository.
This preprocessor applies a Gaussian blur to the image to reduce noise and detail, then calculates the intensity difference between the blurred and original images to highlight edges.

## Usage

```python
from image_gen_aux import LineArtStandardPreprocessor
from image_gen_aux.utils import load_image

input_image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/simple_upscale/hippowaffle.png"
)

lineart_standard_preprocessor = LineArtStandardPreprocessor()
image = lineart_standard_preprocessor(input_image)[0]
image.save("lineart_standard.png")
```

## Additional resources

* [Original implementation](https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/src/custom_controlnet_aux/lineart_standard/__init__.py)
image_gen_aux/preprocessors/lineart_standard/__init__.py
ADDED
@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING

from ...utils import IMAGE_AUX_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {}

try:
    if not (is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    ...
else:
    _import_structure["lineart_standard_preprocessor"] = [
        "LineArtStandardPreprocessor",
    ]

if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        ...
    else:
        from .lineart_standard_preprocessor import (
            LineArtStandardPreprocessor,
        )
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
image_gen_aux/preprocessors/lineart_standard/lineart_standard_preprocessor.py
ADDED
@@ -0,0 +1,82 @@
from typing import List, Union

import cv2
import numpy as np
import PIL.Image

from ...image_processor import ImageMixin


class LineArtStandardPreprocessor(ImageMixin):
    """Preprocessor specifically designed for converting images to line art standard.

    This class inherits from both `Preprocessor` and `ImageMixin`. Please refer to each
    one to get more information.
    """

    def __call__(
        self,
        image: Union[PIL.Image.Image, np.ndarray, List[PIL.Image.Image]],
        resolution_scale: float = 1.0,
        invert: bool = False,
        gaussian_sigma=6.0,
        intensity_threshold=8,
        return_type: str = "pil",
    ):
        """Preprocesses an image and generates line art (standard) using the opencv.

        Args:
            image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]]`): Input image as PIL Image,
                NumPy array, PyTorch tensor format or a list of PIL Images
            resolution_scale (`float`, optional, defaults to 1.0): Scale factor for image resolution during
                preprocessing and post-processing. Defaults to 1.0 for no scaling.
            invert (`bool`, *optional*, defaults to True): Inverts the generated image if True (white or black background).
            gaussian_sigma (float, optional): Sigma value for Gaussian blur. Defaults to 6.0.
            intensity_threshold (int, optional): Threshold for intensity clipping. Defaults to 8.
            return_type (`str`, *optional*, defaults to "pil"): The desired return type, either "pt" for PyTorch tensor, "np" for NumPy array,
                or "pil" for PIL image.

        Returns:
            `Union[PIL.Image.Image, np.ndarray, torch.Tensor]`: The generated line art (standard) in the
            specified output format.
        """
        if not isinstance(image, np.ndarray):
            image = np.array(image).astype(np.float32)

        # check if image has batch, if not, add it
        if len(image.shape) == 3:
            image = image[None, ...]

        image, resolution_scale = (
            self.resize_numpy_image(image, resolution_scale) if resolution_scale != 1.0 else image
        )

        batch_size, height, width, _channels = image.shape
        processed_images = np.empty((batch_size, height, width), dtype=np.uint8)

        # since we're using just cv2, we can't do batch processing
        for i in range(batch_size):
            gaussian = cv2.GaussianBlur(image[i], (0, 0), gaussian_sigma)
            intensity = np.min(gaussian - image[i], axis=2).clip(0, 255)
            intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
            intensity *= 127
            edges = intensity.clip(0, 255).astype(np.uint8)

            processed_images[i] = edges

        if invert:
            processed_images = 255 - processed_images

        processed_images = processed_images[..., None]

        if resolution_scale != 1.0:
            processed_images, _ = self.resize_numpy_image(processed_images, 1 / resolution_scale)
            processed_images = processed_images[..., None]  # cv2 resize removes the channel dimension if grayscale

        if return_type == "np":
            return processed_images

        processed_images = np.transpose(processed_images, (0, 3, 1, 2))
        processed_images = [PIL.Image.fromarray(image.squeeze(), mode="L") for image in processed_images]

        return processed_images
image_gen_aux/preprocessors/preprocessor.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import PIL.Image
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class Preprocessor(ABC):
|
10 |
+
"""
|
11 |
+
This abstract base class defines the interface for image preprocessors.
|
12 |
+
|
13 |
+
Subclasses should implement the abstract methods `from_pretrained` and
|
14 |
+
`__call__` to provide specific loading and preprocessing logic for their
|
15 |
+
respective models.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
model (`nn.Module`): The torch model to use.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, model):
|
22 |
+
self.model = model
|
23 |
+
|
24 |
+
def to(self, device):
|
25 |
+
"""
|
26 |
+
Moves the underlying model to the specified device
|
27 |
+
(e.g., CPU or GPU).
|
28 |
+
|
29 |
+
Args:
|
30 |
+
device (`torch.device`): The target device.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
`Preprocessor`: The preprocessor object itself (for method chaining).
|
34 |
+
"""
|
35 |
+
self.model = self.model.to(device)
|
36 |
+
return self
|
37 |
+
|
38 |
+
@abstractmethod
|
39 |
+
def from_pretrained(self):
|
40 |
+
"""
|
41 |
+
This abstract method defines how the preprocessor loads pre-trained
|
42 |
+
weights or configurations specific to the model it supports. Subclasses
|
43 |
+
must implement this method to handle model-specific loading logic.
|
44 |
+
|
45 |
+
This method might download pre-trained weights from a repository or
|
46 |
+
load them from a local file depending on the model's requirements.
|
47 |
+
"""
|
48 |
+
pass
|
49 |
+
|
50 |
+
@abstractmethod
|
51 |
+
def __call__(
|
52 |
+
self,
|
53 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
54 |
+
resolution_scale: float = 1.0,
|
55 |
+
invert: bool = True,
|
56 |
+
return_type: str = "pil",
|
57 |
+
):
|
58 |
+
"""
|
59 |
+
Preprocesses an image for use with the underlying model.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): Input image as PIL Image,
|
63 |
+
NumPy array, or PyTorch tensor format.
|
64 |
+
resolution_scale (`float`, optional, defaults to 1.0): Scale factor for image resolution during
|
65 |
+
resolution_scale (`float`, *optional*, defaults to 1.0): Scale factor for image resolution during
|
66 |
+
preprocessing and post-processing. Defaults to 1.0 for no scaling.
|
67 |
+
invert (`bool`, *optional*, defaults to True): Inverts the generated image if True.
|
68 |
+
return_type (`str`, *optional*, defaults to "pil"): The desired return type, either "pt" for PyTorch tensor,
|
69 |
+
"np" for NumPy array, or "pil" for PIL image.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`: The preprocessed image in the
|
73 |
+
specified format.
|
74 |
+
"""
|
75 |
+
pass
|
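A minimal sketch of what a concrete subclass of this interface could look like. The `IdentityPreprocessor` name, the no-op model, and the import path (derived from the file layout in this change) are illustrative assumptions, not part of the library:

```python
import numpy as np
import PIL.Image
import torch

from image_gen_aux.preprocessors.preprocessor import Preprocessor  # path assumed from this diff


class IdentityPreprocessor(Preprocessor):
    """Toy preprocessor: returns the input unchanged, optionally inverted."""

    @classmethod
    def from_pretrained(cls):
        # Nothing to download for this toy example; wrap a no-op torch module.
        return cls(torch.nn.Identity())

    def __call__(self, image, resolution_scale=1.0, invert=True, return_type="pil"):
        array = np.array(image, dtype=np.uint8)
        if invert:
            array = 255 - array
        return PIL.Image.fromarray(array) if return_type == "pil" else array
```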
image_gen_aux/preprocessors/teed/LICENSE.txt
ADDED
@@ -0,0 +1,21 @@
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Xavier Soria Poma
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
image_gen_aux/preprocessors/teed/README.md
ADDED
@@ -0,0 +1,26 @@
|
1 |
+
# TEED: Tiny and Efficient Model for the Edge Detection Generalization
|
2 |
+
|
3 |
+
Tiny and Efficient Edge Detector (TEED) is a light convolutional neural network with only 58K parameters, less than 0.2% of the parameter count of state-of-the-art models. Training on the [BIPED](https://www.kaggle.com/datasets/xavysp/biped) dataset takes less than 30 minutes,
|
4 |
+
with each epoch requiring less than 5 minutes. The model is easy to train and converges within the first few
|
5 |
+
epochs, while the predicted edge maps are crisp and of high quality.
|
6 |
+
[This paper has been accepted by ICCV 2023-Workshop RCV](https://arxiv.org/abs/2308.06468).
|
7 |
+
|
8 |
+
## Usage
|
9 |
+
|
10 |
+
```python
|
11 |
+
from image_gen_aux import TeedPreprocessor
|
12 |
+
from image_gen_aux.utils import load_image
|
13 |
+
|
14 |
+
input_image = load_image(
|
15 |
+
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/teed/20240922043215.png"
|
16 |
+
)
|
17 |
+
|
18 |
+
teed_preprocessor = TeedPreprocessor.from_pretrained("OzzyGT/teed").to("cuda")
|
19 |
+
image = teed_preprocessor(input_image)[0]
|
20 |
+
image.save("teed.png")
|
21 |
+
```
|
22 |
+
|
23 |
+
## Additional resources
|
24 |
+
|
25 |
+
* [Project page](https://github.com/xavysp/TEED)
|
26 |
+
* [Paper](https://arxiv.org/abs/2308.06468)
|
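For reference, the optional arguments exposed by `TeedPreprocessor.__call__` (added in `teed_preprocessor.py` later in this change) can be combined as in the sketch below; the values are arbitrary examples:

```python
# Continuing from the snippet above: run at half resolution, invert the edge map,
# and get a NumPy batch back instead of PIL images.
edges = teed_preprocessor(
    input_image,
    resolution_scale=0.5,
    invert=True,
    safe_steps=2,
    batch_size=1,
    return_type="np",
)
```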
image_gen_aux/preprocessors/teed/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
1 |
+
from typing import TYPE_CHECKING
|
2 |
+
|
3 |
+
from ...utils import IMAGE_AUX_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
4 |
+
|
5 |
+
|
6 |
+
_import_structure = {}
|
7 |
+
|
8 |
+
try:
|
9 |
+
if not (is_torch_available()):
|
10 |
+
raise OptionalDependencyNotAvailable()
|
11 |
+
except OptionalDependencyNotAvailable:
|
12 |
+
...
|
13 |
+
else:
|
14 |
+
_import_structure["teed_preprocessor"] = [
|
15 |
+
"TeedPreprocessor",
|
16 |
+
]
|
17 |
+
|
18 |
+
if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
|
19 |
+
try:
|
20 |
+
if not is_torch_available():
|
21 |
+
raise OptionalDependencyNotAvailable()
|
22 |
+
except OptionalDependencyNotAvailable:
|
23 |
+
...
|
24 |
+
else:
|
25 |
+
from .teed_preprocessor import (
|
26 |
+
TeedPreprocessor,
|
27 |
+
)
|
28 |
+
else:
|
29 |
+
import sys
|
30 |
+
|
31 |
+
sys.modules[__name__] = _LazyModule(
|
32 |
+
__name__,
|
33 |
+
globals()["__file__"],
|
34 |
+
_import_structure,
|
35 |
+
module_spec=__spec__,
|
36 |
+
)
|
image_gen_aux/preprocessors/teed/teed.py
ADDED
@@ -0,0 +1,323 @@
|
1 |
+
# Original from: https://github.com/xavysp/TEED
|
2 |
+
# TEED is a Tiny but Efficient Edge Detector; it comes from LDC-B3
|
3 |
+
# with a slight modification
|
4 |
+
# LDC parameters:
|
5 |
+
# 155665
|
6 |
+
# TED > 58K
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from ...modeling_utils import ModelMixin
|
13 |
+
|
14 |
+
|
15 |
+
"""
|
16 |
+
smish_function and Smish script based on:
|
17 |
+
Wang, Xueliang, Honge Ren, and Achuan Wang.
|
18 |
+
"Smish: A Novel Activation Function for Deep Learning Methods.
|
19 |
+
" Electronics 11.4 (2022): 540.
|
20 |
+
smish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + sigmoid(x)))
|
21 |
+
"""
|
22 |
+
|
23 |
+
|
24 |
+
@torch.jit.script
|
25 |
+
def smish_function(input):
|
26 |
+
"""
|
27 |
+
Applies the smish function element-wise:
|
28 |
+
smish(x) = x * tanh(ln(1 + sigmoid(x)))
|
29 |
+
See additional documentation for the Smish class.
|
30 |
+
"""
|
31 |
+
return input * torch.tanh(torch.log(1 + torch.sigmoid(input)))
|
32 |
+
|
33 |
+
|
34 |
+
class Smish(nn.Module):
|
35 |
+
"""
|
36 |
+
Applies the smish function element-wise:
|
37 |
+
smish(x) = x * tanh(ln(1 + sigmoid(x)))
|
38 |
+
Shape:
|
39 |
+
- Input: (N, *) where * means, any number of additional
|
40 |
+
dimensions
|
41 |
+
- Output: (N, *), same shape as the input
|
42 |
+
Examples:
|
43 |
+
>>> m = Smish()
|
44 |
+
>>> input = torch.randn(2)
|
45 |
+
>>> output = m(input)
|
46 |
+
Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self):
|
50 |
+
"""
|
51 |
+
Init method.
|
52 |
+
"""
|
53 |
+
super().__init__()
|
54 |
+
|
55 |
+
def forward(self, input):
|
56 |
+
"""
|
57 |
+
Forward pass of the function.
|
58 |
+
"""
|
59 |
+
return smish_function(input)
|
60 |
+
|
61 |
+
|
62 |
+
def weight_init(m):
|
63 |
+
if isinstance(m, (nn.Conv2d,)):
|
64 |
+
torch.nn.init.xavier_normal_(m.weight, gain=1.0)
|
65 |
+
|
66 |
+
if m.bias is not None:
|
67 |
+
torch.nn.init.zeros_(m.bias)
|
68 |
+
|
69 |
+
# for fusion layer
|
70 |
+
if isinstance(m, (nn.ConvTranspose2d,)):
|
71 |
+
torch.nn.init.xavier_normal_(m.weight, gain=1.0)
|
72 |
+
if m.bias is not None:
|
73 |
+
torch.nn.init.zeros_(m.bias)
|
74 |
+
|
75 |
+
|
76 |
+
class CoFusion(nn.Module):
|
77 |
+
# from LDC
|
78 |
+
|
79 |
+
def __init__(self, in_ch, out_ch):
|
80 |
+
super(CoFusion, self).__init__()
|
81 |
+
self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1) # before 64
|
82 |
+
self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3, stride=1, padding=1) # before 64 instead of 32
|
83 |
+
self.relu = nn.ReLU()
|
84 |
+
self.norm_layer1 = nn.GroupNorm(4, 32) # before 64
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
# fusecat = torch.cat(x, dim=1)
|
88 |
+
attn = self.relu(self.norm_layer1(self.conv1(x)))
|
89 |
+
attn = F.softmax(self.conv3(attn), dim=1)
|
90 |
+
return ((x * attn).sum(1)).unsqueeze(1)
|
91 |
+
|
92 |
+
|
93 |
+
class CoFusion2(nn.Module):
|
94 |
+
# TEDv14-3
|
95 |
+
def __init__(self, in_ch, out_ch):
|
96 |
+
super(CoFusion2, self).__init__()
|
97 |
+
self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1) # before 64
|
98 |
+
self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3, stride=1, padding=1) # before 64 instead of 32
|
99 |
+
self.smish = Smish() # nn.ReLU(inplace=True)
|
100 |
+
|
101 |
+
def forward(self, x):
|
102 |
+
attn = self.conv1(self.smish(x))
|
103 |
+
attn = self.conv3(self.smish(attn)) # before , )dim=1)
|
104 |
+
|
105 |
+
return ((x * attn).sum(1)).unsqueeze(1)
|
106 |
+
|
107 |
+
|
108 |
+
class DoubleFusion(nn.Module):
|
109 |
+
# TED fusion before the final edge map prediction
|
110 |
+
def __init__(self, in_ch, out_ch):
|
111 |
+
super(DoubleFusion, self).__init__()
|
112 |
+
self.DWconv1 = nn.Conv2d(in_ch, in_ch * 8, kernel_size=3, stride=1, padding=1, groups=in_ch) # before 64
|
113 |
+
self.PSconv1 = nn.PixelShuffle(1)
|
114 |
+
|
115 |
+
self.DWconv2 = nn.Conv2d(24, 24 * 1, kernel_size=3, stride=1, padding=1, groups=24) # before 64 instead of 32
|
116 |
+
|
117 |
+
self.AF = Smish() # XAF() #nn.Tanh()# XAF() # # Smish()#
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
attn = self.PSconv1(self.DWconv1(self.AF(x))) # #TEED best res TEDv14 [8, 32, 352, 352]
|
121 |
+
|
122 |
+
attn2 = self.PSconv1(self.DWconv2(self.AF(attn))) # #TEED best res TEDv14[8, 3, 352, 352]
|
123 |
+
|
124 |
+
return smish_function(((attn2 + attn).sum(1)).unsqueeze(1)) # TED best res
|
125 |
+
|
126 |
+
|
127 |
+
class _DenseLayer(nn.Sequential):
|
128 |
+
def __init__(self, input_features, out_features):
|
129 |
+
super(_DenseLayer, self).__init__()
|
130 |
+
|
131 |
+
(
|
132 |
+
self.add_module(
|
133 |
+
"conv1",
|
134 |
+
nn.Conv2d(
|
135 |
+
input_features,
|
136 |
+
out_features,
|
137 |
+
kernel_size=3,
|
138 |
+
stride=1,
|
139 |
+
padding=2,
|
140 |
+
bias=True,
|
141 |
+
),
|
142 |
+
),
|
143 |
+
)
|
144 |
+
(self.add_module("smish1", Smish()),)
|
145 |
+
self.add_module(
|
146 |
+
"conv2",
|
147 |
+
nn.Conv2d(out_features, out_features, kernel_size=3, stride=1, bias=True),
|
148 |
+
)
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
x1, x2 = x
|
152 |
+
|
153 |
+
new_features = super(_DenseLayer, self).forward(smish_function(x1)) # F.relu()
|
154 |
+
|
155 |
+
return 0.5 * (new_features + x2), x2
|
156 |
+
|
157 |
+
|
158 |
+
class _DenseBlock(nn.Sequential):
|
159 |
+
def __init__(self, num_layers, input_features, out_features):
|
160 |
+
super(_DenseBlock, self).__init__()
|
161 |
+
for i in range(num_layers):
|
162 |
+
layer = _DenseLayer(input_features, out_features)
|
163 |
+
self.add_module("denselayer%d" % (i + 1), layer)
|
164 |
+
input_features = out_features
|
165 |
+
|
166 |
+
|
167 |
+
class UpConvBlock(nn.Module):
|
168 |
+
def __init__(self, in_features, up_scale):
|
169 |
+
super(UpConvBlock, self).__init__()
|
170 |
+
self.up_factor = 2
|
171 |
+
self.constant_features = 16
|
172 |
+
|
173 |
+
layers = self.make_deconv_layers(in_features, up_scale)
|
174 |
+
assert layers is not None, layers
|
175 |
+
self.features = nn.Sequential(*layers)
|
176 |
+
|
177 |
+
def make_deconv_layers(self, in_features, up_scale):
|
178 |
+
layers = []
|
179 |
+
all_pads = [0, 0, 1, 3, 7]
|
180 |
+
for i in range(up_scale):
|
181 |
+
kernel_size = 2**up_scale
|
182 |
+
pad = all_pads[up_scale] # kernel_size-1
|
183 |
+
out_features = self.compute_out_features(i, up_scale)
|
184 |
+
layers.append(nn.Conv2d(in_features, out_features, 1))
|
185 |
+
layers.append(Smish())
|
186 |
+
layers.append(nn.ConvTranspose2d(out_features, out_features, kernel_size, stride=2, padding=pad))
|
187 |
+
in_features = out_features
|
188 |
+
return layers
|
189 |
+
|
190 |
+
def compute_out_features(self, idx, up_scale):
|
191 |
+
return 1 if idx == up_scale - 1 else self.constant_features
|
192 |
+
|
193 |
+
def forward(self, x):
|
194 |
+
return self.features(x)
|
195 |
+
|
196 |
+
|
197 |
+
class SingleConvBlock(nn.Module):
|
198 |
+
def __init__(self, in_features, out_features, stride, use_ac=False):
|
199 |
+
super(SingleConvBlock, self).__init__()
|
200 |
+
self.use_ac = use_ac
|
201 |
+
self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, bias=True)
|
202 |
+
if self.use_ac:
|
203 |
+
self.smish = Smish()
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
x = self.conv(x)
|
207 |
+
if self.use_ac:
|
208 |
+
return self.smish(x)
|
209 |
+
else:
|
210 |
+
return x
|
211 |
+
|
212 |
+
|
213 |
+
class DoubleConvBlock(nn.Module):
|
214 |
+
def __init__(self, in_features, mid_features, out_features=None, stride=1, use_act=True):
|
215 |
+
super(DoubleConvBlock, self).__init__()
|
216 |
+
|
217 |
+
self.use_act = use_act
|
218 |
+
if out_features is None:
|
219 |
+
out_features = mid_features
|
220 |
+
self.conv1 = nn.Conv2d(in_features, mid_features, 3, padding=1, stride=stride)
|
221 |
+
self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
|
222 |
+
self.smish = Smish() # nn.ReLU(inplace=True)
|
223 |
+
|
224 |
+
def forward(self, x):
|
225 |
+
x = self.conv1(x)
|
226 |
+
x = self.smish(x)
|
227 |
+
x = self.conv2(x)
|
228 |
+
if self.use_act:
|
229 |
+
x = self.smish(x)
|
230 |
+
return x
|
231 |
+
|
232 |
+
|
233 |
+
class TEED(ModelMixin):
|
234 |
+
"""Definition of Tiny and Efficient Edge Detector
|
235 |
+
model
|
236 |
+
"""
|
237 |
+
|
238 |
+
def __init__(self):
|
239 |
+
super(TEED, self).__init__()
|
240 |
+
self.block_1 = DoubleConvBlock(
|
241 |
+
3,
|
242 |
+
16,
|
243 |
+
16,
|
244 |
+
stride=2,
|
245 |
+
)
|
246 |
+
self.block_2 = DoubleConvBlock(16, 32, use_act=False)
|
247 |
+
self.dblock_3 = _DenseBlock(1, 32, 48) # [32,48,100,100] before (2, 32, 64)
|
248 |
+
|
249 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
250 |
+
|
251 |
+
# skip1 connection, see fig. 2
|
252 |
+
self.side_1 = SingleConvBlock(16, 32, 2)
|
253 |
+
|
254 |
+
# skip2 connection, see fig. 2
|
255 |
+
self.pre_dense_3 = SingleConvBlock(32, 48, 1) # before (32, 64, 1)
|
256 |
+
|
257 |
+
# USNet
|
258 |
+
self.up_block_1 = UpConvBlock(16, 1)
|
259 |
+
self.up_block_2 = UpConvBlock(32, 1)
|
260 |
+
self.up_block_3 = UpConvBlock(48, 2) # (32, 64, 1)
|
261 |
+
|
262 |
+
self.block_cat = DoubleFusion(3, 3) # TEED: DoubleFusion
|
263 |
+
|
264 |
+
self.apply(weight_init)
|
265 |
+
|
266 |
+
def slice(self, tensor, slice_shape):
|
267 |
+
t_shape = tensor.shape
|
268 |
+
img_h, img_w = slice_shape
|
269 |
+
if img_w != t_shape[-1] or img_h != t_shape[2]:
|
270 |
+
new_tensor = F.interpolate(tensor, size=(img_h, img_w), mode="bicubic", align_corners=False)
|
271 |
+
|
272 |
+
else:
|
273 |
+
new_tensor = tensor
|
274 |
+
# tensor[..., :height, :width]
|
275 |
+
return new_tensor
|
276 |
+
|
277 |
+
def resize_input(self, tensor):
|
278 |
+
t_shape = tensor.shape
|
279 |
+
if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0:
|
280 |
+
img_w = ((t_shape[3] // 8) + 1) * 8
|
281 |
+
img_h = ((t_shape[2] // 8) + 1) * 8
|
282 |
+
new_tensor = F.interpolate(tensor, size=(img_h, img_w), mode="bicubic", align_corners=False)
|
283 |
+
else:
|
284 |
+
new_tensor = tensor
|
285 |
+
return new_tensor
|
286 |
+
|
287 |
+
def crop_bdcn(data1, h, w, crop_h, crop_w):
|
288 |
+
# Based on BDCN Implementation @ https://github.com/pkuCactus/BDCN
|
289 |
+
_, _, h1, w1 = data1.size()
|
290 |
+
assert h <= h1 and w <= w1
|
291 |
+
data = data1[:, :, crop_h : crop_h + h, crop_w : crop_w + w]
|
292 |
+
return data
|
293 |
+
|
294 |
+
def forward(self, x, single_test=False):
|
295 |
+
assert x.ndim == 4, x.shape
|
296 |
+
# supose the image size is 352x352
|
297 |
+
|
298 |
+
# Block 1
|
299 |
+
block_1 = self.block_1(x) # [8,16,176,176]
|
300 |
+
block_1_side = self.side_1(block_1) # 16 [8,32,88,88]
|
301 |
+
|
302 |
+
# Block 2
|
303 |
+
block_2 = self.block_2(block_1) # 32 # [8,32,176,176]
|
304 |
+
block_2_down = self.maxpool(block_2) # [8,32,88,88]
|
305 |
+
block_2_add = block_2_down + block_1_side # [8,32,88,88]
|
306 |
+
|
307 |
+
# Block 3
|
308 |
+
block_3_pre_dense = self.pre_dense_3(block_2_down) # [8,64,88,88] block 3 L connection
|
309 |
+
block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense]) # [8,64,88,88]
|
310 |
+
|
311 |
+
# upsampling blocks
|
312 |
+
out_1 = self.up_block_1(block_1)
|
313 |
+
out_2 = self.up_block_2(block_2)
|
314 |
+
out_3 = self.up_block_3(block_3)
|
315 |
+
|
316 |
+
results = [out_1, out_2, out_3]
|
317 |
+
|
318 |
+
# concatenate multiscale outputs
|
319 |
+
block_cat = torch.cat(results, dim=1) # Bx6xHxW
|
320 |
+
block_cat = self.block_cat(block_cat) # Bx1xHxW DoubleFusion
|
321 |
+
|
322 |
+
results.append(block_cat)
|
323 |
+
return results
|
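A quick smoke test of the network defined above; a sketch that assumes the import path implied by this file layout. The model takes a 4-D, 3-channel float tensor and returns the three multi-scale edge maps plus the fused map:

```python
import torch

from image_gen_aux.preprocessors.teed.teed import TEED  # path assumed from this diff

model = TEED().eval()
dummy = torch.rand(1, 3, 352, 352)  # one RGB image, values in [0, 1], sides divisible by 8
with torch.no_grad():
    outputs = model(dummy)
print(len(outputs))                # 4: out_1, out_2, out_3 and the fused block_cat
print([o.shape for o in outputs])  # each is [1, 1, 352, 352]
```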
image_gen_aux/preprocessors/teed/teed_preprocessor.py
ADDED
@@ -0,0 +1,121 @@
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
from typing import List, Union
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import PIL.Image
|
19 |
+
import torch
|
20 |
+
from safetensors.torch import load_file
|
21 |
+
|
22 |
+
from ...image_processor import ImageMixin
|
23 |
+
from ...utils import SAFETENSORS_FILE_EXTENSION, get_model_path
|
24 |
+
from ..preprocessor import Preprocessor
|
25 |
+
from .teed import TEED
|
26 |
+
|
27 |
+
|
28 |
+
class TeedPreprocessor(Preprocessor, ImageMixin):
|
29 |
+
"""Preprocessor specifically designed for detecting edges in images.
|
30 |
+
|
31 |
+
This class inherits from both `Preprocessor` and `ImageMixin`. Please refer to each
|
32 |
+
one to get more information.
|
33 |
+
"""
|
34 |
+
|
35 |
+
@classmethod
|
36 |
+
def from_pretrained(
|
37 |
+
cls,
|
38 |
+
pretrained_model_or_path: Union[str, os.PathLike],
|
39 |
+
filename: str = None,
|
40 |
+
subfolder: str = None,
|
41 |
+
weights_only: bool = True,
|
42 |
+
) -> TEED:
|
43 |
+
model_path = get_model_path(pretrained_model_or_path, filename, subfolder)
|
44 |
+
|
45 |
+
file_extension = os.path.basename(model_path).split(".")[-1]
|
46 |
+
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
47 |
+
state_dict = load_file(model_path, device="cpu")
|
48 |
+
else:
|
49 |
+
state_dict = torch.load(model_path, map_location=torch.device("cpu"), weights_only=weights_only)
|
50 |
+
|
51 |
+
model = TEED()
|
52 |
+
model.load_state_dict(state_dict)
|
53 |
+
|
54 |
+
return cls(model)
|
55 |
+
|
56 |
+
@torch.inference_mode
|
57 |
+
def __call__(
|
58 |
+
self,
|
59 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]],
|
60 |
+
resolution_scale: float = 1.0,
|
61 |
+
invert: bool = False,
|
62 |
+
safe_steps: int = 2,
|
63 |
+
batch_size: int = 1,
|
64 |
+
return_type: str = "pil",
|
65 |
+
):
|
66 |
+
"""Preprocesses an image and detects the edges using the pre-trained model.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]]`): Input image as PIL Image,
|
70 |
+
NumPy array, PyTorch tensor format or a list of PIL Images
|
71 |
+
resolution_scale (`float`, optional, defaults to 1.0): Scale factor for image resolution during
|
72 |
+
preprocessing and post-processing. Defaults to 1.0 for no scaling.
|
73 |
+
invert (`bool`, *optional*, defaults to False): Inverts the generated image if True (white or black background).
|
74 |
+
safe_steps (`int`, *optional*, defaults to 2):
|
75 |
+
Number of safe steps for the TEED model. Defaults to 2.
|
76 |
+
batch_size (`int`, *optional*, defaults to 1): The number of images to process in each batch.
|
77 |
+
return_type (`str`, *optional*, defaults to "pil"): The desired return type, either "pt" for PyTorch tensor, "np" for NumPy array,
|
78 |
+
or "pil" for PIL image.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`: The detected edge map in the
|
82 |
+
specified output format.
|
83 |
+
"""
|
84 |
+
if not isinstance(image, torch.Tensor):
|
85 |
+
image = self.convert_image_to_tensor(image, normalize=False)
|
86 |
+
|
87 |
+
image, resolution_scale = self.scale_image(image, resolution_scale)
|
88 |
+
|
89 |
+
processed_images = []
|
90 |
+
|
91 |
+
for i in range(0, len(image), batch_size):
|
92 |
+
batch = image[i : i + batch_size].to(self.model.device)
|
93 |
+
|
94 |
+
edges = self.model(batch)
|
95 |
+
edges = torch.stack([e[0, 0] for e in edges], dim=2)
|
96 |
+
mean_edges = torch.mean(edges, dim=2)
|
97 |
+
edge = torch.sigmoid(mean_edges)
|
98 |
+
|
99 |
+
if safe_steps != 0:
|
100 |
+
edge = self.safe_step(edge, safe_steps)
|
101 |
+
|
102 |
+
if invert:
|
103 |
+
edge = 1 - edge
|
104 |
+
|
105 |
+
processed_images.append(edge.unsqueeze(0).cpu())
|
106 |
+
teed = torch.cat(processed_images, dim=0)
|
107 |
+
|
108 |
+
# add missing channel
|
109 |
+
teed = teed.unsqueeze(1)
|
110 |
+
|
111 |
+
if resolution_scale != 1.0:
|
112 |
+
teed, _ = self.scale_image(teed, 1 / resolution_scale)
|
113 |
+
|
114 |
+
image = self.post_process_image(teed, return_type)
|
115 |
+
|
116 |
+
return image
|
117 |
+
|
118 |
+
def safe_step(self, x, step=2):
|
119 |
+
y = x.float() * float(step + 1)
|
120 |
+
y = y.int().float() / float(step)
|
121 |
+
return y
|
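`safe_step` above quantizes the edge map into `step + 1` discrete levels before it is optionally inverted; a quick numeric illustration of that behaviour (values picked arbitrarily):

```python
import torch

x = torch.tensor([0.10, 0.40, 0.80, 1.00])
# y = trunc(x * (step + 1)) / step, so with the default step=2 values snap to {0, 0.5, 1.0, 1.5}
y = (x * 3).int().float() / 2
print(y)  # tensor([0.0000, 0.5000, 1.0000, 1.5000])
```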
image_gen_aux/upscalers/README.md
ADDED
@@ -0,0 +1,58 @@
|
1 |
+
# UPSCALERS
|
2 |
+
|
3 |
+
## Upscale with model
|
4 |
+
|
5 |
+
Class to upscale images with a safetensor checkpoint. We use [spandrel](https://github.com/chaiNNer-org/spandrel) for loading, and you can see the list of supported models [here](https://github.com/chaiNNer-org/spandrel?tab=readme-ov-file#model-architecture-support).
|
6 |
+
|
7 |
+
Most of the super-resolution models are provided as `pickle` checkpoints, which are considered unsafe. We promote the use of safetensor checkpoints, and for convenient use, we recommend using the Hugging Face Hub. You can still use a locally downloaded model.
|
8 |
+
|
9 |
+
### Space
|
10 |
+
|
11 |
+
You can test the currently supported super-resolution models with this [Hugging Face Space](https://huggingface.co/spaces/OzzyGT/basic_upscaler).
|
12 |
+
|
13 |
+
### How to use
|
14 |
+
|
15 |
+
```python
|
16 |
+
from image_gen_aux import UpscaleWithModel
|
17 |
+
from image_gen_aux.utils import load_image
|
18 |
+
|
19 |
+
original = load_image(
|
20 |
+
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/simple_upscale/hippowaffle_small.png"
|
21 |
+
)
|
22 |
+
|
23 |
+
upscaler = UpscaleWithModel.from_pretrained("Kim2091/UltraSharp").to("cuda")
|
24 |
+
image = upscaler(original)
|
25 |
+
image.save("upscaled.png")
|
26 |
+
```
|
27 |
+
|
28 |
+
### Tiling
|
29 |
+
|
30 |
+
Tiling can be enabled to use less resources.
|
31 |
+
|
32 |
+
```python
|
33 |
+
image = upscaler(original, tiling=True, tile_width=768, tile_height=768, overlap=8)
|
34 |
+
```
|
35 |
+
|
36 |
+
### Scale
|
37 |
+
|
38 |
+
The scale is automatically obtained from the model but can be overridden with the `scale` argument:
|
39 |
+
|
40 |
+
```python
|
41 |
+
image = upscaler(original, scale=2)
|
42 |
+
```
|
43 |
+
|
44 |
+
### List of safetensors checkpoints
|
45 |
+
|
46 |
+
This is the current list of safetensor checkpoints you can use from the hub.
|
47 |
+
|
48 |
+
|Model|Scale|Repository|Owner|
|
49 |
+
|---|---|---|---|
|
50 |
+
|UltraSharp|4X|Kim2091/UltraSharp|[Kim2091](https://huggingface.co/Kim2091)|
|
51 |
+
|DAT|2X|OzzyGT/DAT_X2|[zhengchen1999](https://github.com/zhengchen1999)|
|
52 |
+
|DAT|3X|OzzyGT/DAT_X3|[zhengchen1999](https://github.com/zhengchen1999)|
|
53 |
+
|DAT|4X|OzzyGT/DAT_X4|[zhengchen1999](https://github.com/zhengchen1999)|
|
54 |
+
|RealPLKSR|4X|Phips/4xNomosWebPhoto_RealPLKSR|[Philip Hofmann](https://huggingface.co/Phips)|
|
55 |
+
|DAT-2|4X|Phips/4xRealWebPhoto_v4_dat2|[Philip Hofmann](https://huggingface.co/Phips)|
|
56 |
+
|BSRGAN|4X|OzzyGT/4xRemacri|[FoolhardyVEVO](https://openmodeldb.info/users/foolhardy)|
|
57 |
+
|
58 |
+
If you own the model and want us to change the repository to your name/organization please open an issue.
|
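When a repository hosts more than one checkpoint, a specific file can be selected through `from_pretrained`; a sketch in which the repo id and filename are placeholders, not checkpoints verified to exist:

```python
from image_gen_aux import UpscaleWithModel

# `filename` (and optionally `subfolder`) are forwarded to the Hub download.
upscaler = UpscaleWithModel.from_pretrained(
    "some-user/some-upscaler-repo",    # placeholder repo id
    filename="model_x4.safetensors",   # placeholder filename
).to("cuda")
image = upscaler(original, tiling=True, tile_width=512, tile_height=512, overlap=8)
```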
image_gen_aux/upscalers/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
1 |
+
from typing import TYPE_CHECKING
|
2 |
+
|
3 |
+
from ..utils import IMAGE_AUX_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
4 |
+
|
5 |
+
|
6 |
+
_import_structure = {}
|
7 |
+
|
8 |
+
try:
|
9 |
+
if not (is_torch_available()):
|
10 |
+
raise OptionalDependencyNotAvailable()
|
11 |
+
except OptionalDependencyNotAvailable:
|
12 |
+
...
|
13 |
+
else:
|
14 |
+
_import_structure["upscale_with_model"] = [
|
15 |
+
"UpscaleWithModel",
|
16 |
+
]
|
17 |
+
|
18 |
+
if TYPE_CHECKING or IMAGE_AUX_SLOW_IMPORT:
|
19 |
+
try:
|
20 |
+
if not is_torch_available():
|
21 |
+
raise OptionalDependencyNotAvailable()
|
22 |
+
except OptionalDependencyNotAvailable:
|
23 |
+
...
|
24 |
+
else:
|
25 |
+
from .upscale_with_model import (
|
26 |
+
UpscaleWithModel,
|
27 |
+
)
|
28 |
+
else:
|
29 |
+
import sys
|
30 |
+
|
31 |
+
sys.modules[__name__] = _LazyModule(
|
32 |
+
__name__,
|
33 |
+
globals()["__file__"],
|
34 |
+
_import_structure,
|
35 |
+
module_spec=__spec__,
|
36 |
+
)
|
image_gen_aux/upscalers/upscale_with_model.py
ADDED
@@ -0,0 +1,118 @@
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
from typing import Union
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import PIL.Image
|
19 |
+
import torch
|
20 |
+
from spandrel import ImageModelDescriptor, ModelLoader
|
21 |
+
|
22 |
+
from ..image_processor import ImageMixin
|
23 |
+
from ..utils import get_model_path, tiled_upscale
|
24 |
+
|
25 |
+
|
26 |
+
class UpscaleWithModel(ImageMixin):
|
27 |
+
r"""
|
28 |
+
Upscaler class that uses a pytorch model.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
model ([`ImageModelDescriptor`]):
|
32 |
+
Upscaler model, must be supported by spandrel.
|
33 |
+
scale (`int`, defaults to the scale of the model):
|
34 |
+
The number of times to scale the image, it is recommended to use the model default scale which
|
35 |
+
usually is what the model was trained for.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, model: ImageModelDescriptor, scale: int = None):
|
39 |
+
super().__init__()
|
40 |
+
self.model = model
|
41 |
+
|
42 |
+
def to(self, device):
|
43 |
+
self.model.to(device)
|
44 |
+
return self
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def from_pretrained(
|
48 |
+
cls, pretrained_model_or_path: Union[str, os.PathLike], filename: str = None, subfolder: str = None
|
49 |
+
) -> ImageModelDescriptor:
|
50 |
+
r"""
|
51 |
+
Instantiate the Upscaler class from pretrained weights.
|
52 |
+
|
53 |
+
Parameters:
|
54 |
+
pretrained_model_or_path (`str` or `os.PathLike`):
|
55 |
+
Can be either:
|
56 |
+
|
57 |
+
- A string, the *repo id* (for example `OzzyGT/UltraSharp`) of a pretrained model
|
58 |
+
hosted on the Hub, must be saved in safetensors. If there's more than one checkpoint
|
59 |
+
in the repository and the filename wasn't specified, the first one found will be loaded.
|
60 |
+
- A path to a *directory* (for example `./upscaler_model/`) containing a pretrained
|
61 |
+
upscaler checkpoint.
|
62 |
+
filename (`str`, *optional*):
|
63 |
+
The name of the file in the repo.
|
64 |
+
subfolder (`str`, *optional*):
|
65 |
+
An optional value corresponding to a folder inside the model repo.
|
66 |
+
"""
|
67 |
+
model_path = get_model_path(pretrained_model_or_path, filename, subfolder)
|
68 |
+
model = ModelLoader().load_from_file(model_path)
|
69 |
+
|
70 |
+
# validate that it's the correct model
|
71 |
+
assert isinstance(model, ImageModelDescriptor)
|
72 |
+
|
73 |
+
return cls(model)
|
74 |
+
|
75 |
+
@torch.inference_mode
|
76 |
+
def __call__(
|
77 |
+
self,
|
78 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
79 |
+
tiling: bool = False,
|
80 |
+
tile_width: int = 512,
|
81 |
+
tile_height: int = 512,
|
82 |
+
overlap: int = 8,
|
83 |
+
return_type: str = "pil",
|
84 |
+
) -> Union[torch.Tensor, PIL.Image.Image, np.ndarray]:
|
85 |
+
r"""
|
86 |
+
Upscales the given image, optionally using tiling.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
image (Union[PIL.Image.Image, np.ndarray, torch.Tensor]):
|
90 |
+
The image to be upscaled. Can be a PIL Image, NumPy array, or PyTorch tensor.
|
91 |
+
tiling (bool, optional):
|
92 |
+
Whether to use tiling for upscaling. Default is False.
|
93 |
+
tile_width (int, optional):
|
94 |
+
The width of each tile if tiling is used. Default is 512.
|
95 |
+
tile_height (int, optional):
|
96 |
+
The height of each tile if tiling is used. Default is 512.
|
97 |
+
overlap (int, optional):
|
98 |
+
The overlap between tiles if tiling is used. Default is 8.
|
99 |
+
return_type (str, optional):
|
100 |
+
The type of the returned image. Can be 'pil', 'np', or 'pt'. Default is 'pil'.
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
Union[torch.Tensor, PIL.Image.Image, np.ndarray]:
|
104 |
+
The upscaled image, in the format specified by `return_type`.
|
105 |
+
"""
|
106 |
+
if not isinstance(image, torch.Tensor):
|
107 |
+
image = self.convert_image_to_tensor(image)
|
108 |
+
|
109 |
+
image = image.to(self.model.device)
|
110 |
+
|
111 |
+
if tiling:
|
112 |
+
upscaled_tensor = tiled_upscale(image, self.model, self.model.scale, tile_width, tile_height, overlap)
|
113 |
+
else:
|
114 |
+
upscaled_tensor = self.model(image)
|
115 |
+
|
116 |
+
image = self.post_process_image(upscaled_tensor, return_type)[0]
|
117 |
+
|
118 |
+
return image
|
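Because `__call__` accepts tensors directly and `return_type` appears to follow the same convention as the preprocessors in this change, the upscaler can also be used inside a tensor pipeline. A sketch, assuming the `"pt"` return type is supported by `post_process_image`:

```python
from image_gen_aux import UpscaleWithModel
from image_gen_aux.utils import load_image

upscaler = UpscaleWithModel.from_pretrained("Kim2091/UltraSharp").to("cuda")
image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/simple_upscale/hippowaffle_small.png"
)
# "pt" is an assumption here; the package docstrings use "pt"/"np"/"pil" elsewhere.
upscaled = upscaler(image, return_type="pt")
```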
image_gen_aux/utils/__init__.py
ADDED
@@ -0,0 +1,25 @@
|
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .constants import SAFETENSORS_FILE_EXTENSION
|
16 |
+
from .import_utils import (
|
17 |
+
IMAGE_AUX_SLOW_IMPORT,
|
18 |
+
OptionalDependencyNotAvailable,
|
19 |
+
_LazyModule,
|
20 |
+
is_torch_available,
|
21 |
+
is_transformers_available,
|
22 |
+
)
|
23 |
+
from .loading_utils import load_image
|
24 |
+
from .model_utils import get_model_path
|
25 |
+
from .tiling_utils import create_gradient_mask, tiled_upscale
|
image_gen_aux/utils/constants.py
ADDED
@@ -0,0 +1 @@
|
1 |
+
SAFETENSORS_FILE_EXTENSION = "safetensors"
|
image_gen_aux/utils/import_utils.py
ADDED
@@ -0,0 +1,128 @@
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Import utilities: Utilities related to imports and our lazy inits.
|
16 |
+
"""
|
17 |
+
import importlib.metadata as importlib_metadata
|
18 |
+
import importlib.util
|
19 |
+
import os
|
20 |
+
from itertools import chain
|
21 |
+
from types import ModuleType
|
22 |
+
from typing import Any
|
23 |
+
|
24 |
+
from . import logging
|
25 |
+
|
26 |
+
|
27 |
+
logger = logging.get_logger(__name__)
|
28 |
+
|
29 |
+
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
30 |
+
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
31 |
+
|
32 |
+
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
33 |
+
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
34 |
+
IMAGE_AUX_SLOW_IMPORT = os.environ.get("IMAGE_AUX_SLOW_IMPORT", "FALSE").upper()
|
35 |
+
IMAGE_AUX_SLOW_IMPORT = IMAGE_AUX_SLOW_IMPORT in ENV_VARS_TRUE_VALUES
|
36 |
+
|
37 |
+
|
38 |
+
_torch_version = "N/A"
|
39 |
+
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
40 |
+
_torch_available = importlib.util.find_spec("torch") is not None
|
41 |
+
if _torch_available:
|
42 |
+
try:
|
43 |
+
_torch_version = importlib_metadata.version("torch")
|
44 |
+
logger.info(f"PyTorch version {_torch_version} available.")
|
45 |
+
except importlib_metadata.PackageNotFoundError:
|
46 |
+
_torch_available = False
|
47 |
+
else:
|
48 |
+
logger.info("Disabling PyTorch because USE_TORCH is set")
|
49 |
+
_torch_available = False
|
50 |
+
|
51 |
+
_transformers_available = importlib.util.find_spec("transformers") is not None
|
52 |
+
try:
|
53 |
+
_transformers_version = importlib_metadata.version("transformers")
|
54 |
+
logger.debug(f"Successfully imported transformers version {_transformers_version}")
|
55 |
+
except importlib_metadata.PackageNotFoundError:
|
56 |
+
_transformers_available = False
|
57 |
+
|
58 |
+
|
59 |
+
def is_torch_available():
|
60 |
+
return _torch_available
|
61 |
+
|
62 |
+
|
63 |
+
def is_transformers_available():
|
64 |
+
return _transformers_available
|
65 |
+
|
66 |
+
|
67 |
+
class OptionalDependencyNotAvailable(BaseException):
|
68 |
+
"""An error indicating that an optional dependency of Diffusers was not found in the environment."""
|
69 |
+
|
70 |
+
|
71 |
+
class _LazyModule(ModuleType):
|
72 |
+
"""
|
73 |
+
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
74 |
+
"""
|
75 |
+
|
76 |
+
# Very heavily inspired by optuna.integration._IntegrationModule
|
77 |
+
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
78 |
+
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
|
79 |
+
super().__init__(name)
|
80 |
+
self._modules = set(import_structure.keys())
|
81 |
+
self._class_to_module = {}
|
82 |
+
for key, values in import_structure.items():
|
83 |
+
for value in values:
|
84 |
+
self._class_to_module[value] = key
|
85 |
+
# Needed for autocompletion in an IDE
|
86 |
+
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
|
87 |
+
self.__file__ = module_file
|
88 |
+
self.__spec__ = module_spec
|
89 |
+
self.__path__ = [os.path.dirname(module_file)]
|
90 |
+
self._objects = {} if extra_objects is None else extra_objects
|
91 |
+
self._name = name
|
92 |
+
self._import_structure = import_structure
|
93 |
+
|
94 |
+
# Needed for autocompletion in an IDE
|
95 |
+
def __dir__(self):
|
96 |
+
result = super().__dir__()
|
97 |
+
# The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
|
98 |
+
# they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
|
99 |
+
for attr in self.__all__:
|
100 |
+
if attr not in result:
|
101 |
+
result.append(attr)
|
102 |
+
return result
|
103 |
+
|
104 |
+
def __getattr__(self, name: str) -> Any:
|
105 |
+
if name in self._objects:
|
106 |
+
return self._objects[name]
|
107 |
+
if name in self._modules:
|
108 |
+
value = self._get_module(name)
|
109 |
+
elif name in self._class_to_module.keys():
|
110 |
+
module = self._get_module(self._class_to_module[name])
|
111 |
+
value = getattr(module, name)
|
112 |
+
else:
|
113 |
+
raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
114 |
+
|
115 |
+
setattr(self, name, value)
|
116 |
+
return value
|
117 |
+
|
118 |
+
def _get_module(self, module_name: str):
|
119 |
+
try:
|
120 |
+
return importlib.import_module("." + module_name, self.__name__)
|
121 |
+
except Exception as e:
|
122 |
+
raise RuntimeError(
|
123 |
+
f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
|
124 |
+
f" traceback):\n{e}"
|
125 |
+
) from e
|
126 |
+
|
127 |
+
def __reduce__(self):
|
128 |
+
return (self.__class__, (self._name, self.__file__, self._import_structure))
|
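`_LazyModule` defers the heavy submodule imports until an attribute is first accessed; setting the `IMAGE_AUX_SLOW_IMPORT` variable defined above switches the package `__init__` files back to eager imports. A small sketch:

```python
import os

# Must be set before the package is imported for the eager branch to be taken.
os.environ["IMAGE_AUX_SLOW_IMPORT"] = "1"

import image_gen_aux  # noqa: E402  (all submodules are imported immediately now)
```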
image_gen_aux/utils/loading_utils.py
ADDED
@@ -0,0 +1,63 @@
|
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import Callable, Union
|
17 |
+
|
18 |
+
import PIL.Image
|
19 |
+
import PIL.ImageOps
|
20 |
+
import requests
|
21 |
+
|
22 |
+
|
23 |
+
def load_image(
|
24 |
+
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
|
25 |
+
) -> PIL.Image.Image:
|
26 |
+
"""
|
27 |
+
Loads `image` to a PIL Image.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
image (`str` or `PIL.Image.Image`):
|
31 |
+
The image to convert to the PIL Image format.
|
32 |
+
convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
|
33 |
+
A conversion method to apply to the image after loading it. When set to `None` the image will be converted
|
34 |
+
"RGB".
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
`PIL.Image.Image`:
|
38 |
+
A PIL Image.
|
39 |
+
"""
|
40 |
+
if isinstance(image, str):
|
41 |
+
if image.startswith("http://") or image.startswith("https://"):
|
42 |
+
image = PIL.Image.open(requests.get(image, stream=True).raw)
|
43 |
+
elif os.path.isfile(image):
|
44 |
+
image = PIL.Image.open(image)
|
45 |
+
else:
|
46 |
+
raise ValueError(
|
47 |
+
f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
|
48 |
+
)
|
49 |
+
elif isinstance(image, PIL.Image.Image):
|
50 |
+
image = image
|
51 |
+
else:
|
52 |
+
raise ValueError(
|
53 |
+
"Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
|
54 |
+
)
|
55 |
+
|
56 |
+
image = PIL.ImageOps.exif_transpose(image)
|
57 |
+
|
58 |
+
if convert_method is not None:
|
59 |
+
image = convert_method(image)
|
60 |
+
else:
|
61 |
+
image = image.convert("RGB")
|
62 |
+
|
63 |
+
return image
|
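`convert_method` replaces the default RGB conversion; for example, to load the image as single-channel grayscale instead (URL reused from the TEED README in this change):

```python
from image_gen_aux.utils import load_image

gray = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/teed/20240922043215.png",
    convert_method=lambda img: img.convert("L"),
)
print(gray.mode)  # "L"
```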
image_gen_aux/utils/logging.py
ADDED
@@ -0,0 +1,341 @@
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 Optuna, Hugging Face
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Logging utilities."""
|
16 |
+
|
17 |
+
import logging
|
18 |
+
import os
|
19 |
+
import sys
|
20 |
+
import threading
|
21 |
+
from logging import (
|
22 |
+
CRITICAL, # NOQA
|
23 |
+
DEBUG, # NOQA
|
24 |
+
ERROR, # NOQA
|
25 |
+
FATAL, # NOQA
|
26 |
+
INFO, # NOQA
|
27 |
+
NOTSET, # NOQA
|
28 |
+
WARN, # NOQA
|
29 |
+
WARNING, # NOQA
|
30 |
+
)
|
31 |
+
from typing import Dict, Optional
|
32 |
+
|
33 |
+
from tqdm import auto as tqdm_lib
|
34 |
+
|
35 |
+
|
36 |
+
_lock = threading.Lock()
|
37 |
+
_default_handler: Optional[logging.Handler] = None
|
38 |
+
|
39 |
+
log_levels = {
|
40 |
+
"debug": logging.DEBUG,
|
41 |
+
"info": logging.INFO,
|
42 |
+
"warning": logging.WARNING,
|
43 |
+
"error": logging.ERROR,
|
44 |
+
"critical": logging.CRITICAL,
|
45 |
+
}
|
46 |
+
|
47 |
+
_default_log_level = logging.WARNING
|
48 |
+
|
49 |
+
_tqdm_active = True
|
50 |
+
|
51 |
+
|
52 |
+
def _get_default_logging_level() -> int:
|
53 |
+
"""
|
54 |
+
If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
|
55 |
+
not - fall back to `_default_log_level`
|
56 |
+
"""
|
57 |
+
env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
|
58 |
+
if env_level_str:
|
59 |
+
if env_level_str in log_levels:
|
60 |
+
return log_levels[env_level_str]
|
61 |
+
else:
|
62 |
+
logging.getLogger().warning(
|
63 |
+
f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
|
64 |
+
f"has to be one of: { ', '.join(log_levels.keys()) }"
|
65 |
+
)
|
66 |
+
return _default_log_level
|
67 |
+
|
68 |
+
|
69 |
+
def _get_library_name() -> str:
|
70 |
+
return __name__.split(".")[0]
|
71 |
+
|
72 |
+
|
73 |
+
def _get_library_root_logger() -> logging.Logger:
|
74 |
+
return logging.getLogger(_get_library_name())
|
75 |
+
|
76 |
+
|
77 |
+
def _configure_library_root_logger() -> None:
|
78 |
+
global _default_handler
|
79 |
+
|
80 |
+
with _lock:
|
81 |
+
if _default_handler:
|
82 |
+
# This library has already configured the library root logger.
|
83 |
+
return
|
84 |
+
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
|
85 |
+
|
86 |
+
if sys.stderr: # only if sys.stderr exists, e.g. when not using pythonw in windows
|
87 |
+
_default_handler.flush = sys.stderr.flush
|
88 |
+
|
89 |
+
# Apply our default configuration to the library root logger.
|
90 |
+
library_root_logger = _get_library_root_logger()
|
91 |
+
library_root_logger.addHandler(_default_handler)
|
92 |
+
library_root_logger.setLevel(_get_default_logging_level())
|
93 |
+
library_root_logger.propagate = False
|
94 |
+
|
95 |
+
|
96 |
+
def _reset_library_root_logger() -> None:
|
97 |
+
global _default_handler
|
98 |
+
|
99 |
+
with _lock:
|
100 |
+
if not _default_handler:
|
101 |
+
return
|
102 |
+
|
103 |
+
library_root_logger = _get_library_root_logger()
|
104 |
+
library_root_logger.removeHandler(_default_handler)
|
105 |
+
library_root_logger.setLevel(logging.NOTSET)
|
106 |
+
_default_handler = None
|
107 |
+
|
108 |
+
|
109 |
+
def get_log_levels_dict() -> Dict[str, int]:
|
110 |
+
return log_levels
|
111 |
+
|
112 |
+
|
113 |
+
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
114 |
+
"""
|
115 |
+
Return a logger with the specified name.
|
116 |
+
|
117 |
+
This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
|
118 |
+
"""
|
119 |
+
|
120 |
+
if name is None:
|
121 |
+
name = _get_library_name()
|
122 |
+
|
123 |
+
_configure_library_root_logger()
|
124 |
+
return logging.getLogger(name)
|
125 |
+
|
126 |
+
|
127 |
+
def get_verbosity() -> int:
|
128 |
+
"""
|
129 |
+
Return the current level for the 🤗 Diffusers' root logger as an `int`.
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
`int`:
|
133 |
+
Logging level integers which can be one of:
|
134 |
+
|
135 |
+
- `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
136 |
+
- `40`: `diffusers.logging.ERROR`
|
137 |
+
- `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
138 |
+
- `20`: `diffusers.logging.INFO`
|
139 |
+
- `10`: `diffusers.logging.DEBUG`
|
140 |
+
|
141 |
+
"""
|
142 |
+
|
143 |
+
_configure_library_root_logger()
|
144 |
+
return _get_library_root_logger().getEffectiveLevel()
|
145 |
+
|
146 |
+
|
147 |
+
def set_verbosity(verbosity: int) -> None:
|
148 |
+
"""
|
149 |
+
Set the verbosity level for the 🤗 Diffusers' root logger.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
verbosity (`int`):
|
153 |
+
Logging level which can be one of:
|
154 |
+
|
155 |
+
- `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
156 |
+
- `diffusers.logging.ERROR`
|
157 |
+
- `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
158 |
+
- `diffusers.logging.INFO`
|
159 |
+
- `diffusers.logging.DEBUG`
|
160 |
+
"""
|
161 |
+
|
162 |
+
_configure_library_root_logger()
|
163 |
+
_get_library_root_logger().setLevel(verbosity)
|
164 |
+
|
165 |
+
|
166 |
+
def set_verbosity_info() -> None:
|
167 |
+
"""Set the verbosity to the `INFO` level."""
|
168 |
+
return set_verbosity(INFO)
|
169 |
+
|
170 |
+
|
171 |
+
def set_verbosity_warning() -> None:
|
172 |
+
"""Set the verbosity to the `WARNING` level."""
|
173 |
+
return set_verbosity(WARNING)
|
174 |
+
|
175 |
+
|
176 |
+
def set_verbosity_debug() -> None:
|
177 |
+
"""Set the verbosity to the `DEBUG` level."""
|
178 |
+
return set_verbosity(DEBUG)
|
179 |
+
|
180 |
+
|
181 |
+
def set_verbosity_error() -> None:
|
182 |
+
"""Set the verbosity to the `ERROR` level."""
|
183 |
+
return set_verbosity(ERROR)
|
184 |
+
|
185 |
+
|
186 |
+
def disable_default_handler() -> None:
|
187 |
+
"""Disable the default handler of the 🤗 Diffusers' root logger."""
|
188 |
+
|
189 |
+
_configure_library_root_logger()
|
190 |
+
|
191 |
+
assert _default_handler is not None
|
192 |
+
_get_library_root_logger().removeHandler(_default_handler)
|
193 |
+
|
194 |
+
|
195 |
+
def enable_default_handler() -> None:
|
196 |
+
"""Enable the default handler of the 🤗 Diffusers' root logger."""
|
197 |
+
|
198 |
+
_configure_library_root_logger()
|
199 |
+
|
200 |
+
assert _default_handler is not None
|
201 |
+
_get_library_root_logger().addHandler(_default_handler)
|
202 |
+
|
203 |
+
|
204 |
+
def add_handler(handler: logging.Handler) -> None:
|
205 |
+
"""adds a handler to the HuggingFace Diffusers' root logger."""
|
206 |
+
|
207 |
+
_configure_library_root_logger()
|
208 |
+
|
209 |
+
assert handler is not None
|
210 |
+
_get_library_root_logger().addHandler(handler)
|
211 |
+
|
212 |
+
|
213 |
+
def remove_handler(handler: logging.Handler) -> None:
|
214 |
+
"""removes given handler from the HuggingFace Diffusers' root logger."""
|
215 |
+
|
216 |
+
_configure_library_root_logger()
|
217 |
+
|
218 |
+
assert handler is not None and handler in _get_library_root_logger().handlers
|
219 |
+
_get_library_root_logger().removeHandler(handler)
|
220 |
+
|
221 |
+
|
222 |
+
def disable_propagation() -> None:
|
223 |
+
"""
|
224 |
+
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
|
225 |
+
"""
|
226 |
+
|
227 |
+
_configure_library_root_logger()
|
228 |
+
_get_library_root_logger().propagate = False
|
229 |
+
|
230 |
+
|
231 |
+
def enable_propagation() -> None:
|
232 |
+
"""
|
233 |
+
Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
|
234 |
+
double logging if the root logger has been configured.
|
235 |
+
"""
|
236 |
+
|
237 |
+
_configure_library_root_logger()
|
238 |
+
_get_library_root_logger().propagate = True
|
239 |
+
|
240 |
+
|
241 |
+
def enable_explicit_format() -> None:
|
242 |
+
"""
|
243 |
+
Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows:
|
244 |
+
```
|
245 |
+
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
|
246 |
+
```
|
247 |
+
All handlers currently bound to the root logger are affected by this method.
|
248 |
+
"""
|
249 |
+
handlers = _get_library_root_logger().handlers
|
250 |
+
|
251 |
+
for handler in handlers:
|
252 |
+
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
|
253 |
+
handler.setFormatter(formatter)
|
254 |
+
|
255 |
+
|
256 |
+
def reset_format() -> None:
|
257 |
+
"""
|
258 |
+
Resets the formatting for 🤗 Diffusers' loggers.
|
259 |
+
|
260 |
+
All handlers currently bound to the root logger are affected by this method.
|
261 |
+
"""
|
262 |
+
handlers = _get_library_root_logger().handlers
|
263 |
+
|
264 |
+
for handler in handlers:
|
265 |
+
handler.setFormatter(None)
|
266 |
+
|
267 |
+
|
268 |
+
def warning_advice(self, *args, **kwargs) -> None:
|
269 |
+
"""
|
270 |
+
This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
|
271 |
+
warning will not be printed
|
272 |
+
"""
|
273 |
+
no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
|
274 |
+
if no_advisory_warnings:
|
275 |
+
return
|
276 |
+
self.warning(*args, **kwargs)
|
277 |
+
|
278 |
+
|
279 |
+
logging.Logger.warning_advice = warning_advice
|
280 |
+
|
281 |
+
|
282 |
+
class EmptyTqdm:
|
283 |
+
"""Dummy tqdm which doesn't do anything."""
|
284 |
+
|
285 |
+
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
286 |
+
self._iterator = args[0] if args else None
|
287 |
+
|
288 |
+
def __iter__(self):
|
289 |
+
return iter(self._iterator)
|
290 |
+
|
291 |
+
def __getattr__(self, _):
|
292 |
+
"""Return empty function."""
|
293 |
+
|
294 |
+
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
|
295 |
+
return
|
296 |
+
|
297 |
+
return empty_fn
|
298 |
+
|
299 |
+
def __enter__(self):
|
300 |
+
return self
|
301 |
+
|
302 |
+
def __exit__(self, type_, value, traceback):
|
303 |
+
return
|
304 |
+
|
305 |
+
|
306 |
+
class _tqdm_cls:
|
307 |
+
def __call__(self, *args, **kwargs):
|
308 |
+
if _tqdm_active:
|
309 |
+
return tqdm_lib.tqdm(*args, **kwargs)
|
310 |
+
else:
|
311 |
+
return EmptyTqdm(*args, **kwargs)
|
312 |
+
|
313 |
+
def set_lock(self, *args, **kwargs):
|
314 |
+
self._lock = None
|
315 |
+
if _tqdm_active:
|
316 |
+
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
|
317 |
+
|
318 |
+
def get_lock(self):
|
319 |
+
if _tqdm_active:
|
320 |
+
return tqdm_lib.tqdm.get_lock()
|
321 |
+
|
322 |
+
|
323 |
+
tqdm = _tqdm_cls()
|
324 |
+
|
325 |
+
|
326 |
+
def is_progress_bar_enabled() -> bool:
|
327 |
+
"""Return a boolean indicating whether tqdm progress bars are enabled."""
|
328 |
+
global _tqdm_active
|
329 |
+
return bool(_tqdm_active)
|
330 |
+
|
331 |
+
|
332 |
+
def enable_progress_bar() -> None:
|
333 |
+
"""Enable tqdm progress bar."""
|
334 |
+
global _tqdm_active
|
335 |
+
_tqdm_active = True
|
336 |
+
|
337 |
+
|
338 |
+
def disable_progress_bar() -> None:
|
339 |
+
"""Disable tqdm progress bar."""
|
340 |
+
global _tqdm_active
|
341 |
+
_tqdm_active = False
|
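These helpers mirror the logging utilities vendored from `diffusers`; typical usage inside the package would look like the following sketch (messages and levels are illustrative):

```python
from image_gen_aux.utils import logging

logger = logging.get_logger(__name__)
logging.set_verbosity_info()     # show INFO-level messages from this library
logger.info("model weights loaded")
logging.disable_progress_bar()   # silence the tqdm wrapper exposed by this module
```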
image_gen_aux/utils/model_utils.py
ADDED
@@ -0,0 +1,37 @@
|
1 |
+
import os
|
2 |
+
|
3 |
+
from huggingface_hub import hf_hub_download, model_info
|
4 |
+
|
5 |
+
|
6 |
+
def get_model_path(pretrained_model_or_path, filename=None, subfolder=None):
|
7 |
+
"""
|
8 |
+
Retrieves the path to the model file.
|
9 |
+
|
10 |
+
If `pretrained_model_or_path` is a file, it returns the path directly.
|
11 |
+
Otherwise, it attempts to find a `.safetensors` file associated with the given model path.
|
12 |
+
If no `.safetensors` file is found, it raises a `FileNotFoundError`.
|
13 |
+
|
14 |
+
Parameters:
|
15 |
+
- pretrained_model_or_path (str): Path to a local model file, or a Hugging Face Hub repo id containing the model.
|
16 |
+
- filename (str, optional): Specific filename to load. If not provided, the function will search for a `.safetensors` file.
|
17 |
+
- subfolder (str, optional): Subfolder within the model directory to look for the file.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
- str: Path to the model file.
|
21 |
+
|
22 |
+
Raises:
|
23 |
+
- FileNotFoundError: If no `.safetensors` file is found when `filename` is not provided.
|
24 |
+
"""
|
25 |
+
if os.path.isfile(pretrained_model_or_path):
|
26 |
+
return pretrained_model_or_path
|
27 |
+
|
28 |
+
if filename is None:
|
29 |
+
# If the filename is not passed, we only try to load a safetensor
|
30 |
+
info = model_info(pretrained_model_or_path)
|
31 |
+
filename = next(
|
32 |
+
(sibling.rfilename for sibling in info.siblings if sibling.rfilename.endswith(".safetensors")), None
|
33 |
+
)
|
34 |
+
if filename is None:
|
35 |
+
raise FileNotFoundError("No safetensors checkpoint found.")
|
36 |
+
|
37 |
+
return hf_hub_download(pretrained_model_or_path, filename, subfolder=subfolder)
|
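A minimal usage sketch for `get_model_path`; the local path and repo id below are placeholders for illustration and are not referenced anywhere in this diff:

    from image_gen_aux.utils.model_utils import get_model_path

    # An existing local file is returned unchanged.
    path = get_model_path("weights/4x_upscaler.safetensors")  # placeholder local path

    # For a Hub repo id, the first .safetensors sibling is located and downloaded.
    path = get_model_path("some-user/some-upscaler-repo")     # placeholder repo id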
image_gen_aux/utils/tiling_utils.py
ADDED
@@ -0,0 +1,86 @@
+from typing import Callable, Tuple
+
+import torch
+
+
+def create_gradient_mask(shape: Tuple, feather: int, device="cpu") -> torch.Tensor:
+    """
+    Create a gradient mask for smooth blending of tiles.
+
+    Args:
+        shape (tuple): Shape of the mask (batch, channels, height, width)
+        feather (int): Width of the feathered edge
+
+    Returns:
+        torch.Tensor: Gradient mask
+    """
+    mask = torch.ones(shape).to(device)
+    _, _, h, w = shape
+    for feather_step in range(feather):
+        factor = (feather_step + 1) / feather
+        mask[:, :, feather_step, :] *= factor
+        mask[:, :, h - 1 - feather_step, :] *= factor
+        mask[:, :, :, feather_step] *= factor
+        mask[:, :, :, w - 1 - feather_step] *= factor
+    return mask
+
+
+def tiled_upscale(
+    samples: torch.Tensor,
+    function: Callable,
+    scale: int,
+    tile_width: int = 512,
+    tile_height: int = 512,
+    overlap: int = 8,
+) -> torch.Tensor:
+    """
+    Apply a scaling function to image samples in a tiled manner.
+
+    Args:
+        samples (torch.Tensor): Input tensor of shape (batch_size, channels, height, width)
+        function (Callable): The scaling function to apply to each tile
+        scale (int): Factor by which to upscale the image
+        tile_width (int): Width of each tile
+        tile_height (int): Height of each tile
+        overlap (int): Overlap between tiles
+
+    Returns:
+        torch.Tensor: Upscaled and processed output tensor
+    """
+    _batch, _channels, height, width = samples.shape
+    out_height, out_width = round(height * scale), round(width * scale)
+    output_device = samples.device
+
+    # Initialize output tensors
+    output = torch.empty((1, 3, out_height, out_width), device=output_device)
+    out = torch.zeros((1, 3, out_height, out_width), device=output_device)
+    out_div = torch.zeros_like(output)
+
+    # Process the image in tiles
+    for y in range(0, height, tile_height - overlap):
+        for x in range(0, width, tile_width - overlap):
+            # Ensure we don't go out of bounds
+            x_end = min(x + tile_width, width)
+            y_end = min(y + tile_height, height)
+            x = max(0, x_end - tile_width)
+            y = max(0, y_end - tile_height)
+
+            # Extract and process the tile
+            tile = samples[:, :, y:y_end, x:x_end]
+            processed_tile = function(tile).to(output_device)
+
+            # Calculate the position in the output tensor
+            out_y, out_x = round(y * scale), round(x * scale)
+            out_h, out_w = processed_tile.shape[2:]
+
+            # Create a feathered mask for smooth blending
+            mask = create_gradient_mask(processed_tile.shape, overlap * scale, device=output_device)
+
+            # Add the processed tile to the output
+            out[:, :, out_y : out_y + out_h, out_x : out_x + out_w] += processed_tile * mask
+            out_div[:, :, out_y : out_y + out_h, out_x : out_x + out_w] += mask
+
+    # Normalize the output
+    output = out / out_div
+
+    return output
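A minimal sketch of how `tiled_upscale` can be driven, using plain bicubic interpolation as a stand-in for a real upscaling model's forward pass; the input size and the interpolation stand-in are assumptions for illustration only:

    import torch
    import torch.nn.functional as F

    from image_gen_aux.utils.tiling_utils import tiled_upscale

    image = torch.rand(1, 3, 768, 1024)  # (batch, channels, height, width)
    upscaled = tiled_upscale(
        image,
        function=lambda tile: F.interpolate(tile, scale_factor=2, mode="bicubic"),
        scale=2,
        tile_width=512,
        tile_height=512,
        overlap=8,
    )
    print(upscaled.shape)  # torch.Size([1, 3, 1536, 2048])

In practice, `function` would be the upscaler model's forward pass, so each 512x512 tile is processed independently and the feathered masks blend the overlapping regions back together.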
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
 einops
-diffusers==0.31.0
 transformers
 accelerate
 gradio
@@ -8,4 +7,5 @@ spaces
 pillow
 peft
 openai
-torch
+torch
+git+https://github.com/huggingface/diffusers