Commit 252e766 • krunakuamar committed
Parent(s): 762cf51
Upload 75 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- Dockerfile +44 -0
- app.py +316 -0
- app/big-lama.pt +3 -0
- app/u2net.onnx +3 -0
- app/yolov8x-seg.pt +3 -0
- lama_cleaner/__init__.py +11 -0
- lama_cleaner/__pycache__/__init__.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/const.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/helper.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/interactive_seg.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/model_manager.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/parse_args.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/runtime.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/schema.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/server2.cpython-38.pyc +0 -0
- lama_cleaner/benchmark.py +109 -0
- lama_cleaner/const.py +68 -0
- lama_cleaner/file_manager/__init__.py +1 -0
- lama_cleaner/file_manager/__pycache__/__init__.cpython-38.pyc +0 -0
- lama_cleaner/file_manager/__pycache__/file_manager.cpython-38.pyc +0 -0
- lama_cleaner/file_manager/__pycache__/storage_backends.cpython-38.pyc +0 -0
- lama_cleaner/file_manager/__pycache__/utils.cpython-38.pyc +0 -0
- lama_cleaner/file_manager/file_manager.py +252 -0
- lama_cleaner/file_manager/storage_backends.py +46 -0
- lama_cleaner/file_manager/utils.py +66 -0
- lama_cleaner/helper.py +218 -0
- lama_cleaner/interactive_seg.py +202 -0
- lama_cleaner/model/__init__.py +0 -0
- lama_cleaner/model/__pycache__/__init__.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/base.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/ddim_sampler.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/fcf.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/lama.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/ldm.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/manga.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/mat.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/opencv2.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/paint_by_example.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/plms_sampler.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/sd.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/utils.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/zits.cpython-38.pyc +0 -0
- lama_cleaner/model/base.py +247 -0
- lama_cleaner/model/ddim_sampler.py +192 -0
- lama_cleaner/model/fcf.py +1212 -0
- lama_cleaner/model/lama.py +61 -0
- lama_cleaner/model/ldm.py +310 -0
- lama_cleaner/model/manga.py +130 -0
- lama_cleaner/model/mat.py +1444 -0
- lama_cleaner/model/opencv2.py +25 -0
Dockerfile
ADDED
@@ -0,0 +1,44 @@
FROM python:3.8

RUN mkdir /app
RUN mkdir /.cache/
RUN mkdir /.cache/matplotlib
RUN mkdir /.cache/huggingface
RUN mkdir /.cache/huggingface/hub/
RUN mkdir /.cache/torch/
RUN mkdir /.config
RUN mkdir /.config/matplotlib/

RUN chmod -R 777 /.cache
RUN chmod -R 777 /.cache/matplotlib
RUN chmod -R 777 /.cache/huggingface/hub
RUN chmod -R 777 /.cache/torch
RUN chmod -R 777 /.config/
RUN chmod -R 777 /.config/matplotlib
RUN chmod -R 777 /app


COPY lama_cleaner ./lama_cleaner
COPY ./app.py ./app.py


COPY app/yolov8x-seg.pt /app
COPY big-lama.pt /app
# COPY clickseg_pplnet.pt /app
COPY u2net.onnx /app
COPY u2net.onnx /tmp

RUN chmod -R a+r /app/yolov8x-seg.pt
RUN chmod -R a+r /app/big-lama.pt
#RUN chmod -R a+r /app/clickseg_pplnet.pt
RUN chmod -R a+r /app/u2net.onnx
RUN chmod -R a+r /tmp/u2net.onnx


COPY ./requirements.txt ./requirements.txt
RUN pip install -r ./requirements.txt

RUN --mount=type=secret,id=SECRET,mode=0444,required=true \
    git clone $(cat /run/secrets/SECRET)

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py
ADDED
@@ -0,0 +1,316 @@
import base64
import imghdr
import os

import cv2
import numpy as np
import torch
from ultralytics import YOLO
from ultralytics.yolo.utils.ops import scale_image
import asyncio
from fastapi import FastAPI, File, UploadFile, Request, Response
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
# from mangum import Mangum
from argparse import ArgumentParser

import lama_cleaner.server2 as server
from lama_cleaner.helper import (
    load_img,
)

# os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/directory"

app = FastAPI()

# handler = Mangum(app)
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """
    Args:
        image_numpy: numpy image
        ext: image extension
    Returns:
        image bytes
    """
    data = cv2.imencode(
        f".{ext}",
        image_numpy,
        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
    )[1].tobytes()
    return data


def get_image_ext(img_bytes):
    """
    Args:
        img_bytes: image bytes
    Returns:
        image extension
    """
    if not img_bytes:
        raise ValueError("Empty input")
    header = img_bytes[:32]
    w = imghdr.what("", header)
    if w is None:
        w = "jpeg"
    return w


def predict_on_image(model, img, conf, retina_masks):
    """
    Args:
        model: YOLOv8 model
        img: image (C, H, W)
        conf: confidence threshold
        retina_masks: use retina masks or not
    Returns:
        boxes: box with xyxy format, (N, 4)
        masks: masks, (N, H, W)
        cls: class of masks, (N, )
        probs: confidence score, (N, 1)
    """
    with torch.no_grad():
        result = model(img, conf=conf, retina_masks=retina_masks, scale=1)[0]

    boxes, masks, cls, probs = None, None, None, None

    if result.boxes.cls.size(0) > 0:
        # detection
        cls = result.boxes.cls.cpu().numpy().astype(np.int32)
        probs = result.boxes.conf.cpu().numpy()  # confidence score, (N, 1)
        boxes = result.boxes.xyxy.cpu().numpy()  # box with xyxy format, (N, 4)

        # segmentation
        masks = result.masks.masks.cpu().numpy()  # masks, (N, H, W)
        masks = np.transpose(masks, (1, 2, 0))  # masks, (H, W, N)
        # rescale masks to original image
        masks = scale_image(masks.shape[:2], masks, result.masks.orig_shape)
        masks = np.transpose(masks, (2, 0, 1))  # masks, (N, H, W)

    return boxes, masks, cls, probs


def overlay(image, mask, color, alpha, id, resize=None):
    """Overlays a binary mask on an image.

    Args:
        image: Image to be overlayed on.
        mask: Binary mask to overlay.
        color: Color to use for the mask.
        alpha: Opacity of the mask.
        id: id of the mask
        resize: Resize the image to this size. If None, no resizing is performed.

    Returns:
        The overlayed image.
    """
    color = color[::-1]
    colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
    colored_mask = np.moveaxis(colored_mask, 0, -1)
    masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
    image_overlay = masked.filled()

    imgray = cv2.cvtColor(image_overlay, cv2.COLOR_BGR2GRAY)

    contour_thickness = 8
    _, thresh = cv2.threshold(imgray, 255, 255, 255)
    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    imgray = cv2.cvtColor(imgray, cv2.COLOR_GRAY2BGR)
    imgray = cv2.drawContours(imgray, contours, -1, (255, 255, 255), contour_thickness)

    imgray = np.where(imgray.any(-1, keepdims=True), (46, 36, 225), 0)

    if resize is not None:
        image = cv2.resize(image.transpose(1, 2, 0), resize)
        image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize)

    return imgray


async def process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls):
    """Process the mask of the image.

    Args:
        idx: index of the mask
        mask_i: mask of the image
        boxes: box with xyxy format, (N, 4)
        probs: confidence score, (N, 1)
        yolo_model: YOLOv8 model
        blank_image: blank image
        cls: class of masks, (N, )

    Returns:
        dictionary_seg: dictionary of the mask of the image
    """
    dictionary_seg = {}
    maskwith_back = overlay(blank_image, mask_i, color=(255, 155, 155), alpha=0.5, id=idx)

    alpha = np.sum(maskwith_back, axis=-1) > 0
    alpha = np.uint8(alpha * 255)
    maskwith_back = np.dstack((maskwith_back, alpha))

    imgencode = await asyncio.get_running_loop().run_in_executor(None, cv2.imencode, '.png', maskwith_back)
    mask = base64.b64encode(imgencode[1]).decode('utf-8')

    dictionary_seg["confi"] = f'{probs[idx] * 100:.2f}'
    dictionary_seg["boxe"] = [int(item) for item in list(boxes[idx])]
    dictionary_seg["mask"] = mask
    dictionary_seg["cls"] = str(yolo_model.names[cls[idx]])

    return dictionary_seg


@app.middleware("http")
async def check_auth_header(request: Request, call_next):
    token = request.headers.get('Authorization')
    if token != os.environ.get("SECRET"):
        return JSONResponse(content={'error': 'Authorization header missing or incorrect.'}, status_code=403)
    else:
        response = await call_next(request)
        return response


@app.post("/api/mask")
async def detect_mask(file: UploadFile = File()):
    """
    Detects masks in an image uploaded via a POST request and returns a JSON response containing the details of the detected masks.

    Args:
        None

    Parameters:
        - file: a file object containing the input image

    Returns:
        A JSON response containing the details of the detected masks:
        - code: 200 if objects were detected, 500 if no objects were detected
        - msg: a message indicating whether objects were detected or not
        - data: a list of dictionaries, where each dictionary contains the following keys:
            - confi: the confidence level of the detected object
            - boxe: a list containing the coordinates of the bounding box of the detected object
            - mask: the mask of the detected object encoded in base64
            - cls: the class of the detected object

    Raises:
        500: No objects detected
    """
    file = await file.read()

    img, _ = load_img(file)

    # predict by YOLOv8
    boxes, masks, cls, probs = predict_on_image(yolo_model, img, conf=0.55, retina_masks=True)

    if boxes is None:
        return {'code': 500, 'msg': 'No objects detected'}

    # overlay masks on original image
    blank_image = np.zeros(img.shape, dtype=np.uint8)

    data = []

    coroutines = [process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls) for idx, mask_i in
                  enumerate(masks)]
    results = await asyncio.gather(*coroutines)

    for result in results:
        data.append(result)

    return {'code': 200, 'msg': "object detected", 'data': data}


@app.post("/api/lama/paint")
async def paint(img: UploadFile = File(), mask: UploadFile = File()):
    """
    Endpoint to process an image with a given mask using the server's process function.

    Route: '/api/lama/paint'
    Method: POST

    Parameters:
        img: The input image file (JPEG or PNG format).
        mask: The mask file (JPEG or PNG format).
    Returns:
        A JSON object containing the processed image in base64 format under the "image" key.
    """
    img = await img.read()
    mask = await mask.read()
    return {"image": server.process(img, mask)}


@app.post("/api/remove")
async def remove(img: UploadFile = File()):
    x = await img.read()
    return {"image": server.remove(x)}

@app.post("/api/lama/model")
def switch_model(new_name: str):
    return server.switch_model(new_name)


@app.get("/api/lama/model")
def current_model():
    return server.current_model()


@app.get("/api/lama/switchmode")
def get_is_disable_model_switch():
    return server.get_is_disable_model_switch()


@app.on_event("startup")
def init_data():
    model_device = "cpu"
    global yolo_model
    # TODO Update for local development
    yolo_model = YOLO('yolov8x-seg.pt')
    # yolo_model = YOLO('/app/yolov8x-seg.pt')
    yolo_model.to(model_device)
    print(f"YOLO model yolov8x-seg.pt loaded.")
    server.initModel()


def create_app(args):
    """
    Creates the FastAPI app and adds the endpoints.

    Args:
        args: The arguments.
    """
    uvicorn.run("app:app", host=args.host, port=args.port, reload=args.reload)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str, default='lama', help='Model name')
    parser.add_argument('--host', type=str, default="0.0.0.0")
    parser.add_argument('--port', type=int, default=5000)
    parser.add_argument('--reload', type=bool, default=True)
    parser.add_argument('--model_device', type=str, default='cpu', help='Model device')
    parser.add_argument('--disable_model_switch', type=bool, default=False, help='Disable model switch')
    parser.add_argument('--gui', type=bool, default=False, help='Enable GUI')
    parser.add_argument('--cpu_offload', type=bool, default=False, help='Enable CPU offload')
    parser.add_argument('--disable_nsfw', type=bool, default=False, help='Disable NSFW')
    parser.add_argument('--enable_xformers', type=bool, default=False, help='Enable xformers')
    parser.add_argument('--hf_access_token', type=str, default='', help='Hugging Face access token')
    parser.add_argument('--local_files_only', type=bool, default=False, help='Enable local files only')
    parser.add_argument('--no_half', type=bool, default=False, help='Disable half')
    parser.add_argument('--sd_cpu_textencoder', type=bool, default=False, help='Enable CPU text encoder')
    parser.add_argument('--sd_disable_nsfw', type=bool, default=False, help='Disable NSFW')
    parser.add_argument('--sd_enable_xformers', type=bool, default=False, help='Enable xformers')
    parser.add_argument('--sd_run_local', type=bool, default=False, help='Enable local files only')

    args = parser.parse_args()
    create_app(args)
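For reference, a minimal client sketch for the endpoints defined in app.py above. The following details are assumptions, not part of the commit: the service is reachable at http://localhost:7860 (the port in the Dockerfile CMD), the SECRET environment variable holds the same token checked by the Authorization middleware, and input.jpg / mask0.png are hypothetical local files.

# Minimal client sketch for the endpoints above (assumptions noted in the lead-in).
import base64
import os

import requests

BASE_URL = "http://localhost:7860"                    # assumed host/port (Dockerfile CMD)
HEADERS = {"Authorization": os.environ["SECRET"]}     # raw token, no "Bearer" prefix

# /api/mask: upload an image, receive YOLOv8 masks as base64-encoded PNGs
with open("input.jpg", "rb") as f:                    # hypothetical input file
    r = requests.post(f"{BASE_URL}/api/mask", headers=HEADERS, files={"file": f})
resp = r.json()
if resp["code"] == 200:
    with open("mask0.png", "wb") as out:
        out.write(base64.b64decode(resp["data"][0]["mask"]))

# /api/lama/paint: inpaint the region covered by the mask
with open("input.jpg", "rb") as img, open("mask0.png", "rb") as mask:
    r = requests.post(f"{BASE_URL}/api/lama/paint", headers=HEADERS,
                      files={"img": img, "mask": mask})
print(r.json().keys())                                # {"image": ...} per the endpoint docstring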
app/big-lama.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:344c77bbcb158f17dd143070d1e789f38a66c04202311ae3a258ef66667a9ea9
size 205669692
app/u2net.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8d10d2f3bb75ae3b6d527c77944fc5e7dcd94b29809d47a739a7a728a912b491
size 175997641
app/yolov8x-seg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d63cbfa5764867c0066bedfa43cf2dcd90a412a1de44b2e238c43978a9d28ea6
size 144076467
lama_cleaner/__init__.py
ADDED
@@ -0,0 +1,11 @@
import warnings
warnings.simplefilter("ignore", UserWarning)

from lama_cleaner.parse_args import parse_args

def entry_point():
    args = parse_args()
    # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
    # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
    from lama_cleaner.server import main
    main(args)
lama_cleaner/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (478 Bytes)
lama_cleaner/__pycache__/const.cpython-38.pyc
ADDED
Binary file (1.79 kB)
lama_cleaner/__pycache__/helper.cpython-38.pyc
ADDED
Binary file (5.43 kB)
lama_cleaner/__pycache__/interactive_seg.cpython-38.pyc
ADDED
Binary file (6.77 kB)
lama_cleaner/__pycache__/model_manager.cpython-38.pyc
ADDED
Binary file (2.27 kB)
lama_cleaner/__pycache__/parse_args.cpython-38.pyc
ADDED
Binary file (4.28 kB)
lama_cleaner/__pycache__/runtime.cpython-38.pyc
ADDED
Binary file (1.35 kB)
lama_cleaner/__pycache__/schema.cpython-38.pyc
ADDED
Binary file (2.42 kB)
lama_cleaner/__pycache__/server2.cpython-38.pyc
ADDED
Binary file (6.31 kB)
lama_cleaner/benchmark.py
ADDED
@@ -0,0 +1,109 @@
#!/usr/bin/env python3

import argparse
import os
import time

import numpy as np
import nvidia_smi
import psutil
import torch

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy, SDSampler

try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except:
    pass

NUM_THREADS = str(4)

os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]


def run_model(model, size):
    # RGB
    image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
    mask = np.random.randint(0, 255, size).astype(np.uint8)

    config = Config(
        ldm_steps=2,
        hd_strategy=HDStrategy.ORIGINAL,
        hd_strategy_crop_margin=128,
        hd_strategy_crop_trigger_size=128,
        hd_strategy_resize_limit=128,
        prompt="a fox is sitting on a bench",
        sd_steps=5,
        sd_sampler=SDSampler.ddim
    )
    model(image, mask, config)


def benchmark(model, times: int, empty_cache: bool):
    sizes = [(512, 512)]

    nvidia_smi.nvmlInit()
    device_id = 0
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)

    def format(metrics):
        return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"

    process = psutil.Process(os.getpid())
    # Report GPU memory and RAM usage for each size
    for size in sizes:
        torch.cuda.empty_cache()
        time_metrics = []
        cpu_metrics = []
        memory_metrics = []
        gpu_memory_metrics = []
        for _ in range(times):
            start = time.time()
            run_model(model, size)
            torch.cuda.synchronize()

            # cpu_metrics.append(process.cpu_percent())
            time_metrics.append((time.time() - start) * 1000)
            memory_metrics.append(process.memory_info().rss / 1024 / 1024)
            gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024)

        print(f"size: {size}".center(80, "-"))
        # print(f"cpu: {format(cpu_metrics)}")
        print(f"latency: {format(time_metrics)}ms")
        print(f"memory: {format(memory_metrics)} MB")
        print(f"gpu memory: {format(gpu_memory_metrics)} MB")

    nvidia_smi.nvmlShutdown()


def get_args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--name")
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--times", default=10, type=int)
    parser.add_argument("--empty-cache", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args_parser()
    device = torch.device(args.device)
    model = ModelManager(
        name=args.name,
        device=device,
        sd_run_local=True,
        disable_nsfw=True,
        sd_cpu_textencoder=True,
        hf_access_token="123"
    )
    benchmark(model, args.times, args.empty_cache)
lama_cleaner/const.py
ADDED
@@ -0,0 +1,68 @@
import os

DEFAULT_MODEL = "lama"
AVAILABLE_MODELS = [
    "lama",
    "ldm",
    "zits",
    "mat",
    "fcf",
    "sd1.5",
    "cv2",
    "manga",
    "sd2",
    "paint_by_example"
]

AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
DEFAULT_DEVICE = 'cuda'

NO_HALF_HELP = """
Use the full precision model.
If your generated result is always black or green, use this argument. (sd/paint_by_example)
"""

CPU_OFFLOAD_HELP = """
Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example)
"""

DISABLE_NSFW_HELP = """
Disable NSFW checker. (sd/paint_by_example)
"""

SD_CPU_TEXTENCODER_HELP = """
Run Stable Diffusion text encoder model on CPU to save GPU memory.
"""

LOCAL_FILES_ONLY_HELP = """
Use local files only; do not connect to the Hugging Face server. (sd/paint_by_example)
"""

ENABLE_XFORMERS_HELP = """
Enable xFormers optimizations. Requires the xformers package to be installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example)
"""

DEFAULT_MODEL_DIR = os.getenv(
    "XDG_CACHE_HOME",
    os.path.join(os.path.expanduser("~"), ".cache")
)
MODEL_DIR_HELP = """
Model download directory (set via the XDG_CACHE_HOME environment variable); by default models are downloaded to ~/.cache
"""

OUTPUT_DIR_HELP = """
Only required when --input is a directory. Result images will be saved to the output directory automatically.
"""

INPUT_HELP = """
If input is an image, it will be loaded by default.
If input is a directory, you can browse and select an image in the file manager.
"""

GUI_HELP = """
Launch Lama Cleaner as a desktop app
"""

NO_GUI_AUTO_CLOSE_HELP = """
Prevent the backend from auto-closing after the GUI window is closed.
"""
lama_cleaner/file_manager/__init__.py
ADDED
@@ -0,0 +1 @@
from .file_manager import FileManager
lama_cleaner/file_manager/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (227 Bytes)
lama_cleaner/file_manager/__pycache__/file_manager.cpython-38.pyc
ADDED
Binary file (7.68 kB)
lama_cleaner/file_manager/__pycache__/storage_backends.cpython-38.pyc
ADDED
Binary file (2.01 kB)
lama_cleaner/file_manager/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (1.64 kB)
lama_cleaner/file_manager/file_manager.py
ADDED
@@ -0,0 +1,252 @@
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py
import os
import time
from datetime import datetime
from io import BytesIO
from pathlib import Path

import cv2
import numpy as np
from PIL import Image, ImageOps, PngImagePlugin
from loguru import logger
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer

LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024 ** 2)
from .storage_backends import FilesystemStorageBackend
from .utils import aspect_to_string, generate_filename, glob_img


class FileManager(FileSystemEventHandler):
    def __init__(self, app=None):
        self.app = app
        self._default_root_directory = "media"
        self._default_thumbnail_directory = "media"
        self._default_root_url = "/"
        self._default_thumbnail_root_url = "/"
        self._default_format = "JPEG"
        self.output_dir: Path = None

        if app is not None:
            self.init_app(app)

        self.image_dir_filenames = []
        self.output_dir_filenames = []

        self.image_dir_observer = None
        self.output_dir_observer = None

        self.modified_time = {
            "image": datetime.utcnow(),
            "output": datetime.utcnow(),
        }

    def start(self):
        self.image_dir_filenames = self._media_names(self.root_directory)
        self.output_dir_filenames = self._media_names(self.output_dir)

        logger.info(f"Start watching image directory: {self.root_directory}")
        self.image_dir_observer = Observer()
        self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
        self.image_dir_observer.start()

        logger.info(f"Start watching output directory: {self.output_dir}")
        self.output_dir_observer = Observer()
        self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
        self.output_dir_observer.start()

    def on_modified(self, event):
        if not os.path.isdir(event.src_path):
            return
        if event.src_path == str(self.root_directory):
            logger.info(f"Image directory {event.src_path} modified")
            self.image_dir_filenames = self._media_names(self.root_directory)
            self.modified_time['image'] = datetime.utcnow()
        elif event.src_path == str(self.output_dir):
            logger.info(f"Output directory {event.src_path} modified")
            self.output_dir_filenames = self._media_names(self.output_dir)
            self.modified_time['output'] = datetime.utcnow()

    def init_app(self, app):
        if self.app is None:
            self.app = app
        app.thumbnail_instance = self

        if not hasattr(app, "extensions"):
            app.extensions = {}

        if "thumbnail" in app.extensions:
            raise RuntimeError("Flask-thumbnail extension already initialized")

        app.extensions["thumbnail"] = self

        app.config.setdefault("THUMBNAIL_MEDIA_ROOT", self._default_root_directory)
        app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_ROOT", self._default_thumbnail_directory)
        app.config.setdefault("THUMBNAIL_MEDIA_URL", self._default_root_url)
        app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url)
        app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format)

    def save_to_output_directory(self, image: np.ndarray, filename: str):
        fp = Path(filename)
        new_name = fp.stem + f"_{int(time.time())}" + fp.suffix
        if image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)

        cv2.imwrite(str(self.output_dir / new_name), image)

    @property
    def root_directory(self):
        path = self.app.config["THUMBNAIL_MEDIA_ROOT"]

        if os.path.isabs(path):
            return path
        else:
            return os.path.join(self.app.root_path, path)

    @property
    def thumbnail_directory(self):
        path = self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"]

        if os.path.isabs(path):
            return path
        else:
            return os.path.join(self.app.root_path, path)

    @property
    def root_url(self):
        return self.app.config["THUMBNAIL_MEDIA_URL"]

    @property
    def media_names(self):
        # return self.image_dir_filenames
        return self._media_names(self.root_directory)

    @property
    def output_media_names(self):
        return self._media_names(self.output_dir)
        # return self.output_dir_filenames

    @staticmethod
    def _media_names(directory: Path):
        names = sorted([it.name for it in glob_img(directory)])
        res = []
        for name in names:
            path = os.path.join(directory, name)
            img = Image.open(path)
            res.append({"name": name, "height": img.height, "width": img.width, "ctime": os.path.getctime(path)})
        return res

    @property
    def thumbnail_url(self):
        return self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_URL"]

    def get_thumbnail(self, directory: Path, original_filename: str, width, height, **options):
        storage = FilesystemStorageBackend(self.app)
        crop = options.get("crop", "fit")
        background = options.get("background")
        quality = options.get("quality", 90)

        original_path, original_filename = os.path.split(original_filename)
        original_filepath = os.path.join(directory, original_path, original_filename)
        image = Image.open(BytesIO(storage.read(original_filepath)))

        # keep ratio resize
        if width is not None:
            height = int(image.height * width / image.width)
        else:
            width = int(image.width * height / image.height)

        thumbnail_size = (width, height)

        thumbnail_filename = generate_filename(
            original_filename, aspect_to_string(thumbnail_size), crop, background, quality
        )

        thumbnail_filepath = os.path.join(
            self.thumbnail_directory, original_path, thumbnail_filename
        )
        thumbnail_url = os.path.join(self.thumbnail_url, original_path, thumbnail_filename)

        if storage.exists(thumbnail_filepath):
            return thumbnail_url, (width, height)

        try:
            image.load()
        except (IOError, OSError):
            self.app.logger.warning("Thumbnail not load image: %s", original_filepath)
            return thumbnail_url, (width, height)

        # get original image format
        options["format"] = options.get("format", image.format)

        image = self._create_thumbnail(image, thumbnail_size, crop, background=background)

        raw_data = self.get_raw_data(image, **options)
        storage.save(thumbnail_filepath, raw_data)

        return thumbnail_url, (width, height)

    def get_raw_data(self, image, **options):
        data = {
            "format": self._get_format(image, **options),
            "quality": options.get("quality", 90),
        }

        _file = BytesIO()
        image.save(_file, **data)
        return _file.getvalue()

    @staticmethod
    def colormode(image, colormode="RGB"):
        if colormode == "RGB" or colormode == "RGBA":
            if image.mode == "RGBA":
                return image
            if image.mode == "LA":
                return image.convert("RGBA")
            return image.convert(colormode)

        if colormode == "GRAY":
            return image.convert("L")

        return image.convert(colormode)

    @staticmethod
    def background(original_image, color=0xFF):
        size = (max(original_image.size),) * 2
        image = Image.new("L", size, color)
        image.paste(
            original_image,
            tuple(map(lambda x: (x[0] - x[1]) / 2, zip(size, original_image.size))),
        )

        return image

    def _get_format(self, image, **options):
        if options.get("format"):
            return options.get("format")
        if image.format:
            return image.format

        return self.app.config["THUMBNAIL_DEFAULT_FORMAT"]

    def _create_thumbnail(self, image, size, crop="fit", background=None):
        try:
            resample = Image.Resampling.LANCZOS
        except AttributeError:  # pylint: disable=raise-missing-from
            resample = Image.ANTIALIAS

        if crop == "fit":
            image = ImageOps.fit(image, size, resample)
        else:
            image = image.copy()
            image.thumbnail(size, resample=resample)

        if background is not None:
            image = self.background(image)

        image = self.colormode(image)

        return image
lama_cleaner/file_manager/storage_backends.py
ADDED
@@ -0,0 +1,46 @@
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py
import errno
import os
from abc import ABC, abstractmethod


class BaseStorageBackend(ABC):
    def __init__(self, app=None):
        self.app = app

    @abstractmethod
    def read(self, filepath, mode="rb", **kwargs):
        raise NotImplementedError

    @abstractmethod
    def exists(self, filepath):
        raise NotImplementedError

    @abstractmethod
    def save(self, filepath, data):
        raise NotImplementedError


class FilesystemStorageBackend(BaseStorageBackend):
    def read(self, filepath, mode="rb", **kwargs):
        with open(filepath, mode) as f:  # pylint: disable=unspecified-encoding
            return f.read()

    def exists(self, filepath):
        return os.path.exists(filepath)

    def save(self, filepath, data):
        directory = os.path.dirname(filepath)

        if not os.path.exists(directory):
            try:
                os.makedirs(directory)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        if not os.path.isdir(directory):
            raise IOError("{} is not a directory".format(directory))

        with open(filepath, "wb") as f:
            f.write(data)
lama_cleaner/file_manager/utils.py
ADDED
@@ -0,0 +1,66 @@
# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
import os
from pathlib import Path

from typing import Union


def generate_filename(original_filename, *options):
    name, ext = os.path.splitext(original_filename)
    for v in options:
        if v:
            name += "_%s" % v
    name += ext

    return name


def parse_size(size):
    if isinstance(size, int):
        # If the size parameter is a single number, assume square aspect.
        return [size, size]

    if isinstance(size, (tuple, list)):
        if len(size) == 1:
            # If a single-value tuple/list is provided, expand it to two elements.
            return size + type(size)(size)
        return size

    try:
        thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
    except ValueError:
        raise ValueError(  # pylint: disable=raise-missing-from
            "Bad thumbnail size format. Valid format is INTxINT."
        )

    if len(thumbnail_size) == 1:
        # If the size parameter only contains a single integer, assume square aspect.
        thumbnail_size.append(thumbnail_size[0])

    return thumbnail_size


def aspect_to_string(size):
    if isinstance(size, str):
        return size

    return "x".join(map(str, size))


IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'}


def glob_img(p: Union[Path, str], recursive: bool = False):
    p = Path(p)
    if p.is_file() and p.suffix in IMG_SUFFIX:
        yield p
    else:
        if recursive:
            files = Path(p).glob("**/*.*")
        else:
            files = Path(p).glob("*.*")

        for it in files:
            if it.suffix not in IMG_SUFFIX:
                continue
            yield it
lama_cleaner/helper.py
ADDED
@@ -0,0 +1,218 @@
import io
import os
import sys
from typing import List, Optional
from urllib.parse import urlparse

import cv2
import numpy as np
import torch
from PIL import Image, ImageOps
from loguru import logger
from torch.hub import download_url_to_file, get_dir


def get_cache_path_by_url(url):
    parts = urlparse(url)
    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, "checkpoints")
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    return cached_file


def download_model(url):
    cached_file = get_cache_path_by_url(url)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
    return cached_file


def ceil_modulo(x, mod):
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def load_jit_model(url_or_path, device):
    # if os.path.exists(url_or_path):
    #     model_path = url_or_path
    # else:
    #     model_path = download_model(url_or_path)
    model_path = os.getcwd()
    logger.info(f"Load model from: {model_path}")
    try:
        model = torch.jit.load(model_path).to(device)
    except:
        logger.error(
            f"Failed to load {model_path}, delete model and restart lama-cleaner"
        )
        exit(-1)
    model.eval()
    return model


def load_model(model: torch.nn.Module, url_or_path, device):
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)

    try:
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
        logger.info(f"Load model from: {model_path}")
    except:
        logger.error(
            f"Failed to load {model_path}, delete model and restart lama-cleaner"
        )
        exit(-1)
    model.eval()
    return model


def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    data = cv2.imencode(
        f".{ext}",
        image_numpy,
        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
    )[1]
    image_bytes = data.tobytes()
    return image_bytes


def load_img(img_bytes, gray: bool = False):
    alpha_channel = None
    image = Image.open(io.BytesIO(img_bytes))
    try:
        image = ImageOps.exif_transpose(image)
    except:
        pass

    if gray:
        image = image.convert('L')
        np_img = np.array(image)
    else:
        if image.mode == 'RGBA':
            np_img = np.array(image)
            alpha_channel = np_img[:, :, -1]
            np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
        else:
            image = image.convert('RGB')
            np_img = np.array(image)

    return np_img, alpha_channel


def norm_img(np_img):
    if len(np_img.shape) == 2:
        np_img = np_img[:, :, np.newaxis]
    np_img = np.transpose(np_img, (2, 0, 1))
    np_img = np_img.astype("float32") / 255
    return np_img


def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    # Resize image's longer side to size_limit if it is larger than size_limit
    h, w = np_img.shape[:2]
    if max(h, w) > size_limit:
        ratio = size_limit / max(h, w)
        new_w = int(w * ratio + 0.5)
        new_h = int(h * ratio + 0.5)
        return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
    else:
        return np_img


def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """

    Args:
        img: [H, W, C]
        mod:
        square: whether to pad to a square
        min_size:

    Returns:

    """
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)

    if min_size is not None:
        assert min_size % mod == 0
        out_width = max(min_size, out_width)
        out_height = max(min_size, out_height)

    if square:
        max_size = max(out_height, out_width)
        out_height = max_size
        out_width = max_size

    return np.pad(
        img,
        ((0, out_height - height), (0, out_width - width), (0, 0)),
        mode="symmetric",
    )


def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """
    Args:
        mask: (h, w, 1) 0~255

    Returns:

    """
    height, width = mask.shape[:2]
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        box = np.array([x, y, x + w, y + h]).astype(int)

        box[::2] = np.clip(box[::2], 0, width)
        box[1::2] = np.clip(box[1::2], 0, height)
        boxes.append(box)

    return boxes


def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
    """
    Args:
        mask: (h, w) 0~255

    Returns:

    """
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    max_area = 0
    max_index = -1
    for i, cnt in enumerate(contours):
        area = cv2.contourArea(cnt)
        if area > max_area:
            max_area = area
            max_index = i

    if max_index != -1:
        new_mask = np.zeros_like(mask)
        return cv2.drawContours(new_mask, contours, max_index, 255, -1)
    else:
        return mask
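A quick sanity-check sketch of the image utilities in helper.py above, assuming the lama_cleaner package is importable; the shapes in the comments follow directly from the function definitions, and the input arrays are made up.

# Sanity check of helper.py utilities (dummy data, no model required).
import numpy as np

from lama_cleaner.helper import boxes_from_mask, norm_img, pad_img_to_modulo, resize_max_size

img = np.random.randint(0, 256, (300, 500, 3), dtype=np.uint8)  # dummy RGB image, H=300, W=500
mask = np.zeros((300, 500), dtype=np.uint8)
mask[100:200, 150:250] = 255                                    # one rectangular foreground region

print(pad_img_to_modulo(img, mod=8).shape)         # (304, 504, 3): H and W padded up to multiples of 8
print(norm_img(img).shape)                         # (3, 300, 500): HWC uint8 -> CHW float32 in [0, 1]
print(resize_max_size(img, size_limit=256).shape)  # (154, 256, 3): longer side capped at 256
print(boxes_from_mask(mask))                       # [array([150, 100, 250, 200])]: one xyxy box per contour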
lama_cleaner/interactive_seg.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Tuple, List
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from loguru import logger
|
9 |
+
from pydantic import BaseModel
|
10 |
+
|
11 |
+
from lama_cleaner.helper import load_jit_model
|
12 |
+
|
13 |
+
|
14 |
+
class Click(BaseModel):
|
15 |
+
# [y, x]
|
16 |
+
coords: Tuple[float, float]
|
17 |
+
is_positive: bool
|
18 |
+
indx: int
|
19 |
+
|
20 |
+
@property
|
21 |
+
def coords_and_indx(self):
|
22 |
+
return (*self.coords, self.indx)
|
23 |
+
|
24 |
+
def scale(self, x_ratio: float, y_ratio: float) -> 'Click':
|
25 |
+
return Click(
|
26 |
+
coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio),
|
27 |
+
is_positive=self.is_positive,
|
28 |
+
indx=self.indx
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
class ResizeTrans:
|
33 |
+
def __init__(self, size=480):
|
34 |
+
super().__init__()
|
35 |
+
self.crop_height = size
|
36 |
+
self.crop_width = size
|
37 |
+
|
38 |
+
def transform(self, image_nd, clicks_lists):
|
39 |
+
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
|
40 |
+
image_height, image_width = image_nd.shape[2:4]
|
41 |
+
self.image_height = image_height
|
42 |
+
self.image_width = image_width
|
43 |
+
image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode='bilinear', align_corners=True)
|
44 |
+
|
45 |
+
y_ratio = self.crop_height / image_height
|
46 |
+
x_ratio = self.crop_width / image_width
|
47 |
+
|
48 |
+
clicks_lists_resized = []
|
49 |
+
for clicks_list in clicks_lists:
|
50 |
+
clicks_list_resized = [click.scale(y_ratio, x_ratio) for click in clicks_list]
|
51 |
+
clicks_lists_resized.append(clicks_list_resized)
|
52 |
+
|
53 |
+
return image_nd_r, clicks_lists_resized
|
54 |
+
|
55 |
+
def inv_transform(self, prob_map):
|
56 |
+
new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear',
|
57 |
+
align_corners=True)
|
58 |
+
|
59 |
+
return new_prob_map
|
60 |
+
|
61 |
+
|
62 |
+
class ISPredictor(object):
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
model,
|
66 |
+
device,
|
67 |
+
open_kernel_size: int,
|
68 |
+
dilate_kernel_size: int,
|
69 |
+
net_clicks_limit=None,
|
70 |
+
zoom_in=None,
|
71 |
+
infer_size=384,
|
72 |
+
):
|
73 |
+
self.model = model
|
74 |
+
self.open_kernel_size = open_kernel_size
|
75 |
+
self.dilate_kernel_size = dilate_kernel_size
|
76 |
+
self.net_clicks_limit = net_clicks_limit
|
77 |
+
self.device = device
|
78 |
+
self.zoom_in = zoom_in
|
79 |
+
self.infer_size = infer_size
|
80 |
+
|
81 |
+
# self.transforms = [zoom_in] if zoom_in is not None else []
|
82 |
+
|
83 |
+
def __call__(self, input_image: torch.Tensor, clicks: List[Click], prev_mask):
|
84 |
+
"""
|
85 |
+
|
86 |
+
Args:
|
87 |
+
input_image: [1, 3, H, W] [0~1]
|
88 |
+
clicks: List[Click]
|
89 |
+
prev_mask: [1, 1, H, W]
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
|
93 |
+
"""
|
94 |
+
transforms = [ResizeTrans(self.infer_size)]
|
95 |
+
input_image = torch.cat((input_image, prev_mask), dim=1)
|
96 |
+
|
97 |
+
# image_nd resized to infer_size
|
98 |
+
for t in transforms:
|
99 |
+
image_nd, clicks_lists = t.transform(input_image, [clicks])
|
100 |
+
|
101 |
+
# image_nd.shape = [1, 4, 256, 256]
|
102 |
+
# points_nd.sha[e = [1, 2, 3]
|
103 |
+
# clicks_lists[0][0] Click 类
|
104 |
+
points_nd = self.get_points_nd(clicks_lists)
|
105 |
+
pred_logits = self.model(image_nd, points_nd)
|
106 |
+
pred = torch.sigmoid(pred_logits)
|
107 |
+
pred = self.post_process(pred)
|
108 |
+
|
109 |
+
prediction = F.interpolate(pred, mode='bilinear', align_corners=True,
|
110 |
+
size=image_nd.size()[2:])
|
111 |
+
|
112 |
+
for t in reversed(transforms):
|
113 |
+
prediction = t.inv_transform(prediction)
|
114 |
+
|
115 |
+
# if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
|
116 |
+
# return self.get_prediction(clicker)
|
117 |
+
|
118 |
+
return prediction.cpu().numpy()[0, 0]
|
119 |
+
|
120 |
+
def post_process(self, pred: torch.Tensor) -> torch.Tensor:
|
121 |
+
pred_mask = pred.cpu().numpy()[0][0]
|
122 |
+
# morph_open to remove small noise
|
123 |
+
kernel_size = self.open_kernel_size
|
124 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
125 |
+
pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1)
|
126 |
+
|
127 |
+
# Why dilate: make region slightly larger to avoid missing some pixels, this generally works better
|
128 |
+
dilate_kernel_size = self.dilate_kernel_size
|
129 |
+
if dilate_kernel_size > 1:
|
130 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size))
|
131 |
+
pred_mask = cv2.dilate(pred_mask, kernel, 1)
|
132 |
+
return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0)
|
133 |
+
|
134 |
+
def get_points_nd(self, clicks_lists):
|
135 |
+
total_clicks = []
|
136 |
+
num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
|
137 |
+
num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
|
138 |
+
num_max_points = max(num_pos_clicks + num_neg_clicks)
|
139 |
+
if self.net_clicks_limit is not None:
|
140 |
+
num_max_points = min(self.net_clicks_limit, num_max_points)
|
141 |
+
num_max_points = max(1, num_max_points)
|
142 |
+
|
143 |
+
for clicks_list in clicks_lists:
|
144 |
+
clicks_list = clicks_list[:self.net_clicks_limit]
|
145 |
+
pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
|
146 |
+
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
|
147 |
+
|
148 |
+
neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
|
149 |
+
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
|
150 |
+
total_clicks.append(pos_clicks + neg_clicks)
|
151 |
+
|
152 |
+
return torch.tensor(total_clicks, device=self.device)
|
153 |
+
|
154 |
+
|
155 |
+
INTERACTIVE_SEG_MODEL_URL = os.environ.get(
|
156 |
+
"INTERACTIVE_SEG_MODEL_URL",
|
157 |
+
"https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt",
|
158 |
+
)
|
159 |
+
|
160 |
+
|
161 |
+
class InteractiveSeg:
|
162 |
+
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
|
163 |
+
device = torch.device('cpu')
|
164 |
+
model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device).eval()
|
165 |
+
self.predictor = ISPredictor(model, device,
|
166 |
+
infer_size=infer_size,
|
167 |
+
open_kernel_size=open_kernel_size,
|
168 |
+
dilate_kernel_size=dilate_kernel_size)
|
169 |
+
|
170 |
+
def __call__(self, image, clicks, prev_mask=None):
|
171 |
+
"""
|
172 |
+
|
173 |
+
Args:
|
174 |
+
image: [H,W,C] RGB
|
175 |
+
clicks: list of Click point prompts (positive / negative) for the object to segment
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
[H, W, 4] RGBA uint8 mask with the frontend brush color applied to the predicted foreground
|
179 |
+
"""
|
180 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
181 |
+
image = torch.from_numpy((image / 255).transpose(2, 0, 1)).unsqueeze(0).float()
|
182 |
+
if prev_mask is None:
|
183 |
+
mask = torch.zeros_like(image[:, :1, :, :])
|
184 |
+
else:
|
185 |
+
logger.info('InteractiveSeg run with prev_mask')
|
186 |
+
mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float()
|
187 |
+
|
188 |
+
pred_probs = self.predictor(image, clicks, mask)
|
189 |
+
pred_mask = pred_probs > 0.5
|
190 |
+
pred_mask = (pred_mask * 255).astype(np.uint8)
|
191 |
+
|
192 |
+
# Find largest contour
|
193 |
+
# pred_mask = only_keep_largest_contour(pred_mask)
|
194 |
+
# To simplify frontend process, add mask brush color here
|
195 |
+
fg = pred_mask == 255
|
196 |
+
bg = pred_mask != 255
|
197 |
+
pred_mask = cv2.cvtColor(pred_mask, cv2.COLOR_GRAY2BGRA)
|
198 |
+
# frontend brush color "ffcc00bb"
|
199 |
+
pred_mask[bg] = 0
|
200 |
+
pred_mask[fg] = [255, 203, 0, int(255 * 0.73)]
|
201 |
+
pred_mask = cv2.cvtColor(pred_mask, cv2.COLOR_BGRA2RGBA)
|
202 |
+
return pred_mask
|
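|
For orientation, a minimal usage sketch of the class above (a sketch, not part of the upload): it assumes the Click point-prompt class defined earlier in this file is importable from lama_cleaner.interactive_seg and that its constructor takes the click coordinates, a click index and a positive/negative flag; the exact Click signature is an assumption.

import cv2

from lama_cleaner.interactive_seg import Click, InteractiveSeg  # Click constructor below is assumed

seg = InteractiveSeg(infer_size=384, open_kernel_size=3, dilate_kernel_size=3)

image = cv2.cvtColor(cv2.imread("input.jpg"), cv2.COLOR_BGR2RGB)  # [H, W, 3] RGB uint8
clicks = [Click(coords=(120, 200), indx=0, is_positive=True)]     # one positive click; hypothetical arguments
rgba_mask = seg(image, clicks)                                    # [H, W, 4] RGBA, foreground in the brush color
cv2.imwrite("mask.png", cv2.cvtColor(rgba_mask, cv2.COLOR_RGBA2BGRA))
|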
lama_cleaner/model/__init__.py
ADDED
File without changes
|
lama_cleaner/model/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (172 Bytes)
|
lama_cleaner/model/__pycache__/base.cpython-38.pyc
ADDED
Binary file (6.7 kB)
|
lama_cleaner/model/__pycache__/ddim_sampler.cpython-38.pyc
ADDED
Binary file (4.74 kB)
|
lama_cleaner/model/__pycache__/fcf.cpython-38.pyc
ADDED
Binary file (33.4 kB)
|
lama_cleaner/model/__pycache__/lama.cpython-38.pyc
ADDED
Binary file (2.12 kB)
|
lama_cleaner/model/__pycache__/ldm.cpython-38.pyc
ADDED
Binary file (7.79 kB)
|
lama_cleaner/model/__pycache__/manga.cpython-38.pyc
ADDED
Binary file (2.72 kB)
|
lama_cleaner/model/__pycache__/mat.cpython-38.pyc
ADDED
Binary file (38.8 kB)
|
lama_cleaner/model/__pycache__/opencv2.cpython-38.pyc
ADDED
Binary file (1.13 kB)
|
lama_cleaner/model/__pycache__/paint_by_example.cpython-38.pyc
ADDED
Binary file (4.25 kB)
|
lama_cleaner/model/__pycache__/plms_sampler.cpython-38.pyc
ADDED
Binary file (7.09 kB)
|
lama_cleaner/model/__pycache__/sd.cpython-38.pyc
ADDED
Binary file (6.26 kB)
|
lama_cleaner/model/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (26.3 kB)
|
lama_cleaner/model/__pycache__/zits.cpython-38.pyc
ADDED
Binary file (10.5 kB)
|
lama_cleaner/model/base.py
ADDED
@@ -0,0 +1,247 @@
1 |
+
import abc
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
|
10 |
+
from lama_cleaner.schema import Config, HDStrategy
|
11 |
+
|
12 |
+
|
13 |
+
class InpaintModel:
|
14 |
+
min_size: Optional[int] = None
|
15 |
+
pad_mod = 8
|
16 |
+
pad_to_square = False
|
17 |
+
|
18 |
+
def __init__(self, device, **kwargs):
|
19 |
+
"""
|
20 |
+
|
21 |
+
Args:
|
22 |
+
device: torch.device the model weights are loaded onto
|
23 |
+
"""
|
24 |
+
self.device = device
|
25 |
+
self.init_model(device, **kwargs)
|
26 |
+
|
27 |
+
@abc.abstractmethod
|
28 |
+
def init_model(self, device, **kwargs):
|
29 |
+
...
|
30 |
+
|
31 |
+
@staticmethod
|
32 |
+
@abc.abstractmethod
|
33 |
+
def is_downloaded() -> bool:
|
34 |
+
...
|
35 |
+
|
36 |
+
@abc.abstractmethod
|
37 |
+
def forward(self, image, mask, config: Config):
|
38 |
+
"""Input images and output images have same size
|
39 |
+
images: [H, W, C] RGB
|
40 |
+
masks: [H, W, 1], 255 marks the region to inpaint
|
41 |
+
return: BGR IMAGE
|
42 |
+
"""
|
43 |
+
...
|
44 |
+
|
45 |
+
def _pad_forward(self, image, mask, config: Config):
|
46 |
+
origin_height, origin_width = image.shape[:2]
|
47 |
+
pad_image = pad_img_to_modulo(
|
48 |
+
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
49 |
+
)
|
50 |
+
pad_mask = pad_img_to_modulo(
|
51 |
+
mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
52 |
+
)
|
53 |
+
|
54 |
+
logger.info(f"final forward pad size: {pad_image.shape}")
|
55 |
+
|
56 |
+
result = self.forward(pad_image, pad_mask, config)
|
57 |
+
result = result[0:origin_height, 0:origin_width, :]
|
58 |
+
|
59 |
+
result, image, mask = self.forward_post_process(result, image, mask, config)
|
60 |
+
|
61 |
+
mask = mask[:, :, np.newaxis]
|
62 |
+
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
|
63 |
+
return result
|
64 |
+
|
65 |
+
def forward_post_process(self, result, image, mask, config):
|
66 |
+
return result, image, mask
|
67 |
+
|
68 |
+
@torch.no_grad()
|
69 |
+
def __call__(self, image, mask, config: Config):
|
70 |
+
"""
|
71 |
+
images: [H, W, C] RGB, not normalized
|
72 |
+
masks: [H, W]
|
73 |
+
return: BGR IMAGE
|
74 |
+
"""
|
75 |
+
inpaint_result = None
|
76 |
+
logger.info(f"hd_strategy: {config.hd_strategy}")
|
77 |
+
if config.hd_strategy == HDStrategy.CROP:
|
78 |
+
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
79 |
+
logger.info(f"Run crop strategy")
|
80 |
+
boxes = boxes_from_mask(mask)
|
81 |
+
crop_result = []
|
82 |
+
for box in boxes:
|
83 |
+
crop_image, crop_box = self._run_box(image, mask, box, config)
|
84 |
+
crop_result.append((crop_image, crop_box))
|
85 |
+
|
86 |
+
inpaint_result = image[:, :, ::-1]
|
87 |
+
for crop_image, crop_box in crop_result:
|
88 |
+
x1, y1, x2, y2 = crop_box
|
89 |
+
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
90 |
+
|
91 |
+
elif config.hd_strategy == HDStrategy.RESIZE:
|
92 |
+
if max(image.shape) > config.hd_strategy_resize_limit:
|
93 |
+
origin_size = image.shape[:2]
|
94 |
+
downsize_image = resize_max_size(
|
95 |
+
image, size_limit=config.hd_strategy_resize_limit
|
96 |
+
)
|
97 |
+
downsize_mask = resize_max_size(
|
98 |
+
mask, size_limit=config.hd_strategy_resize_limit
|
99 |
+
)
|
100 |
+
|
101 |
+
logger.info(
|
102 |
+
f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
|
103 |
+
)
|
104 |
+
inpaint_result = self._pad_forward(
|
105 |
+
downsize_image, downsize_mask, config
|
106 |
+
)
|
107 |
+
|
108 |
+
# only paste masked area result
|
109 |
+
inpaint_result = cv2.resize(
|
110 |
+
inpaint_result,
|
111 |
+
(origin_size[1], origin_size[0]),
|
112 |
+
interpolation=cv2.INTER_CUBIC,
|
113 |
+
)
|
114 |
+
original_pixel_indices = mask < 127
|
115 |
+
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
116 |
+
original_pixel_indices
|
117 |
+
]
|
118 |
+
|
119 |
+
if inpaint_result is None:
|
120 |
+
inpaint_result = self._pad_forward(image, mask, config)
|
121 |
+
|
122 |
+
return inpaint_result
|
123 |
+
|
124 |
+
def _crop_box(self, image, mask, box, config: Config):
|
125 |
+
"""
|
126 |
+
|
127 |
+
Args:
|
128 |
+
image: [H, W, C] RGB
|
129 |
+
mask: [H, W, 1]
|
130 |
+
box: [left,top,right,bottom]
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
BGR IMAGE, (l, t, r, b)
|
134 |
+
"""
|
135 |
+
box_h = box[3] - box[1]
|
136 |
+
box_w = box[2] - box[0]
|
137 |
+
cx = (box[0] + box[2]) // 2
|
138 |
+
cy = (box[1] + box[3]) // 2
|
139 |
+
img_h, img_w = image.shape[:2]
|
140 |
+
|
141 |
+
w = box_w + config.hd_strategy_crop_margin * 2
|
142 |
+
h = box_h + config.hd_strategy_crop_margin * 2
|
143 |
+
|
144 |
+
_l = cx - w // 2
|
145 |
+
_r = cx + w // 2
|
146 |
+
_t = cy - h // 2
|
147 |
+
_b = cy + h // 2
|
148 |
+
|
149 |
+
l = max(_l, 0)
|
150 |
+
r = min(_r, img_w)
|
151 |
+
t = max(_t, 0)
|
152 |
+
b = min(_b, img_h)
|
153 |
+
|
154 |
+
# try to get more context when crop around image edge
|
155 |
+
if _l < 0:
|
156 |
+
r += abs(_l)
|
157 |
+
if _r > img_w:
|
158 |
+
l -= _r - img_w
|
159 |
+
if _t < 0:
|
160 |
+
b += abs(_t)
|
161 |
+
if _b > img_h:
|
162 |
+
t -= _b - img_h
|
163 |
+
|
164 |
+
l = max(l, 0)
|
165 |
+
r = min(r, img_w)
|
166 |
+
t = max(t, 0)
|
167 |
+
b = min(b, img_h)
|
168 |
+
|
169 |
+
crop_img = image[t:b, l:r, :]
|
170 |
+
crop_mask = mask[t:b, l:r]
|
171 |
+
|
172 |
+
logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
173 |
+
|
174 |
+
return crop_img, crop_mask, [l, t, r, b]
|
175 |
+
|
176 |
+
def _calculate_cdf(self, histogram):
|
177 |
+
cdf = histogram.cumsum()
|
178 |
+
normalized_cdf = cdf / float(cdf.max())
|
179 |
+
return normalized_cdf
|
180 |
+
|
181 |
+
def _calculate_lookup(self, source_cdf, reference_cdf):
|
182 |
+
lookup_table = np.zeros(256)
|
183 |
+
lookup_val = 0
|
184 |
+
for source_index, source_val in enumerate(source_cdf):
|
185 |
+
for reference_index, reference_val in enumerate(reference_cdf):
|
186 |
+
if reference_val >= source_val:
|
187 |
+
lookup_val = reference_index
|
188 |
+
break
|
189 |
+
lookup_table[source_index] = lookup_val
|
190 |
+
return lookup_table
|
191 |
+
|
192 |
+
def _match_histograms(self, source, reference, mask):
|
193 |
+
transformed_channels = []
|
194 |
+
for channel in range(source.shape[-1]):
|
195 |
+
source_channel = source[:, :, channel]
|
196 |
+
reference_channel = reference[:, :, channel]
|
197 |
+
|
198 |
+
# only calculate histograms for non-masked parts
|
199 |
+
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
|
200 |
+
reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0, 256])
|
201 |
+
|
202 |
+
source_cdf = self._calculate_cdf(source_histogram)
|
203 |
+
reference_cdf = self._calculate_cdf(reference_histogram)
|
204 |
+
|
205 |
+
lookup = self._calculate_lookup(source_cdf, reference_cdf)
|
206 |
+
|
207 |
+
transformed_channels.append(cv2.LUT(source_channel, lookup))
|
208 |
+
|
209 |
+
result = cv2.merge(transformed_channels)
|
210 |
+
result = cv2.convertScaleAbs(result)
|
211 |
+
|
212 |
+
return result
|
213 |
+
|
214 |
+
def _apply_cropper(self, image, mask, config: Config):
|
215 |
+
img_h, img_w = image.shape[:2]
|
216 |
+
l, t, w, h = (
|
217 |
+
config.croper_x,
|
218 |
+
config.croper_y,
|
219 |
+
config.croper_width,
|
220 |
+
config.croper_height,
|
221 |
+
)
|
222 |
+
r = l + w
|
223 |
+
b = t + h
|
224 |
+
|
225 |
+
l = max(l, 0)
|
226 |
+
r = min(r, img_w)
|
227 |
+
t = max(t, 0)
|
228 |
+
b = min(b, img_h)
|
229 |
+
|
230 |
+
crop_img = image[t:b, l:r, :]
|
231 |
+
crop_mask = mask[t:b, l:r]
|
232 |
+
return crop_img, crop_mask, (l, t, r, b)
|
233 |
+
|
234 |
+
def _run_box(self, image, mask, box, config: Config):
|
235 |
+
"""
|
236 |
+
|
237 |
+
Args:
|
238 |
+
image: [H, W, C] RGB
|
239 |
+
mask: [H, W, 1]
|
240 |
+
box: [left,top,right,bottom]
|
241 |
+
|
242 |
+
Returns:
|
243 |
+
BGR IMAGE
|
244 |
+
"""
|
245 |
+
crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
|
246 |
+
|
247 |
+
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
|
lama_cleaner/model/ddim_sampler.py
ADDED
@@ -0,0 +1,192 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from loguru import logger
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
|
7 |
+
|
8 |
+
|
9 |
+
class DDIMSampler(object):
|
10 |
+
def __init__(self, model, schedule="linear"):
|
11 |
+
super().__init__()
|
12 |
+
self.model = model
|
13 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
14 |
+
self.schedule = schedule
|
15 |
+
|
16 |
+
def register_buffer(self, name, attr):
|
17 |
+
setattr(self, name, attr)
|
18 |
+
|
19 |
+
def make_schedule(
|
20 |
+
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
21 |
+
):
|
22 |
+
self.ddim_timesteps = make_ddim_timesteps(
|
23 |
+
ddim_discr_method=ddim_discretize,
|
24 |
+
num_ddim_timesteps=ddim_num_steps,
|
25 |
+
# array([1])
|
26 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
27 |
+
verbose=verbose,
|
28 |
+
)
|
29 |
+
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
|
30 |
+
assert (
|
31 |
+
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
32 |
+
), "alphas have to be defined for each timestep"
|
33 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
34 |
+
|
35 |
+
self.register_buffer("betas", to_torch(self.model.betas))
|
36 |
+
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
37 |
+
self.register_buffer(
|
38 |
+
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
39 |
+
)
|
40 |
+
|
41 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
42 |
+
self.register_buffer(
|
43 |
+
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
44 |
+
)
|
45 |
+
self.register_buffer(
|
46 |
+
"sqrt_one_minus_alphas_cumprod",
|
47 |
+
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
48 |
+
)
|
49 |
+
self.register_buffer(
|
50 |
+
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
51 |
+
)
|
52 |
+
self.register_buffer(
|
53 |
+
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
54 |
+
)
|
55 |
+
self.register_buffer(
|
56 |
+
"sqrt_recipm1_alphas_cumprod",
|
57 |
+
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
58 |
+
)
|
59 |
+
|
60 |
+
# ddim sampling parameters
|
61 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
62 |
+
alphacums=alphas_cumprod.cpu(),
|
63 |
+
ddim_timesteps=self.ddim_timesteps,
|
64 |
+
eta=ddim_eta,
|
65 |
+
verbose=verbose,
|
66 |
+
)
|
67 |
+
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
68 |
+
self.register_buffer("ddim_alphas", ddim_alphas)
|
69 |
+
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
70 |
+
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
71 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
72 |
+
(1 - self.alphas_cumprod_prev)
|
73 |
+
/ (1 - self.alphas_cumprod)
|
74 |
+
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
75 |
+
)
|
76 |
+
self.register_buffer(
|
77 |
+
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
78 |
+
)
|
79 |
+
|
80 |
+
@torch.no_grad()
|
81 |
+
def sample(self, steps, conditioning, batch_size, shape):
|
82 |
+
self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
|
83 |
+
# sampling
|
84 |
+
C, H, W = shape
|
85 |
+
size = (batch_size, C, H, W)
|
86 |
+
|
87 |
+
# samples: 1,3,128,128
|
88 |
+
return self.ddim_sampling(
|
89 |
+
conditioning,
|
90 |
+
size,
|
91 |
+
quantize_denoised=False,
|
92 |
+
ddim_use_original_steps=False,
|
93 |
+
noise_dropout=0,
|
94 |
+
temperature=1.0,
|
95 |
+
)
|
96 |
+
|
97 |
+
@torch.no_grad()
|
98 |
+
def ddim_sampling(
|
99 |
+
self,
|
100 |
+
cond,
|
101 |
+
shape,
|
102 |
+
ddim_use_original_steps=False,
|
103 |
+
quantize_denoised=False,
|
104 |
+
temperature=1.0,
|
105 |
+
noise_dropout=0.0,
|
106 |
+
):
|
107 |
+
device = self.model.betas.device
|
108 |
+
b = shape[0]
|
109 |
+
img = torch.randn(shape, device=device, dtype=cond.dtype)
|
110 |
+
timesteps = (
|
111 |
+
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
112 |
+
)
|
113 |
+
|
114 |
+
time_range = (
|
115 |
+
reversed(range(0, timesteps))
|
116 |
+
if ddim_use_original_steps
|
117 |
+
else np.flip(timesteps)
|
118 |
+
)
|
119 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
120 |
+
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
121 |
+
|
122 |
+
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
123 |
+
|
124 |
+
for i, step in enumerate(iterator):
|
125 |
+
index = total_steps - i - 1
|
126 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
127 |
+
|
128 |
+
outs = self.p_sample_ddim(
|
129 |
+
img,
|
130 |
+
cond,
|
131 |
+
ts,
|
132 |
+
index=index,
|
133 |
+
use_original_steps=ddim_use_original_steps,
|
134 |
+
quantize_denoised=quantize_denoised,
|
135 |
+
temperature=temperature,
|
136 |
+
noise_dropout=noise_dropout,
|
137 |
+
)
|
138 |
+
img, _ = outs
|
139 |
+
|
140 |
+
return img
|
141 |
+
|
142 |
+
@torch.no_grad()
|
143 |
+
def p_sample_ddim(
|
144 |
+
self,
|
145 |
+
x,
|
146 |
+
c,
|
147 |
+
t,
|
148 |
+
index,
|
149 |
+
repeat_noise=False,
|
150 |
+
use_original_steps=False,
|
151 |
+
quantize_denoised=False,
|
152 |
+
temperature=1.0,
|
153 |
+
noise_dropout=0.0,
|
154 |
+
):
|
155 |
+
b, *_, device = *x.shape, x.device
|
156 |
+
e_t = self.model.apply_model(x, t, c)
|
157 |
+
|
158 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
159 |
+
alphas_prev = (
|
160 |
+
self.model.alphas_cumprod_prev
|
161 |
+
if use_original_steps
|
162 |
+
else self.ddim_alphas_prev
|
163 |
+
)
|
164 |
+
sqrt_one_minus_alphas = (
|
165 |
+
self.model.sqrt_one_minus_alphas_cumprod
|
166 |
+
if use_original_steps
|
167 |
+
else self.ddim_sqrt_one_minus_alphas
|
168 |
+
)
|
169 |
+
sigmas = (
|
170 |
+
self.model.ddim_sigmas_for_original_num_steps
|
171 |
+
if use_original_steps
|
172 |
+
else self.ddim_sigmas
|
173 |
+
)
|
174 |
+
# select parameters corresponding to the currently considered timestep
|
175 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
176 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
177 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
178 |
+
sqrt_one_minus_at = torch.full(
|
179 |
+
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
180 |
+
)
|
181 |
+
|
182 |
+
# current prediction for x_0
|
183 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
184 |
+
if quantize_denoised:  # not used in this pipeline
|
185 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
186 |
+
# direction pointing to x_t
|
187 |
+
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
188 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
189 |
+
if noise_dropout > 0.0:  # not used in this pipeline
|
190 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
191 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
192 |
+
return x_prev, pred_x0
|
lama_cleaner/model/fcf.py
ADDED
@@ -0,0 +1,1212 @@
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.fft as fft
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import conv2d, nn
|
10 |
+
|
11 |
+
from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img, boxes_from_mask, resize_max_size
|
12 |
+
from lama_cleaner.model.base import InpaintModel
|
13 |
+
from lama_cleaner.model.utils import setup_filter, _parse_scaling, _parse_padding, Conv2dLayer, FullyConnectedLayer, \
|
14 |
+
MinibatchStdLayer, activation_funcs, conv2d_resample, bias_act, upsample2d, normalize_2nd_moment, downsample2d
|
15 |
+
from lama_cleaner.schema import Config
|
16 |
+
|
17 |
+
|
18 |
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
19 |
+
assert isinstance(x, torch.Tensor)
|
20 |
+
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
21 |
+
|
22 |
+
|
23 |
+
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
24 |
+
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
25 |
+
"""
|
26 |
+
# Validate arguments.
|
27 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
28 |
+
if f is None:
|
29 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
30 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
31 |
+
assert f.dtype == torch.float32 and not f.requires_grad
|
32 |
+
batch_size, num_channels, in_height, in_width = x.shape
|
33 |
+
upx, upy = _parse_scaling(up)
|
34 |
+
downx, downy = _parse_scaling(down)
|
35 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
36 |
+
|
37 |
+
# Upsample by inserting zeros.
|
38 |
+
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
39 |
+
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
40 |
+
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
41 |
+
|
42 |
+
# Pad or crop.
|
43 |
+
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
44 |
+
x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)]
|
45 |
+
|
46 |
+
# Setup filter.
|
47 |
+
f = f * (gain ** (f.ndim / 2))
|
48 |
+
f = f.to(x.dtype)
|
49 |
+
if not flip_filter:
|
50 |
+
f = f.flip(list(range(f.ndim)))
|
51 |
+
|
52 |
+
# Convolve with the filter.
|
53 |
+
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
54 |
+
if f.ndim == 4:
|
55 |
+
x = conv2d(input=x, weight=f, groups=num_channels)
|
56 |
+
else:
|
57 |
+
x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
58 |
+
x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
59 |
+
|
60 |
+
# Downsample by throwing away pixels.
|
61 |
+
x = x[:, :, ::downy, ::downx]
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class EncoderEpilogue(torch.nn.Module):
|
66 |
+
def __init__(self,
|
67 |
+
in_channels, # Number of input channels.
|
68 |
+
cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
|
69 |
+
z_dim, # Output Latent (Z) dimensionality.
|
70 |
+
resolution, # Resolution of this block.
|
71 |
+
img_channels, # Number of input color channels.
|
72 |
+
architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'.
|
73 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
74 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
75 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
76 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
77 |
+
):
|
78 |
+
assert architecture in ['orig', 'skip', 'resnet']
|
79 |
+
super().__init__()
|
80 |
+
self.in_channels = in_channels
|
81 |
+
self.cmap_dim = cmap_dim
|
82 |
+
self.resolution = resolution
|
83 |
+
self.img_channels = img_channels
|
84 |
+
self.architecture = architecture
|
85 |
+
|
86 |
+
if architecture == 'skip':
|
87 |
+
self.fromrgb = Conv2dLayer(self.img_channels, in_channels, kernel_size=1, activation=activation)
|
88 |
+
self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size,
|
89 |
+
num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
|
90 |
+
self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation,
|
91 |
+
conv_clamp=conv_clamp)
|
92 |
+
self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), z_dim, activation=activation)
|
93 |
+
self.dropout = torch.nn.Dropout(p=0.5)
|
94 |
+
|
95 |
+
def forward(self, x, cmap, force_fp32=False):
|
96 |
+
_ = force_fp32 # unused
|
97 |
+
dtype = torch.float32
|
98 |
+
memory_format = torch.contiguous_format
|
99 |
+
|
100 |
+
# FromRGB.
|
101 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
102 |
+
|
103 |
+
# Main layers.
|
104 |
+
if self.mbstd is not None:
|
105 |
+
x = self.mbstd(x)
|
106 |
+
const_e = self.conv(x)
|
107 |
+
x = self.fc(const_e.flatten(1))
|
108 |
+
x = self.dropout(x)
|
109 |
+
|
110 |
+
# Conditioning.
|
111 |
+
if self.cmap_dim > 0:
|
112 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
113 |
+
|
114 |
+
assert x.dtype == dtype
|
115 |
+
return x, const_e
|
116 |
+
|
117 |
+
|
118 |
+
class EncoderBlock(torch.nn.Module):
|
119 |
+
def __init__(self,
|
120 |
+
in_channels, # Number of input channels, 0 = first block.
|
121 |
+
tmp_channels, # Number of intermediate channels.
|
122 |
+
out_channels, # Number of output channels.
|
123 |
+
resolution, # Resolution of this block.
|
124 |
+
img_channels, # Number of input color channels.
|
125 |
+
first_layer_idx, # Index of the first layer.
|
126 |
+
architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
|
127 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
128 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
129 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
130 |
+
use_fp16=False, # Use FP16 for this block?
|
131 |
+
fp16_channels_last=False, # Use channels-last memory format with FP16?
|
132 |
+
freeze_layers=0, # Freeze-D: Number of layers to freeze.
|
133 |
+
):
|
134 |
+
assert in_channels in [0, tmp_channels]
|
135 |
+
assert architecture in ['orig', 'skip', 'resnet']
|
136 |
+
super().__init__()
|
137 |
+
self.in_channels = in_channels
|
138 |
+
self.resolution = resolution
|
139 |
+
self.img_channels = img_channels + 1
|
140 |
+
self.first_layer_idx = first_layer_idx
|
141 |
+
self.architecture = architecture
|
142 |
+
self.use_fp16 = use_fp16
|
143 |
+
self.channels_last = (use_fp16 and fp16_channels_last)
|
144 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
145 |
+
|
146 |
+
self.num_layers = 0
|
147 |
+
|
148 |
+
def trainable_gen():
|
149 |
+
while True:
|
150 |
+
layer_idx = self.first_layer_idx + self.num_layers
|
151 |
+
trainable = (layer_idx >= freeze_layers)
|
152 |
+
self.num_layers += 1
|
153 |
+
yield trainable
|
154 |
+
|
155 |
+
trainable_iter = trainable_gen()
|
156 |
+
|
157 |
+
if in_channels == 0:
|
158 |
+
self.fromrgb = Conv2dLayer(self.img_channels, tmp_channels, kernel_size=1, activation=activation,
|
159 |
+
trainable=next(trainable_iter), conv_clamp=conv_clamp,
|
160 |
+
channels_last=self.channels_last)
|
161 |
+
|
162 |
+
self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
|
163 |
+
trainable=next(trainable_iter), conv_clamp=conv_clamp,
|
164 |
+
channels_last=self.channels_last)
|
165 |
+
|
166 |
+
self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
|
167 |
+
trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp,
|
168 |
+
channels_last=self.channels_last)
|
169 |
+
|
170 |
+
if architecture == 'resnet':
|
171 |
+
self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
|
172 |
+
trainable=next(trainable_iter), resample_filter=resample_filter,
|
173 |
+
channels_last=self.channels_last)
|
174 |
+
|
175 |
+
def forward(self, x, img, force_fp32=False):
|
176 |
+
# dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
177 |
+
dtype = torch.float32
|
178 |
+
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
179 |
+
|
180 |
+
# Input.
|
181 |
+
if x is not None:
|
182 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
183 |
+
|
184 |
+
# FromRGB.
|
185 |
+
if self.in_channels == 0:
|
186 |
+
img = img.to(dtype=dtype, memory_format=memory_format)
|
187 |
+
y = self.fromrgb(img)
|
188 |
+
x = x + y if x is not None else y
|
189 |
+
img = downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
|
190 |
+
|
191 |
+
# Main layers.
|
192 |
+
if self.architecture == 'resnet':
|
193 |
+
y = self.skip(x, gain=np.sqrt(0.5))
|
194 |
+
x = self.conv0(x)
|
195 |
+
feat = x.clone()
|
196 |
+
x = self.conv1(x, gain=np.sqrt(0.5))
|
197 |
+
x = y.add_(x)
|
198 |
+
else:
|
199 |
+
x = self.conv0(x)
|
200 |
+
feat = x.clone()
|
201 |
+
x = self.conv1(x)
|
202 |
+
|
203 |
+
assert x.dtype == dtype
|
204 |
+
return x, img, feat
|
205 |
+
|
206 |
+
|
207 |
+
class EncoderNetwork(torch.nn.Module):
|
208 |
+
def __init__(self,
|
209 |
+
c_dim, # Conditioning label (C) dimensionality.
|
210 |
+
z_dim, # Input latent (Z) dimensionality.
|
211 |
+
img_resolution, # Input resolution.
|
212 |
+
img_channels, # Number of input color channels.
|
213 |
+
architecture='orig', # Architecture: 'orig', 'skip', 'resnet'.
|
214 |
+
channel_base=16384, # Overall multiplier for the number of channels.
|
215 |
+
channel_max=512, # Maximum number of channels in any layer.
|
216 |
+
num_fp16_res=0, # Use FP16 for the N highest resolutions.
|
217 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
218 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
219 |
+
block_kwargs={}, # Arguments for DiscriminatorBlock.
|
220 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
221 |
+
epilogue_kwargs={}, # Arguments for EncoderEpilogue.
|
222 |
+
):
|
223 |
+
super().__init__()
|
224 |
+
self.c_dim = c_dim
|
225 |
+
self.z_dim = z_dim
|
226 |
+
self.img_resolution = img_resolution
|
227 |
+
self.img_resolution_log2 = int(np.log2(img_resolution))
|
228 |
+
self.img_channels = img_channels
|
229 |
+
self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
|
230 |
+
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
|
231 |
+
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
|
232 |
+
|
233 |
+
if cmap_dim is None:
|
234 |
+
cmap_dim = channels_dict[4]
|
235 |
+
if c_dim == 0:
|
236 |
+
cmap_dim = 0
|
237 |
+
|
238 |
+
common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
|
239 |
+
cur_layer_idx = 0
|
240 |
+
for res in self.block_resolutions:
|
241 |
+
in_channels = channels_dict[res] if res < img_resolution else 0
|
242 |
+
tmp_channels = channels_dict[res]
|
243 |
+
out_channels = channels_dict[res // 2]
|
244 |
+
use_fp16 = (res >= fp16_resolution)
|
245 |
+
use_fp16 = False
|
246 |
+
block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res,
|
247 |
+
first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
|
248 |
+
setattr(self, f'b{res}', block)
|
249 |
+
cur_layer_idx += block.num_layers
|
250 |
+
if c_dim > 0:
|
251 |
+
self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None,
|
252 |
+
**mapping_kwargs)
|
253 |
+
self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4, **epilogue_kwargs,
|
254 |
+
**common_kwargs)
|
255 |
+
|
256 |
+
def forward(self, img, c, **block_kwargs):
|
257 |
+
x = None
|
258 |
+
feats = {}
|
259 |
+
for res in self.block_resolutions:
|
260 |
+
block = getattr(self, f'b{res}')
|
261 |
+
x, img, feat = block(x, img, **block_kwargs)
|
262 |
+
feats[res] = feat
|
263 |
+
|
264 |
+
cmap = None
|
265 |
+
if self.c_dim > 0:
|
266 |
+
cmap = self.mapping(None, c)
|
267 |
+
x, const_e = self.b4(x, cmap)
|
268 |
+
feats[4] = const_e
|
269 |
+
|
270 |
+
B, _ = x.shape
|
271 |
+
z = torch.zeros((B, self.z_dim), requires_grad=False, dtype=x.dtype,
|
272 |
+
device=x.device) ## Noise for Co-Modulation
|
273 |
+
return x, z, feats
|
274 |
+
|
275 |
+
|
276 |
+
def fma(a, b, c): # => a * b + c
|
277 |
+
return _FusedMultiplyAdd.apply(a, b, c)
|
278 |
+
|
279 |
+
|
280 |
+
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
281 |
+
@staticmethod
|
282 |
+
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
283 |
+
out = torch.addcmul(c, a, b)
|
284 |
+
ctx.save_for_backward(a, b)
|
285 |
+
ctx.c_shape = c.shape
|
286 |
+
return out
|
287 |
+
|
288 |
+
@staticmethod
|
289 |
+
def backward(ctx, dout): # pylint: disable=arguments-differ
|
290 |
+
a, b = ctx.saved_tensors
|
291 |
+
c_shape = ctx.c_shape
|
292 |
+
da = None
|
293 |
+
db = None
|
294 |
+
dc = None
|
295 |
+
|
296 |
+
if ctx.needs_input_grad[0]:
|
297 |
+
da = _unbroadcast(dout * b, a.shape)
|
298 |
+
|
299 |
+
if ctx.needs_input_grad[1]:
|
300 |
+
db = _unbroadcast(dout * a, b.shape)
|
301 |
+
|
302 |
+
if ctx.needs_input_grad[2]:
|
303 |
+
dc = _unbroadcast(dout, c_shape)
|
304 |
+
|
305 |
+
return da, db, dc
|
306 |
+
|
307 |
+
|
308 |
+
def _unbroadcast(x, shape):
|
309 |
+
extra_dims = x.ndim - len(shape)
|
310 |
+
assert extra_dims >= 0
|
311 |
+
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
312 |
+
if len(dim):
|
313 |
+
x = x.sum(dim=dim, keepdim=True)
|
314 |
+
if extra_dims:
|
315 |
+
x = x.reshape(-1, *x.shape[extra_dims + 1:])
|
316 |
+
assert x.shape == shape
|
317 |
+
return x
|
318 |
+
|
319 |
+
|
320 |
+
def modulated_conv2d(
|
321 |
+
x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
|
322 |
+
weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
|
323 |
+
styles, # Modulation coefficients of shape [batch_size, in_channels].
|
324 |
+
noise=None, # Optional noise tensor to add to the output activations.
|
325 |
+
up=1, # Integer upsampling factor.
|
326 |
+
down=1, # Integer downsampling factor.
|
327 |
+
padding=0, # Padding with respect to the upsampled image.
|
328 |
+
resample_filter=None,
|
329 |
+
# Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
|
330 |
+
demodulate=True, # Apply weight demodulation?
|
331 |
+
flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
|
332 |
+
fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation?
|
333 |
+
):
|
334 |
+
batch_size = x.shape[0]
|
335 |
+
out_channels, in_channels, kh, kw = weight.shape
|
336 |
+
|
337 |
+
# Pre-normalize inputs to avoid FP16 overflow.
|
338 |
+
if x.dtype == torch.float16 and demodulate:
|
339 |
+
weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3],
|
340 |
+
keepdim=True)) # max_Ikk
|
341 |
+
styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
|
342 |
+
|
343 |
+
# Calculate per-sample weights and demodulation coefficients.
|
344 |
+
w = None
|
345 |
+
dcoefs = None
|
346 |
+
if demodulate or fused_modconv:
|
347 |
+
w = weight.unsqueeze(0) # [NOIkk]
|
348 |
+
w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
|
349 |
+
if demodulate:
|
350 |
+
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO]
|
351 |
+
if demodulate and fused_modconv:
|
352 |
+
w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
|
353 |
+
# Execute by scaling the activations before and after the convolution.
|
354 |
+
if not fused_modconv:
|
355 |
+
x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
|
356 |
+
x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down,
|
357 |
+
padding=padding, flip_weight=flip_weight)
|
358 |
+
if demodulate and noise is not None:
|
359 |
+
x = fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
|
360 |
+
elif demodulate:
|
361 |
+
x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
|
362 |
+
elif noise is not None:
|
363 |
+
x = x.add_(noise.to(x.dtype))
|
364 |
+
return x
|
365 |
+
|
366 |
+
# Execute as one fused op using grouped convolution.
|
367 |
+
batch_size = int(batch_size)
|
368 |
+
x = x.reshape(1, -1, *x.shape[2:])
|
369 |
+
w = w.reshape(-1, in_channels, kh, kw)
|
370 |
+
x = conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding,
|
371 |
+
groups=batch_size, flip_weight=flip_weight)
|
372 |
+
x = x.reshape(batch_size, -1, *x.shape[2:])
|
373 |
+
if noise is not None:
|
374 |
+
x = x.add_(noise)
|
375 |
+
return x
|
376 |
+
|
377 |
+
|
378 |
+
class SynthesisLayer(torch.nn.Module):
|
379 |
+
def __init__(self,
|
380 |
+
in_channels, # Number of input channels.
|
381 |
+
out_channels, # Number of output channels.
|
382 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
383 |
+
resolution, # Resolution of this layer.
|
384 |
+
kernel_size=3, # Convolution kernel size.
|
385 |
+
up=1, # Integer upsampling factor.
|
386 |
+
use_noise=True, # Enable noise input?
|
387 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
388 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
389 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
390 |
+
channels_last=False, # Use channels_last format for the weights?
|
391 |
+
):
|
392 |
+
super().__init__()
|
393 |
+
self.resolution = resolution
|
394 |
+
self.up = up
|
395 |
+
self.use_noise = use_noise
|
396 |
+
self.activation = activation
|
397 |
+
self.conv_clamp = conv_clamp
|
398 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
399 |
+
self.padding = kernel_size // 2
|
400 |
+
self.act_gain = activation_funcs[activation].def_gain
|
401 |
+
|
402 |
+
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
|
403 |
+
memory_format = torch.channels_last if channels_last else torch.contiguous_format
|
404 |
+
self.weight = torch.nn.Parameter(
|
405 |
+
torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
|
406 |
+
if use_noise:
|
407 |
+
self.register_buffer('noise_const', torch.randn([resolution, resolution]))
|
408 |
+
self.noise_strength = torch.nn.Parameter(torch.zeros([]))
|
409 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
410 |
+
|
411 |
+
def forward(self, x, w, noise_mode='none', fused_modconv=True, gain=1):
|
412 |
+
assert noise_mode in ['random', 'const', 'none']
|
413 |
+
in_resolution = self.resolution // self.up
|
414 |
+
styles = self.affine(w)
|
415 |
+
|
416 |
+
noise = None
|
417 |
+
if self.use_noise and noise_mode == 'random':
|
418 |
+
noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution],
|
419 |
+
device=x.device) * self.noise_strength
|
420 |
+
if self.use_noise and noise_mode == 'const':
|
421 |
+
noise = self.noise_const * self.noise_strength
|
422 |
+
|
423 |
+
flip_weight = (self.up == 1) # slightly faster
|
424 |
+
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
|
425 |
+
padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight,
|
426 |
+
fused_modconv=fused_modconv)
|
427 |
+
|
428 |
+
act_gain = self.act_gain * gain
|
429 |
+
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
430 |
+
x = F.leaky_relu(x, negative_slope=0.2, inplace=False)
|
431 |
+
if act_gain != 1:
|
432 |
+
x = x * act_gain
|
433 |
+
if act_clamp is not None:
|
434 |
+
x = x.clamp(-act_clamp, act_clamp)
|
435 |
+
return x
|
436 |
+
|
437 |
+
|
438 |
+
class ToRGBLayer(torch.nn.Module):
|
439 |
+
def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
|
440 |
+
super().__init__()
|
441 |
+
self.conv_clamp = conv_clamp
|
442 |
+
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
|
443 |
+
memory_format = torch.channels_last if channels_last else torch.contiguous_format
|
444 |
+
self.weight = torch.nn.Parameter(
|
445 |
+
torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
|
446 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
447 |
+
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
448 |
+
|
449 |
+
def forward(self, x, w, fused_modconv=True):
|
450 |
+
styles = self.affine(w) * self.weight_gain
|
451 |
+
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
|
452 |
+
x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
|
453 |
+
return x
|
454 |
+
|
455 |
+
|
456 |
+
class SynthesisForeword(torch.nn.Module):
|
457 |
+
def __init__(self,
|
458 |
+
z_dim, # Output Latent (Z) dimensionality.
|
459 |
+
resolution, # Resolution of this block.
|
460 |
+
in_channels,
|
461 |
+
img_channels, # Number of input color channels.
|
462 |
+
architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
|
463 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
464 |
+
|
465 |
+
):
|
466 |
+
super().__init__()
|
467 |
+
self.in_channels = in_channels
|
468 |
+
self.z_dim = z_dim
|
469 |
+
self.resolution = resolution
|
470 |
+
self.img_channels = img_channels
|
471 |
+
self.architecture = architecture
|
472 |
+
|
473 |
+
self.fc = FullyConnectedLayer(self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation)
|
474 |
+
self.conv = SynthesisLayer(self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4)
|
475 |
+
|
476 |
+
if architecture == 'skip':
|
477 |
+
self.torgb = ToRGBLayer(self.in_channels, self.img_channels, kernel_size=1, w_dim=(z_dim // 2) * 3)
|
478 |
+
|
479 |
+
def forward(self, x, ws, feats, img, force_fp32=False):
|
480 |
+
_ = force_fp32 # unused
|
481 |
+
dtype = torch.float32
|
482 |
+
memory_format = torch.contiguous_format
|
483 |
+
|
484 |
+
x_global = x.clone()
|
485 |
+
# ToRGB.
|
486 |
+
x = self.fc(x)
|
487 |
+
x = x.view(-1, self.z_dim // 2, 4, 4)
|
488 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
489 |
+
|
490 |
+
# Main layers.
|
491 |
+
x_skip = feats[4].clone()
|
492 |
+
x = x + x_skip
|
493 |
+
|
494 |
+
mod_vector = []
|
495 |
+
mod_vector.append(ws[:, 0])
|
496 |
+
mod_vector.append(x_global.clone())
|
497 |
+
mod_vector = torch.cat(mod_vector, dim=1)
|
498 |
+
|
499 |
+
x = self.conv(x, mod_vector)
|
500 |
+
|
501 |
+
mod_vector = []
|
502 |
+
mod_vector.append(ws[:, 2 * 2 - 3])
|
503 |
+
mod_vector.append(x_global.clone())
|
504 |
+
mod_vector = torch.cat(mod_vector, dim=1)
|
505 |
+
|
506 |
+
if self.architecture == 'skip':
|
507 |
+
img = self.torgb(x, mod_vector)
|
508 |
+
img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
|
509 |
+
|
510 |
+
assert x.dtype == dtype
|
511 |
+
return x, img
|
512 |
+
|
513 |
+
|
514 |
+
class SELayer(nn.Module):
|
515 |
+
def __init__(self, channel, reduction=16):
|
516 |
+
super(SELayer, self).__init__()
|
517 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
518 |
+
self.fc = nn.Sequential(
|
519 |
+
nn.Linear(channel, channel // reduction, bias=False),
|
520 |
+
nn.ReLU(inplace=False),
|
521 |
+
nn.Linear(channel // reduction, channel, bias=False),
|
522 |
+
nn.Sigmoid()
|
523 |
+
)
|
524 |
+
|
525 |
+
def forward(self, x):
|
526 |
+
b, c, _, _ = x.size()
|
527 |
+
y = self.avg_pool(x).view(b, c)
|
528 |
+
y = self.fc(y).view(b, c, 1, 1)
|
529 |
+
res = x * y.expand_as(x)
|
530 |
+
return res
|
531 |
+
|
532 |
+
|
533 |
+
class FourierUnit(nn.Module):
|
534 |
+
|
535 |
+
def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
|
536 |
+
spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
|
537 |
+
# bn_layer not used
|
538 |
+
super(FourierUnit, self).__init__()
|
539 |
+
self.groups = groups
|
540 |
+
|
541 |
+
self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
|
542 |
+
out_channels=out_channels * 2,
|
543 |
+
kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
|
544 |
+
self.relu = torch.nn.ReLU(inplace=False)
|
545 |
+
|
546 |
+
# squeeze and excitation block
|
547 |
+
self.use_se = use_se
|
548 |
+
if use_se:
|
549 |
+
if se_kwargs is None:
|
550 |
+
se_kwargs = {}
|
551 |
+
self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
|
552 |
+
|
553 |
+
self.spatial_scale_factor = spatial_scale_factor
|
554 |
+
self.spatial_scale_mode = spatial_scale_mode
|
555 |
+
self.spectral_pos_encoding = spectral_pos_encoding
|
556 |
+
self.ffc3d = ffc3d
|
557 |
+
self.fft_norm = fft_norm
|
558 |
+
|
559 |
+
def forward(self, x):
|
560 |
+
batch = x.shape[0]
|
561 |
+
|
562 |
+
if self.spatial_scale_factor is not None:
|
563 |
+
orig_size = x.shape[-2:]
|
564 |
+
x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
|
565 |
+
align_corners=False)
|
566 |
+
|
567 |
+
r_size = x.size()
|
568 |
+
# (batch, c, h, w/2+1, 2)
|
569 |
+
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
|
570 |
+
ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
|
571 |
+
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
|
572 |
+
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
|
573 |
+
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
|
574 |
+
|
575 |
+
if self.spectral_pos_encoding:
|
576 |
+
height, width = ffted.shape[-2:]
|
577 |
+
coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
|
578 |
+
coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
|
579 |
+
ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
|
580 |
+
|
581 |
+
if self.use_se:
|
582 |
+
ffted = self.se(ffted)
|
583 |
+
|
584 |
+
ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
|
585 |
+
ffted = self.relu(ffted)
|
586 |
+
|
587 |
+
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
|
588 |
+
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
|
589 |
+
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
|
590 |
+
|
591 |
+
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
|
592 |
+
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
|
593 |
+
|
594 |
+
if self.spatial_scale_factor is not None:
|
595 |
+
output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
|
596 |
+
|
597 |
+
return output
|
598 |
+
|
599 |
+
|
600 |
+
class SpectralTransform(nn.Module):
|
601 |
+
|
602 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
|
603 |
+
# bn_layer not used
|
604 |
+
super(SpectralTransform, self).__init__()
|
605 |
+
self.enable_lfu = enable_lfu
|
606 |
+
if stride == 2:
|
607 |
+
self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
|
608 |
+
else:
|
609 |
+
self.downsample = nn.Identity()
|
610 |
+
|
611 |
+
self.stride = stride
|
612 |
+
self.conv1 = nn.Sequential(
|
613 |
+
nn.Conv2d(in_channels, out_channels //
|
614 |
+
2, kernel_size=1, groups=groups, bias=False),
|
615 |
+
# nn.BatchNorm2d(out_channels // 2),
|
616 |
+
nn.ReLU(inplace=True)
|
617 |
+
)
|
618 |
+
self.fu = FourierUnit(
|
619 |
+
out_channels // 2, out_channels // 2, groups, **fu_kwargs)
|
620 |
+
if self.enable_lfu:
|
621 |
+
self.lfu = FourierUnit(
|
622 |
+
out_channels // 2, out_channels // 2, groups)
|
623 |
+
self.conv2 = torch.nn.Conv2d(
|
624 |
+
out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
|
625 |
+
|
626 |
+
def forward(self, x):
|
627 |
+
|
628 |
+
x = self.downsample(x)
|
629 |
+
x = self.conv1(x)
|
630 |
+
output = self.fu(x)
|
631 |
+
|
632 |
+
if self.enable_lfu:
|
633 |
+
n, c, h, w = x.shape
|
634 |
+
split_no = 2
|
635 |
+
split_s = h // split_no
|
636 |
+
xs = torch.cat(torch.split(
|
637 |
+
x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
|
638 |
+
xs = torch.cat(torch.split(xs, split_s, dim=-1),
|
639 |
+
dim=1).contiguous()
|
640 |
+
xs = self.lfu(xs)
|
641 |
+
xs = xs.repeat(1, 1, split_no, split_no).contiguous()
|
642 |
+
else:
|
643 |
+
xs = 0
|
644 |
+
|
645 |
+
output = self.conv2(x + output + xs)
|
646 |
+
|
647 |
+
return output
|
648 |
+
|
649 |
+
|
650 |
+
class FFC(nn.Module):
|
651 |
+
|
652 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
653 |
+
ratio_gin, ratio_gout, stride=1, padding=0,
|
654 |
+
dilation=1, groups=1, bias=False, enable_lfu=True,
|
655 |
+
padding_type='reflect', gated=False, **spectral_kwargs):
|
656 |
+
super(FFC, self).__init__()
|
657 |
+
|
658 |
+
assert stride == 1 or stride == 2, "Stride should be 1 or 2."
|
659 |
+
self.stride = stride
|
660 |
+
|
661 |
+
in_cg = int(in_channels * ratio_gin)
|
662 |
+
in_cl = in_channels - in_cg
|
663 |
+
out_cg = int(out_channels * ratio_gout)
|
664 |
+
out_cl = out_channels - out_cg
|
665 |
+
# groups_g = 1 if groups == 1 else int(groups * ratio_gout)
|
666 |
+
# groups_l = 1 if groups == 1 else groups - groups_g
|
667 |
+
|
668 |
+
self.ratio_gin = ratio_gin
|
669 |
+
self.ratio_gout = ratio_gout
|
670 |
+
self.global_in_num = in_cg
|
671 |
+
|
672 |
+
module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
|
673 |
+
self.convl2l = module(in_cl, out_cl, kernel_size,
|
674 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
675 |
+
module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
|
676 |
+
self.convl2g = module(in_cl, out_cg, kernel_size,
|
677 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
678 |
+
module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
|
679 |
+
self.convg2l = module(in_cg, out_cl, kernel_size,
|
680 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
681 |
+
module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
|
682 |
+
self.convg2g = module(
|
683 |
+
in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
|
684 |
+
|
685 |
+
self.gated = gated
|
686 |
+
module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
|
687 |
+
self.gate = module(in_channels, 2, 1)
|
688 |
+
|
689 |
+
def forward(self, x, fname=None):
|
690 |
+
x_l, x_g = x if type(x) is tuple else (x, 0)
|
691 |
+
out_xl, out_xg = 0, 0
|
692 |
+
|
693 |
+
if self.gated:
|
694 |
+
total_input_parts = [x_l]
|
695 |
+
if torch.is_tensor(x_g):
|
696 |
+
total_input_parts.append(x_g)
|
697 |
+
total_input = torch.cat(total_input_parts, dim=1)
|
698 |
+
|
699 |
+
gates = torch.sigmoid(self.gate(total_input))
|
700 |
+
g2l_gate, l2g_gate = gates.chunk(2, dim=1)
|
701 |
+
else:
|
702 |
+
g2l_gate, l2g_gate = 1, 1
|
703 |
+
|
704 |
+
spec_x = self.convg2g(x_g)
|
705 |
+
|
706 |
+
if self.ratio_gout != 1:
|
707 |
+
out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
|
708 |
+
if self.ratio_gout != 0:
|
709 |
+
out_xg = self.convl2g(x_l) * l2g_gate + spec_x
|
710 |
+
|
711 |
+
return out_xl, out_xg
|
712 |
+
|
713 |
+
|
714 |
+
class FFC_BN_ACT(nn.Module):
|
715 |
+
|
716 |
+
def __init__(self, in_channels, out_channels,
|
717 |
+
kernel_size, ratio_gin, ratio_gout,
|
718 |
+
stride=1, padding=0, dilation=1, groups=1, bias=False,
|
719 |
+
norm_layer=nn.SyncBatchNorm, activation_layer=nn.Identity,
|
720 |
+
padding_type='reflect',
|
721 |
+
enable_lfu=True, **kwargs):
|
722 |
+
super(FFC_BN_ACT, self).__init__()
|
723 |
+
self.ffc = FFC(in_channels, out_channels, kernel_size,
|
724 |
+
ratio_gin, ratio_gout, stride, padding, dilation,
|
725 |
+
groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
|
726 |
+
lnorm = nn.Identity if ratio_gout == 1 else norm_layer
|
727 |
+
gnorm = nn.Identity if ratio_gout == 0 else norm_layer
|
728 |
+
global_channels = int(out_channels * ratio_gout)
|
729 |
+
# self.bn_l = lnorm(out_channels - global_channels)
|
730 |
+
# self.bn_g = gnorm(global_channels)
|
731 |
+
|
732 |
+
lact = nn.Identity if ratio_gout == 1 else activation_layer
|
733 |
+
gact = nn.Identity if ratio_gout == 0 else activation_layer
|
734 |
+
self.act_l = lact(inplace=True)
|
735 |
+
self.act_g = gact(inplace=True)
|
736 |
+
|
737 |
+
def forward(self, x, fname=None):
|
738 |
+
x_l, x_g = self.ffc(x, fname=fname, )
|
739 |
+
x_l = self.act_l(x_l)
|
740 |
+
x_g = self.act_g(x_g)
|
741 |
+
return x_l, x_g
|
742 |
+
|
743 |
+
|
744 |
+
class FFCResnetBlock(nn.Module):
|
745 |
+
def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
|
746 |
+
spatial_transform_kwargs=None, inline=False, ratio_gin=0.75, ratio_gout=0.75):
|
747 |
+
super().__init__()
|
748 |
+
self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
|
749 |
+
norm_layer=norm_layer,
|
750 |
+
activation_layer=activation_layer,
|
751 |
+
padding_type=padding_type,
|
752 |
+
ratio_gin=ratio_gin, ratio_gout=ratio_gout)
|
753 |
+
self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
|
754 |
+
norm_layer=norm_layer,
|
755 |
+
activation_layer=activation_layer,
|
756 |
+
padding_type=padding_type,
|
757 |
+
ratio_gin=ratio_gin, ratio_gout=ratio_gout)
|
758 |
+
self.inline = inline
|
759 |
+
|
760 |
+
def forward(self, x, fname=None):
|
761 |
+
if self.inline:
|
762 |
+
x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
|
763 |
+
else:
|
764 |
+
x_l, x_g = x if type(x) is tuple else (x, 0)
|
765 |
+
|
766 |
+
id_l, id_g = x_l, x_g
|
767 |
+
|
768 |
+
x_l, x_g = self.conv1((x_l, x_g), fname=fname)
|
769 |
+
x_l, x_g = self.conv2((x_l, x_g), fname=fname)
|
770 |
+
|
771 |
+
x_l, x_g = id_l + x_l, id_g + x_g
|
772 |
+
out = x_l, x_g
|
773 |
+
if self.inline:
|
774 |
+
out = torch.cat(out, dim=1)
|
775 |
+
return out
|
776 |
+
|
777 |
+
|
778 |
+
class ConcatTupleLayer(nn.Module):
|
779 |
+
def forward(self, x):
|
780 |
+
assert isinstance(x, tuple)
|
781 |
+
x_l, x_g = x
|
782 |
+
assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
|
783 |
+
if not torch.is_tensor(x_g):
|
784 |
+
return x_l
|
785 |
+
return torch.cat(x, dim=1)
|
786 |
+
|
787 |
+
|
788 |
+
class FFCBlock(torch.nn.Module):
|
789 |
+
def __init__(self,
|
790 |
+
dim, # Number of output/input channels.
|
791 |
+
kernel_size, # Width and height of the convolution kernel.
|
792 |
+
padding,
|
793 |
+
ratio_gin=0.75,
|
794 |
+
ratio_gout=0.75,
|
795 |
+
activation='linear', # Activation function: 'relu', 'lrelu', etc.
|
796 |
+
):
|
797 |
+
super().__init__()
|
798 |
+
if activation == 'linear':
|
799 |
+
self.activation = nn.Identity
|
800 |
+
else:
|
801 |
+
self.activation = nn.ReLU
|
802 |
+
self.padding = padding
|
803 |
+
self.kernel_size = kernel_size
|
804 |
+
self.ffc_block = FFCResnetBlock(dim=dim,
|
805 |
+
padding_type='reflect',
|
806 |
+
norm_layer=nn.SyncBatchNorm,
|
807 |
+
activation_layer=self.activation,
|
808 |
+
dilation=1,
|
809 |
+
ratio_gin=ratio_gin,
|
810 |
+
ratio_gout=ratio_gout)
|
811 |
+
|
812 |
+
self.concat_layer = ConcatTupleLayer()
|
813 |
+
|
814 |
+
def forward(self, gen_ft, mask, fname=None):
|
815 |
+
x = gen_ft.float()
|
816 |
+
|
817 |
+
x_l, x_g = x[:, :-self.ffc_block.conv1.ffc.global_in_num], x[:, -self.ffc_block.conv1.ffc.global_in_num:]
|
818 |
+
id_l, id_g = x_l, x_g
|
819 |
+
|
820 |
+
x_l, x_g = self.ffc_block((x_l, x_g), fname=fname)
|
821 |
+
x_l, x_g = id_l + x_l, id_g + x_g
|
822 |
+
x = self.concat_layer((x_l, x_g))
|
823 |
+
|
824 |
+
return x + gen_ft.float()
|
825 |
+
|
826 |
+
|
827 |
+
class FFCSkipLayer(torch.nn.Module):
|
828 |
+
def __init__(self,
|
829 |
+
dim, # Number of input/output channels.
|
830 |
+
kernel_size=3, # Convolution kernel size.
|
831 |
+
ratio_gin=0.75,
|
832 |
+
ratio_gout=0.75,
|
833 |
+
):
|
834 |
+
super().__init__()
|
835 |
+
self.padding = kernel_size // 2
|
836 |
+
|
837 |
+
self.ffc_act = FFCBlock(dim=dim, kernel_size=kernel_size, activation=nn.ReLU,
|
838 |
+
padding=self.padding, ratio_gin=ratio_gin, ratio_gout=ratio_gout)
|
839 |
+
|
840 |
+
def forward(self, gen_ft, mask, fname=None):
|
841 |
+
x = self.ffc_act(gen_ft, mask, fname=fname)
|
842 |
+
return x
|
843 |
+
|
844 |
+
|
845 |
+
class SynthesisBlock(torch.nn.Module):
|
846 |
+
def __init__(self,
|
847 |
+
in_channels, # Number of input channels, 0 = first block.
|
848 |
+
out_channels, # Number of output channels.
|
849 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
850 |
+
resolution, # Resolution of this block.
|
851 |
+
img_channels, # Number of output color channels.
|
852 |
+
is_last, # Is this the last block?
|
853 |
+
architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
|
854 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
855 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
856 |
+
use_fp16=False, # Use FP16 for this block?
|
857 |
+
fp16_channels_last=False, # Use channels-last memory format with FP16?
|
858 |
+
**layer_kwargs, # Arguments for SynthesisLayer.
|
859 |
+
):
|
860 |
+
assert architecture in ['orig', 'skip', 'resnet']
|
861 |
+
super().__init__()
|
862 |
+
self.in_channels = in_channels
|
863 |
+
self.w_dim = w_dim
|
864 |
+
self.resolution = resolution
|
865 |
+
self.img_channels = img_channels
|
866 |
+
self.is_last = is_last
|
867 |
+
self.architecture = architecture
|
868 |
+
self.use_fp16 = use_fp16
|
869 |
+
self.channels_last = (use_fp16 and fp16_channels_last)
|
870 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
871 |
+
self.num_conv = 0
|
872 |
+
self.num_torgb = 0
|
873 |
+
self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1}
|
874 |
+
|
875 |
+
if in_channels != 0 and resolution >= 8:
|
876 |
+
self.ffc_skip = nn.ModuleList()
|
877 |
+
for _ in range(self.res_ffc[resolution]):
|
878 |
+
self.ffc_skip.append(FFCSkipLayer(dim=out_channels))
|
879 |
+
|
880 |
+
if in_channels == 0:
|
881 |
+
self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
|
882 |
+
|
883 |
+
if in_channels != 0:
|
884 |
+
self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim * 3, resolution=resolution, up=2,
|
885 |
+
resample_filter=resample_filter, conv_clamp=conv_clamp,
|
886 |
+
channels_last=self.channels_last, **layer_kwargs)
|
887 |
+
self.num_conv += 1
|
888 |
+
|
889 |
+
self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim * 3, resolution=resolution,
|
890 |
+
conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
|
891 |
+
self.num_conv += 1
|
892 |
+
|
893 |
+
if is_last or architecture == 'skip':
|
894 |
+
self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim * 3,
|
895 |
+
conv_clamp=conv_clamp, channels_last=self.channels_last)
|
896 |
+
self.num_torgb += 1
|
897 |
+
|
898 |
+
if in_channels != 0 and architecture == 'resnet':
|
899 |
+
self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
|
900 |
+
resample_filter=resample_filter, channels_last=self.channels_last)
|
901 |
+
|
902 |
+
def forward(self, x, mask, feats, img, ws, fname=None, force_fp32=False, fused_modconv=None, **layer_kwargs):
|
903 |
+
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
904 |
+
dtype = torch.float32
|
905 |
+
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
906 |
+
if fused_modconv is None:
|
907 |
+
fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
|
908 |
+
|
909 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
910 |
+
x_skip = feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format)
|
911 |
+
|
912 |
+
# Main layers.
|
913 |
+
if self.in_channels == 0:
|
914 |
+
x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs)
|
915 |
+
elif self.architecture == 'resnet':
|
916 |
+
y = self.skip(x, gain=np.sqrt(0.5))
|
917 |
+
x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
|
918 |
+
if len(self.ffc_skip) > 0:
|
919 |
+
mask = F.interpolate(mask, size=x_skip.shape[2:], )
|
920 |
+
z = x + x_skip
|
921 |
+
for fres in self.ffc_skip:
|
922 |
+
z = fres(z, mask)
|
923 |
+
x = x + z
|
924 |
+
else:
|
925 |
+
x = x + x_skip
|
926 |
+
x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
|
927 |
+
x = y.add_(x)
|
928 |
+
else:
|
929 |
+
x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
|
930 |
+
if len(self.ffc_skip) > 0:
|
931 |
+
mask = F.interpolate(mask, size=x_skip.shape[2:], )
|
932 |
+
z = x + x_skip
|
933 |
+
for fres in self.ffc_skip:
|
934 |
+
z = fres(z, mask)
|
935 |
+
x = x + z
|
936 |
+
else:
|
937 |
+
x = x + x_skip
|
938 |
+
x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs)
|
939 |
+
# ToRGB.
|
940 |
+
if img is not None:
|
941 |
+
img = upsample2d(img, self.resample_filter)
|
942 |
+
if self.is_last or self.architecture == 'skip':
|
943 |
+
y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv)
|
944 |
+
y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
|
945 |
+
img = img.add_(y) if img is not None else y
|
946 |
+
|
947 |
+
x = x.to(dtype=dtype)
|
948 |
+
assert x.dtype == dtype
|
949 |
+
assert img is None or img.dtype == torch.float32
|
950 |
+
return x, img
|
951 |
+
|
952 |
+
|
953 |
+
class SynthesisNetwork(torch.nn.Module):
|
954 |
+
def __init__(self,
|
955 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
956 |
+
z_dim, # Output Latent (Z) dimensionality.
|
957 |
+
img_resolution, # Output image resolution.
|
958 |
+
img_channels, # Number of color channels.
|
959 |
+
channel_base=16384, # Overall multiplier for the number of channels.
|
960 |
+
channel_max=512, # Maximum number of channels in any layer.
|
961 |
+
num_fp16_res=0, # Use FP16 for the N highest resolutions.
|
962 |
+
**block_kwargs, # Arguments for SynthesisBlock.
|
963 |
+
):
|
964 |
+
assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
|
965 |
+
super().__init__()
|
966 |
+
self.w_dim = w_dim
|
967 |
+
self.img_resolution = img_resolution
|
968 |
+
self.img_resolution_log2 = int(np.log2(img_resolution))
|
969 |
+
self.img_channels = img_channels
|
970 |
+
self.block_resolutions = [2 ** i for i in range(3, self.img_resolution_log2 + 1)]
|
971 |
+
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
|
972 |
+
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
|
973 |
+
|
974 |
+
self.foreword = SynthesisForeword(img_channels=img_channels, in_channels=min(channel_base // 4, channel_max),
|
975 |
+
z_dim=z_dim * 2, resolution=4)
|
976 |
+
|
977 |
+
self.num_ws = self.img_resolution_log2 * 2 - 2
|
978 |
+
for res in self.block_resolutions:
|
979 |
+
if res // 2 in channels_dict.keys():
|
980 |
+
in_channels = channels_dict[res // 2] if res > 4 else 0
|
981 |
+
else:
|
982 |
+
in_channels = min(channel_base // (res // 2), channel_max)
|
983 |
+
out_channels = channels_dict[res]
|
984 |
+
use_fp16 = (res >= fp16_resolution)
|
985 |
+
use_fp16 = False
|
986 |
+
is_last = (res == self.img_resolution)
|
987 |
+
block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
|
988 |
+
img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
|
989 |
+
setattr(self, f'b{res}', block)
|
990 |
+
|
991 |
+
def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
|
992 |
+
|
993 |
+
img = None
|
994 |
+
|
995 |
+
x, img = self.foreword(x_global, ws, feats, img)
|
996 |
+
|
997 |
+
for res in self.block_resolutions:
|
998 |
+
block = getattr(self, f'b{res}')
|
999 |
+
mod_vector0 = []
|
1000 |
+
mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5])
|
1001 |
+
mod_vector0.append(x_global.clone())
|
1002 |
+
mod_vector0 = torch.cat(mod_vector0, dim=1)
|
1003 |
+
|
1004 |
+
mod_vector1 = []
|
1005 |
+
mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4])
|
1006 |
+
mod_vector1.append(x_global.clone())
|
1007 |
+
mod_vector1 = torch.cat(mod_vector1, dim=1)
|
1008 |
+
|
1009 |
+
mod_vector_rgb = []
|
1010 |
+
mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3])
|
1011 |
+
mod_vector_rgb.append(x_global.clone())
|
1012 |
+
mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1)
|
1013 |
+
x, img = block(x, mask, feats, img, (mod_vector0, mod_vector1, mod_vector_rgb), fname=fname, **block_kwargs)
|
1014 |
+
return img
|
1015 |
+
|
1016 |
+
|
1017 |
+
class MappingNetwork(torch.nn.Module):
|
1018 |
+
def __init__(self,
|
1019 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
1020 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
1021 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1022 |
+
num_ws, # Number of intermediate latents to output, None = do not broadcast.
|
1023 |
+
num_layers=8, # Number of mapping layers.
|
1024 |
+
embed_features=None, # Label embedding dimensionality, None = same as w_dim.
|
1025 |
+
layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
|
1026 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
1027 |
+
lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
|
1028 |
+
w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
|
1029 |
+
):
|
1030 |
+
super().__init__()
|
1031 |
+
self.z_dim = z_dim
|
1032 |
+
self.c_dim = c_dim
|
1033 |
+
self.w_dim = w_dim
|
1034 |
+
self.num_ws = num_ws
|
1035 |
+
self.num_layers = num_layers
|
1036 |
+
self.w_avg_beta = w_avg_beta
|
1037 |
+
|
1038 |
+
if embed_features is None:
|
1039 |
+
embed_features = w_dim
|
1040 |
+
if c_dim == 0:
|
1041 |
+
embed_features = 0
|
1042 |
+
if layer_features is None:
|
1043 |
+
layer_features = w_dim
|
1044 |
+
features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
|
1045 |
+
|
1046 |
+
if c_dim > 0:
|
1047 |
+
self.embed = FullyConnectedLayer(c_dim, embed_features)
|
1048 |
+
for idx in range(num_layers):
|
1049 |
+
in_features = features_list[idx]
|
1050 |
+
out_features = features_list[idx + 1]
|
1051 |
+
layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
|
1052 |
+
setattr(self, f'fc{idx}', layer)
|
1053 |
+
|
1054 |
+
if num_ws is not None and w_avg_beta is not None:
|
1055 |
+
self.register_buffer('w_avg', torch.zeros([w_dim]))
|
1056 |
+
|
1057 |
+
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
|
1058 |
+
# Embed, normalize, and concat inputs.
|
1059 |
+
x = None
|
1060 |
+
with torch.autograd.profiler.record_function('input'):
|
1061 |
+
if self.z_dim > 0:
|
1062 |
+
x = normalize_2nd_moment(z.to(torch.float32))
|
1063 |
+
if self.c_dim > 0:
|
1064 |
+
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
|
1065 |
+
x = torch.cat([x, y], dim=1) if x is not None else y
|
1066 |
+
|
1067 |
+
# Main layers.
|
1068 |
+
for idx in range(self.num_layers):
|
1069 |
+
layer = getattr(self, f'fc{idx}')
|
1070 |
+
x = layer(x)
|
1071 |
+
|
1072 |
+
# Update moving average of W.
|
1073 |
+
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
|
1074 |
+
with torch.autograd.profiler.record_function('update_w_avg'):
|
1075 |
+
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
|
1076 |
+
|
1077 |
+
# Broadcast.
|
1078 |
+
if self.num_ws is not None:
|
1079 |
+
with torch.autograd.profiler.record_function('broadcast'):
|
1080 |
+
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
1081 |
+
|
1082 |
+
# Apply truncation.
|
1083 |
+
if truncation_psi != 1:
|
1084 |
+
with torch.autograd.profiler.record_function('truncate'):
|
1085 |
+
assert self.w_avg_beta is not None
|
1086 |
+
if self.num_ws is None or truncation_cutoff is None:
|
1087 |
+
x = self.w_avg.lerp(x, truncation_psi)
|
1088 |
+
else:
|
1089 |
+
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
|
1090 |
+
return x
|
1091 |
+
|
1092 |
+
|
1093 |
+
class Generator(torch.nn.Module):
|
1094 |
+
def __init__(self,
|
1095 |
+
z_dim, # Input latent (Z) dimensionality.
|
1096 |
+
c_dim, # Conditioning label (C) dimensionality.
|
1097 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1098 |
+
img_resolution, # Output resolution.
|
1099 |
+
img_channels, # Number of output color channels.
|
1100 |
+
encoder_kwargs={}, # Arguments for EncoderNetwork.
|
1101 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
1102 |
+
synthesis_kwargs={}, # Arguments for SynthesisNetwork.
|
1103 |
+
):
|
1104 |
+
super().__init__()
|
1105 |
+
self.z_dim = z_dim
|
1106 |
+
self.c_dim = c_dim
|
1107 |
+
self.w_dim = w_dim
|
1108 |
+
self.img_resolution = img_resolution
|
1109 |
+
self.img_channels = img_channels
|
1110 |
+
self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution,
|
1111 |
+
img_channels=img_channels, **encoder_kwargs)
|
1112 |
+
self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution,
|
1113 |
+
img_channels=img_channels, **synthesis_kwargs)
|
1114 |
+
self.num_ws = self.synthesis.num_ws
|
1115 |
+
self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
|
1116 |
+
|
1117 |
+
def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
|
1118 |
+
mask = img[:, -1].unsqueeze(1)
|
1119 |
+
x_global, z, feats = self.encoder(img, c)
|
1120 |
+
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
|
1121 |
+
img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
|
1122 |
+
return img
|
1123 |
+
|
1124 |
+
|
1125 |
+
FCF_MODEL_URL = os.environ.get(
|
1126 |
+
"FCF_MODEL_URL",
|
1127 |
+
"https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth",
|
1128 |
+
)
|
1129 |
+
|
1130 |
+
|
1131 |
+
class FcF(InpaintModel):
|
1132 |
+
min_size = 512
|
1133 |
+
pad_mod = 512
|
1134 |
+
pad_to_square = True
|
1135 |
+
|
1136 |
+
def init_model(self, device, **kwargs):
|
1137 |
+
seed = 0
|
1138 |
+
random.seed(seed)
|
1139 |
+
np.random.seed(seed)
|
1140 |
+
torch.manual_seed(seed)
|
1141 |
+
torch.cuda.manual_seed_all(seed)
|
1142 |
+
torch.backends.cudnn.deterministic = True
|
1143 |
+
torch.backends.cudnn.benchmark = False
|
1144 |
+
|
1145 |
+
kwargs = {'channel_base': 1 * 32768, 'channel_max': 512, 'num_fp16_res': 4, 'conv_clamp': 256}
|
1146 |
+
G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3,
|
1147 |
+
synthesis_kwargs=kwargs, encoder_kwargs=kwargs, mapping_kwargs={'num_layers': 2})
|
1148 |
+
self.model = load_model(G, FCF_MODEL_URL, device)
|
1149 |
+
self.label = torch.zeros([1, self.model.c_dim], device=device)
|
1150 |
+
|
1151 |
+
@staticmethod
|
1152 |
+
def is_downloaded() -> bool:
|
1153 |
+
return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))
|
1154 |
+
|
1155 |
+
@torch.no_grad()
|
1156 |
+
def __call__(self, image, mask, config: Config):
|
1157 |
+
"""
|
1158 |
+
images: [H, W, C] RGB, not normalized
|
1159 |
+
masks: [H, W]
|
1160 |
+
return: BGR IMAGE
|
1161 |
+
"""
|
1162 |
+
if image.shape[0] == 512 and image.shape[1] == 512:
|
1163 |
+
return self._pad_forward(image, mask, config)
|
1164 |
+
|
1165 |
+
boxes = boxes_from_mask(mask)
|
1166 |
+
crop_result = []
|
1167 |
+
config.hd_strategy_crop_margin = 128
|
1168 |
+
for box in boxes:
|
1169 |
+
crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
|
1170 |
+
origin_size = crop_image.shape[:2]
|
1171 |
+
resize_image = resize_max_size(crop_image, size_limit=512)
|
1172 |
+
resize_mask = resize_max_size(crop_mask, size_limit=512)
|
1173 |
+
inpaint_result = self._pad_forward(resize_image, resize_mask, config)
|
1174 |
+
|
1175 |
+
# only paste masked area result
|
1176 |
+
inpaint_result = cv2.resize(inpaint_result, (origin_size[1], origin_size[0]), interpolation=cv2.INTER_CUBIC)
|
1177 |
+
|
1178 |
+
original_pixel_indices = crop_mask < 127
|
1179 |
+
inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][original_pixel_indices]
|
1180 |
+
|
1181 |
+
crop_result.append((inpaint_result, crop_box))
|
1182 |
+
|
1183 |
+
inpaint_result = image[:, :, ::-1]
|
1184 |
+
for crop_image, crop_box in crop_result:
|
1185 |
+
x1, y1, x2, y2 = crop_box
|
1186 |
+
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
1187 |
+
|
1188 |
+
return inpaint_result
|
1189 |
+
|
1190 |
+
def forward(self, image, mask, config: Config):
|
1191 |
+
"""Input images and output images have same size
|
1192 |
+
images: [H, W, C] RGB
|
1193 |
+
masks: [H, W] mask area == 255
|
1194 |
+
return: BGR IMAGE
|
1195 |
+
"""
|
1196 |
+
|
1197 |
+
image = norm_img(image) # [0, 1]
|
1198 |
+
image = image * 2 - 1 # [0, 1] -> [-1, 1]
|
1199 |
+
mask = (mask > 120) * 255
|
1200 |
+
mask = norm_img(mask)
|
1201 |
+
|
1202 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
1203 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
1204 |
+
|
1205 |
+
erased_img = image * (1 - mask)
|
1206 |
+
input_image = torch.cat([0.5 - mask, erased_img], dim=1)
|
1207 |
+
|
1208 |
+
output = self.model(input_image, self.label, truncation_psi=0.1, noise_mode='none')
|
1209 |
+
output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
|
1210 |
+
output = output[0].cpu().numpy()
|
1211 |
+
cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
1212 |
+
return cur_res
|
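For readability, here is the paste-back step from `FcF.__call__` above restated in isolation: only pixels under the mask take the inpainted values, everything else keeps the original crop. The arrays below are synthetic stand-ins (the real code gets them from `boxes_from_mask` and `_crop_box`), so this is a sketch of the indexing, not part of the uploaded file.

```python
# Sketch of FcF's paste-back: resize the 512x512 model output to the crop size,
# then copy inpainted pixels only where the mask is "on". Data here is random.
import cv2
import numpy as np

crop_image = np.random.randint(0, 255, (300, 200, 3), dtype=np.uint8)      # RGB crop
crop_mask = np.zeros((300, 200), dtype=np.uint8)
crop_mask[100:180, 50:120] = 255                                            # area to inpaint

inpaint_result = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)  # BGR, 512x512 model output
origin_size = crop_image.shape[:2]

# resize the fixed-size result back to the crop's original size
inpaint_result = cv2.resize(inpaint_result, (origin_size[1], origin_size[0]),
                            interpolation=cv2.INTER_CUBIC)

# keep original pixels (RGB flipped to BGR) wherever the mask is below threshold
original_pixel_indices = crop_mask < 127
inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][original_pixel_indices]
```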
lama_cleaner/model/lama.py
ADDED
@@ -0,0 +1,61 @@
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from loguru import logger
|
7 |
+
|
8 |
+
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
9 |
+
from lama_cleaner.model.base import InpaintModel
|
10 |
+
from lama_cleaner.schema import Config
|
11 |
+
|
12 |
+
LAMA_MODEL_URL = os.environ.get(
|
13 |
+
"LAMA_MODEL_URL",
|
14 |
+
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
class LaMa(InpaintModel):
|
19 |
+
pad_mod = 8
|
20 |
+
|
21 |
+
def init_model(self, device, **kwargs):
|
22 |
+
if os.environ.get("LAMA_MODEL"):
|
23 |
+
model_path = os.environ.get("LAMA_MODEL")
|
24 |
+
if not os.path.exists(model_path):
|
25 |
+
raise FileNotFoundError(
|
26 |
+
f"lama torchscript model not found: {model_path}"
|
27 |
+
)
|
28 |
+
else:
|
29 |
+
model_path = download_model(LAMA_MODEL_URL)
|
30 |
+
# TODO used to create a lambda docker image
|
31 |
+
# model_path = '../app/big-lama.pt'
|
32 |
+
logger.info(f"Load LaMa model from: {model_path}")
|
33 |
+
model = torch.jit.load(model_path, map_location="cpu")
|
34 |
+
model = model.to(device)
|
35 |
+
model.eval()
|
36 |
+
self.model = model
|
37 |
+
self.model_path = model_path
|
38 |
+
|
39 |
+
@staticmethod
|
40 |
+
def is_downloaded() -> bool:
|
41 |
+
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
|
42 |
+
|
43 |
+
def forward(self, image, mask, config: Config):
|
44 |
+
"""Input image and output image have same size
|
45 |
+
image: [H, W, C] RGB
|
46 |
+
mask: [H, W]
|
47 |
+
return: BGR IMAGE
|
48 |
+
"""
|
49 |
+
image = norm_img(image)
|
50 |
+
mask = norm_img(mask)
|
51 |
+
|
52 |
+
mask = (mask > 0) * 1
|
53 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
54 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
55 |
+
|
56 |
+
inpainted_image = self.model(image, mask)
|
57 |
+
|
58 |
+
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
59 |
+
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
60 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
61 |
+
return cur_res
|
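A minimal usage sketch for the `LaMa` wrapper above. It assumes the base `InpaintModel` constructor takes only the device (as the other wrappers suggest), that `__call__` accepts `(image, mask, config)` as `FcF.__call__` does, and that `Config()` can be built with defaults; its actual fields live in `lama_cleaner/schema.py`, which is not shown here. `app/big-lama.pt` is the checkpoint bundled with this Space.

```python
# Hypothetical usage sketch, not part of the upload.
import os

import cv2
import torch

# point LAMA_MODEL at the bundled checkpoint instead of downloading LAMA_MODEL_URL
os.environ["LAMA_MODEL"] = "app/big-lama.pt"

from lama_cleaner.model.lama import LaMa
from lama_cleaner.schema import Config  # assumed: Config() works with defaults

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LaMa(device)

image = cv2.cvtColor(cv2.imread("input.jpg"), cv2.COLOR_BGR2RGB)  # [H, W, C] RGB
mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)               # [H, W], 255 = remove

result_bgr = model(image, mask, Config())  # __call__ from InpaintModel pads to pad_mod=8
cv2.imwrite("result.png", result_bgr)
```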
lama_cleaner/model/ldm.py
ADDED
@@ -0,0 +1,310 @@
|
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from lama_cleaner.model.base import InpaintModel
|
7 |
+
from lama_cleaner.model.ddim_sampler import DDIMSampler
|
8 |
+
from lama_cleaner.model.plms_sampler import PLMSSampler
|
9 |
+
from lama_cleaner.schema import Config, LDMSampler
|
10 |
+
|
11 |
+
torch.manual_seed(42)
|
12 |
+
import torch.nn as nn
|
13 |
+
from lama_cleaner.helper import (
|
14 |
+
norm_img,
|
15 |
+
get_cache_path_by_url,
|
16 |
+
load_jit_model,
|
17 |
+
)
|
18 |
+
from lama_cleaner.model.utils import (
|
19 |
+
make_beta_schedule,
|
20 |
+
timestep_embedding,
|
21 |
+
)
|
22 |
+
|
23 |
+
LDM_ENCODE_MODEL_URL = os.environ.get(
|
24 |
+
"LDM_ENCODE_MODEL_URL",
|
25 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
|
26 |
+
)
|
27 |
+
|
28 |
+
LDM_DECODE_MODEL_URL = os.environ.get(
|
29 |
+
"LDM_DECODE_MODEL_URL",
|
30 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
|
31 |
+
)
|
32 |
+
|
33 |
+
LDM_DIFFUSION_MODEL_URL = os.environ.get(
|
34 |
+
"LDM_DIFFUSION_MODEL_URL",
|
35 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
class DDPM(nn.Module):
|
40 |
+
# classic DDPM with Gaussian diffusion, in image space
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
device,
|
44 |
+
timesteps=1000,
|
45 |
+
beta_schedule="linear",
|
46 |
+
linear_start=0.0015,
|
47 |
+
linear_end=0.0205,
|
48 |
+
cosine_s=0.008,
|
49 |
+
original_elbo_weight=0.0,
|
50 |
+
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
51 |
+
l_simple_weight=1.0,
|
52 |
+
parameterization="eps", # all assuming fixed variance schedules
|
53 |
+
use_positional_encodings=False,
|
54 |
+
):
|
55 |
+
super().__init__()
|
56 |
+
self.device = device
|
57 |
+
self.parameterization = parameterization
|
58 |
+
self.use_positional_encodings = use_positional_encodings
|
59 |
+
|
60 |
+
self.v_posterior = v_posterior
|
61 |
+
self.original_elbo_weight = original_elbo_weight
|
62 |
+
self.l_simple_weight = l_simple_weight
|
63 |
+
|
64 |
+
self.register_schedule(
|
65 |
+
beta_schedule=beta_schedule,
|
66 |
+
timesteps=timesteps,
|
67 |
+
linear_start=linear_start,
|
68 |
+
linear_end=linear_end,
|
69 |
+
cosine_s=cosine_s,
|
70 |
+
)
|
71 |
+
|
72 |
+
def register_schedule(
|
73 |
+
self,
|
74 |
+
given_betas=None,
|
75 |
+
beta_schedule="linear",
|
76 |
+
timesteps=1000,
|
77 |
+
linear_start=1e-4,
|
78 |
+
linear_end=2e-2,
|
79 |
+
cosine_s=8e-3,
|
80 |
+
):
|
81 |
+
betas = make_beta_schedule(
|
82 |
+
self.device,
|
83 |
+
beta_schedule,
|
84 |
+
timesteps,
|
85 |
+
linear_start=linear_start,
|
86 |
+
linear_end=linear_end,
|
87 |
+
cosine_s=cosine_s,
|
88 |
+
)
|
89 |
+
alphas = 1.0 - betas
|
90 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
91 |
+
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
|
92 |
+
|
93 |
+
(timesteps,) = betas.shape
|
94 |
+
self.num_timesteps = int(timesteps)
|
95 |
+
self.linear_start = linear_start
|
96 |
+
self.linear_end = linear_end
|
97 |
+
assert (
|
98 |
+
alphas_cumprod.shape[0] == self.num_timesteps
|
99 |
+
), "alphas have to be defined for each timestep"
|
100 |
+
|
101 |
+
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
102 |
+
|
103 |
+
self.register_buffer("betas", to_torch(betas))
|
104 |
+
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
105 |
+
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
|
106 |
+
|
107 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
108 |
+
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
|
109 |
+
self.register_buffer(
|
110 |
+
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
|
111 |
+
)
|
112 |
+
self.register_buffer(
|
113 |
+
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
|
114 |
+
)
|
115 |
+
self.register_buffer(
|
116 |
+
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
|
117 |
+
)
|
118 |
+
self.register_buffer(
|
119 |
+
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
|
120 |
+
)
|
121 |
+
|
122 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
123 |
+
posterior_variance = (1 - self.v_posterior) * betas * (
|
124 |
+
1.0 - alphas_cumprod_prev
|
125 |
+
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
126 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
127 |
+
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
128 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
129 |
+
self.register_buffer(
|
130 |
+
"posterior_log_variance_clipped",
|
131 |
+
to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
|
132 |
+
)
|
133 |
+
self.register_buffer(
|
134 |
+
"posterior_mean_coef1",
|
135 |
+
to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
|
136 |
+
)
|
137 |
+
self.register_buffer(
|
138 |
+
"posterior_mean_coef2",
|
139 |
+
to_torch(
|
140 |
+
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
|
141 |
+
),
|
142 |
+
)
|
143 |
+
|
144 |
+
if self.parameterization == "eps":
|
145 |
+
lvlb_weights = self.betas**2 / (
|
146 |
+
2
|
147 |
+
* self.posterior_variance
|
148 |
+
* to_torch(alphas)
|
149 |
+
* (1 - self.alphas_cumprod)
|
150 |
+
)
|
151 |
+
elif self.parameterization == "x0":
|
152 |
+
lvlb_weights = (
|
153 |
+
0.5
|
154 |
+
* np.sqrt(torch.Tensor(alphas_cumprod))
|
155 |
+
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
156 |
+
)
|
157 |
+
else:
|
158 |
+
raise NotImplementedError("mu not supported")
|
159 |
+
# TODO how to choose this term
|
160 |
+
lvlb_weights[0] = lvlb_weights[1]
|
161 |
+
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
|
162 |
+
assert not torch.isnan(self.lvlb_weights).all()
|
163 |
+
|
164 |
+
|
165 |
+
class LatentDiffusion(DDPM):
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
diffusion_model,
|
169 |
+
device,
|
170 |
+
cond_stage_key="image",
|
171 |
+
cond_stage_trainable=False,
|
172 |
+
concat_mode=True,
|
173 |
+
scale_factor=1.0,
|
174 |
+
scale_by_std=False,
|
175 |
+
*args,
|
176 |
+
**kwargs,
|
177 |
+
):
|
178 |
+
self.num_timesteps_cond = 1
|
179 |
+
self.scale_by_std = scale_by_std
|
180 |
+
super().__init__(device, *args, **kwargs)
|
181 |
+
self.diffusion_model = diffusion_model
|
182 |
+
self.concat_mode = concat_mode
|
183 |
+
self.cond_stage_trainable = cond_stage_trainable
|
184 |
+
self.cond_stage_key = cond_stage_key
|
185 |
+
self.num_downs = 2
|
186 |
+
self.scale_factor = scale_factor
|
187 |
+
|
188 |
+
def make_cond_schedule(
|
189 |
+
self,
|
190 |
+
):
|
191 |
+
self.cond_ids = torch.full(
|
192 |
+
size=(self.num_timesteps,),
|
193 |
+
fill_value=self.num_timesteps - 1,
|
194 |
+
dtype=torch.long,
|
195 |
+
)
|
196 |
+
ids = torch.round(
|
197 |
+
torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
|
198 |
+
).long()
|
199 |
+
self.cond_ids[: self.num_timesteps_cond] = ids
|
200 |
+
|
201 |
+
def register_schedule(
|
202 |
+
self,
|
203 |
+
given_betas=None,
|
204 |
+
beta_schedule="linear",
|
205 |
+
timesteps=1000,
|
206 |
+
linear_start=1e-4,
|
207 |
+
linear_end=2e-2,
|
208 |
+
cosine_s=8e-3,
|
209 |
+
):
|
210 |
+
super().register_schedule(
|
211 |
+
given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
|
212 |
+
)
|
213 |
+
|
214 |
+
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
215 |
+
if self.shorten_cond_schedule:
|
216 |
+
self.make_cond_schedule()
|
217 |
+
|
218 |
+
def apply_model(self, x_noisy, t, cond):
|
219 |
+
# x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128
|
220 |
+
t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
|
221 |
+
x_recon = self.diffusion_model(x_noisy, t_emb, cond)
|
222 |
+
return x_recon
|
223 |
+
|
224 |
+
|
225 |
+
class LDM(InpaintModel):
|
226 |
+
pad_mod = 32
|
227 |
+
|
228 |
+
def __init__(self, device, fp16: bool = True, **kwargs):
|
229 |
+
self.fp16 = fp16
|
230 |
+
super().__init__(device)
|
231 |
+
self.device = device
|
232 |
+
|
233 |
+
def init_model(self, device, **kwargs):
|
234 |
+
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
|
235 |
+
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
|
236 |
+
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
|
237 |
+
if self.fp16 and "cuda" in str(device):
|
238 |
+
self.diffusion_model = self.diffusion_model.half()
|
239 |
+
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
|
240 |
+
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
|
241 |
+
|
242 |
+
self.model = LatentDiffusion(self.diffusion_model, device)
|
243 |
+
|
244 |
+
@staticmethod
|
245 |
+
def is_downloaded() -> bool:
|
246 |
+
model_paths = [
|
247 |
+
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
|
248 |
+
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
|
249 |
+
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
|
250 |
+
]
|
251 |
+
return all([os.path.exists(it) for it in model_paths])
|
252 |
+
|
253 |
+
@torch.cuda.amp.autocast()
|
254 |
+
def forward(self, image, mask, config: Config):
|
255 |
+
"""
|
256 |
+
image: [H, W, C] RGB
|
257 |
+
mask: [H, W, 1]
|
258 |
+
return: BGR IMAGE
|
259 |
+
"""
|
260 |
+
# image [1,3,512,512] float32
|
261 |
+
# mask: [1,1,512,512] float32
|
262 |
+
# masked_image: [1,3,512,512] float32
|
263 |
+
if config.ldm_sampler == LDMSampler.ddim:
|
264 |
+
sampler = DDIMSampler(self.model)
|
265 |
+
elif config.ldm_sampler == LDMSampler.plms:
|
266 |
+
sampler = PLMSSampler(self.model)
|
267 |
+
else:
|
268 |
+
raise ValueError()
|
269 |
+
|
270 |
+
steps = config.ldm_steps
|
271 |
+
image = norm_img(image)
|
272 |
+
mask = norm_img(mask)
|
273 |
+
|
274 |
+
mask[mask < 0.5] = 0
|
275 |
+
mask[mask >= 0.5] = 1
|
276 |
+
|
277 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
278 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
279 |
+
masked_image = (1 - mask) * image
|
280 |
+
|
281 |
+
mask = self._norm(mask)
|
282 |
+
masked_image = self._norm(masked_image)
|
283 |
+
|
284 |
+
c = self.cond_stage_model_encode(masked_image)
|
285 |
+
torch.cuda.empty_cache()
|
286 |
+
|
287 |
+
cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128
|
288 |
+
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
289 |
+
|
290 |
+
shape = (c.shape[1] - 1,) + c.shape[2:]
|
291 |
+
samples_ddim = sampler.sample(
|
292 |
+
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
|
293 |
+
)
|
294 |
+
torch.cuda.empty_cache()
|
295 |
+
x_samples_ddim = self.cond_stage_model_decode(
|
296 |
+
samples_ddim
|
297 |
+
) # samples_ddim: 1, 3, 128, 128 float32
|
298 |
+
torch.cuda.empty_cache()
|
299 |
+
|
300 |
+
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
301 |
+
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
|
302 |
+
inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
303 |
+
|
304 |
+
# inpainted = (1 - mask) * image + mask * predicted_image
|
305 |
+
inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
|
306 |
+
inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
|
307 |
+
return inpainted_image
|
308 |
+
|
309 |
+
def _norm(self, tensor):
|
310 |
+
return tensor * 2.0 - 1.0
|
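The shape bookkeeping in `LDM.forward` above, written out with dummy tensors so the conditioning concat and the sampler shape are explicit. Values are random and no model is loaded; the encoded latent is a stand-in for `cond_stage_model_encode(masked_image)`.

```python
# Sketch of the LDM conditioning shapes: 512x512 inputs, 128x128 latents (num_downs = 2).
import torch

image = torch.rand(1, 3, 512, 512)                 # normalized RGB
mask = (torch.rand(1, 1, 512, 512) > 0.5).float()  # binarized mask
masked_image = (1 - mask) * image                  # what the encoder would see

c = torch.rand(1, 3, 128, 128)                     # stand-in for cond_stage_model_encode(masked_image)
cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:])  # 1,1,128,128
c = torch.cat((c, cc), dim=1)                      # 1,4,128,128 conditioning

shape = (c.shape[1] - 1,) + c.shape[2:]            # (3, 128, 128): latent sampled by DDIM/PLMS
print(shape)
```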
lama_cleaner/model/manga.py
ADDED
@@ -0,0 +1,130 @@
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from loguru import logger
|
9 |
+
|
10 |
+
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
|
11 |
+
from lama_cleaner.model.base import InpaintModel
|
12 |
+
from lama_cleaner.schema import Config
|
13 |
+
|
14 |
+
# def norm(np_img):
|
15 |
+
# return np_img / 255 * 2 - 1.0
|
16 |
+
#
|
17 |
+
#
|
18 |
+
# @torch.no_grad()
|
19 |
+
# def run():
|
20 |
+
# name = 'manga_1080x740.jpg'
|
21 |
+
# img_p = f'/Users/qing/code/github/MangaInpainting/examples/test/imgs/{name}'
|
22 |
+
# mask_p = f'/Users/qing/code/github/MangaInpainting/examples/test/masks/mask_{name}'
|
23 |
+
# erika_model = torch.jit.load('erika.jit')
|
24 |
+
# manga_inpaintor_model = torch.jit.load('manga_inpaintor.jit')
|
25 |
+
#
|
26 |
+
# img = cv2.imread(img_p)
|
27 |
+
# gray_img = cv2.imread(img_p, cv2.IMREAD_GRAYSCALE)
|
28 |
+
# mask = cv2.imread(mask_p, cv2.IMREAD_GRAYSCALE)
|
29 |
+
#
|
30 |
+
# kernel = np.ones((9, 9), dtype=np.uint8)
|
31 |
+
# mask = cv2.dilate(mask, kernel, 2)
|
32 |
+
# # cv2.imwrite("mask.jpg", mask)
|
33 |
+
# # cv2.imshow('dilated_mask', cv2.hconcat([mask, dilated_mask]))
|
34 |
+
# # cv2.waitKey(0)
|
35 |
+
# # exit()
|
36 |
+
#
|
37 |
+
# # img = pad(img)
|
38 |
+
# gray_img = pad(gray_img).astype(np.float32)
|
39 |
+
# mask = pad(mask)
|
40 |
+
#
|
41 |
+
# # pad_mod = 16
|
42 |
+
# import time
|
43 |
+
# start = time.time()
|
44 |
+
# y = erika_model(torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :]))
|
45 |
+
# y = torch.clamp(y, 0, 255)
|
46 |
+
# lines = y.cpu().numpy()
|
47 |
+
# print(f"erika_model time: {time.time() - start}")
|
48 |
+
#
|
49 |
+
# cv2.imwrite('lines.png', lines[0][0])
|
50 |
+
#
|
51 |
+
# start = time.time()
|
52 |
+
# masks = torch.from_numpy(mask[np.newaxis, np.newaxis, :, :])
|
53 |
+
# masks = torch.where(masks > 0.5, torch.tensor(1.0), torch.tensor(0.0))
|
54 |
+
# noise = torch.randn_like(masks)
|
55 |
+
#
|
56 |
+
# images = torch.from_numpy(norm(gray_img)[np.newaxis, np.newaxis, :, :])
|
57 |
+
# lines = torch.from_numpy(norm(lines))
|
58 |
+
#
|
59 |
+
# outputs = manga_inpaintor_model(images, lines, masks, noise)
|
60 |
+
# print(f"manga_inpaintor_model time: {time.time() - start}")
|
61 |
+
#
|
62 |
+
# outputs_merged = (outputs * masks) + (images * (1 - masks))
|
63 |
+
# outputs_merged = outputs_merged * 127.5 + 127.5
|
64 |
+
# outputs_merged = outputs_merged.permute(0, 2, 3, 1)[0].detach().cpu().numpy().astype(np.uint8)
|
65 |
+
# cv2.imwrite(f'output_{name}', outputs_merged)
|
66 |
+
|
67 |
+
|
68 |
+
MANGA_INPAINTOR_MODEL_URL = os.environ.get(
|
69 |
+
"MANGA_INPAINTOR_MODEL_URL",
|
70 |
+
"https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit"
|
71 |
+
)
|
72 |
+
MANGA_LINE_MODEL_URL = os.environ.get(
|
73 |
+
"MANGA_LINE_MODEL_URL",
|
74 |
+
"https://github.com/Sanster/models/releases/download/manga/erika.jit"
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
class Manga(InpaintModel):
|
79 |
+
pad_mod = 16
|
80 |
+
|
81 |
+
def init_model(self, device, **kwargs):
|
82 |
+
self.inpaintor_model = load_jit_model(MANGA_INPAINTOR_MODEL_URL, device)
|
83 |
+
self.line_model = load_jit_model(MANGA_LINE_MODEL_URL, device)
|
84 |
+
self.seed = 42
|
85 |
+
|
86 |
+
@staticmethod
|
87 |
+
def is_downloaded() -> bool:
|
88 |
+
model_paths = [
|
89 |
+
get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL),
|
90 |
+
get_cache_path_by_url(MANGA_LINE_MODEL_URL),
|
91 |
+
]
|
92 |
+
return all([os.path.exists(it) for it in model_paths])
|
93 |
+
|
94 |
+
def forward(self, image, mask, config: Config):
|
95 |
+
"""
|
96 |
+
image: [H, W, C] RGB
|
97 |
+
mask: [H, W, 1]
|
98 |
+
return: BGR IMAGE
|
99 |
+
"""
|
100 |
+
seed = self.seed
|
101 |
+
random.seed(seed)
|
102 |
+
np.random.seed(seed)
|
103 |
+
torch.manual_seed(seed)
|
104 |
+
torch.cuda.manual_seed_all(seed)
|
105 |
+
|
106 |
+
gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
107 |
+
gray_img = torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)).to(self.device)
|
108 |
+
start = time.time()
|
109 |
+
lines = self.line_model(gray_img)
|
110 |
+
torch.cuda.empty_cache()
|
111 |
+
lines = torch.clamp(lines, 0, 255)
|
112 |
+
logger.info(f"erika_model time: {time.time() - start}")
|
113 |
+
|
114 |
+
mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device)
|
115 |
+
mask = mask.permute(0, 3, 1, 2)
|
116 |
+
mask = torch.where(mask > 0.5, 1.0, 0.0)
|
117 |
+
noise = torch.randn_like(mask)
|
118 |
+
ones = torch.ones_like(mask)
|
119 |
+
|
120 |
+
gray_img = gray_img / 255 * 2 - 1.0
|
121 |
+
lines = lines / 255 * 2 - 1.0
|
122 |
+
|
123 |
+
start = time.time()
|
124 |
+
inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones)
|
125 |
+
logger.info(f"image_inpaintor_model time: {time.time() - start}")
|
126 |
+
|
127 |
+
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
128 |
+
cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8)
|
129 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR)
|
130 |
+
return cur_res
|
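`Manga.forward` above moves the grayscale image and the extracted line art into `[-1, 1]` before inpainting and maps the result back with `* 127.5 + 127.5`. A tiny sketch of that normalization round trip with a dummy array:

```python
# The Manga model works in [-1, 1] gray space; this only shows the value mapping,
# no model is involved.
import numpy as np

gray = np.random.randint(0, 256, (16, 16)).astype(np.float32)  # uint8-range grayscale

normed = gray / 255 * 2 - 1.0        # [0, 255] -> [-1, 1], as fed to the inpaintor
restored = normed * 127.5 + 127.5    # [-1, 1] -> [0, 255], as done after inference

assert np.allclose(gray, restored, atol=1e-3)
```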
lama_cleaner/model/mat.py
ADDED
@@ -0,0 +1,1444 @@
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch.utils.checkpoint as checkpoint
|
10 |
+
|
11 |
+
from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img
|
12 |
+
from lama_cleaner.model.base import InpaintModel
|
13 |
+
from lama_cleaner.model.utils import setup_filter, Conv2dLayer, FullyConnectedLayer, conv2d_resample, bias_act, \
|
14 |
+
upsample2d, activation_funcs, MinibatchStdLayer, to_2tuple, normalize_2nd_moment
|
15 |
+
from lama_cleaner.schema import Config
|
16 |
+
|
17 |
+
|
18 |
+
class ModulatedConv2d(nn.Module):
|
19 |
+
def __init__(self,
|
20 |
+
in_channels, # Number of input channels.
|
21 |
+
out_channels, # Number of output channels.
|
22 |
+
kernel_size, # Width and height of the convolution kernel.
|
23 |
+
style_dim, # dimension of the style code
|
24 |
+
demodulate=True, # perform demodulation
|
25 |
+
up=1, # Integer upsampling factor.
|
26 |
+
down=1, # Integer downsampling factor.
|
27 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
28 |
+
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
|
29 |
+
):
|
30 |
+
super().__init__()
|
31 |
+
self.demodulate = demodulate
|
32 |
+
|
33 |
+
self.weight = torch.nn.Parameter(torch.randn([1, out_channels, in_channels, kernel_size, kernel_size]))
|
34 |
+
self.out_channels = out_channels
|
35 |
+
self.kernel_size = kernel_size
|
36 |
+
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
37 |
+
self.padding = self.kernel_size // 2
|
38 |
+
self.up = up
|
39 |
+
self.down = down
|
40 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
41 |
+
self.conv_clamp = conv_clamp
|
42 |
+
|
43 |
+
self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1)
|
44 |
+
|
45 |
+
def forward(self, x, style):
|
46 |
+
batch, in_channels, height, width = x.shape
|
47 |
+
style = self.affine(style).view(batch, 1, in_channels, 1, 1)
|
48 |
+
weight = self.weight * self.weight_gain * style
|
49 |
+
|
50 |
+
if self.demodulate:
|
51 |
+
decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
|
52 |
+
weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1)
|
53 |
+
|
54 |
+
weight = weight.view(batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size)
|
55 |
+
x = x.view(1, batch * in_channels, height, width)
|
56 |
+
x = conv2d_resample(x=x, w=weight, f=self.resample_filter, up=self.up, down=self.down,
|
57 |
+
padding=self.padding, groups=batch)
|
58 |
+
out = x.view(batch, self.out_channels, *x.shape[2:])
|
59 |
+
|
60 |
+
return out
|
61 |
+
|
62 |
+
|
63 |
+
class StyleConv(torch.nn.Module):
|
64 |
+
def __init__(self,
|
65 |
+
in_channels, # Number of input channels.
|
66 |
+
out_channels, # Number of output channels.
|
67 |
+
style_dim, # Intermediate latent (W) dimensionality.
|
68 |
+
resolution, # Resolution of this layer.
|
69 |
+
kernel_size=3, # Convolution kernel size.
|
70 |
+
up=1, # Integer upsampling factor.
|
71 |
+
use_noise=False, # Enable noise input?
|
72 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
73 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
74 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
75 |
+
demodulate=True, # perform demodulation
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
|
79 |
+
self.conv = ModulatedConv2d(in_channels=in_channels,
|
80 |
+
out_channels=out_channels,
|
81 |
+
kernel_size=kernel_size,
|
82 |
+
style_dim=style_dim,
|
83 |
+
demodulate=demodulate,
|
84 |
+
up=up,
|
85 |
+
resample_filter=resample_filter,
|
86 |
+
conv_clamp=conv_clamp)
|
87 |
+
|
88 |
+
self.use_noise = use_noise
|
89 |
+
self.resolution = resolution
|
90 |
+
if use_noise:
|
91 |
+
self.register_buffer('noise_const', torch.randn([resolution, resolution]))
|
92 |
+
self.noise_strength = torch.nn.Parameter(torch.zeros([]))
|
93 |
+
|
94 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
95 |
+
self.activation = activation
|
96 |
+
self.act_gain = activation_funcs[activation].def_gain
|
97 |
+
self.conv_clamp = conv_clamp
|
98 |
+
|
99 |
+
def forward(self, x, style, noise_mode='random', gain=1):
|
100 |
+
x = self.conv(x, style)
|
101 |
+
|
102 |
+
assert noise_mode in ['random', 'const', 'none']
|
103 |
+
|
104 |
+
if self.use_noise:
|
105 |
+
if noise_mode == 'random':
|
106 |
+
xh, xw = x.size()[-2:]
|
107 |
+
noise = torch.randn([x.shape[0], 1, xh, xw], device=x.device) \
|
108 |
+
* self.noise_strength
|
109 |
+
if noise_mode == 'const':
|
110 |
+
noise = self.noise_const * self.noise_strength
|
111 |
+
x = x + noise
|
112 |
+
|
113 |
+
act_gain = self.act_gain * gain
|
114 |
+
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
115 |
+
out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
|
116 |
+
|
117 |
+
return out
|
118 |
+
|
119 |
+
|
120 |
+
class ToRGB(torch.nn.Module):
|
121 |
+
def __init__(self,
|
122 |
+
in_channels,
|
123 |
+
out_channels,
|
124 |
+
style_dim,
|
125 |
+
kernel_size=1,
|
126 |
+
resample_filter=[1, 3, 3, 1],
|
127 |
+
conv_clamp=None,
|
128 |
+
demodulate=False):
|
129 |
+
super().__init__()
|
130 |
+
|
131 |
+
self.conv = ModulatedConv2d(in_channels=in_channels,
|
132 |
+
out_channels=out_channels,
|
133 |
+
kernel_size=kernel_size,
|
134 |
+
style_dim=style_dim,
|
135 |
+
demodulate=demodulate,
|
136 |
+
resample_filter=resample_filter,
|
137 |
+
conv_clamp=conv_clamp)
|
138 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
139 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
140 |
+
self.conv_clamp = conv_clamp
|
141 |
+
|
142 |
+
def forward(self, x, style, skip=None):
|
143 |
+
x = self.conv(x, style)
|
144 |
+
out = bias_act(x, self.bias, clamp=self.conv_clamp)
|
145 |
+
|
146 |
+
if skip is not None:
|
147 |
+
if skip.shape != out.shape:
|
148 |
+
skip = upsample2d(skip, self.resample_filter)
|
149 |
+
out = out + skip
|
150 |
+
|
151 |
+
return out
|
152 |
+
|
153 |
+
|
154 |
+
def get_style_code(a, b):
|
155 |
+
return torch.cat([a, b], dim=1)
|
156 |
+
|
157 |
+
|
158 |
+
class DecBlockFirst(nn.Module):
|
159 |
+
def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
160 |
+
super().__init__()
|
161 |
+
self.fc = FullyConnectedLayer(in_features=in_channels * 2,
|
162 |
+
out_features=in_channels * 4 ** 2,
|
163 |
+
activation=activation)
|
164 |
+
self.conv = StyleConv(in_channels=in_channels,
|
165 |
+
out_channels=out_channels,
|
166 |
+
style_dim=style_dim,
|
167 |
+
resolution=4,
|
168 |
+
kernel_size=3,
|
169 |
+
use_noise=use_noise,
|
170 |
+
activation=activation,
|
171 |
+
demodulate=demodulate,
|
172 |
+
)
|
173 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
174 |
+
out_channels=img_channels,
|
175 |
+
style_dim=style_dim,
|
176 |
+
kernel_size=1,
|
177 |
+
demodulate=False,
|
178 |
+
)
|
179 |
+
|
180 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
181 |
+
x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
182 |
+
x = x + E_features[2]
|
183 |
+
style = get_style_code(ws[:, 0], gs)
|
184 |
+
x = self.conv(x, style, noise_mode=noise_mode)
|
185 |
+
style = get_style_code(ws[:, 1], gs)
|
186 |
+
img = self.toRGB(x, style, skip=None)
|
187 |
+
|
188 |
+
return x, img
|
189 |
+
|
190 |
+
|
191 |
+
class DecBlockFirstV2(nn.Module):
|
192 |
+
def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
193 |
+
super().__init__()
|
194 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
195 |
+
out_channels=in_channels,
|
196 |
+
kernel_size=3,
|
197 |
+
activation=activation,
|
198 |
+
)
|
199 |
+
self.conv1 = StyleConv(in_channels=in_channels,
|
200 |
+
out_channels=out_channels,
|
201 |
+
style_dim=style_dim,
|
202 |
+
resolution=4,
|
203 |
+
kernel_size=3,
|
204 |
+
use_noise=use_noise,
|
205 |
+
activation=activation,
|
206 |
+
demodulate=demodulate,
|
207 |
+
)
|
208 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
209 |
+
out_channels=img_channels,
|
210 |
+
style_dim=style_dim,
|
211 |
+
kernel_size=1,
|
212 |
+
demodulate=False,
|
213 |
+
)
|
214 |
+
|
215 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
216 |
+
# x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
217 |
+
x = self.conv0(x)
|
218 |
+
x = x + E_features[2]
|
219 |
+
style = get_style_code(ws[:, 0], gs)
|
220 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
221 |
+
style = get_style_code(ws[:, 1], gs)
|
222 |
+
img = self.toRGB(x, style, skip=None)
|
223 |
+
|
224 |
+
return x, img
|
225 |
+
|
226 |
+
|
227 |
+
class DecBlock(nn.Module):
|
228 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate,
|
229 |
+
img_channels): # res = 2, ..., resolution_log2
|
230 |
+
super().__init__()
|
231 |
+
self.res = res
|
232 |
+
|
233 |
+
self.conv0 = StyleConv(in_channels=in_channels,
|
234 |
+
out_channels=out_channels,
|
235 |
+
style_dim=style_dim,
|
236 |
+
resolution=2 ** res,
|
237 |
+
kernel_size=3,
|
238 |
+
up=2,
|
239 |
+
use_noise=use_noise,
|
240 |
+
activation=activation,
|
241 |
+
demodulate=demodulate,
|
242 |
+
)
|
243 |
+
self.conv1 = StyleConv(in_channels=out_channels,
|
244 |
+
out_channels=out_channels,
|
245 |
+
style_dim=style_dim,
|
246 |
+
resolution=2 ** res,
|
247 |
+
kernel_size=3,
|
248 |
+
use_noise=use_noise,
|
249 |
+
activation=activation,
|
250 |
+
demodulate=demodulate,
|
251 |
+
)
|
252 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
253 |
+
out_channels=img_channels,
|
254 |
+
style_dim=style_dim,
|
255 |
+
kernel_size=1,
|
256 |
+
demodulate=False,
|
257 |
+
)
|
258 |
+
|
259 |
+
def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
|
260 |
+
style = get_style_code(ws[:, self.res * 2 - 5], gs)
|
261 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
262 |
+
x = x + E_features[self.res]
|
263 |
+
style = get_style_code(ws[:, self.res * 2 - 4], gs)
|
264 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
265 |
+
style = get_style_code(ws[:, self.res * 2 - 3], gs)
|
266 |
+
img = self.toRGB(x, style, skip=img)
|
267 |
+
|
268 |
+
return x, img
|
269 |
+
|
270 |
+
|
271 |
+
class MappingNet(torch.nn.Module):
|
272 |
+
def __init__(self,
|
273 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
274 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
275 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
276 |
+
num_ws, # Number of intermediate latents to output, None = do not broadcast.
|
277 |
+
num_layers=8, # Number of mapping layers.
|
278 |
+
embed_features=None, # Label embedding dimensionality, None = same as w_dim.
|
279 |
+
layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
|
280 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
281 |
+
lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
|
282 |
+
w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
|
283 |
+
):
|
284 |
+
super().__init__()
|
285 |
+
self.z_dim = z_dim
|
286 |
+
self.c_dim = c_dim
|
287 |
+
self.w_dim = w_dim
|
288 |
+
self.num_ws = num_ws
|
289 |
+
self.num_layers = num_layers
|
290 |
+
self.w_avg_beta = w_avg_beta
|
291 |
+
|
292 |
+
if embed_features is None:
|
293 |
+
embed_features = w_dim
|
294 |
+
if c_dim == 0:
|
295 |
+
embed_features = 0
|
296 |
+
if layer_features is None:
|
297 |
+
layer_features = w_dim
|
298 |
+
features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
|
299 |
+
|
300 |
+
if c_dim > 0:
|
301 |
+
self.embed = FullyConnectedLayer(c_dim, embed_features)
|
302 |
+
for idx in range(num_layers):
|
303 |
+
in_features = features_list[idx]
|
304 |
+
out_features = features_list[idx + 1]
|
305 |
+
layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
|
306 |
+
setattr(self, f'fc{idx}', layer)
|
307 |
+
|
308 |
+
if num_ws is not None and w_avg_beta is not None:
|
309 |
+
self.register_buffer('w_avg', torch.zeros([w_dim]))
|
310 |
+
|
311 |
+
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
|
312 |
+
# Embed, normalize, and concat inputs.
|
313 |
+
x = None
|
314 |
+
with torch.autograd.profiler.record_function('input'):
|
315 |
+
if self.z_dim > 0:
|
316 |
+
x = normalize_2nd_moment(z.to(torch.float32))
|
317 |
+
if self.c_dim > 0:
|
318 |
+
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
|
319 |
+
x = torch.cat([x, y], dim=1) if x is not None else y
|
320 |
+
|
321 |
+
# Main layers.
|
322 |
+
for idx in range(self.num_layers):
|
323 |
+
layer = getattr(self, f'fc{idx}')
|
324 |
+
x = layer(x)
|
325 |
+
|
326 |
+
# Update moving average of W.
|
327 |
+
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
|
328 |
+
with torch.autograd.profiler.record_function('update_w_avg'):
|
329 |
+
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
|
330 |
+
|
331 |
+
# Broadcast.
|
332 |
+
if self.num_ws is not None:
|
333 |
+
with torch.autograd.profiler.record_function('broadcast'):
|
334 |
+
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
335 |
+
|
336 |
+
# Apply truncation.
|
337 |
+
if truncation_psi != 1:
|
338 |
+
with torch.autograd.profiler.record_function('truncate'):
|
339 |
+
assert self.w_avg_beta is not None
|
340 |
+
if self.num_ws is None or truncation_cutoff is None:
|
341 |
+
x = self.w_avg.lerp(x, truncation_psi)
|
342 |
+
else:
|
343 |
+
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
|
344 |
+
|
345 |
+
return x
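# Illustrative note, assuming num_ws is set: the mapping network returns a
# broadcast latent of shape (batch, num_ws, w_dim), e.g.
#   m = MappingNet(z_dim=512, c_dim=0, w_dim=512, num_ws=12)
#   ws = m(torch.randn(4, 512), None)   # -> torch.Size([4, 12, 512])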
|
346 |
+
|
347 |
+
|
348 |
+
class DisFromRGB(nn.Module):
|
349 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
|
350 |
+
super().__init__()
|
351 |
+
self.conv = Conv2dLayer(in_channels=in_channels,
|
352 |
+
out_channels=out_channels,
|
353 |
+
kernel_size=1,
|
354 |
+
activation=activation,
|
355 |
+
)
|
356 |
+
|
357 |
+
def forward(self, x):
|
358 |
+
return self.conv(x)
|
359 |
+
|
360 |
+
|
361 |
+
class DisBlock(nn.Module):
|
362 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
|
363 |
+
super().__init__()
|
364 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
365 |
+
out_channels=in_channels,
|
366 |
+
kernel_size=3,
|
367 |
+
activation=activation,
|
368 |
+
)
|
369 |
+
self.conv1 = Conv2dLayer(in_channels=in_channels,
|
370 |
+
out_channels=out_channels,
|
371 |
+
kernel_size=3,
|
372 |
+
down=2,
|
373 |
+
activation=activation,
|
374 |
+
)
|
375 |
+
self.skip = Conv2dLayer(in_channels=in_channels,
|
376 |
+
out_channels=out_channels,
|
377 |
+
kernel_size=1,
|
378 |
+
down=2,
|
379 |
+
bias=False,
|
380 |
+
)
|
381 |
+
|
382 |
+
def forward(self, x):
|
383 |
+
skip = self.skip(x, gain=np.sqrt(0.5))
|
384 |
+
x = self.conv0(x)
|
385 |
+
x = self.conv1(x, gain=np.sqrt(0.5))
|
386 |
+
out = skip + x
|
387 |
+
|
388 |
+
return out
|
389 |
+
|
390 |
+
|
391 |
+
class Discriminator(torch.nn.Module):
|
392 |
+
def __init__(self,
|
393 |
+
c_dim, # Conditioning label (C) dimensionality.
|
394 |
+
img_resolution, # Input resolution.
|
395 |
+
img_channels, # Number of input color channels.
|
396 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
397 |
+
channel_max=512, # Maximum number of channels in any layer.
|
398 |
+
channel_decay=1,
|
399 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
400 |
+
activation='lrelu',
|
401 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
402 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
403 |
+
):
|
404 |
+
super().__init__()
|
405 |
+
self.c_dim = c_dim
|
406 |
+
self.img_resolution = img_resolution
|
407 |
+
self.img_channels = img_channels
|
408 |
+
|
409 |
+
resolution_log2 = int(np.log2(img_resolution))
|
410 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
411 |
+
self.resolution_log2 = resolution_log2
|
412 |
+
|
413 |
+
def nf(stage):
|
414 |
+
return np.clip(int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max)
|
415 |
+
|
416 |
+
if cmap_dim is None:
|
417 |
+
cmap_dim = nf(2)
|
418 |
+
if c_dim == 0:
|
419 |
+
cmap_dim = 0
|
420 |
+
self.cmap_dim = cmap_dim
|
421 |
+
|
422 |
+
if c_dim > 0:
|
423 |
+
self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)
|
424 |
+
|
425 |
+
Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
|
426 |
+
for res in range(resolution_log2, 2, -1):
|
427 |
+
Dis.append(DisBlock(nf(res), nf(res - 1), activation))
|
428 |
+
|
429 |
+
if mbstd_num_channels > 0:
|
430 |
+
Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
|
431 |
+
Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
|
432 |
+
self.Dis = nn.Sequential(*Dis)
|
433 |
+
|
434 |
+
self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
|
435 |
+
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
|
436 |
+
|
437 |
+
def forward(self, images_in, masks_in, c):
|
438 |
+
x = torch.cat([masks_in - 0.5, images_in], dim=1)
|
439 |
+
x = self.Dis(x)
|
440 |
+
x = self.fc1(self.fc0(x.flatten(start_dim=1)))
|
441 |
+
|
442 |
+
if self.c_dim > 0:
|
443 |
+
cmap = self.mapping(None, c)
|
444 |
+
|
445 |
+
if self.cmap_dim > 0:
|
446 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
447 |
+
|
448 |
+
return x
|
449 |
+
|
450 |
+
|
451 |
+
def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
|
452 |
+
NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
|
453 |
+
return NF[2 ** stage]
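# Note: this lookup-table variant ignores the channel_base / channel_decay /
# channel_max arguments and simply maps the feature-map size 2 ** stage to a
# fixed width, e.g. nf(2) -> 512 for 4x4 features and nf(9) -> 64 for 512x512.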
|
454 |
+
|
455 |
+
|
456 |
+
class Mlp(nn.Module):
|
457 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
458 |
+
super().__init__()
|
459 |
+
out_features = out_features or in_features
|
460 |
+
hidden_features = hidden_features or in_features
|
461 |
+
self.fc1 = FullyConnectedLayer(in_features=in_features, out_features=hidden_features, activation='lrelu')
|
462 |
+
self.fc2 = FullyConnectedLayer(in_features=hidden_features, out_features=out_features)
|
463 |
+
|
464 |
+
def forward(self, x):
|
465 |
+
x = self.fc1(x)
|
466 |
+
x = self.fc2(x)
|
467 |
+
return x
|
468 |
+
|
469 |
+
|
470 |
+
def window_partition(x, window_size):
|
471 |
+
"""
|
472 |
+
Args:
|
473 |
+
x: (B, H, W, C)
|
474 |
+
window_size (int): window size
|
475 |
+
Returns:
|
476 |
+
windows: (num_windows*B, window_size, window_size, C)
|
477 |
+
"""
|
478 |
+
B, H, W, C = x.shape
|
479 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
480 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
481 |
+
return windows
|
482 |
+
|
483 |
+
|
484 |
+
def window_reverse(windows, window_size: int, H: int, W: int):
|
485 |
+
"""
|
486 |
+
Args:
|
487 |
+
windows: (num_windows*B, window_size, window_size, C)
|
488 |
+
window_size (int): Window size
|
489 |
+
H (int): Height of image
|
490 |
+
W (int): Width of image
|
491 |
+
Returns:
|
492 |
+
x: (B, H, W, C)
|
493 |
+
"""
|
494 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
495 |
+
# B = windows.shape[0] / (H * W / window_size / window_size)
|
496 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
497 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
498 |
+
return x
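# Illustrative round-trip, assuming H and W are divisible by the window size:
#   x = torch.randn(2, 64, 64, 180)                 # (B, H, W, C)
#   windows = window_partition(x, 8)                # (2 * 8 * 8, 8, 8, 180)
#   assert torch.equal(window_reverse(windows, 8, 64, 64), x)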
|
499 |
+
|
500 |
+
|
501 |
+
class Conv2dLayerPartial(nn.Module):
|
502 |
+
def __init__(self,
|
503 |
+
in_channels, # Number of input channels.
|
504 |
+
out_channels, # Number of output channels.
|
505 |
+
kernel_size, # Width and height of the convolution kernel.
|
506 |
+
bias=True, # Apply additive bias before the activation function?
|
507 |
+
activation='linear', # Activation function: 'relu', 'lrelu', etc.
|
508 |
+
up=1, # Integer upsampling factor.
|
509 |
+
down=1, # Integer downsampling factor.
|
510 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
511 |
+
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
|
512 |
+
trainable=True, # Update the weights of this layer during training?
|
513 |
+
):
|
514 |
+
super().__init__()
|
515 |
+
self.conv = Conv2dLayer(in_channels, out_channels, kernel_size, bias, activation, up, down, resample_filter,
|
516 |
+
conv_clamp, trainable)
|
517 |
+
|
518 |
+
self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
|
519 |
+
self.slide_winsize = kernel_size ** 2
|
520 |
+
self.stride = down
|
521 |
+
self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
|
522 |
+
|
523 |
+
def forward(self, x, mask=None):
|
524 |
+
if mask is not None:
|
525 |
+
with torch.no_grad():
|
526 |
+
if self.weight_maskUpdater.type() != x.type():
|
527 |
+
self.weight_maskUpdater = self.weight_maskUpdater.to(x)
|
528 |
+
update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
|
529 |
+
padding=self.padding)
|
530 |
+
mask_ratio = self.slide_winsize / (update_mask + 1e-8)
|
531 |
+
update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1
|
532 |
+
mask_ratio = torch.mul(mask_ratio, update_mask)
|
533 |
+
x = self.conv(x)
|
534 |
+
x = torch.mul(x, mask_ratio)
|
535 |
+
return x, update_mask
|
536 |
+
else:
|
537 |
+
x = self.conv(x)
|
538 |
+
return x, None
|
539 |
+
|
540 |
+
|
541 |
+
class WindowAttention(nn.Module):
|
542 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
543 |
+
It supports both shifted and non-shifted windows.
|
544 |
+
Args:
|
545 |
+
dim (int): Number of input channels.
|
546 |
+
window_size (tuple[int]): The height and width of the window.
|
547 |
+
num_heads (int): Number of attention heads.
|
548 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
549 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
550 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
551 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
552 |
+
"""
|
553 |
+
|
554 |
+
def __init__(self, dim, window_size, num_heads, down_ratio=1, qkv_bias=True, qk_scale=None, attn_drop=0.,
|
555 |
+
proj_drop=0.):
|
556 |
+
|
557 |
+
super().__init__()
|
558 |
+
self.dim = dim
|
559 |
+
self.window_size = window_size # Wh, Ww
|
560 |
+
self.num_heads = num_heads
|
561 |
+
head_dim = dim // num_heads
|
562 |
+
self.scale = qk_scale or head_dim ** -0.5
|
563 |
+
|
564 |
+
self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
|
565 |
+
self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
|
566 |
+
self.v = FullyConnectedLayer(in_features=dim, out_features=dim)
|
567 |
+
self.proj = FullyConnectedLayer(in_features=dim, out_features=dim)
|
568 |
+
|
569 |
+
self.softmax = nn.Softmax(dim=-1)
|
570 |
+
|
571 |
+
def forward(self, x, mask_windows=None, mask=None):
|
572 |
+
"""
|
573 |
+
Args:
|
574 |
+
x: input features with shape of (num_windows*B, N, C)
|
575 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
576 |
+
"""
|
577 |
+
B_, N, C = x.shape
|
578 |
+
norm_x = F.normalize(x, p=2.0, dim=-1)
|
579 |
+
q = self.q(norm_x).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
580 |
+
k = self.k(norm_x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 3, 1)
|
581 |
+
v = self.v(x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
582 |
+
|
583 |
+
attn = (q @ k) * self.scale
|
584 |
+
|
585 |
+
if mask is not None:
|
586 |
+
nW = mask.shape[0]
|
587 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
588 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
589 |
+
|
590 |
+
if mask_windows is not None:
|
591 |
+
attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1)
|
592 |
+
attn = attn + attn_mask_windows.masked_fill(attn_mask_windows == 0, float(-100.0)).masked_fill(
|
593 |
+
attn_mask_windows == 1, float(0.0))
|
594 |
+
with torch.no_grad():
|
595 |
+
mask_windows = torch.clamp(torch.sum(mask_windows, dim=1, keepdim=True), 0, 1).repeat(1, N, 1)
|
596 |
+
|
597 |
+
attn = self.softmax(attn)
|
598 |
+
|
599 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
600 |
+
x = self.proj(x)
|
601 |
+
return x, mask_windows
|
602 |
+
|
603 |
+
|
604 |
+
class SwinTransformerBlock(nn.Module):
|
605 |
+
r""" Swin Transformer Block.
|
606 |
+
Args:
|
607 |
+
dim (int): Number of input channels.
|
608 |
+
input_resolution (tuple[int]): Input resolution.
|
609 |
+
num_heads (int): Number of attention heads.
|
610 |
+
window_size (int): Window size.
|
611 |
+
shift_size (int): Shift size for SW-MSA.
|
612 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
613 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
614 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
615 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
616 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
617 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
618 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
619 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
620 |
+
"""
|
621 |
+
|
622 |
+
def __init__(self, dim, input_resolution, num_heads, down_ratio=1, window_size=7, shift_size=0,
|
623 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
624 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
625 |
+
super().__init__()
|
626 |
+
self.dim = dim
|
627 |
+
self.input_resolution = input_resolution
|
628 |
+
self.num_heads = num_heads
|
629 |
+
self.window_size = window_size
|
630 |
+
self.shift_size = shift_size
|
631 |
+
self.mlp_ratio = mlp_ratio
|
632 |
+
if min(self.input_resolution) <= self.window_size:
|
633 |
+
# if window size is larger than input resolution, we don't partition windows
|
634 |
+
self.shift_size = 0
|
635 |
+
self.window_size = min(self.input_resolution)
|
636 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
|
637 |
+
|
638 |
+
if self.shift_size > 0:
|
639 |
+
down_ratio = 1
|
640 |
+
self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
641 |
+
down_ratio=down_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
|
642 |
+
proj_drop=drop)
|
643 |
+
|
644 |
+
self.fuse = FullyConnectedLayer(in_features=dim * 2, out_features=dim, activation='lrelu')
|
645 |
+
|
646 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
647 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
648 |
+
|
649 |
+
if self.shift_size > 0:
|
650 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
651 |
+
else:
|
652 |
+
attn_mask = None
|
653 |
+
|
654 |
+
self.register_buffer("attn_mask", attn_mask)
|
655 |
+
|
656 |
+
def calculate_mask(self, x_size):
|
657 |
+
# calculate attention mask for SW-MSA
|
658 |
+
H, W = x_size
|
659 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
660 |
+
h_slices = (slice(0, -self.window_size),
|
661 |
+
slice(-self.window_size, -self.shift_size),
|
662 |
+
slice(-self.shift_size, None))
|
663 |
+
w_slices = (slice(0, -self.window_size),
|
664 |
+
slice(-self.window_size, -self.shift_size),
|
665 |
+
slice(-self.shift_size, None))
|
666 |
+
cnt = 0
|
667 |
+
for h in h_slices:
|
668 |
+
for w in w_slices:
|
669 |
+
img_mask[:, h, w, :] = cnt
|
670 |
+
cnt += 1
|
671 |
+
|
672 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
673 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
674 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
675 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
676 |
+
|
677 |
+
return attn_mask
|
678 |
+
|
679 |
+
def forward(self, x, x_size, mask=None):
|
680 |
+
# H, W = self.input_resolution
|
681 |
+
H, W = x_size
|
682 |
+
B, L, C = x.shape
|
683 |
+
# assert L == H * W, "input feature has wrong size"
|
684 |
+
|
685 |
+
shortcut = x
|
686 |
+
x = x.view(B, H, W, C)
|
687 |
+
if mask is not None:
|
688 |
+
mask = mask.view(B, H, W, 1)
|
689 |
+
|
690 |
+
# cyclic shift
|
691 |
+
if self.shift_size > 0:
|
692 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
693 |
+
if mask is not None:
|
694 |
+
shifted_mask = torch.roll(mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
695 |
+
else:
|
696 |
+
shifted_x = x
|
697 |
+
if mask is not None:
|
698 |
+
shifted_mask = mask
|
699 |
+
|
700 |
+
# partition windows
|
701 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
702 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
703 |
+
if mask is not None:
|
704 |
+
mask_windows = window_partition(shifted_mask, self.window_size)
|
705 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1)
|
706 |
+
else:
|
707 |
+
mask_windows = None
|
708 |
+
|
709 |
+
# W-MSA/SW-MSA (kept compatible with testing on images whose shape is a multiple of the window size)
|
710 |
+
if self.input_resolution == x_size:
|
711 |
+
attn_windows, mask_windows = self.attn(x_windows, mask_windows,
|
712 |
+
mask=self.attn_mask) # nW*B, window_size*window_size, C
|
713 |
+
else:
|
714 |
+
attn_windows, mask_windows = self.attn(x_windows, mask_windows, mask=self.calculate_mask(x_size).to(
|
715 |
+
x.device)) # nW*B, window_size*window_size, C
|
716 |
+
|
717 |
+
# merge windows
|
718 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
719 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
720 |
+
if mask is not None:
|
721 |
+
mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1)
|
722 |
+
shifted_mask = window_reverse(mask_windows, self.window_size, H, W)
|
723 |
+
|
724 |
+
# reverse cyclic shift
|
725 |
+
if self.shift_size > 0:
|
726 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
727 |
+
if mask is not None:
|
728 |
+
mask = torch.roll(shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
729 |
+
else:
|
730 |
+
x = shifted_x
|
731 |
+
if mask is not None:
|
732 |
+
mask = shifted_mask
|
733 |
+
x = x.view(B, H * W, C)
|
734 |
+
if mask is not None:
|
735 |
+
mask = mask.view(B, H * W, 1)
|
736 |
+
|
737 |
+
# FFN
|
738 |
+
x = self.fuse(torch.cat([shortcut, x], dim=-1))
|
739 |
+
x = self.mlp(x)
|
740 |
+
|
741 |
+
return x, mask
|
742 |
+
|
743 |
+
|
744 |
+
class PatchMerging(nn.Module):
|
745 |
+
def __init__(self, in_channels, out_channels, down=2):
|
746 |
+
super().__init__()
|
747 |
+
self.conv = Conv2dLayerPartial(in_channels=in_channels,
|
748 |
+
out_channels=out_channels,
|
749 |
+
kernel_size=3,
|
750 |
+
activation='lrelu',
|
751 |
+
down=down,
|
752 |
+
)
|
753 |
+
self.down = down
|
754 |
+
|
755 |
+
def forward(self, x, x_size, mask=None):
|
756 |
+
x = token2feature(x, x_size)
|
757 |
+
if mask is not None:
|
758 |
+
mask = token2feature(mask, x_size)
|
759 |
+
x, mask = self.conv(x, mask)
|
760 |
+
if self.down != 1:
|
761 |
+
ratio = 1 / self.down
|
762 |
+
x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio))
|
763 |
+
x = feature2token(x)
|
764 |
+
if mask is not None:
|
765 |
+
mask = feature2token(mask)
|
766 |
+
return x, x_size, mask
|
767 |
+
|
768 |
+
|
769 |
+
class PatchUpsampling(nn.Module):
|
770 |
+
def __init__(self, in_channels, out_channels, up=2):
|
771 |
+
super().__init__()
|
772 |
+
self.conv = Conv2dLayerPartial(in_channels=in_channels,
|
773 |
+
out_channels=out_channels,
|
774 |
+
kernel_size=3,
|
775 |
+
activation='lrelu',
|
776 |
+
up=up,
|
777 |
+
)
|
778 |
+
self.up = up
|
779 |
+
|
780 |
+
def forward(self, x, x_size, mask=None):
|
781 |
+
x = token2feature(x, x_size)
|
782 |
+
if mask is not None:
|
783 |
+
mask = token2feature(mask, x_size)
|
784 |
+
x, mask = self.conv(x, mask)
|
785 |
+
if self.up != 1:
|
786 |
+
x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up))
|
787 |
+
x = feature2token(x)
|
788 |
+
if mask is not None:
|
789 |
+
mask = feature2token(mask)
|
790 |
+
return x, x_size, mask
|
791 |
+
|
792 |
+
|
793 |
+
class BasicLayer(nn.Module):
|
794 |
+
""" A basic Swin Transformer layer for one stage.
|
795 |
+
Args:
|
796 |
+
dim (int): Number of input channels.
|
797 |
+
input_resolution (tuple[int]): Input resolution.
|
798 |
+
depth (int): Number of blocks.
|
799 |
+
num_heads (int): Number of attention heads.
|
800 |
+
window_size (int): Local window size.
|
801 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
802 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
803 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
804 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
805 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
806 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
807 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
808 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
809 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
810 |
+
"""
|
811 |
+
|
812 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size, down_ratio=1,
|
813 |
+
mlp_ratio=2., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
814 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
815 |
+
|
816 |
+
super().__init__()
|
817 |
+
self.dim = dim
|
818 |
+
self.input_resolution = input_resolution
|
819 |
+
self.depth = depth
|
820 |
+
self.use_checkpoint = use_checkpoint
|
821 |
+
|
822 |
+
# patch merging layer
|
823 |
+
if downsample is not None:
|
824 |
+
# self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
825 |
+
self.downsample = downsample
|
826 |
+
else:
|
827 |
+
self.downsample = None
|
828 |
+
|
829 |
+
# build blocks
|
830 |
+
self.blocks = nn.ModuleList([
|
831 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
832 |
+
num_heads=num_heads, down_ratio=down_ratio, window_size=window_size,
|
833 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
834 |
+
mlp_ratio=mlp_ratio,
|
835 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
836 |
+
drop=drop, attn_drop=attn_drop,
|
837 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
838 |
+
norm_layer=norm_layer)
|
839 |
+
for i in range(depth)])
|
840 |
+
|
841 |
+
self.conv = Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, activation='lrelu')
|
842 |
+
|
843 |
+
def forward(self, x, x_size, mask=None):
|
844 |
+
if self.downsample is not None:
|
845 |
+
x, x_size, mask = self.downsample(x, x_size, mask)
|
846 |
+
identity = x
|
847 |
+
for blk in self.blocks:
|
848 |
+
if self.use_checkpoint:
|
849 |
+
x, mask = checkpoint.checkpoint(blk, x, x_size, mask)
|
850 |
+
else:
|
851 |
+
x, mask = blk(x, x_size, mask)
|
852 |
+
if mask is not None:
|
853 |
+
mask = token2feature(mask, x_size)
|
854 |
+
x, mask = self.conv(token2feature(x, x_size), mask)
|
855 |
+
x = feature2token(x) + identity
|
856 |
+
if mask is not None:
|
857 |
+
mask = feature2token(mask)
|
858 |
+
return x, x_size, mask
|
859 |
+
|
860 |
+
|
861 |
+
class ToToken(nn.Module):
|
862 |
+
def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1):
|
863 |
+
super().__init__()
|
864 |
+
|
865 |
+
self.proj = Conv2dLayerPartial(in_channels=in_channels, out_channels=dim, kernel_size=kernel_size,
|
866 |
+
activation='lrelu')
|
867 |
+
|
868 |
+
def forward(self, x, mask):
|
869 |
+
x, mask = self.proj(x, mask)
|
870 |
+
|
871 |
+
return x, mask
|
872 |
+
|
873 |
+
|
874 |
+
class EncFromRGB(nn.Module):
|
875 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
|
876 |
+
super().__init__()
|
877 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
878 |
+
out_channels=out_channels,
|
879 |
+
kernel_size=1,
|
880 |
+
activation=activation,
|
881 |
+
)
|
882 |
+
self.conv1 = Conv2dLayer(in_channels=out_channels,
|
883 |
+
out_channels=out_channels,
|
884 |
+
kernel_size=3,
|
885 |
+
activation=activation,
|
886 |
+
)
|
887 |
+
|
888 |
+
def forward(self, x):
|
889 |
+
x = self.conv0(x)
|
890 |
+
x = self.conv1(x)
|
891 |
+
|
892 |
+
return x
|
893 |
+
|
894 |
+
|
895 |
+
class ConvBlockDown(nn.Module):
|
896 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log
|
897 |
+
super().__init__()
|
898 |
+
|
899 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
900 |
+
out_channels=out_channels,
|
901 |
+
kernel_size=3,
|
902 |
+
activation=activation,
|
903 |
+
down=2,
|
904 |
+
)
|
905 |
+
self.conv1 = Conv2dLayer(in_channels=out_channels,
|
906 |
+
out_channels=out_channels,
|
907 |
+
kernel_size=3,
|
908 |
+
activation=activation,
|
909 |
+
)
|
910 |
+
|
911 |
+
def forward(self, x):
|
912 |
+
x = self.conv0(x)
|
913 |
+
x = self.conv1(x)
|
914 |
+
|
915 |
+
return x
|
916 |
+
|
917 |
+
|
918 |
+
def token2feature(x, x_size):
|
919 |
+
B, N, C = x.shape
|
920 |
+
h, w = x_size
|
921 |
+
x = x.permute(0, 2, 1).reshape(B, C, h, w)
|
922 |
+
return x
|
923 |
+
|
924 |
+
|
925 |
+
def feature2token(x):
|
926 |
+
B, C, H, W = x.shape
|
927 |
+
x = x.view(B, C, -1).transpose(1, 2)
|
928 |
+
return x
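# Illustrative round-trip between spatial feature maps and token sequences:
#   feat = torch.randn(1, 180, 16, 16)              # (B, C, H, W)
#   tokens = feature2token(feat)                    # (1, 256, 180) == (B, H*W, C)
#   assert torch.equal(token2feature(tokens, (16, 16)), feat)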
|
929 |
+
|
930 |
+
|
931 |
+
class Encoder(nn.Module):
|
932 |
+
def __init__(self, res_log2, img_channels, activation, patch_size=5, channels=16, drop_path_rate=0.1):
|
933 |
+
super().__init__()
|
934 |
+
|
935 |
+
self.resolution = []
|
936 |
+
|
937 |
+
for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16
|
938 |
+
res = 2 ** i
|
939 |
+
self.resolution.append(res)
|
940 |
+
if i == res_log2:
|
941 |
+
block = EncFromRGB(img_channels * 2 + 1, nf(i), activation)
|
942 |
+
else:
|
943 |
+
block = ConvBlockDown(nf(i + 1), nf(i), activation)
|
944 |
+
setattr(self, 'EncConv_Block_%dx%d' % (res, res), block)
|
945 |
+
|
946 |
+
def forward(self, x):
|
947 |
+
out = {}
|
948 |
+
for res in self.resolution:
|
949 |
+
res_log2 = int(np.log2(res))
|
950 |
+
x = getattr(self, 'EncConv_Block_%dx%d' % (res, res))(x)
|
951 |
+
out[res_log2] = x
|
952 |
+
|
953 |
+
return out
|
954 |
+
|
955 |
+
|
956 |
+
class ToStyle(nn.Module):
|
957 |
+
def __init__(self, in_channels, out_channels, activation, drop_rate):
|
958 |
+
super().__init__()
|
959 |
+
self.conv = nn.Sequential(
|
960 |
+
Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
|
961 |
+
down=2),
|
962 |
+
Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
|
963 |
+
down=2),
|
964 |
+
Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
|
965 |
+
down=2),
|
966 |
+
)
|
967 |
+
|
968 |
+
self.pool = nn.AdaptiveAvgPool2d(1)
|
969 |
+
self.fc = FullyConnectedLayer(in_features=in_channels,
|
970 |
+
out_features=out_channels,
|
971 |
+
activation=activation)
|
972 |
+
# self.dropout = nn.Dropout(drop_rate)
|
973 |
+
|
974 |
+
def forward(self, x):
|
975 |
+
x = self.conv(x)
|
976 |
+
x = self.pool(x)
|
977 |
+
x = self.fc(x.flatten(start_dim=1))
|
978 |
+
# x = self.dropout(x)
|
979 |
+
|
980 |
+
return x
|
981 |
+
|
982 |
+
|
983 |
+
class DecBlockFirstV2(nn.Module):
|
984 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
985 |
+
super().__init__()
|
986 |
+
self.res = res
|
987 |
+
|
988 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
989 |
+
out_channels=in_channels,
|
990 |
+
kernel_size=3,
|
991 |
+
activation=activation,
|
992 |
+
)
|
993 |
+
self.conv1 = StyleConv(in_channels=in_channels,
|
994 |
+
out_channels=out_channels,
|
995 |
+
style_dim=style_dim,
|
996 |
+
resolution=2 ** res,
|
997 |
+
kernel_size=3,
|
998 |
+
use_noise=use_noise,
|
999 |
+
activation=activation,
|
1000 |
+
demodulate=demodulate,
|
1001 |
+
)
|
1002 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
1003 |
+
out_channels=img_channels,
|
1004 |
+
style_dim=style_dim,
|
1005 |
+
kernel_size=1,
|
1006 |
+
demodulate=False,
|
1007 |
+
)
|
1008 |
+
|
1009 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
1010 |
+
# x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
1011 |
+
x = self.conv0(x)
|
1012 |
+
x = x + E_features[self.res]
|
1013 |
+
style = get_style_code(ws[:, 0], gs)
|
1014 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
1015 |
+
style = get_style_code(ws[:, 1], gs)
|
1016 |
+
img = self.toRGB(x, style, skip=None)
|
1017 |
+
|
1018 |
+
return x, img
|
1019 |
+
|
1020 |
+
|
1021 |
+
class DecBlock(nn.Module):
|
1022 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate,
|
1023 |
+
img_channels): # res = 4, ..., resolution_log2
|
1024 |
+
super().__init__()
|
1025 |
+
self.res = res
|
1026 |
+
|
1027 |
+
self.conv0 = StyleConv(in_channels=in_channels,
|
1028 |
+
out_channels=out_channels,
|
1029 |
+
style_dim=style_dim,
|
1030 |
+
resolution=2 ** res,
|
1031 |
+
kernel_size=3,
|
1032 |
+
up=2,
|
1033 |
+
use_noise=use_noise,
|
1034 |
+
activation=activation,
|
1035 |
+
demodulate=demodulate,
|
1036 |
+
)
|
1037 |
+
self.conv1 = StyleConv(in_channels=out_channels,
|
1038 |
+
out_channels=out_channels,
|
1039 |
+
style_dim=style_dim,
|
1040 |
+
resolution=2 ** res,
|
1041 |
+
kernel_size=3,
|
1042 |
+
use_noise=use_noise,
|
1043 |
+
activation=activation,
|
1044 |
+
demodulate=demodulate,
|
1045 |
+
)
|
1046 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
1047 |
+
out_channels=img_channels,
|
1048 |
+
style_dim=style_dim,
|
1049 |
+
kernel_size=1,
|
1050 |
+
demodulate=False,
|
1051 |
+
)
|
1052 |
+
|
1053 |
+
def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
|
1054 |
+
style = get_style_code(ws[:, self.res * 2 - 9], gs)
|
1055 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
1056 |
+
x = x + E_features[self.res]
|
1057 |
+
style = get_style_code(ws[:, self.res * 2 - 8], gs)
|
1058 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
1059 |
+
style = get_style_code(ws[:, self.res * 2 - 7], gs)
|
1060 |
+
img = self.toRGB(x, style, skip=img)
|
1061 |
+
|
1062 |
+
return x, img
|
1063 |
+
|
1064 |
+
|
1065 |
+
class Decoder(nn.Module):
|
1066 |
+
def __init__(self, res_log2, activation, style_dim, use_noise, demodulate, img_channels):
|
1067 |
+
super().__init__()
|
1068 |
+
self.Dec_16x16 = DecBlockFirstV2(4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels)
|
1069 |
+
for res in range(5, res_log2 + 1):
|
1070 |
+
setattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res),
|
1071 |
+
DecBlock(res, nf(res - 1), nf(res), activation, style_dim, use_noise, demodulate, img_channels))
|
1072 |
+
self.res_log2 = res_log2
|
1073 |
+
|
1074 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
1075 |
+
x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
|
1076 |
+
for res in range(5, self.res_log2 + 1):
|
1077 |
+
block = getattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res))
|
1078 |
+
x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode)
|
1079 |
+
|
1080 |
+
return img
|
1081 |
+
|
1082 |
+
|
1083 |
+
class DecStyleBlock(nn.Module):
|
1084 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
1085 |
+
super().__init__()
|
1086 |
+
self.res = res
|
1087 |
+
|
1088 |
+
self.conv0 = StyleConv(in_channels=in_channels,
|
1089 |
+
out_channels=out_channels,
|
1090 |
+
style_dim=style_dim,
|
1091 |
+
resolution=2 ** res,
|
1092 |
+
kernel_size=3,
|
1093 |
+
up=2,
|
1094 |
+
use_noise=use_noise,
|
1095 |
+
activation=activation,
|
1096 |
+
demodulate=demodulate,
|
1097 |
+
)
|
1098 |
+
self.conv1 = StyleConv(in_channels=out_channels,
|
1099 |
+
out_channels=out_channels,
|
1100 |
+
style_dim=style_dim,
|
1101 |
+
resolution=2 ** res,
|
1102 |
+
kernel_size=3,
|
1103 |
+
use_noise=use_noise,
|
1104 |
+
activation=activation,
|
1105 |
+
demodulate=demodulate,
|
1106 |
+
)
|
1107 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
1108 |
+
out_channels=img_channels,
|
1109 |
+
style_dim=style_dim,
|
1110 |
+
kernel_size=1,
|
1111 |
+
demodulate=False,
|
1112 |
+
)
|
1113 |
+
|
1114 |
+
def forward(self, x, img, style, skip, noise_mode='random'):
|
1115 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
1116 |
+
x = x + skip
|
1117 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
1118 |
+
img = self.toRGB(x, style, skip=img)
|
1119 |
+
|
1120 |
+
return x, img
|
1121 |
+
|
1122 |
+
|
1123 |
+
class FirstStage(nn.Module):
|
1124 |
+
def __init__(self, img_channels, img_resolution=256, dim=180, w_dim=512, use_noise=False, demodulate=True,
|
1125 |
+
activation='lrelu'):
|
1126 |
+
super().__init__()
|
1127 |
+
res = 64
|
1128 |
+
|
1129 |
+
self.conv_first = Conv2dLayerPartial(in_channels=img_channels + 1, out_channels=dim, kernel_size=3,
|
1130 |
+
activation=activation)
|
1131 |
+
self.enc_conv = nn.ModuleList()
|
1132 |
+
down_time = int(np.log2(img_resolution // res))
|
1133 |
+
# build the Swin Transformer stages according to the input image size
|
1134 |
+
for i in range(down_time): # from input size to 64
|
1135 |
+
self.enc_conv.append(
|
1136 |
+
Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation)
|
1137 |
+
)
|
1138 |
+
|
1139 |
+
# from 64 -> 16 -> 64
|
1140 |
+
depths = [2, 3, 4, 3, 2]
|
1141 |
+
ratios = [1, 1 / 2, 1 / 2, 2, 2]
|
1142 |
+
num_heads = 6
|
1143 |
+
window_sizes = [8, 16, 16, 16, 8]
|
1144 |
+
drop_path_rate = 0.1
|
1145 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
1146 |
+
|
1147 |
+
self.tran = nn.ModuleList()
|
1148 |
+
for i, depth in enumerate(depths):
|
1149 |
+
res = int(res * ratios[i])
|
1150 |
+
if ratios[i] < 1:
|
1151 |
+
merge = PatchMerging(dim, dim, down=int(1 / ratios[i]))
|
1152 |
+
elif ratios[i] > 1:
|
1153 |
+
merge = PatchUpsampling(dim, dim, up=ratios[i])
|
1154 |
+
else:
|
1155 |
+
merge = None
|
1156 |
+
self.tran.append(
|
1157 |
+
BasicLayer(dim=dim, input_resolution=[res, res], depth=depth, num_heads=num_heads,
|
1158 |
+
window_size=window_sizes[i], drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
1159 |
+
downsample=merge)
|
1160 |
+
)
|
1161 |
+
|
1162 |
+
# global style
|
1163 |
+
down_conv = []
|
1164 |
+
for i in range(int(np.log2(16))):
|
1165 |
+
down_conv.append(
|
1166 |
+
Conv2dLayer(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation))
|
1167 |
+
down_conv.append(nn.AdaptiveAvgPool2d((1, 1)))
|
1168 |
+
self.down_conv = nn.Sequential(*down_conv)
|
1169 |
+
self.to_style = FullyConnectedLayer(in_features=dim, out_features=dim * 2, activation=activation)
|
1170 |
+
self.ws_style = FullyConnectedLayer(in_features=w_dim, out_features=dim, activation=activation)
|
1171 |
+
self.to_square = FullyConnectedLayer(in_features=dim, out_features=16 * 16, activation=activation)
|
1172 |
+
|
1173 |
+
style_dim = dim * 3
|
1174 |
+
self.dec_conv = nn.ModuleList()
|
1175 |
+
for i in range(down_time): # from 64 to input size
|
1176 |
+
res = res * 2
|
1177 |
+
self.dec_conv.append(
|
1178 |
+
DecStyleBlock(res, dim, dim, activation, style_dim, use_noise, demodulate, img_channels))
|
1179 |
+
|
1180 |
+
def forward(self, images_in, masks_in, ws, noise_mode='random'):
|
1181 |
+
x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1)
|
1182 |
+
|
1183 |
+
skips = []
|
1184 |
+
x, mask = self.conv_first(x, masks_in) # input size
|
1185 |
+
skips.append(x)
|
1186 |
+
for i, block in enumerate(self.enc_conv): # input size to 64
|
1187 |
+
x, mask = block(x, mask)
|
1188 |
+
if i != len(self.enc_conv) - 1:
|
1189 |
+
skips.append(x)
|
1190 |
+
|
1191 |
+
x_size = x.size()[-2:]
|
1192 |
+
x = feature2token(x)
|
1193 |
+
mask = feature2token(mask)
|
1194 |
+
mid = len(self.tran) // 2
|
1195 |
+
for i, block in enumerate(self.tran): # 64 to 16
|
1196 |
+
if i < mid:
|
1197 |
+
x, x_size, mask = block(x, x_size, mask)
|
1198 |
+
skips.append(x)
|
1199 |
+
elif i > mid:
|
1200 |
+
x, x_size, mask = block(x, x_size, None)
|
1201 |
+
x = x + skips[mid - i]
|
1202 |
+
else:
|
1203 |
+
x, x_size, mask = block(x, x_size, None)
|
1204 |
+
|
1205 |
+
mul_map = torch.ones_like(x) * 0.5
|
1206 |
+
mul_map = F.dropout(mul_map, training=True)
|
1207 |
+
ws = self.ws_style(ws[:, -1])
|
1208 |
+
add_n = self.to_square(ws).unsqueeze(1)
|
1209 |
+
add_n = F.interpolate(add_n, size=x.size(1), mode='linear', align_corners=False).squeeze(1).unsqueeze(
|
1210 |
+
-1)
|
1211 |
+
x = x * mul_map + add_n * (1 - mul_map)
|
1212 |
+
gs = self.to_style(self.down_conv(token2feature(x, x_size)).flatten(start_dim=1))
|
1213 |
+
style = torch.cat([gs, ws], dim=1)
|
1214 |
+
|
1215 |
+
x = token2feature(x, x_size).contiguous()
|
1216 |
+
img = None
|
1217 |
+
for i, block in enumerate(self.dec_conv):
|
1218 |
+
x, img = block(x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode)
|
1219 |
+
|
1220 |
+
# ensemble
|
1221 |
+
img = img * (1 - masks_in) + images_in * masks_in
|
1222 |
+
|
1223 |
+
return img
|
1224 |
+
|
1225 |
+
|
1226 |
+
class SynthesisNet(nn.Module):
|
1227 |
+
def __init__(self,
|
1228 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1229 |
+
img_resolution, # Output image resolution.
|
1230 |
+
img_channels=3, # Number of color channels.
|
1231 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
1232 |
+
channel_decay=1.0,
|
1233 |
+
channel_max=512, # Maximum number of channels in any layer.
|
1234 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
1235 |
+
drop_rate=0.5,
|
1236 |
+
use_noise=False,
|
1237 |
+
demodulate=True,
|
1238 |
+
):
|
1239 |
+
super().__init__()
|
1240 |
+
resolution_log2 = int(np.log2(img_resolution))
|
1241 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
1242 |
+
|
1243 |
+
self.num_layers = resolution_log2 * 2 - 3 * 2
|
1244 |
+
self.img_resolution = img_resolution
|
1245 |
+
self.resolution_log2 = resolution_log2
|
1246 |
+
|
1247 |
+
# first stage
|
1248 |
+
self.first_stage = FirstStage(img_channels, img_resolution=img_resolution, w_dim=w_dim, use_noise=False,
|
1249 |
+
demodulate=demodulate)
|
1250 |
+
|
1251 |
+
# second stage
|
1252 |
+
self.enc = Encoder(resolution_log2, img_channels, activation, patch_size=5, channels=16)
|
1253 |
+
self.to_square = FullyConnectedLayer(in_features=w_dim, out_features=16 * 16, activation=activation)
|
1254 |
+
self.to_style = ToStyle(in_channels=nf(4), out_channels=nf(2) * 2, activation=activation, drop_rate=drop_rate)
|
1255 |
+
style_dim = w_dim + nf(2) * 2
|
1256 |
+
self.dec = Decoder(resolution_log2, activation, style_dim, use_noise, demodulate, img_channels)
|
1257 |
+
|
1258 |
+
def forward(self, images_in, masks_in, ws, noise_mode='random', return_stg1=False):
|
1259 |
+
out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode)
|
1260 |
+
|
1261 |
+
# encoder
|
1262 |
+
x = images_in * masks_in + out_stg1 * (1 - masks_in)
|
1263 |
+
x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1)
|
1264 |
+
E_features = self.enc(x)
|
1265 |
+
|
1266 |
+
fea_16 = E_features[4]
|
1267 |
+
mul_map = torch.ones_like(fea_16) * 0.5
|
1268 |
+
mul_map = F.dropout(mul_map, training=True)
|
1269 |
+
add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1)
|
1270 |
+
add_n = F.interpolate(add_n, size=fea_16.size()[-2:], mode='bilinear', align_corners=False)
|
1271 |
+
fea_16 = fea_16 * mul_map + add_n * (1 - mul_map)
|
1272 |
+
E_features[4] = fea_16
|
1273 |
+
|
1274 |
+
# style
|
1275 |
+
gs = self.to_style(fea_16)
|
1276 |
+
|
1277 |
+
# decoder
|
1278 |
+
img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode)
|
1279 |
+
|
1280 |
+
# ensemble
|
1281 |
+
img = img * (1 - masks_in) + images_in * masks_in
|
1282 |
+
|
1283 |
+
if not return_stg1:
|
1284 |
+
return img
|
1285 |
+
else:
|
1286 |
+
return img, out_stg1
|
1287 |
+
|
1288 |
+
|
1289 |
+
class Generator(nn.Module):
|
1290 |
+
def __init__(self,
|
1291 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
1292 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
1293 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1294 |
+
img_resolution, # resolution of generated image
|
1295 |
+
img_channels, # Number of input color channels.
|
1296 |
+
synthesis_kwargs={}, # Arguments for SynthesisNetwork.
|
1297 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
1298 |
+
):
|
1299 |
+
super().__init__()
|
1300 |
+
self.z_dim = z_dim
|
1301 |
+
self.c_dim = c_dim
|
1302 |
+
self.w_dim = w_dim
|
1303 |
+
self.img_resolution = img_resolution
|
1304 |
+
self.img_channels = img_channels
|
1305 |
+
|
1306 |
+
self.synthesis = SynthesisNet(w_dim=w_dim,
|
1307 |
+
img_resolution=img_resolution,
|
1308 |
+
img_channels=img_channels,
|
1309 |
+
**synthesis_kwargs)
|
1310 |
+
self.mapping = MappingNet(z_dim=z_dim,
|
1311 |
+
c_dim=c_dim,
|
1312 |
+
w_dim=w_dim,
|
1313 |
+
num_ws=self.synthesis.num_layers,
|
1314 |
+
**mapping_kwargs)
|
1315 |
+
|
1316 |
+
def forward(self, images_in, masks_in, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False,
|
1317 |
+
noise_mode='none', return_stg1=False):
|
1318 |
+
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff,
|
1319 |
+
skip_w_avg_update=skip_w_avg_update)
|
1320 |
+
img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode)
|
1321 |
+
return img
|
1322 |
+
|
1323 |
+
|
1324 |
+
class Discriminator(torch.nn.Module):
|
1325 |
+
def __init__(self,
|
1326 |
+
c_dim, # Conditioning label (C) dimensionality.
|
1327 |
+
img_resolution, # Input resolution.
|
1328 |
+
img_channels, # Number of input color channels.
|
1329 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
1330 |
+
channel_max=512, # Maximum number of channels in any layer.
|
1331 |
+
channel_decay=1,
|
1332 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
1333 |
+
activation='lrelu',
|
1334 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
1335 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
1336 |
+
):
|
1337 |
+
super().__init__()
|
1338 |
+
self.c_dim = c_dim
|
1339 |
+
self.img_resolution = img_resolution
|
1340 |
+
self.img_channels = img_channels
|
1341 |
+
|
1342 |
+
resolution_log2 = int(np.log2(img_resolution))
|
1343 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
1344 |
+
self.resolution_log2 = resolution_log2
|
1345 |
+
|
1346 |
+
if cmap_dim is None:
|
1347 |
+
cmap_dim = nf(2)
|
1348 |
+
if c_dim == 0:
|
1349 |
+
cmap_dim = 0
|
1350 |
+
self.cmap_dim = cmap_dim
|
1351 |
+
|
1352 |
+
if c_dim > 0:
|
1353 |
+
self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)
|
1354 |
+
|
1355 |
+
Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
|
1356 |
+
for res in range(resolution_log2, 2, -1):
|
1357 |
+
Dis.append(DisBlock(nf(res), nf(res - 1), activation))
|
1358 |
+
|
1359 |
+
if mbstd_num_channels > 0:
|
1360 |
+
Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
|
1361 |
+
Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
|
1362 |
+
self.Dis = nn.Sequential(*Dis)
|
1363 |
+
|
1364 |
+
self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
|
1365 |
+
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
|
1366 |
+
|
1367 |
+
# for 64x64
|
1368 |
+
Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)]
|
1369 |
+
for res in range(resolution_log2, 2, -1):
|
1370 |
+
Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation))
|
1371 |
+
|
1372 |
+
if mbstd_num_channels > 0:
|
1373 |
+
Dis_stg1.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
|
1374 |
+
Dis_stg1.append(Conv2dLayer(nf(2) // 2 + mbstd_num_channels, nf(2) // 2, kernel_size=3, activation=activation))
|
1375 |
+
self.Dis_stg1 = nn.Sequential(*Dis_stg1)
|
1376 |
+
|
1377 |
+
self.fc0_stg1 = FullyConnectedLayer(nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation)
|
1378 |
+
self.fc1_stg1 = FullyConnectedLayer(nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim)
|
1379 |
+
|
1380 |
+
def forward(self, images_in, masks_in, images_stg1, c):
|
1381 |
+
x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1))
|
1382 |
+
x = self.fc1(self.fc0(x.flatten(start_dim=1)))
|
1383 |
+
|
1384 |
+
x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1))
|
1385 |
+
x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1)))
|
1386 |
+
|
1387 |
+
if self.c_dim > 0:
|
1388 |
+
cmap = self.mapping(None, c)
|
1389 |
+
|
1390 |
+
if self.cmap_dim > 0:
|
1391 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
1392 |
+
x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
1393 |
+
|
1394 |
+
return x, x_stg1
|
1395 |
+
|
1396 |
+
|
1397 |
+
MAT_MODEL_URL = os.environ.get(
|
1398 |
+
"MAT_MODEL_URL",
|
1399 |
+
"https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth",
|
1400 |
+
)
|
1401 |
+
|
1402 |
+
|
1403 |
+
class MAT(InpaintModel):
|
1404 |
+
min_size = 512
|
1405 |
+
pad_mod = 512
|
1406 |
+
pad_to_square = True
|
1407 |
+
|
1408 |
+
def init_model(self, device, **kwargs):
|
1409 |
+
seed = 240  # fixed seed so the sampled latent z is reproducible
|
1410 |
+
random.seed(seed)
|
1411 |
+
np.random.seed(seed)
|
1412 |
+
torch.manual_seed(seed)
|
1413 |
+
|
1414 |
+
G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3)
|
1415 |
+
self.model = load_model(G, MAT_MODEL_URL, device)
|
1416 |
+
self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device) # [1., 512]
|
1417 |
+
self.label = torch.zeros([1, self.model.c_dim], device=device)
|
1418 |
+
|
1419 |
+
@staticmethod
|
1420 |
+
def is_downloaded() -> bool:
|
1421 |
+
return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
|
1422 |
+
|
1423 |
+
def forward(self, image, mask, config: Config):
|
1424 |
+
"""Input images and output images have same size
|
1425 |
+
images: [H, W, C] RGB
|
1426 |
+
masks: [H, W] mask area == 255
|
1427 |
+
return: BGR IMAGE
|
1428 |
+
"""
|
1429 |
+
|
1430 |
+
image = norm_img(image) # [0, 1]
|
1431 |
+
image = image * 2 - 1 # [0, 1] -> [-1, 1]
|
1432 |
+
|
1433 |
+
mask = (mask > 127) * 255
|
1434 |
+
mask = 255 - mask
|
1435 |
+
mask = norm_img(mask)
|
1436 |
+
|
1437 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
1438 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
1439 |
+
|
1440 |
+
output = self.model(image, mask, self.z, self.label, truncation_psi=1, noise_mode='none')
|
1441 |
+
output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
|
1442 |
+
output = output[0].cpu().numpy()
|
1443 |
+
cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
1444 |
+
return cur_res
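# Minimal usage sketch, assuming the lama_cleaner InpaintModel.__call__ wrapper
# and a populated Config from lama_cleaner/schema.py:
#   model = MAT("cuda")                          # downloads Places_512_FullData_G.pth on first use
#   bgr_result = model(rgb_image, mask, config)  # rgb_image: HxWx3 uint8, mask: HxW, 255 marks holes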
|
lama_cleaner/model/opencv2.py
ADDED
@@ -0,0 +1,25 @@
1 |
+
import cv2
|
2 |
+
|
3 |
+
from lama_cleaner.model.base import InpaintModel
|
4 |
+
from lama_cleaner.schema import Config
|
5 |
+
|
6 |
+
flag_map = {
|
7 |
+
"INPAINT_NS": cv2.INPAINT_NS,
|
8 |
+
"INPAINT_TELEA": cv2.INPAINT_TELEA
|
9 |
+
}
|
10 |
+
|
11 |
+
class OpenCV2(InpaintModel):
|
12 |
+
pad_mod = 1
|
13 |
+
|
14 |
+
@staticmethod
|
15 |
+
def is_downloaded() -> bool:
|
16 |
+
return True
|
17 |
+
|
18 |
+
def forward(self, image, mask, config: Config):
|
19 |
+
"""Input image and output image have same size
|
20 |
+
image: [H, W, C] RGB
|
21 |
+
mask: [H, W, 1]
|
22 |
+
return: BGR IMAGE
|
23 |
+
"""
|
24 |
+
cur_res = cv2.inpaint(image[:, :, ::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag])
|
25 |
+
return cur_res
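# Illustrative note: cv2.inpaint expects an 8-bit 1- or 3-channel image plus an
# 8-bit single-channel mask whose non-zero pixels mark the region to fill, e.g.
#   import numpy as np
#   img = np.zeros((64, 64, 3), np.uint8)
#   msk = np.zeros((64, 64), np.uint8)
#   msk[20:40, 20:40] = 255
#   out = cv2.inpaint(img, msk, inpaintRadius=3, flags=cv2.INPAINT_TELEA)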
|