krunakuamar commited on
Commit
252e766
1 Parent(s): 762cf51

Upload 75 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +44 -0
  2. app.py +316 -0
  3. app/big-lama.pt +3 -0
  4. app/u2net.onnx +3 -0
  5. app/yolov8x-seg.pt +3 -0
  6. lama_cleaner/__init__.py +11 -0
  7. lama_cleaner/__pycache__/__init__.cpython-38.pyc +0 -0
  8. lama_cleaner/__pycache__/const.cpython-38.pyc +0 -0
  9. lama_cleaner/__pycache__/helper.cpython-38.pyc +0 -0
  10. lama_cleaner/__pycache__/interactive_seg.cpython-38.pyc +0 -0
  11. lama_cleaner/__pycache__/model_manager.cpython-38.pyc +0 -0
  12. lama_cleaner/__pycache__/parse_args.cpython-38.pyc +0 -0
  13. lama_cleaner/__pycache__/runtime.cpython-38.pyc +0 -0
  14. lama_cleaner/__pycache__/schema.cpython-38.pyc +0 -0
  15. lama_cleaner/__pycache__/server2.cpython-38.pyc +0 -0
  16. lama_cleaner/benchmark.py +109 -0
  17. lama_cleaner/const.py +68 -0
  18. lama_cleaner/file_manager/__init__.py +1 -0
  19. lama_cleaner/file_manager/__pycache__/__init__.cpython-38.pyc +0 -0
  20. lama_cleaner/file_manager/__pycache__/file_manager.cpython-38.pyc +0 -0
  21. lama_cleaner/file_manager/__pycache__/storage_backends.cpython-38.pyc +0 -0
  22. lama_cleaner/file_manager/__pycache__/utils.cpython-38.pyc +0 -0
  23. lama_cleaner/file_manager/file_manager.py +252 -0
  24. lama_cleaner/file_manager/storage_backends.py +46 -0
  25. lama_cleaner/file_manager/utils.py +66 -0
  26. lama_cleaner/helper.py +218 -0
  27. lama_cleaner/interactive_seg.py +202 -0
  28. lama_cleaner/model/__init__.py +0 -0
  29. lama_cleaner/model/__pycache__/__init__.cpython-38.pyc +0 -0
  30. lama_cleaner/model/__pycache__/base.cpython-38.pyc +0 -0
  31. lama_cleaner/model/__pycache__/ddim_sampler.cpython-38.pyc +0 -0
  32. lama_cleaner/model/__pycache__/fcf.cpython-38.pyc +0 -0
  33. lama_cleaner/model/__pycache__/lama.cpython-38.pyc +0 -0
  34. lama_cleaner/model/__pycache__/ldm.cpython-38.pyc +0 -0
  35. lama_cleaner/model/__pycache__/manga.cpython-38.pyc +0 -0
  36. lama_cleaner/model/__pycache__/mat.cpython-38.pyc +0 -0
  37. lama_cleaner/model/__pycache__/opencv2.cpython-38.pyc +0 -0
  38. lama_cleaner/model/__pycache__/paint_by_example.cpython-38.pyc +0 -0
  39. lama_cleaner/model/__pycache__/plms_sampler.cpython-38.pyc +0 -0
  40. lama_cleaner/model/__pycache__/sd.cpython-38.pyc +0 -0
  41. lama_cleaner/model/__pycache__/utils.cpython-38.pyc +0 -0
  42. lama_cleaner/model/__pycache__/zits.cpython-38.pyc +0 -0
  43. lama_cleaner/model/base.py +247 -0
  44. lama_cleaner/model/ddim_sampler.py +192 -0
  45. lama_cleaner/model/fcf.py +1212 -0
  46. lama_cleaner/model/lama.py +61 -0
  47. lama_cleaner/model/ldm.py +310 -0
  48. lama_cleaner/model/manga.py +130 -0
  49. lama_cleaner/model/mat.py +1444 -0
  50. lama_cleaner/model/opencv2.py +25 -0
Dockerfile ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.8
2
+
3
+ RUN mkdir /app
4
+ RUN mkdir /.cache/
5
+ RUN mkdir /.cache/matplotlib
6
+ RUN mkdir /.cache/huggingface
7
+ RUN mkdir /.cache/huggingface/hub/
8
+ RUN mkdir /.cache/torch/
9
+ RUN mkdir /.config
10
+ RUN mkdir /.config/matplotlib/
11
+
12
+ RUN chmod -R 777 /.cache
13
+ RUN chmod -R 777 /.cache/matplotlib
14
+ RUN chmod -R 777 /.cache/huggingface/hub
15
+ RUN chmod -R 777 /.cache/torch
16
+ RUN chmod -R 777 /.config/
17
+ RUN chmod -R 777 /.config/matplotlib
18
+ RUN chmod -R 777 /app
19
+
20
+
21
+ COPY lama_cleaner ./lama_cleaner
22
+ COPY ./app.py ./app.py
23
+
24
+
25
+ COPY app/yolov8x-seg.pt /app
26
+ COPY big-lama.pt /app
27
+ # COPY clickseg_pplnet.pt /app
28
+ COPY u2net.onnx /app
29
+ COPY u2net.onnx /tmp
30
+
31
+ RUN chmod -R a+r /app/yolov8x-seg.pt
32
+ RUN chmod -R a+r /app/big-lama.pt
33
+ #RUN chmod -R a+r /app/clickseg_pplnet.pt
34
+ RUN chmod -R a+r /app/u2net.onnx
35
+ RUN chmod -R a+r /tmp/u2net.onnx
36
+
37
+
38
+ COPY ./requirements.txt ./requirements.txt
39
+ RUN pip install -r ./requirements.txt
40
+
41
+ RUN --mount=type=secret,id=SECRET,mode=0444,required=true \
42
+ git clone $(cat /run/secrets/SECRET)
43
+
44
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import imghdr
3
+ import os
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ from ultralytics import YOLO
9
+ from ultralytics.yolo.utils.ops import scale_image
10
+ import asyncio
11
+ from fastapi import FastAPI, File, UploadFile, Request, Response
12
+ from fastapi.responses import JSONResponse
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ import uvicorn
15
+ # from mangum import Mangum
16
+ from argparse import ArgumentParser
17
+
18
+ import lama_cleaner.server2 as server
19
+ from lama_cleaner.helper import (
20
+ load_img,
21
+ )
22
+
23
+ # os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/directory"
24
+
25
+ app = FastAPI()
26
+
27
+ # handler = Mangum(app)
28
+ origins = ["*"]
29
+
30
+ app.add_middleware(
31
+ CORSMiddleware,
32
+ allow_origins=origins,
33
+ allow_credentials=True,
34
+ allow_methods=["*"],
35
+ allow_headers=["*"],
36
+ )
37
+
38
+
39
+ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
40
+ """
41
+ Args:
42
+ image_numpy: numpy image
43
+ ext: image extension
44
+ Returns:
45
+ image bytes
46
+ """
47
+ data = cv2.imencode(
48
+ f".{ext}",
49
+ image_numpy,
50
+ [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
51
+ )[1].tobytes()
52
+ return data
53
+
54
+
55
+ def get_image_ext(img_bytes):
56
+ """
57
+ Args:
58
+ img_bytes: image bytes
59
+ Returns:
60
+ image extension
61
+ """
62
+ if not img_bytes:
63
+ raise ValueError("Empty input")
64
+ header = img_bytes[:32]
65
+ w = imghdr.what("", header)
66
+ if w is None:
67
+ w = "jpeg"
68
+ return w
69
+
70
+
71
+ def predict_on_image(model, img, conf, retina_masks):
72
+ """
73
+ Args:
74
+ model: YOLOv8 model
75
+ img: image (C, H, W)
76
+ conf: confidence threshold
77
+ retina_masks: use retina masks or not
78
+ Returns:
79
+ boxes: box with xyxy format, (N, 4)
80
+ masks: masks, (N, H, W)
81
+ cls: class of masks, (N, )
82
+ probs: confidence score, (N, 1)
83
+ """
84
+ with torch.no_grad():
85
+ result = model(img, conf=conf, retina_masks=retina_masks, scale=1)[0]
86
+
87
+ boxes, masks, cls, probs = None, None, None, None
88
+
89
+ if result.boxes.cls.size(0) > 0:
90
+ # detection
91
+ cls = result.boxes.cls.cpu().numpy().astype(np.int32)
92
+ probs = result.boxes.conf.cpu().numpy() # confidence score, (N, 1)
93
+ boxes = result.boxes.xyxy.cpu().numpy() # box with xyxy format, (N, 4)
94
+
95
+ # segmentation
96
+ masks = result.masks.masks.cpu().numpy() # masks, (N, H, W)
97
+ masks = np.transpose(masks, (1, 2, 0)) # masks, (H, W, N)
98
+ # rescale masks to original image
99
+ masks = scale_image(masks.shape[:2], masks, result.masks.orig_shape)
100
+ masks = np.transpose(masks, (2, 0, 1)) # masks, (N, H, W)
101
+
102
+ return boxes, masks, cls, probs
103
+
104
+
105
+ def overlay(image, mask, color, alpha, id, resize=None):
106
+ """Overlays a binary mask on an image.
107
+
108
+ Args:
109
+ image: Image to be overlayed on.
110
+ mask: Binary mask to overlay.
111
+ color: Color to use for the mask.
112
+ alpha: Opacity of the mask.
113
+ id: id of the mask
114
+ resize: Resize the image to this size. If None, no resizing is performed.
115
+
116
+ Returns:
117
+ The overlayed image.
118
+ """
119
+ color = color[::-1]
120
+ colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
121
+ colored_mask = np.moveaxis(colored_mask, 0, -1)
122
+ masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
123
+ image_overlay = masked.filled()
124
+
125
+ imgray = cv2.cvtColor(image_overlay, cv2.COLOR_BGR2GRAY)
126
+
127
+ contour_thickness = 8
128
+ _, thresh = cv2.threshold(imgray, 255, 255, 255)
129
+ contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
130
+ imgray = cv2.cvtColor(imgray, cv2.COLOR_GRAY2BGR)
131
+ imgray = cv2.drawContours(imgray, contours, -1, (255, 255, 255), contour_thickness)
132
+
133
+ imgray = np.where(imgray.any(-1, keepdims=True), (46, 36, 225), 0)
134
+
135
+ if resize is not None:
136
+ image = cv2.resize(image.transpose(1, 2, 0), resize)
137
+ image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize)
138
+
139
+ return imgray
140
+
141
+
142
+ async def process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls):
143
+ """Process the mask of the image.
144
+
145
+ Args:
146
+ idx: index of the mask
147
+ mask_i: mask of the image
148
+ boxes: box with xyxy format, (N, 4)
149
+ probs: confidence score, (N, 1)
150
+ yolo_model: YOLOv8 model
151
+ blank_image: blank image
152
+ cls: class of masks, (N, )
153
+
154
+ Returns:
155
+ dictionary_seg: dictionary of the mask of the image
156
+ """
157
+ dictionary_seg = {}
158
+ maskwith_back = overlay(blank_image, mask_i, color=(255, 155, 155), alpha=0.5, id=idx)
159
+
160
+ alpha = np.sum(maskwith_back, axis=-1) > 0
161
+ alpha = np.uint8(alpha * 255)
162
+ maskwith_back = np.dstack((maskwith_back, alpha))
163
+
164
+ imgencode = await asyncio.get_running_loop().run_in_executor(None, cv2.imencode, '.png', maskwith_back)
165
+ mask = base64.b64encode(imgencode[1]).decode('utf-8')
166
+
167
+ dictionary_seg["confi"] = f'{probs[idx] * 100:.2f}'
168
+ dictionary_seg["boxe"] = [int(item) for item in list(boxes[idx])]
169
+ dictionary_seg["mask"] = mask
170
+ dictionary_seg["cls"] = str(yolo_model.names[cls[idx]])
171
+
172
+ return dictionary_seg
173
+
174
+
175
+ @app.middleware("http")
176
+ async def check_auth_header(request: Request, call_next):
177
+ token = request.headers.get('Authorization')
178
+ if token != os.environ.get("SECRET"):
179
+ return JSONResponse(content={'error': 'Authorization header missing or incorrect.'}, status_code=403)
180
+ else:
181
+ response = await call_next(request)
182
+ return response
183
+
184
+
185
+ @app.post("/api/mask")
186
+ async def detect_mask(file: UploadFile = File()):
187
+ """
188
+ Detects masks in an image uploaded via a POST request and returns a JSON response containing the details of the detected masks.
189
+
190
+ Args:
191
+ None
192
+
193
+ Parameters:
194
+ - file: a file object containing the input image
195
+
196
+ Returns:
197
+ A JSON response containing the details of the detected masks:
198
+ - code: 200 if objects were detected, 500 if no objects were detected
199
+ - msg: a message indicating whether objects were detected or not
200
+ - data: a list of dictionaries, where each dictionary contains the following keys:
201
+ - confi: the confidence level of the detected object
202
+ - boxe: a list containing the coordinates of the bounding box of the detected object
203
+ - mask: the mask of the detected object encoded in base64
204
+ - cls: the class of the detected object
205
+
206
+ Raises:
207
+ 500: No objects detected
208
+ """
209
+ file = await file.read()
210
+
211
+ img, _ = load_img(file)
212
+
213
+ # predict by YOLOv8
214
+ boxes, masks, cls, probs = predict_on_image(yolo_model, img, conf=0.55, retina_masks=True)
215
+
216
+ if boxes is None:
217
+ return {'code': 500, 'msg': 'No objects detected'}
218
+
219
+ # overlay masks on original image
220
+ blank_image = np.zeros(img.shape, dtype=np.uint8)
221
+
222
+ data = []
223
+
224
+ coroutines = [process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls) for idx, mask_i in
225
+ enumerate(masks)]
226
+ results = await asyncio.gather(*coroutines)
227
+
228
+ for result in results:
229
+ data.append(result)
230
+
231
+ return {'code': 200, 'msg': "object detected", 'data': data}
232
+
233
+
234
+ @app.post("/api/lama/paint")
235
+ async def paint(img: UploadFile = File(), mask: UploadFile = File()):
236
+ """
237
+ Endpoint to process an image with a given mask using the server's process function.
238
+
239
+ Route: '/api/lama/paint'
240
+ Method: POST
241
+
242
+ Parameters:
243
+ img: The input image file (JPEG or PNG format).
244
+ mask: The mask file (JPEG or PNG format).
245
+ Returns:
246
+ A JSON object containing the processed image in base64 format under the "image" key.
247
+ """
248
+ img = await img.read()
249
+ mask = await mask.read()
250
+ return {"image": server.process(img, mask)}
251
+
252
+
253
+ @app.post("/api/remove")
254
+ async def remove(img: UploadFile = File()):
255
+ x = await img.read()
256
+ return {"image": server.remove(x)}
257
+
258
+ @app.post("/api/lama/model")
259
+ def switch_model(new_name: str):
260
+ return server.switch_model(new_name)
261
+
262
+
263
+ @app.get("/api/lama/model")
264
+ def current_model():
265
+ return server.current_model()
266
+
267
+
268
+ @app.get("/api/lama/switchmode")
269
+ def get_is_disable_model_switch():
270
+ return server.get_is_disable_model_switch()
271
+
272
+
273
+ @app.on_event("startup")
274
+ def init_data():
275
+ model_device = "cpu"
276
+ global yolo_model
277
+ # TODO Update for local development
278
+ yolo_model = YOLO('yolov8x-seg.pt')
279
+ # yolo_model = YOLO('/app/yolov8x-seg.pt')
280
+ yolo_model.to(model_device)
281
+ print(f"YOLO model yolov8x-seg.pt loaded.")
282
+ server.initModel()
283
+
284
+
285
+ def create_app(args):
286
+ """
287
+ Creates the FastAPI app and adds the endpoints.
288
+
289
+ Args:
290
+ args: The arguments.
291
+ """
292
+ uvicorn.run("app:app", host=args.host, port=args.port, reload=args.reload)
293
+
294
+
295
+ if __name__ == "__main__":
296
+ parser = ArgumentParser()
297
+ parser.add_argument('--model_name', type=str, default='lama', help='Model name')
298
+ parser.add_argument('--host', type=str, default="0.0.0.0")
299
+ parser.add_argument('--port', type=int, default=5000)
300
+ parser.add_argument('--reload', type=bool, default=True)
301
+ parser.add_argument('--model_device', type=str, default='cpu', help='Model device')
302
+ parser.add_argument('--disable_model_switch', type=bool, default=False, help='Disable model switch')
303
+ parser.add_argument('--gui', type=bool, default=False, help='Enable GUI')
304
+ parser.add_argument('--cpu_offload', type=bool, default=False, help='Enable CPU offload')
305
+ parser.add_argument('--disable_nsfw', type=bool, default=False, help='Disable NSFW')
306
+ parser.add_argument('--enable_xformers', type=bool, default=False, help='Enable xformers')
307
+ parser.add_argument('--hf_access_token', type=str, default='', help='Hugging Face access token')
308
+ parser.add_argument('--local_files_only', type=bool, default=False, help='Enable local files only')
309
+ parser.add_argument('--no_half', type=bool, default=False, help='Disable half')
310
+ parser.add_argument('--sd_cpu_textencoder', type=bool, default=False, help='Enable CPU text encoder')
311
+ parser.add_argument('--sd_disable_nsfw', type=bool, default=False, help='Disable NSFW')
312
+ parser.add_argument('--sd_enable_xformers', type=bool, default=False, help='Enable xformers')
313
+ parser.add_argument('--sd_run_local', type=bool, default=False, help='Enable local files only')
314
+
315
+ args = parser.parse_args()
316
+ create_app(args)
app/big-lama.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:344c77bbcb158f17dd143070d1e789f38a66c04202311ae3a258ef66667a9ea9
3
+ size 205669692
app/u2net.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d10d2f3bb75ae3b6d527c77944fc5e7dcd94b29809d47a739a7a728a912b491
3
+ size 175997641
app/yolov8x-seg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d63cbfa5764867c0066bedfa43cf2dcd90a412a1de44b2e238c43978a9d28ea6
3
+ size 144076467
lama_cleaner/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.simplefilter("ignore", UserWarning)
3
+
4
+ from lama_cleaner.parse_args import parse_args
5
+
6
+ def entry_point():
7
+ args = parse_args()
8
+ # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
9
+ # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
10
+ from lama_cleaner.server import main
11
+ main(args)
lama_cleaner/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (478 Bytes). View file
 
lama_cleaner/__pycache__/const.cpython-38.pyc ADDED
Binary file (1.79 kB). View file
 
lama_cleaner/__pycache__/helper.cpython-38.pyc ADDED
Binary file (5.43 kB). View file
 
lama_cleaner/__pycache__/interactive_seg.cpython-38.pyc ADDED
Binary file (6.77 kB). View file
 
lama_cleaner/__pycache__/model_manager.cpython-38.pyc ADDED
Binary file (2.27 kB). View file
 
lama_cleaner/__pycache__/parse_args.cpython-38.pyc ADDED
Binary file (4.28 kB). View file
 
lama_cleaner/__pycache__/runtime.cpython-38.pyc ADDED
Binary file (1.35 kB). View file
 
lama_cleaner/__pycache__/schema.cpython-38.pyc ADDED
Binary file (2.42 kB). View file
 
lama_cleaner/__pycache__/server2.cpython-38.pyc ADDED
Binary file (6.31 kB). View file
 
lama_cleaner/benchmark.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import argparse
4
+ import os
5
+ import time
6
+
7
+ import numpy as np
8
+ import nvidia_smi
9
+ import psutil
10
+ import torch
11
+
12
+ from lama_cleaner.model_manager import ModelManager
13
+ from lama_cleaner.schema import Config, HDStrategy, SDSampler
14
+
15
+ try:
16
+ torch._C._jit_override_can_fuse_on_cpu(False)
17
+ torch._C._jit_override_can_fuse_on_gpu(False)
18
+ torch._C._jit_set_texpr_fuser_enabled(False)
19
+ torch._C._jit_set_nvfuser_enabled(False)
20
+ except:
21
+ pass
22
+
23
+ NUM_THREADS = str(4)
24
+
25
+ os.environ["OMP_NUM_THREADS"] = NUM_THREADS
26
+ os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
27
+ os.environ["MKL_NUM_THREADS"] = NUM_THREADS
28
+ os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
29
+ os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
30
+ if os.environ.get("CACHE_DIR"):
31
+ os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
32
+
33
+
34
+ def run_model(model, size):
35
+ # RGB
36
+ image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
37
+ mask = np.random.randint(0, 255, size).astype(np.uint8)
38
+
39
+ config = Config(
40
+ ldm_steps=2,
41
+ hd_strategy=HDStrategy.ORIGINAL,
42
+ hd_strategy_crop_margin=128,
43
+ hd_strategy_crop_trigger_size=128,
44
+ hd_strategy_resize_limit=128,
45
+ prompt="a fox is sitting on a bench",
46
+ sd_steps=5,
47
+ sd_sampler=SDSampler.ddim
48
+ )
49
+ model(image, mask, config)
50
+
51
+
52
+ def benchmark(model, times: int, empty_cache: bool):
53
+ sizes = [(512, 512)]
54
+
55
+ nvidia_smi.nvmlInit()
56
+ device_id = 0
57
+ handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)
58
+
59
+ def format(metrics):
60
+ return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"
61
+
62
+ process = psutil.Process(os.getpid())
63
+ # 每个 size 给出显存和内存占用的指标
64
+ for size in sizes:
65
+ torch.cuda.empty_cache()
66
+ time_metrics = []
67
+ cpu_metrics = []
68
+ memory_metrics = []
69
+ gpu_memory_metrics = []
70
+ for _ in range(times):
71
+ start = time.time()
72
+ run_model(model, size)
73
+ torch.cuda.synchronize()
74
+
75
+ # cpu_metrics.append(process.cpu_percent())
76
+ time_metrics.append((time.time() - start) * 1000)
77
+ memory_metrics.append(process.memory_info().rss / 1024 / 1024)
78
+ gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024)
79
+
80
+ print(f"size: {size}".center(80, "-"))
81
+ # print(f"cpu: {format(cpu_metrics)}")
82
+ print(f"latency: {format(time_metrics)}ms")
83
+ print(f"memory: {format(memory_metrics)} MB")
84
+ print(f"gpu memory: {format(gpu_memory_metrics)} MB")
85
+
86
+ nvidia_smi.nvmlShutdown()
87
+
88
+
89
+ def get_args_parser():
90
+ parser = argparse.ArgumentParser()
91
+ parser.add_argument("--name")
92
+ parser.add_argument("--device", default="cuda", type=str)
93
+ parser.add_argument("--times", default=10, type=int)
94
+ parser.add_argument("--empty-cache", action="store_true")
95
+ return parser.parse_args()
96
+
97
+
98
+ if __name__ == "__main__":
99
+ args = get_args_parser()
100
+ device = torch.device(args.device)
101
+ model = ModelManager(
102
+ name=args.name,
103
+ device=device,
104
+ sd_run_local=True,
105
+ disable_nsfw=True,
106
+ sd_cpu_textencoder=True,
107
+ hf_access_token="123"
108
+ )
109
+ benchmark(model, args.times, args.empty_cache)
lama_cleaner/const.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ DEFAULT_MODEL = "lama"
4
+ AVAILABLE_MODELS = [
5
+ "lama",
6
+ "ldm",
7
+ "zits",
8
+ "mat",
9
+ "fcf",
10
+ "sd1.5",
11
+ "cv2",
12
+ "manga",
13
+ "sd2",
14
+ "paint_by_example"
15
+ ]
16
+
17
+ AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
18
+ DEFAULT_DEVICE = 'cuda'
19
+
20
+ NO_HALF_HELP = """
21
+ Using full precision model.
22
+ If your generate result is always black or green, use this argument. (sd/paint_by_exmaple)
23
+ """
24
+
25
+ CPU_OFFLOAD_HELP = """
26
+ Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example)
27
+ """
28
+
29
+ DISABLE_NSFW_HELP = """
30
+ Disable NSFW checker. (sd/paint_by_example)
31
+ """
32
+
33
+ SD_CPU_TEXTENCODER_HELP = """
34
+ Run Stable Diffusion text encoder model on CPU to save GPU memory.
35
+ """
36
+
37
+ LOCAL_FILES_ONLY_HELP = """
38
+ Use local files only, not connect to Hugging Face server. (sd/paint_by_example)
39
+ """
40
+
41
+ ENABLE_XFORMERS_HELP = """
42
+ Enable xFormers optimizations. Requires xformers package has been installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example)
43
+ """
44
+
45
+ DEFAULT_MODEL_DIR = os.getenv(
46
+ "XDG_CACHE_HOME",
47
+ os.path.join(os.path.expanduser("~"), ".cache")
48
+ )
49
+ MODEL_DIR_HELP = """
50
+ Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache
51
+ """
52
+
53
+ OUTPUT_DIR_HELP = """
54
+ Only required when --input is directory. Result images will be saved to output directory automatically.
55
+ """
56
+
57
+ INPUT_HELP = """
58
+ If input is image, it will be loaded by default.
59
+ If input is directory, you can browse and select image in file manager.
60
+ """
61
+
62
+ GUI_HELP = """
63
+ Launch Lama Cleaner as desktop app
64
+ """
65
+
66
+ NO_GUI_AUTO_CLOSE_HELP = """
67
+ Prevent backend auto close after the GUI window closed.
68
+ """
lama_cleaner/file_manager/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .file_manager import FileManager
lama_cleaner/file_manager/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (227 Bytes). View file
 
lama_cleaner/file_manager/__pycache__/file_manager.cpython-38.pyc ADDED
Binary file (7.68 kB). View file
 
lama_cleaner/file_manager/__pycache__/storage_backends.cpython-38.pyc ADDED
Binary file (2.01 kB). View file
 
lama_cleaner/file_manager/__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.64 kB). View file
 
lama_cleaner/file_manager/file_manager.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py
2
+ import os
3
+ import time
4
+ from datetime import datetime
5
+ from io import BytesIO
6
+ from pathlib import Path
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image, ImageOps, PngImagePlugin
11
+ from loguru import logger
12
+ from watchdog.events import FileSystemEventHandler
13
+ from watchdog.observers import Observer
14
+
15
+ LARGE_ENOUGH_NUMBER = 100
16
+ PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024 ** 2)
17
+ from .storage_backends import FilesystemStorageBackend
18
+ from .utils import aspect_to_string, generate_filename, glob_img
19
+
20
+
21
+ class FileManager(FileSystemEventHandler):
22
+ def __init__(self, app=None):
23
+ self.app = app
24
+ self._default_root_directory = "media"
25
+ self._default_thumbnail_directory = "media"
26
+ self._default_root_url = "/"
27
+ self._default_thumbnail_root_url = "/"
28
+ self._default_format = "JPEG"
29
+ self.output_dir: Path = None
30
+
31
+ if app is not None:
32
+ self.init_app(app)
33
+
34
+ self.image_dir_filenames = []
35
+ self.output_dir_filenames = []
36
+
37
+ self.image_dir_observer = None
38
+ self.output_dir_observer = None
39
+
40
+ self.modified_time = {
41
+ "image": datetime.utcnow(),
42
+ "output": datetime.utcnow(),
43
+ }
44
+
45
+ def start(self):
46
+ self.image_dir_filenames = self._media_names(self.root_directory)
47
+ self.output_dir_filenames = self._media_names(self.output_dir)
48
+
49
+ logger.info(f"Start watching image directory: {self.root_directory}")
50
+ self.image_dir_observer = Observer()
51
+ self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
52
+ self.image_dir_observer.start()
53
+
54
+ logger.info(f"Start watching output directory: {self.output_dir}")
55
+ self.output_dir_observer = Observer()
56
+ self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
57
+ self.output_dir_observer.start()
58
+
59
+ def on_modified(self, event):
60
+ if not os.path.isdir(event.src_path):
61
+ return
62
+ if event.src_path == str(self.root_directory):
63
+ logger.info(f"Image directory {event.src_path} modified")
64
+ self.image_dir_filenames = self._media_names(self.root_directory)
65
+ self.modified_time['image'] = datetime.utcnow()
66
+ elif event.src_path == str(self.output_dir):
67
+ logger.info(f"Output directory {event.src_path} modified")
68
+ self.output_dir_filenames = self._media_names(self.output_dir)
69
+ self.modified_time['output'] = datetime.utcnow()
70
+
71
+ def init_app(self, app):
72
+ if self.app is None:
73
+ self.app = app
74
+ app.thumbnail_instance = self
75
+
76
+ if not hasattr(app, "extensions"):
77
+ app.extensions = {}
78
+
79
+ if "thumbnail" in app.extensions:
80
+ raise RuntimeError("Flask-thumbnail extension already initialized")
81
+
82
+ app.extensions["thumbnail"] = self
83
+
84
+ app.config.setdefault("THUMBNAIL_MEDIA_ROOT", self._default_root_directory)
85
+ app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_ROOT", self._default_thumbnail_directory)
86
+ app.config.setdefault("THUMBNAIL_MEDIA_URL", self._default_root_url)
87
+ app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url)
88
+ app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format)
89
+
90
+ def save_to_output_directory(self, image: np.ndarray, filename: str):
91
+ fp = Path(filename)
92
+ new_name = fp.stem + f"_{int(time.time())}" + fp.suffix
93
+ if image.shape[2] == 3:
94
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
95
+ elif image.shape[2] == 4:
96
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)
97
+
98
+ cv2.imwrite(str(self.output_dir / new_name), image)
99
+
100
+ @property
101
+ def root_directory(self):
102
+ path = self.app.config["THUMBNAIL_MEDIA_ROOT"]
103
+
104
+ if os.path.isabs(path):
105
+ return path
106
+ else:
107
+ return os.path.join(self.app.root_path, path)
108
+
109
+ @property
110
+ def thumbnail_directory(self):
111
+ path = self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"]
112
+
113
+ if os.path.isabs(path):
114
+ return path
115
+ else:
116
+ return os.path.join(self.app.root_path, path)
117
+
118
+ @property
119
+ def root_url(self):
120
+ return self.app.config["THUMBNAIL_MEDIA_URL"]
121
+
122
+ @property
123
+ def media_names(self):
124
+ # return self.image_dir_filenames
125
+ return self._media_names(self.root_directory)
126
+
127
+ @property
128
+ def output_media_names(self):
129
+ return self._media_names(self.output_dir)
130
+ # return self.output_dir_filenames
131
+
132
+ @staticmethod
133
+ def _media_names(directory: Path):
134
+ names = sorted([it.name for it in glob_img(directory)])
135
+ res = []
136
+ for name in names:
137
+ path = os.path.join(directory, name)
138
+ img = Image.open(path)
139
+ res.append({"name": name, "height": img.height, "width": img.width, "ctime": os.path.getctime(path)})
140
+ return res
141
+
142
+ @property
143
+ def thumbnail_url(self):
144
+ return self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_URL"]
145
+
146
+ def get_thumbnail(self, directory: Path, original_filename: str, width, height, **options):
147
+ storage = FilesystemStorageBackend(self.app)
148
+ crop = options.get("crop", "fit")
149
+ background = options.get("background")
150
+ quality = options.get("quality", 90)
151
+
152
+ original_path, original_filename = os.path.split(original_filename)
153
+ original_filepath = os.path.join(directory, original_path, original_filename)
154
+ image = Image.open(BytesIO(storage.read(original_filepath)))
155
+
156
+ # keep ratio resize
157
+ if width is not None:
158
+ height = int(image.height * width / image.width)
159
+ else:
160
+ width = int(image.width * height / image.height)
161
+
162
+ thumbnail_size = (width, height)
163
+
164
+ thumbnail_filename = generate_filename(
165
+ original_filename, aspect_to_string(thumbnail_size), crop, background, quality
166
+ )
167
+
168
+ thumbnail_filepath = os.path.join(
169
+ self.thumbnail_directory, original_path, thumbnail_filename
170
+ )
171
+ thumbnail_url = os.path.join(self.thumbnail_url, original_path, thumbnail_filename)
172
+
173
+ if storage.exists(thumbnail_filepath):
174
+ return thumbnail_url, (width, height)
175
+
176
+ try:
177
+ image.load()
178
+ except (IOError, OSError):
179
+ self.app.logger.warning("Thumbnail not load image: %s", original_filepath)
180
+ return thumbnail_url, (width, height)
181
+
182
+ # get original image format
183
+ options["format"] = options.get("format", image.format)
184
+
185
+ image = self._create_thumbnail(image, thumbnail_size, crop, background=background)
186
+
187
+ raw_data = self.get_raw_data(image, **options)
188
+ storage.save(thumbnail_filepath, raw_data)
189
+
190
+ return thumbnail_url, (width, height)
191
+
192
+ def get_raw_data(self, image, **options):
193
+ data = {
194
+ "format": self._get_format(image, **options),
195
+ "quality": options.get("quality", 90),
196
+ }
197
+
198
+ _file = BytesIO()
199
+ image.save(_file, **data)
200
+ return _file.getvalue()
201
+
202
+ @staticmethod
203
+ def colormode(image, colormode="RGB"):
204
+ if colormode == "RGB" or colormode == "RGBA":
205
+ if image.mode == "RGBA":
206
+ return image
207
+ if image.mode == "LA":
208
+ return image.convert("RGBA")
209
+ return image.convert(colormode)
210
+
211
+ if colormode == "GRAY":
212
+ return image.convert("L")
213
+
214
+ return image.convert(colormode)
215
+
216
+ @staticmethod
217
+ def background(original_image, color=0xFF):
218
+ size = (max(original_image.size),) * 2
219
+ image = Image.new("L", size, color)
220
+ image.paste(
221
+ original_image,
222
+ tuple(map(lambda x: (x[0] - x[1]) / 2, zip(size, original_image.size))),
223
+ )
224
+
225
+ return image
226
+
227
+ def _get_format(self, image, **options):
228
+ if options.get("format"):
229
+ return options.get("format")
230
+ if image.format:
231
+ return image.format
232
+
233
+ return self.app.config["THUMBNAIL_DEFAULT_FORMAT"]
234
+
235
+ def _create_thumbnail(self, image, size, crop="fit", background=None):
236
+ try:
237
+ resample = Image.Resampling.LANCZOS
238
+ except AttributeError: # pylint: disable=raise-missing-from
239
+ resample = Image.ANTIALIAS
240
+
241
+ if crop == "fit":
242
+ image = ImageOps.fit(image, size, resample)
243
+ else:
244
+ image = image.copy()
245
+ image.thumbnail(size, resample=resample)
246
+
247
+ if background is not None:
248
+ image = self.background(image)
249
+
250
+ image = self.colormode(image)
251
+
252
+ return image
lama_cleaner/file_manager/storage_backends.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py
2
+ import errno
3
+ import os
4
+ from abc import ABC, abstractmethod
5
+
6
+
7
+ class BaseStorageBackend(ABC):
8
+ def __init__(self, app=None):
9
+ self.app = app
10
+
11
+ @abstractmethod
12
+ def read(self, filepath, mode="rb", **kwargs):
13
+ raise NotImplementedError
14
+
15
+ @abstractmethod
16
+ def exists(self, filepath):
17
+ raise NotImplementedError
18
+
19
+ @abstractmethod
20
+ def save(self, filepath, data):
21
+ raise NotImplementedError
22
+
23
+
24
+ class FilesystemStorageBackend(BaseStorageBackend):
25
+ def read(self, filepath, mode="rb", **kwargs):
26
+ with open(filepath, mode) as f: # pylint: disable=unspecified-encoding
27
+ return f.read()
28
+
29
+ def exists(self, filepath):
30
+ return os.path.exists(filepath)
31
+
32
+ def save(self, filepath, data):
33
+ directory = os.path.dirname(filepath)
34
+
35
+ if not os.path.exists(directory):
36
+ try:
37
+ os.makedirs(directory)
38
+ except OSError as e:
39
+ if e.errno != errno.EEXIST:
40
+ raise
41
+
42
+ if not os.path.isdir(directory):
43
+ raise IOError("{} is not a directory".format(directory))
44
+
45
+ with open(filepath, "wb") as f:
46
+ f.write(data)
lama_cleaner/file_manager/utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from typing import Union
6
+
7
+
8
+ def generate_filename(original_filename, *options):
9
+ name, ext = os.path.splitext(original_filename)
10
+ for v in options:
11
+ if v:
12
+ name += "_%s" % v
13
+ name += ext
14
+
15
+ return name
16
+
17
+
18
+ def parse_size(size):
19
+ if isinstance(size, int):
20
+ # If the size parameter is a single number, assume square aspect.
21
+ return [size, size]
22
+
23
+ if isinstance(size, (tuple, list)):
24
+ if len(size) == 1:
25
+ # If single value tuple/list is provided, exand it to two elements
26
+ return size + type(size)(size)
27
+ return size
28
+
29
+ try:
30
+ thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
31
+ except ValueError:
32
+ raise ValueError( # pylint: disable=raise-missing-from
33
+ "Bad thumbnail size format. Valid format is INTxINT."
34
+ )
35
+
36
+ if len(thumbnail_size) == 1:
37
+ # If the size parameter only contains a single integer, assume square aspect.
38
+ thumbnail_size.append(thumbnail_size[0])
39
+
40
+ return thumbnail_size
41
+
42
+
43
+ def aspect_to_string(size):
44
+ if isinstance(size, str):
45
+ return size
46
+
47
+ return "x".join(map(str, size))
48
+
49
+
50
+ IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'}
51
+
52
+
53
+ def glob_img(p: Union[Path, str], recursive: bool = False):
54
+ p = Path(p)
55
+ if p.is_file() and p.suffix in IMG_SUFFIX:
56
+ yield p
57
+ else:
58
+ if recursive:
59
+ files = Path(p).glob("**/*.*")
60
+ else:
61
+ files = Path(p).glob("*.*")
62
+
63
+ for it in files:
64
+ if it.suffix not in IMG_SUFFIX:
65
+ continue
66
+ yield it
lama_cleaner/helper.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import sys
4
+ from typing import List, Optional
5
+ from urllib.parse import urlparse
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image, ImageOps
11
+ from loguru import logger
12
+ from torch.hub import download_url_to_file, get_dir
13
+
14
+
15
+ def get_cache_path_by_url(url):
16
+ parts = urlparse(url)
17
+ hub_dir = get_dir()
18
+ model_dir = os.path.join(hub_dir, "checkpoints")
19
+ if not os.path.isdir(model_dir):
20
+ os.makedirs(model_dir)
21
+ filename = os.path.basename(parts.path)
22
+ cached_file = os.path.join(model_dir, filename)
23
+ return cached_file
24
+
25
+
26
+ def download_model(url):
27
+ cached_file = get_cache_path_by_url(url)
28
+ if not os.path.exists(cached_file):
29
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
30
+ hash_prefix = None
31
+ download_url_to_file(url, cached_file, hash_prefix, progress=True)
32
+ return cached_file
33
+
34
+
35
+ def ceil_modulo(x, mod):
36
+ if x % mod == 0:
37
+ return x
38
+ return (x // mod + 1) * mod
39
+
40
+
41
+ def load_jit_model(url_or_path, device):
42
+ # if os.path.exists(url_or_path):
43
+ # model_path = url_or_path
44
+ # else:
45
+ # model_path = download_model(url_or_path)
46
+ model_path = os.getcwd()
47
+ logger.info(f"Load model from: {model_path}")
48
+ try:
49
+ model = torch.jit.load(model_path).to(device)
50
+ except:
51
+ logger.error(
52
+ f"Failed to load {model_path}, delete model and restart lama-cleaner"
53
+ )
54
+ exit(-1)
55
+ model.eval()
56
+ return model
57
+
58
+
59
+ def load_model(model: torch.nn.Module, url_or_path, device):
60
+ if os.path.exists(url_or_path):
61
+ model_path = url_or_path
62
+ else:
63
+ model_path = download_model(url_or_path)
64
+
65
+ try:
66
+ state_dict = torch.load(model_path, map_location='cpu')
67
+ model.load_state_dict(state_dict, strict=True)
68
+ model.to(device)
69
+ logger.info(f"Load model from: {model_path}")
70
+ except:
71
+ logger.error(
72
+ f"Failed to load {model_path}, delete model and restart lama-cleaner"
73
+ )
74
+ exit(-1)
75
+ model.eval()
76
+ return model
77
+
78
+
79
+ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
80
+ data = cv2.imencode(
81
+ f".{ext}",
82
+ image_numpy,
83
+ [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
84
+ )[1]
85
+ image_bytes = data.tobytes()
86
+ return image_bytes
87
+
88
+
89
+ def load_img(img_bytes, gray: bool = False):
90
+ alpha_channel = None
91
+ image = Image.open(io.BytesIO(img_bytes))
92
+ try:
93
+ image = ImageOps.exif_transpose(image)
94
+ except:
95
+ pass
96
+
97
+ if gray:
98
+ image = image.convert('L')
99
+ np_img = np.array(image)
100
+ else:
101
+ if image.mode == 'RGBA':
102
+ np_img = np.array(image)
103
+ alpha_channel = np_img[:, :, -1]
104
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
105
+ else:
106
+ image = image.convert('RGB')
107
+ np_img = np.array(image)
108
+
109
+ return np_img, alpha_channel
110
+
111
+
112
+ def norm_img(np_img):
113
+ if len(np_img.shape) == 2:
114
+ np_img = np_img[:, :, np.newaxis]
115
+ np_img = np.transpose(np_img, (2, 0, 1))
116
+ np_img = np_img.astype("float32") / 255
117
+ return np_img
118
+
119
+
120
+ def resize_max_size(
121
+ np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
122
+ ) -> np.ndarray:
123
+ # Resize image's longer size to size_limit if longer size larger than size_limit
124
+ h, w = np_img.shape[:2]
125
+ if max(h, w) > size_limit:
126
+ ratio = size_limit / max(h, w)
127
+ new_w = int(w * ratio + 0.5)
128
+ new_h = int(h * ratio + 0.5)
129
+ return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
130
+ else:
131
+ return np_img
132
+
133
+
134
+ def pad_img_to_modulo(
135
+ img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
136
+ ):
137
+ """
138
+
139
+ Args:
140
+ img: [H, W, C]
141
+ mod:
142
+ square: 是否为正方形
143
+ min_size:
144
+
145
+ Returns:
146
+
147
+ """
148
+ if len(img.shape) == 2:
149
+ img = img[:, :, np.newaxis]
150
+ height, width = img.shape[:2]
151
+ out_height = ceil_modulo(height, mod)
152
+ out_width = ceil_modulo(width, mod)
153
+
154
+ if min_size is not None:
155
+ assert min_size % mod == 0
156
+ out_width = max(min_size, out_width)
157
+ out_height = max(min_size, out_height)
158
+
159
+ if square:
160
+ max_size = max(out_height, out_width)
161
+ out_height = max_size
162
+ out_width = max_size
163
+
164
+ return np.pad(
165
+ img,
166
+ ((0, out_height - height), (0, out_width - width), (0, 0)),
167
+ mode="symmetric",
168
+ )
169
+
170
+
171
+ def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
172
+ """
173
+ Args:
174
+ mask: (h, w, 1) 0~255
175
+
176
+ Returns:
177
+
178
+ """
179
+ height, width = mask.shape[:2]
180
+ _, thresh = cv2.threshold(mask, 127, 255, 0)
181
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
182
+
183
+ boxes = []
184
+ for cnt in contours:
185
+ x, y, w, h = cv2.boundingRect(cnt)
186
+ box = np.array([x, y, x + w, y + h]).astype(int)
187
+
188
+ box[::2] = np.clip(box[::2], 0, width)
189
+ box[1::2] = np.clip(box[1::2], 0, height)
190
+ boxes.append(box)
191
+
192
+ return boxes
193
+
194
+
195
+ def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
196
+ """
197
+ Args:
198
+ mask: (h, w) 0~255
199
+
200
+ Returns:
201
+
202
+ """
203
+ _, thresh = cv2.threshold(mask, 127, 255, 0)
204
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
205
+
206
+ max_area = 0
207
+ max_index = -1
208
+ for i, cnt in enumerate(contours):
209
+ area = cv2.contourArea(cnt)
210
+ if area > max_area:
211
+ max_area = area
212
+ max_index = i
213
+
214
+ if max_index != -1:
215
+ new_mask = np.zeros_like(mask)
216
+ return cv2.drawContours(new_mask, contours, max_index, 255, -1)
217
+ else:
218
+ return mask
lama_cleaner/interactive_seg.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Tuple, List
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from loguru import logger
9
+ from pydantic import BaseModel
10
+
11
+ from lama_cleaner.helper import load_jit_model
12
+
13
+
14
+ class Click(BaseModel):
15
+ # [y, x]
16
+ coords: Tuple[float, float]
17
+ is_positive: bool
18
+ indx: int
19
+
20
+ @property
21
+ def coords_and_indx(self):
22
+ return (*self.coords, self.indx)
23
+
24
+ def scale(self, x_ratio: float, y_ratio: float) -> 'Click':
25
+ return Click(
26
+ coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio),
27
+ is_positive=self.is_positive,
28
+ indx=self.indx
29
+ )
30
+
31
+
32
+ class ResizeTrans:
33
+ def __init__(self, size=480):
34
+ super().__init__()
35
+ self.crop_height = size
36
+ self.crop_width = size
37
+
38
+ def transform(self, image_nd, clicks_lists):
39
+ assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
40
+ image_height, image_width = image_nd.shape[2:4]
41
+ self.image_height = image_height
42
+ self.image_width = image_width
43
+ image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode='bilinear', align_corners=True)
44
+
45
+ y_ratio = self.crop_height / image_height
46
+ x_ratio = self.crop_width / image_width
47
+
48
+ clicks_lists_resized = []
49
+ for clicks_list in clicks_lists:
50
+ clicks_list_resized = [click.scale(y_ratio, x_ratio) for click in clicks_list]
51
+ clicks_lists_resized.append(clicks_list_resized)
52
+
53
+ return image_nd_r, clicks_lists_resized
54
+
55
+ def inv_transform(self, prob_map):
56
+ new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear',
57
+ align_corners=True)
58
+
59
+ return new_prob_map
60
+
61
+
62
+ class ISPredictor(object):
63
+ def __init__(
64
+ self,
65
+ model,
66
+ device,
67
+ open_kernel_size: int,
68
+ dilate_kernel_size: int,
69
+ net_clicks_limit=None,
70
+ zoom_in=None,
71
+ infer_size=384,
72
+ ):
73
+ self.model = model
74
+ self.open_kernel_size = open_kernel_size
75
+ self.dilate_kernel_size = dilate_kernel_size
76
+ self.net_clicks_limit = net_clicks_limit
77
+ self.device = device
78
+ self.zoom_in = zoom_in
79
+ self.infer_size = infer_size
80
+
81
+ # self.transforms = [zoom_in] if zoom_in is not None else []
82
+
83
+ def __call__(self, input_image: torch.Tensor, clicks: List[Click], prev_mask):
84
+ """
85
+
86
+ Args:
87
+ input_image: [1, 3, H, W] [0~1]
88
+ clicks: List[Click]
89
+ prev_mask: [1, 1, H, W]
90
+
91
+ Returns:
92
+
93
+ """
94
+ transforms = [ResizeTrans(self.infer_size)]
95
+ input_image = torch.cat((input_image, prev_mask), dim=1)
96
+
97
+ # image_nd resized to infer_size
98
+ for t in transforms:
99
+ image_nd, clicks_lists = t.transform(input_image, [clicks])
100
+
101
+ # image_nd.shape = [1, 4, 256, 256]
102
+ # points_nd.sha[e = [1, 2, 3]
103
+ # clicks_lists[0][0] Click 类
104
+ points_nd = self.get_points_nd(clicks_lists)
105
+ pred_logits = self.model(image_nd, points_nd)
106
+ pred = torch.sigmoid(pred_logits)
107
+ pred = self.post_process(pred)
108
+
109
+ prediction = F.interpolate(pred, mode='bilinear', align_corners=True,
110
+ size=image_nd.size()[2:])
111
+
112
+ for t in reversed(transforms):
113
+ prediction = t.inv_transform(prediction)
114
+
115
+ # if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
116
+ # return self.get_prediction(clicker)
117
+
118
+ return prediction.cpu().numpy()[0, 0]
119
+
120
+ def post_process(self, pred: torch.Tensor) -> torch.Tensor:
121
+ pred_mask = pred.cpu().numpy()[0][0]
122
+ # morph_open to remove small noise
123
+ kernel_size = self.open_kernel_size
124
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
125
+ pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1)
126
+
127
+ # Why dilate: make region slightly larger to avoid missing some pixels, this generally works better
128
+ dilate_kernel_size = self.dilate_kernel_size
129
+ if dilate_kernel_size > 1:
130
+ kernel = cv2.getStructuringElement(cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size))
131
+ pred_mask = cv2.dilate(pred_mask, kernel, 1)
132
+ return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0)
133
+
134
+ def get_points_nd(self, clicks_lists):
135
+ total_clicks = []
136
+ num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
137
+ num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
138
+ num_max_points = max(num_pos_clicks + num_neg_clicks)
139
+ if self.net_clicks_limit is not None:
140
+ num_max_points = min(self.net_clicks_limit, num_max_points)
141
+ num_max_points = max(1, num_max_points)
142
+
143
+ for clicks_list in clicks_lists:
144
+ clicks_list = clicks_list[:self.net_clicks_limit]
145
+ pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
146
+ pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
147
+
148
+ neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
149
+ neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
150
+ total_clicks.append(pos_clicks + neg_clicks)
151
+
152
+ return torch.tensor(total_clicks, device=self.device)
153
+
154
+
155
+ INTERACTIVE_SEG_MODEL_URL = os.environ.get(
156
+ "INTERACTIVE_SEG_MODEL_URL",
157
+ "https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt",
158
+ )
159
+
160
+
161
+ class InteractiveSeg:
162
+ def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
163
+ device = torch.device('cpu')
164
+ model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device).eval()
165
+ self.predictor = ISPredictor(model, device,
166
+ infer_size=infer_size,
167
+ open_kernel_size=open_kernel_size,
168
+ dilate_kernel_size=dilate_kernel_size)
169
+
170
+ def __call__(self, image, clicks, prev_mask=None):
171
+ """
172
+
173
+ Args:
174
+ image: [H,W,C] RGB
175
+ clicks:
176
+
177
+ Returns:
178
+
179
+ """
180
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
181
+ image = torch.from_numpy((image / 255).transpose(2, 0, 1)).unsqueeze(0).float()
182
+ if prev_mask is None:
183
+ mask = torch.zeros_like(image[:, :1, :, :])
184
+ else:
185
+ logger.info('InteractiveSeg run with prev_mask')
186
+ mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float()
187
+
188
+ pred_probs = self.predictor(image, clicks, mask)
189
+ pred_mask = pred_probs > 0.5
190
+ pred_mask = (pred_mask * 255).astype(np.uint8)
191
+
192
+ # Find largest contour
193
+ # pred_mask = only_keep_largest_contour(pred_mask)
194
+ # To simplify frontend process, add mask brush color here
195
+ fg = pred_mask == 255
196
+ bg = pred_mask != 255
197
+ pred_mask = cv2.cvtColor(pred_mask, cv2.COLOR_GRAY2BGRA)
198
+ # frontend brush color "ffcc00bb"
199
+ pred_mask[bg] = 0
200
+ pred_mask[fg] = [255, 203, 0, int(255 * 0.73)]
201
+ pred_mask = cv2.cvtColor(pred_mask, cv2.COLOR_BGRA2RGBA)
202
+ return pred_mask
lama_cleaner/model/__init__.py ADDED
File without changes
lama_cleaner/model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (172 Bytes). View file
 
lama_cleaner/model/__pycache__/base.cpython-38.pyc ADDED
Binary file (6.7 kB). View file
 
lama_cleaner/model/__pycache__/ddim_sampler.cpython-38.pyc ADDED
Binary file (4.74 kB). View file
 
lama_cleaner/model/__pycache__/fcf.cpython-38.pyc ADDED
Binary file (33.4 kB). View file
 
lama_cleaner/model/__pycache__/lama.cpython-38.pyc ADDED
Binary file (2.12 kB). View file
 
lama_cleaner/model/__pycache__/ldm.cpython-38.pyc ADDED
Binary file (7.79 kB). View file
 
lama_cleaner/model/__pycache__/manga.cpython-38.pyc ADDED
Binary file (2.72 kB). View file
 
lama_cleaner/model/__pycache__/mat.cpython-38.pyc ADDED
Binary file (38.8 kB). View file
 
lama_cleaner/model/__pycache__/opencv2.cpython-38.pyc ADDED
Binary file (1.13 kB). View file
 
lama_cleaner/model/__pycache__/paint_by_example.cpython-38.pyc ADDED
Binary file (4.25 kB). View file
 
lama_cleaner/model/__pycache__/plms_sampler.cpython-38.pyc ADDED
Binary file (7.09 kB). View file
 
lama_cleaner/model/__pycache__/sd.cpython-38.pyc ADDED
Binary file (6.26 kB). View file
 
lama_cleaner/model/__pycache__/utils.cpython-38.pyc ADDED
Binary file (26.3 kB). View file
 
lama_cleaner/model/__pycache__/zits.cpython-38.pyc ADDED
Binary file (10.5 kB). View file
 
lama_cleaner/model/base.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Optional
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from loguru import logger
8
+
9
+ from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
10
+ from lama_cleaner.schema import Config, HDStrategy
11
+
12
+
13
+ class InpaintModel:
14
+ min_size: Optional[int] = None
15
+ pad_mod = 8
16
+ pad_to_square = False
17
+
18
+ def __init__(self, device, **kwargs):
19
+ """
20
+
21
+ Args:
22
+ device:
23
+ """
24
+ self.device = device
25
+ self.init_model(device, **kwargs)
26
+
27
+ @abc.abstractmethod
28
+ def init_model(self, device, **kwargs):
29
+ ...
30
+
31
+ @staticmethod
32
+ @abc.abstractmethod
33
+ def is_downloaded() -> bool:
34
+ ...
35
+
36
+ @abc.abstractmethod
37
+ def forward(self, image, mask, config: Config):
38
+ """Input images and output images have same size
39
+ images: [H, W, C] RGB
40
+ masks: [H, W, 1] 255 为 masks 区域
41
+ return: BGR IMAGE
42
+ """
43
+ ...
44
+
45
+ def _pad_forward(self, image, mask, config: Config):
46
+ origin_height, origin_width = image.shape[:2]
47
+ pad_image = pad_img_to_modulo(
48
+ image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
49
+ )
50
+ pad_mask = pad_img_to_modulo(
51
+ mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
52
+ )
53
+
54
+ logger.info(f"final forward pad size: {pad_image.shape}")
55
+
56
+ result = self.forward(pad_image, pad_mask, config)
57
+ result = result[0:origin_height, 0:origin_width, :]
58
+
59
+ result, image, mask = self.forward_post_process(result, image, mask, config)
60
+
61
+ mask = mask[:, :, np.newaxis]
62
+ result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
63
+ return result
64
+
65
+ def forward_post_process(self, result, image, mask, config):
66
+ return result, image, mask
67
+
68
+ @torch.no_grad()
69
+ def __call__(self, image, mask, config: Config):
70
+ """
71
+ images: [H, W, C] RGB, not normalized
72
+ masks: [H, W]
73
+ return: BGR IMAGE
74
+ """
75
+ inpaint_result = None
76
+ logger.info(f"hd_strategy: {config.hd_strategy}")
77
+ if config.hd_strategy == HDStrategy.CROP:
78
+ if max(image.shape) > config.hd_strategy_crop_trigger_size:
79
+ logger.info(f"Run crop strategy")
80
+ boxes = boxes_from_mask(mask)
81
+ crop_result = []
82
+ for box in boxes:
83
+ crop_image, crop_box = self._run_box(image, mask, box, config)
84
+ crop_result.append((crop_image, crop_box))
85
+
86
+ inpaint_result = image[:, :, ::-1]
87
+ for crop_image, crop_box in crop_result:
88
+ x1, y1, x2, y2 = crop_box
89
+ inpaint_result[y1:y2, x1:x2, :] = crop_image
90
+
91
+ elif config.hd_strategy == HDStrategy.RESIZE:
92
+ if max(image.shape) > config.hd_strategy_resize_limit:
93
+ origin_size = image.shape[:2]
94
+ downsize_image = resize_max_size(
95
+ image, size_limit=config.hd_strategy_resize_limit
96
+ )
97
+ downsize_mask = resize_max_size(
98
+ mask, size_limit=config.hd_strategy_resize_limit
99
+ )
100
+
101
+ logger.info(
102
+ f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
103
+ )
104
+ inpaint_result = self._pad_forward(
105
+ downsize_image, downsize_mask, config
106
+ )
107
+
108
+ # only paste masked area result
109
+ inpaint_result = cv2.resize(
110
+ inpaint_result,
111
+ (origin_size[1], origin_size[0]),
112
+ interpolation=cv2.INTER_CUBIC,
113
+ )
114
+ original_pixel_indices = mask < 127
115
+ inpaint_result[original_pixel_indices] = image[:, :, ::-1][
116
+ original_pixel_indices
117
+ ]
118
+
119
+ if inpaint_result is None:
120
+ inpaint_result = self._pad_forward(image, mask, config)
121
+
122
+ return inpaint_result
123
+
124
+ def _crop_box(self, image, mask, box, config: Config):
125
+ """
126
+
127
+ Args:
128
+ image: [H, W, C] RGB
129
+ mask: [H, W, 1]
130
+ box: [left,top,right,bottom]
131
+
132
+ Returns:
133
+ BGR IMAGE, (l, r, r, b)
134
+ """
135
+ box_h = box[3] - box[1]
136
+ box_w = box[2] - box[0]
137
+ cx = (box[0] + box[2]) // 2
138
+ cy = (box[1] + box[3]) // 2
139
+ img_h, img_w = image.shape[:2]
140
+
141
+ w = box_w + config.hd_strategy_crop_margin * 2
142
+ h = box_h + config.hd_strategy_crop_margin * 2
143
+
144
+ _l = cx - w // 2
145
+ _r = cx + w // 2
146
+ _t = cy - h // 2
147
+ _b = cy + h // 2
148
+
149
+ l = max(_l, 0)
150
+ r = min(_r, img_w)
151
+ t = max(_t, 0)
152
+ b = min(_b, img_h)
153
+
154
+ # try to get more context when crop around image edge
155
+ if _l < 0:
156
+ r += abs(_l)
157
+ if _r > img_w:
158
+ l -= _r - img_w
159
+ if _t < 0:
160
+ b += abs(_t)
161
+ if _b > img_h:
162
+ t -= _b - img_h
163
+
164
+ l = max(l, 0)
165
+ r = min(r, img_w)
166
+ t = max(t, 0)
167
+ b = min(b, img_h)
168
+
169
+ crop_img = image[t:b, l:r, :]
170
+ crop_mask = mask[t:b, l:r]
171
+
172
+ logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
173
+
174
+ return crop_img, crop_mask, [l, t, r, b]
175
+
176
+ def _calculate_cdf(self, histogram):
177
+ cdf = histogram.cumsum()
178
+ normalized_cdf = cdf / float(cdf.max())
179
+ return normalized_cdf
180
+
181
+ def _calculate_lookup(self, source_cdf, reference_cdf):
182
+ lookup_table = np.zeros(256)
183
+ lookup_val = 0
184
+ for source_index, source_val in enumerate(source_cdf):
185
+ for reference_index, reference_val in enumerate(reference_cdf):
186
+ if reference_val >= source_val:
187
+ lookup_val = reference_index
188
+ break
189
+ lookup_table[source_index] = lookup_val
190
+ return lookup_table
191
+
192
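+ # Histogram matching: remap each channel of `source` so its distribution over
+ # unmasked pixels (mask == 0) follows `reference`; this can be used to harmonize
+ # the inpainted colors with the original image.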
+ def _match_histograms(self, source, reference, mask):
193
+ transformed_channels = []
194
+ for channel in range(source.shape[-1]):
195
+ source_channel = source[:, :, channel]
196
+ reference_channel = reference[:, :, channel]
197
+
198
+ # only calculate histograms for non-masked parts
199
+ source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
200
+ reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0, 256])
201
+
202
+ source_cdf = self._calculate_cdf(source_histogram)
203
+ reference_cdf = self._calculate_cdf(reference_histogram)
204
+
205
+ lookup = self._calculate_lookup(source_cdf, reference_cdf)
206
+
207
+ transformed_channels.append(cv2.LUT(source_channel, lookup))
208
+
209
+ result = cv2.merge(transformed_channels)
210
+ result = cv2.convertScaleAbs(result)
211
+
212
+ return result
213
+
214
+ def _apply_cropper(self, image, mask, config: Config):
215
+ img_h, img_w = image.shape[:2]
216
+ l, t, w, h = (
217
+ config.croper_x,
218
+ config.croper_y,
219
+ config.croper_width,
220
+ config.croper_height,
221
+ )
222
+ r = l + w
223
+ b = t + h
224
+
225
+ l = max(l, 0)
226
+ r = min(r, img_w)
227
+ t = max(t, 0)
228
+ b = min(b, img_h)
229
+
230
+ crop_img = image[t:b, l:r, :]
231
+ crop_mask = mask[t:b, l:r]
232
+ return crop_img, crop_mask, (l, t, r, b)
233
+
234
+ def _run_box(self, image, mask, box, config: Config):
235
+ """
236
+
237
+ Args:
238
+ image: [H, W, C] RGB
239
+ mask: [H, W, 1]
240
+ box: [left,top,right,bottom]
241
+
242
+ Returns:
243
+ BGR IMAGE
244
+ """
245
+ crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
246
+
247
+ return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
lama_cleaner/model/ddim_sampler.py ADDED
@@ -0,0 +1,192 @@
1
+ import numpy as np
2
+ import torch
3
+ from loguru import logger
4
+ from tqdm import tqdm
5
+
6
+ from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
7
+
8
+
9
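+ # DDIM sampler: runs the reverse diffusion process of the wrapped LDM with a
+ # reduced number of timesteps (eta=0 in `sample`, i.e. deterministic sampling).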
+ class DDIMSampler(object):
10
+ def __init__(self, model, schedule="linear"):
11
+ super().__init__()
12
+ self.model = model
13
+ self.ddpm_num_timesteps = model.num_timesteps
14
+ self.schedule = schedule
15
+
16
+ def register_buffer(self, name, attr):
17
+ setattr(self, name, attr)
18
+
19
+ def make_schedule(
20
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
21
+ ):
22
+ self.ddim_timesteps = make_ddim_timesteps(
23
+ ddim_discr_method=ddim_discretize,
24
+ num_ddim_timesteps=ddim_num_steps,
25
+ # array([1])
26
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
27
+ verbose=verbose,
28
+ )
29
+ alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
30
+ assert (
31
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
32
+ ), "alphas have to be defined for each timestep"
33
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
34
+
35
+ self.register_buffer("betas", to_torch(self.model.betas))
36
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
37
+ self.register_buffer(
38
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
39
+ )
40
+
41
+ # calculations for diffusion q(x_t | x_{t-1}) and others
42
+ self.register_buffer(
43
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
44
+ )
45
+ self.register_buffer(
46
+ "sqrt_one_minus_alphas_cumprod",
47
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
48
+ )
49
+ self.register_buffer(
50
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
51
+ )
52
+ self.register_buffer(
53
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
54
+ )
55
+ self.register_buffer(
56
+ "sqrt_recipm1_alphas_cumprod",
57
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
58
+ )
59
+
60
+ # ddim sampling parameters
61
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
62
+ alphacums=alphas_cumprod.cpu(),
63
+ ddim_timesteps=self.ddim_timesteps,
64
+ eta=ddim_eta,
65
+ verbose=verbose,
66
+ )
67
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
68
+ self.register_buffer("ddim_alphas", ddim_alphas)
69
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
70
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
71
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
72
+ (1 - self.alphas_cumprod_prev)
73
+ / (1 - self.alphas_cumprod)
74
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
75
+ )
76
+ self.register_buffer(
77
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
78
+ )
79
+
80
+ @torch.no_grad()
81
+ def sample(self, steps, conditioning, batch_size, shape):
82
+ self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
83
+ # sampling
84
+ C, H, W = shape
85
+ size = (batch_size, C, H, W)
86
+
87
+ # samples: 1,3,128,128
88
+ return self.ddim_sampling(
89
+ conditioning,
90
+ size,
91
+ quantize_denoised=False,
92
+ ddim_use_original_steps=False,
93
+ noise_dropout=0,
94
+ temperature=1.0,
95
+ )
96
+
97
+ @torch.no_grad()
98
+ def ddim_sampling(
99
+ self,
100
+ cond,
101
+ shape,
102
+ ddim_use_original_steps=False,
103
+ quantize_denoised=False,
104
+ temperature=1.0,
105
+ noise_dropout=0.0,
106
+ ):
107
+ device = self.model.betas.device
108
+ b = shape[0]
109
+ img = torch.randn(shape, device=device, dtype=cond.dtype)
110
+ timesteps = (
111
+ self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
112
+ )
113
+
114
+ time_range = (
115
+ reversed(range(0, timesteps))
116
+ if ddim_use_original_steps
117
+ else np.flip(timesteps)
118
+ )
119
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
120
+ logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
121
+
122
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
123
+
124
+ for i, step in enumerate(iterator):
125
+ index = total_steps - i - 1
126
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
127
+
128
+ outs = self.p_sample_ddim(
129
+ img,
130
+ cond,
131
+ ts,
132
+ index=index,
133
+ use_original_steps=ddim_use_original_steps,
134
+ quantize_denoised=quantize_denoised,
135
+ temperature=temperature,
136
+ noise_dropout=noise_dropout,
137
+ )
138
+ img, _ = outs
139
+
140
+ return img
141
+
142
+ @torch.no_grad()
143
+ def p_sample_ddim(
144
+ self,
145
+ x,
146
+ c,
147
+ t,
148
+ index,
149
+ repeat_noise=False,
150
+ use_original_steps=False,
151
+ quantize_denoised=False,
152
+ temperature=1.0,
153
+ noise_dropout=0.0,
154
+ ):
155
+ b, *_, device = *x.shape, x.device
156
+ e_t = self.model.apply_model(x, t, c)
157
+
158
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
159
+ alphas_prev = (
160
+ self.model.alphas_cumprod_prev
161
+ if use_original_steps
162
+ else self.ddim_alphas_prev
163
+ )
164
+ sqrt_one_minus_alphas = (
165
+ self.model.sqrt_one_minus_alphas_cumprod
166
+ if use_original_steps
167
+ else self.ddim_sqrt_one_minus_alphas
168
+ )
169
+ sigmas = (
170
+ self.model.ddim_sigmas_for_original_num_steps
171
+ if use_original_steps
172
+ else self.ddim_sigmas
173
+ )
174
+ # select parameters corresponding to the currently considered timestep
175
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
176
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
177
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
178
+ sqrt_one_minus_at = torch.full(
179
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
180
+ )
181
+
182
+ # current prediction for x_0
183
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
184
+ if quantize_denoised: # not used here
185
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
186
+ # direction pointing to x_t
187
+ dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
188
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
189
+ if noise_dropout > 0.0: # not used here
190
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
191
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
192
+ return x_prev, pred_x0
lama_cleaner/model/fcf.py ADDED
@@ -0,0 +1,1212 @@
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import torch.fft as fft
8
+ import torch.nn.functional as F
9
+ from torch import conv2d, nn
10
+
11
+ from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img, boxes_from_mask, resize_max_size
12
+ from lama_cleaner.model.base import InpaintModel
13
+ from lama_cleaner.model.utils import setup_filter, _parse_scaling, _parse_padding, Conv2dLayer, FullyConnectedLayer, \
14
+ MinibatchStdLayer, activation_funcs, conv2d_resample, bias_act, upsample2d, normalize_2nd_moment, downsample2d
15
+ from lama_cleaner.schema import Config
16
+
17
+
18
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
19
+ assert isinstance(x, torch.Tensor)
20
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
21
+
22
+
23
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
24
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
25
+ """
26
+ # Validate arguments.
27
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
28
+ if f is None:
29
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
30
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
31
+ assert f.dtype == torch.float32 and not f.requires_grad
32
+ batch_size, num_channels, in_height, in_width = x.shape
33
+ upx, upy = _parse_scaling(up)
34
+ downx, downy = _parse_scaling(down)
35
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
36
+
37
+ # Upsample by inserting zeros.
38
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
39
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
40
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
41
+
42
+ # Pad or crop.
43
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
44
+ x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)]
45
+
46
+ # Setup filter.
47
+ f = f * (gain ** (f.ndim / 2))
48
+ f = f.to(x.dtype)
49
+ if not flip_filter:
50
+ f = f.flip(list(range(f.ndim)))
51
+
52
+ # Convolve with the filter.
53
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
54
+ if f.ndim == 4:
55
+ x = conv2d(input=x, weight=f, groups=num_channels)
56
+ else:
57
+ x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
58
+ x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
59
+
60
+ # Downsample by throwing away pixels.
61
+ x = x[:, :, ::downy, ::downx]
62
+ return x
63
+
64
+
65
+ class EncoderEpilogue(torch.nn.Module):
66
+ def __init__(self,
67
+ in_channels, # Number of input channels.
68
+ cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
69
+ z_dim, # Output Latent (Z) dimensionality.
70
+ resolution, # Resolution of this block.
71
+ img_channels, # Number of input color channels.
72
+ architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'.
73
+ mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
74
+ mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
75
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
76
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
77
+ ):
78
+ assert architecture in ['orig', 'skip', 'resnet']
79
+ super().__init__()
80
+ self.in_channels = in_channels
81
+ self.cmap_dim = cmap_dim
82
+ self.resolution = resolution
83
+ self.img_channels = img_channels
84
+ self.architecture = architecture
85
+
86
+ if architecture == 'skip':
87
+ self.fromrgb = Conv2dLayer(self.img_channels, in_channels, kernel_size=1, activation=activation)
88
+ self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size,
89
+ num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
90
+ self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation,
91
+ conv_clamp=conv_clamp)
92
+ self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), z_dim, activation=activation)
93
+ self.dropout = torch.nn.Dropout(p=0.5)
94
+
95
+ def forward(self, x, cmap, force_fp32=False):
96
+ _ = force_fp32 # unused
97
+ dtype = torch.float32
98
+ memory_format = torch.contiguous_format
99
+
100
+ # FromRGB.
101
+ x = x.to(dtype=dtype, memory_format=memory_format)
102
+
103
+ # Main layers.
104
+ if self.mbstd is not None:
105
+ x = self.mbstd(x)
106
+ const_e = self.conv(x)
107
+ x = self.fc(const_e.flatten(1))
108
+ x = self.dropout(x)
109
+
110
+ # Conditioning.
111
+ if self.cmap_dim > 0:
112
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
113
+
114
+ assert x.dtype == dtype
115
+ return x, const_e
116
+
117
+
118
+ class EncoderBlock(torch.nn.Module):
119
+ def __init__(self,
120
+ in_channels, # Number of input channels, 0 = first block.
121
+ tmp_channels, # Number of intermediate channels.
122
+ out_channels, # Number of output channels.
123
+ resolution, # Resolution of this block.
124
+ img_channels, # Number of input color channels.
125
+ first_layer_idx, # Index of the first layer.
126
+ architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
127
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
128
+ resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
129
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
130
+ use_fp16=False, # Use FP16 for this block?
131
+ fp16_channels_last=False, # Use channels-last memory format with FP16?
132
+ freeze_layers=0, # Freeze-D: Number of layers to freeze.
133
+ ):
134
+ assert in_channels in [0, tmp_channels]
135
+ assert architecture in ['orig', 'skip', 'resnet']
136
+ super().__init__()
137
+ self.in_channels = in_channels
138
+ self.resolution = resolution
139
+ self.img_channels = img_channels + 1
140
+ self.first_layer_idx = first_layer_idx
141
+ self.architecture = architecture
142
+ self.use_fp16 = use_fp16
143
+ self.channels_last = (use_fp16 and fp16_channels_last)
144
+ self.register_buffer('resample_filter', setup_filter(resample_filter))
145
+
146
+ self.num_layers = 0
147
+
148
+ def trainable_gen():
149
+ while True:
150
+ layer_idx = self.first_layer_idx + self.num_layers
151
+ trainable = (layer_idx >= freeze_layers)
152
+ self.num_layers += 1
153
+ yield trainable
154
+
155
+ trainable_iter = trainable_gen()
156
+
157
+ if in_channels == 0:
158
+ self.fromrgb = Conv2dLayer(self.img_channels, tmp_channels, kernel_size=1, activation=activation,
159
+ trainable=next(trainable_iter), conv_clamp=conv_clamp,
160
+ channels_last=self.channels_last)
161
+
162
+ self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
163
+ trainable=next(trainable_iter), conv_clamp=conv_clamp,
164
+ channels_last=self.channels_last)
165
+
166
+ self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
167
+ trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp,
168
+ channels_last=self.channels_last)
169
+
170
+ if architecture == 'resnet':
171
+ self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
172
+ trainable=next(trainable_iter), resample_filter=resample_filter,
173
+ channels_last=self.channels_last)
174
+
175
+ def forward(self, x, img, force_fp32=False):
176
+ # dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
177
+ dtype = torch.float32
178
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
179
+
180
+ # Input.
181
+ if x is not None:
182
+ x = x.to(dtype=dtype, memory_format=memory_format)
183
+
184
+ # FromRGB.
185
+ if self.in_channels == 0:
186
+ img = img.to(dtype=dtype, memory_format=memory_format)
187
+ y = self.fromrgb(img)
188
+ x = x + y if x is not None else y
189
+ img = downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
190
+
191
+ # Main layers.
192
+ if self.architecture == 'resnet':
193
+ y = self.skip(x, gain=np.sqrt(0.5))
194
+ x = self.conv0(x)
195
+ feat = x.clone()
196
+ x = self.conv1(x, gain=np.sqrt(0.5))
197
+ x = y.add_(x)
198
+ else:
199
+ x = self.conv0(x)
200
+ feat = x.clone()
201
+ x = self.conv1(x)
202
+
203
+ assert x.dtype == dtype
204
+ return x, img, feat
205
+
206
+
207
+ class EncoderNetwork(torch.nn.Module):
208
+ def __init__(self,
209
+ c_dim, # Conditioning label (C) dimensionality.
210
+ z_dim, # Input latent (Z) dimensionality.
211
+ img_resolution, # Input resolution.
212
+ img_channels, # Number of input color channels.
213
+ architecture='orig', # Architecture: 'orig', 'skip', 'resnet'.
214
+ channel_base=16384, # Overall multiplier for the number of channels.
215
+ channel_max=512, # Maximum number of channels in any layer.
216
+ num_fp16_res=0, # Use FP16 for the N highest resolutions.
217
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
218
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
219
+ block_kwargs={}, # Arguments for DiscriminatorBlock.
220
+ mapping_kwargs={}, # Arguments for MappingNetwork.
221
+ epilogue_kwargs={}, # Arguments for EncoderEpilogue.
222
+ ):
223
+ super().__init__()
224
+ self.c_dim = c_dim
225
+ self.z_dim = z_dim
226
+ self.img_resolution = img_resolution
227
+ self.img_resolution_log2 = int(np.log2(img_resolution))
228
+ self.img_channels = img_channels
229
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
230
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
231
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
232
+
233
+ if cmap_dim is None:
234
+ cmap_dim = channels_dict[4]
235
+ if c_dim == 0:
236
+ cmap_dim = 0
237
+
238
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
239
+ cur_layer_idx = 0
240
+ for res in self.block_resolutions:
241
+ in_channels = channels_dict[res] if res < img_resolution else 0
242
+ tmp_channels = channels_dict[res]
243
+ out_channels = channels_dict[res // 2]
244
+ use_fp16 = (res >= fp16_resolution)
245
+ use_fp16 = False
246
+ block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res,
247
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
248
+ setattr(self, f'b{res}', block)
249
+ cur_layer_idx += block.num_layers
250
+ if c_dim > 0:
251
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None,
252
+ **mapping_kwargs)
253
+ self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4, **epilogue_kwargs,
254
+ **common_kwargs)
255
+
256
+ def forward(self, img, c, **block_kwargs):
257
+ x = None
258
+ feats = {}
259
+ for res in self.block_resolutions:
260
+ block = getattr(self, f'b{res}')
261
+ x, img, feat = block(x, img, **block_kwargs)
262
+ feats[res] = feat
263
+
264
+ cmap = None
265
+ if self.c_dim > 0:
266
+ cmap = self.mapping(None, c)
267
+ x, const_e = self.b4(x, cmap)
268
+ feats[4] = const_e
269
+
270
+ B, _ = x.shape
271
+ z = torch.zeros((B, self.z_dim), requires_grad=False, dtype=x.dtype,
272
+ device=x.device) ## Noise for Co-Modulation
273
+ return x, z, feats
274
+
275
+
276
+ def fma(a, b, c): # => a * b + c
277
+ return _FusedMultiplyAdd.apply(a, b, c)
278
+
279
+
280
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
281
+ @staticmethod
282
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
283
+ out = torch.addcmul(c, a, b)
284
+ ctx.save_for_backward(a, b)
285
+ ctx.c_shape = c.shape
286
+ return out
287
+
288
+ @staticmethod
289
+ def backward(ctx, dout): # pylint: disable=arguments-differ
290
+ a, b = ctx.saved_tensors
291
+ c_shape = ctx.c_shape
292
+ da = None
293
+ db = None
294
+ dc = None
295
+
296
+ if ctx.needs_input_grad[0]:
297
+ da = _unbroadcast(dout * b, a.shape)
298
+
299
+ if ctx.needs_input_grad[1]:
300
+ db = _unbroadcast(dout * a, b.shape)
301
+
302
+ if ctx.needs_input_grad[2]:
303
+ dc = _unbroadcast(dout, c_shape)
304
+
305
+ return da, db, dc
306
+
307
+
308
+ def _unbroadcast(x, shape):
309
+ extra_dims = x.ndim - len(shape)
310
+ assert extra_dims >= 0
311
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
312
+ if len(dim):
313
+ x = x.sum(dim=dim, keepdim=True)
314
+ if extra_dims:
315
+ x = x.reshape(-1, *x.shape[extra_dims + 1:])
316
+ assert x.shape == shape
317
+ return x
318
+
319
+
320
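+ # StyleGAN2-style modulated convolution: per-sample styles scale the convolution
+ # weights, optional demodulation normalizes them, and (when fused) the batch is
+ # folded into the group dimension so a single grouped conv handles all samples.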
+ def modulated_conv2d(
321
+ x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
322
+ weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
323
+ styles, # Modulation coefficients of shape [batch_size, in_channels].
324
+ noise=None, # Optional noise tensor to add to the output activations.
325
+ up=1, # Integer upsampling factor.
326
+ down=1, # Integer downsampling factor.
327
+ padding=0, # Padding with respect to the upsampled image.
328
+ resample_filter=None,
329
+ # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
330
+ demodulate=True, # Apply weight demodulation?
331
+ flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
332
+ fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation?
333
+ ):
334
+ batch_size = x.shape[0]
335
+ out_channels, in_channels, kh, kw = weight.shape
336
+
337
+ # Pre-normalize inputs to avoid FP16 overflow.
338
+ if x.dtype == torch.float16 and demodulate:
339
+ weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3],
340
+ keepdim=True)) # max_Ikk
341
+ styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
342
+
343
+ # Calculate per-sample weights and demodulation coefficients.
344
+ w = None
345
+ dcoefs = None
346
+ if demodulate or fused_modconv:
347
+ w = weight.unsqueeze(0) # [NOIkk]
348
+ w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
349
+ if demodulate:
350
+ dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO]
351
+ if demodulate and fused_modconv:
352
+ w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
353
+ # Execute by scaling the activations before and after the convolution.
354
+ if not fused_modconv:
355
+ x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
356
+ x = conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down,
357
+ padding=padding, flip_weight=flip_weight)
358
+ if demodulate and noise is not None:
359
+ x = fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
360
+ elif demodulate:
361
+ x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
362
+ elif noise is not None:
363
+ x = x.add_(noise.to(x.dtype))
364
+ return x
365
+
366
+ # Execute as one fused op using grouped convolution.
367
+ batch_size = int(batch_size)
368
+ x = x.reshape(1, -1, *x.shape[2:])
369
+ w = w.reshape(-1, in_channels, kh, kw)
370
+ x = conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding,
371
+ groups=batch_size, flip_weight=flip_weight)
372
+ x = x.reshape(batch_size, -1, *x.shape[2:])
373
+ if noise is not None:
374
+ x = x.add_(noise)
375
+ return x
376
+
377
+
378
+ class SynthesisLayer(torch.nn.Module):
379
+ def __init__(self,
380
+ in_channels, # Number of input channels.
381
+ out_channels, # Number of output channels.
382
+ w_dim, # Intermediate latent (W) dimensionality.
383
+ resolution, # Resolution of this layer.
384
+ kernel_size=3, # Convolution kernel size.
385
+ up=1, # Integer upsampling factor.
386
+ use_noise=True, # Enable noise input?
387
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
388
+ resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
389
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
390
+ channels_last=False, # Use channels_last format for the weights?
391
+ ):
392
+ super().__init__()
393
+ self.resolution = resolution
394
+ self.up = up
395
+ self.use_noise = use_noise
396
+ self.activation = activation
397
+ self.conv_clamp = conv_clamp
398
+ self.register_buffer('resample_filter', setup_filter(resample_filter))
399
+ self.padding = kernel_size // 2
400
+ self.act_gain = activation_funcs[activation].def_gain
401
+
402
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
403
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
404
+ self.weight = torch.nn.Parameter(
405
+ torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
406
+ if use_noise:
407
+ self.register_buffer('noise_const', torch.randn([resolution, resolution]))
408
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
409
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
410
+
411
+ def forward(self, x, w, noise_mode='none', fused_modconv=True, gain=1):
412
+ assert noise_mode in ['random', 'const', 'none']
413
+ in_resolution = self.resolution // self.up
414
+ styles = self.affine(w)
415
+
416
+ noise = None
417
+ if self.use_noise and noise_mode == 'random':
418
+ noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution],
419
+ device=x.device) * self.noise_strength
420
+ if self.use_noise and noise_mode == 'const':
421
+ noise = self.noise_const * self.noise_strength
422
+
423
+ flip_weight = (self.up == 1) # slightly faster
424
+ x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
425
+ padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight,
426
+ fused_modconv=fused_modconv)
427
+
428
+ act_gain = self.act_gain * gain
429
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
430
+ x = F.leaky_relu(x, negative_slope=0.2, inplace=False)
431
+ if act_gain != 1:
432
+ x = x * act_gain
433
+ if act_clamp is not None:
434
+ x = x.clamp(-act_clamp, act_clamp)
435
+ return x
436
+
437
+
438
+ class ToRGBLayer(torch.nn.Module):
439
+ def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
440
+ super().__init__()
441
+ self.conv_clamp = conv_clamp
442
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
443
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
444
+ self.weight = torch.nn.Parameter(
445
+ torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
446
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
447
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
448
+
449
+ def forward(self, x, w, fused_modconv=True):
450
+ styles = self.affine(w) * self.weight_gain
451
+ x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
452
+ x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
453
+ return x
454
+
455
+
456
+ class SynthesisForeword(torch.nn.Module):
457
+ def __init__(self,
458
+ z_dim, # Output Latent (Z) dimensionality.
459
+ resolution, # Resolution of this block.
460
+ in_channels,
461
+ img_channels, # Number of input color channels.
462
+ architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
463
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
464
+
465
+ ):
466
+ super().__init__()
467
+ self.in_channels = in_channels
468
+ self.z_dim = z_dim
469
+ self.resolution = resolution
470
+ self.img_channels = img_channels
471
+ self.architecture = architecture
472
+
473
+ self.fc = FullyConnectedLayer(self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation)
474
+ self.conv = SynthesisLayer(self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4)
475
+
476
+ if architecture == 'skip':
477
+ self.torgb = ToRGBLayer(self.in_channels, self.img_channels, kernel_size=1, w_dim=(z_dim // 2) * 3)
478
+
479
+ def forward(self, x, ws, feats, img, force_fp32=False):
480
+ _ = force_fp32 # unused
481
+ dtype = torch.float32
482
+ memory_format = torch.contiguous_format
483
+
484
+ x_global = x.clone()
485
+ # ToRGB.
486
+ x = self.fc(x)
487
+ x = x.view(-1, self.z_dim // 2, 4, 4)
488
+ x = x.to(dtype=dtype, memory_format=memory_format)
489
+
490
+ # Main layers.
491
+ x_skip = feats[4].clone()
492
+ x = x + x_skip
493
+
494
+ mod_vector = []
495
+ mod_vector.append(ws[:, 0])
496
+ mod_vector.append(x_global.clone())
497
+ mod_vector = torch.cat(mod_vector, dim=1)
498
+
499
+ x = self.conv(x, mod_vector)
500
+
501
+ mod_vector = []
502
+ mod_vector.append(ws[:, 2 * 2 - 3])
503
+ mod_vector.append(x_global.clone())
504
+ mod_vector = torch.cat(mod_vector, dim=1)
505
+
506
+ if self.architecture == 'skip':
507
+ img = self.torgb(x, mod_vector)
508
+ img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
509
+
510
+ assert x.dtype == dtype
511
+ return x, img
512
+
513
+
514
+ class SELayer(nn.Module):
515
+ def __init__(self, channel, reduction=16):
516
+ super(SELayer, self).__init__()
517
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
518
+ self.fc = nn.Sequential(
519
+ nn.Linear(channel, channel // reduction, bias=False),
520
+ nn.ReLU(inplace=False),
521
+ nn.Linear(channel // reduction, channel, bias=False),
522
+ nn.Sigmoid()
523
+ )
524
+
525
+ def forward(self, x):
526
+ b, c, _, _ = x.size()
527
+ y = self.avg_pool(x).view(b, c)
528
+ y = self.fc(y).view(b, c, 1, 1)
529
+ res = x * y.expand_as(x)
530
+ return res
531
+
532
+
533
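+ # FourierUnit: the core of the Fast Fourier Convolution. It applies a real FFT,
+ # runs a 1x1 convolution over the stacked real/imaginary channels, and transforms
+ # back with an inverse FFT, giving the layer an image-wide receptive field.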
+ class FourierUnit(nn.Module):
534
+
535
+ def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
536
+ spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
537
+ # bn_layer not used
538
+ super(FourierUnit, self).__init__()
539
+ self.groups = groups
540
+
541
+ self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
542
+ out_channels=out_channels * 2,
543
+ kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
544
+ self.relu = torch.nn.ReLU(inplace=False)
545
+
546
+ # squeeze and excitation block
547
+ self.use_se = use_se
548
+ if use_se:
549
+ if se_kwargs is None:
550
+ se_kwargs = {}
551
+ self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
552
+
553
+ self.spatial_scale_factor = spatial_scale_factor
554
+ self.spatial_scale_mode = spatial_scale_mode
555
+ self.spectral_pos_encoding = spectral_pos_encoding
556
+ self.ffc3d = ffc3d
557
+ self.fft_norm = fft_norm
558
+
559
+ def forward(self, x):
560
+ batch = x.shape[0]
561
+
562
+ if self.spatial_scale_factor is not None:
563
+ orig_size = x.shape[-2:]
564
+ x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
565
+ align_corners=False)
566
+
567
+ r_size = x.size()
568
+ # (batch, c, h, w/2+1, 2)
569
+ fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
570
+ ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
571
+ ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
572
+ ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
573
+ ffted = ffted.view((batch, -1,) + ffted.size()[3:])
574
+
575
+ if self.spectral_pos_encoding:
576
+ height, width = ffted.shape[-2:]
577
+ coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
578
+ coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
579
+ ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
580
+
581
+ if self.use_se:
582
+ ffted = self.se(ffted)
583
+
584
+ ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
585
+ ffted = self.relu(ffted)
586
+
587
+ ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
588
+ 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
589
+ ffted = torch.complex(ffted[..., 0], ffted[..., 1])
590
+
591
+ ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
592
+ output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
593
+
594
+ if self.spatial_scale_factor is not None:
595
+ output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
596
+
597
+ return output
598
+
599
+
600
+ class SpectralTransform(nn.Module):
601
+
602
+ def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
603
+ # bn_layer not used
604
+ super(SpectralTransform, self).__init__()
605
+ self.enable_lfu = enable_lfu
606
+ if stride == 2:
607
+ self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
608
+ else:
609
+ self.downsample = nn.Identity()
610
+
611
+ self.stride = stride
612
+ self.conv1 = nn.Sequential(
613
+ nn.Conv2d(in_channels, out_channels //
614
+ 2, kernel_size=1, groups=groups, bias=False),
615
+ # nn.BatchNorm2d(out_channels // 2),
616
+ nn.ReLU(inplace=True)
617
+ )
618
+ self.fu = FourierUnit(
619
+ out_channels // 2, out_channels // 2, groups, **fu_kwargs)
620
+ if self.enable_lfu:
621
+ self.lfu = FourierUnit(
622
+ out_channels // 2, out_channels // 2, groups)
623
+ self.conv2 = torch.nn.Conv2d(
624
+ out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
625
+
626
+ def forward(self, x):
627
+
628
+ x = self.downsample(x)
629
+ x = self.conv1(x)
630
+ output = self.fu(x)
631
+
632
+ if self.enable_lfu:
633
+ n, c, h, w = x.shape
634
+ split_no = 2
635
+ split_s = h // split_no
636
+ xs = torch.cat(torch.split(
637
+ x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
638
+ xs = torch.cat(torch.split(xs, split_s, dim=-1),
639
+ dim=1).contiguous()
640
+ xs = self.lfu(xs)
641
+ xs = xs.repeat(1, 1, split_no, split_no).contiguous()
642
+ else:
643
+ xs = 0
644
+
645
+ output = self.conv2(x + output + xs)
646
+
647
+ return output
648
+
649
+
650
+ class FFC(nn.Module):
651
+
652
+ def __init__(self, in_channels, out_channels, kernel_size,
653
+ ratio_gin, ratio_gout, stride=1, padding=0,
654
+ dilation=1, groups=1, bias=False, enable_lfu=True,
655
+ padding_type='reflect', gated=False, **spectral_kwargs):
656
+ super(FFC, self).__init__()
657
+
658
+ assert stride == 1 or stride == 2, "Stride should be 1 or 2."
659
+ self.stride = stride
660
+
661
+ in_cg = int(in_channels * ratio_gin)
662
+ in_cl = in_channels - in_cg
663
+ out_cg = int(out_channels * ratio_gout)
664
+ out_cl = out_channels - out_cg
665
+ # groups_g = 1 if groups == 1 else int(groups * ratio_gout)
666
+ # groups_l = 1 if groups == 1 else groups - groups_g
667
+
668
+ self.ratio_gin = ratio_gin
669
+ self.ratio_gout = ratio_gout
670
+ self.global_in_num = in_cg
671
+
672
+ module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
673
+ self.convl2l = module(in_cl, out_cl, kernel_size,
674
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
675
+ module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
676
+ self.convl2g = module(in_cl, out_cg, kernel_size,
677
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
678
+ module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
679
+ self.convg2l = module(in_cg, out_cl, kernel_size,
680
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
681
+ module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
682
+ self.convg2g = module(
683
+ in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
684
+
685
+ self.gated = gated
686
+ module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
687
+ self.gate = module(in_channels, 2, 1)
688
+
689
+ def forward(self, x, fname=None):
690
+ x_l, x_g = x if type(x) is tuple else (x, 0)
691
+ out_xl, out_xg = 0, 0
692
+
693
+ if self.gated:
694
+ total_input_parts = [x_l]
695
+ if torch.is_tensor(x_g):
696
+ total_input_parts.append(x_g)
697
+ total_input = torch.cat(total_input_parts, dim=1)
698
+
699
+ gates = torch.sigmoid(self.gate(total_input))
700
+ g2l_gate, l2g_gate = gates.chunk(2, dim=1)
701
+ else:
702
+ g2l_gate, l2g_gate = 1, 1
703
+
704
+ spec_x = self.convg2g(x_g)
705
+
706
+ if self.ratio_gout != 1:
707
+ out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
708
+ if self.ratio_gout != 0:
709
+ out_xg = self.convl2g(x_l) * l2g_gate + spec_x
710
+
711
+ return out_xl, out_xg
712
+
713
+
714
+ class FFC_BN_ACT(nn.Module):
715
+
716
+ def __init__(self, in_channels, out_channels,
717
+ kernel_size, ratio_gin, ratio_gout,
718
+ stride=1, padding=0, dilation=1, groups=1, bias=False,
719
+ norm_layer=nn.SyncBatchNorm, activation_layer=nn.Identity,
720
+ padding_type='reflect',
721
+ enable_lfu=True, **kwargs):
722
+ super(FFC_BN_ACT, self).__init__()
723
+ self.ffc = FFC(in_channels, out_channels, kernel_size,
724
+ ratio_gin, ratio_gout, stride, padding, dilation,
725
+ groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
726
+ lnorm = nn.Identity if ratio_gout == 1 else norm_layer
727
+ gnorm = nn.Identity if ratio_gout == 0 else norm_layer
728
+ global_channels = int(out_channels * ratio_gout)
729
+ # self.bn_l = lnorm(out_channels - global_channels)
730
+ # self.bn_g = gnorm(global_channels)
731
+
732
+ lact = nn.Identity if ratio_gout == 1 else activation_layer
733
+ gact = nn.Identity if ratio_gout == 0 else activation_layer
734
+ self.act_l = lact(inplace=True)
735
+ self.act_g = gact(inplace=True)
736
+
737
+ def forward(self, x, fname=None):
738
+ x_l, x_g = self.ffc(x, fname=fname, )
739
+ x_l = self.act_l(x_l)
740
+ x_g = self.act_g(x_g)
741
+ return x_l, x_g
742
+
743
+
744
+ class FFCResnetBlock(nn.Module):
745
+ def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
746
+ spatial_transform_kwargs=None, inline=False, ratio_gin=0.75, ratio_gout=0.75):
747
+ super().__init__()
748
+ self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
749
+ norm_layer=norm_layer,
750
+ activation_layer=activation_layer,
751
+ padding_type=padding_type,
752
+ ratio_gin=ratio_gin, ratio_gout=ratio_gout)
753
+ self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
754
+ norm_layer=norm_layer,
755
+ activation_layer=activation_layer,
756
+ padding_type=padding_type,
757
+ ratio_gin=ratio_gin, ratio_gout=ratio_gout)
758
+ self.inline = inline
759
+
760
+ def forward(self, x, fname=None):
761
+ if self.inline:
762
+ x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
763
+ else:
764
+ x_l, x_g = x if type(x) is tuple else (x, 0)
765
+
766
+ id_l, id_g = x_l, x_g
767
+
768
+ x_l, x_g = self.conv1((x_l, x_g), fname=fname)
769
+ x_l, x_g = self.conv2((x_l, x_g), fname=fname)
770
+
771
+ x_l, x_g = id_l + x_l, id_g + x_g
772
+ out = x_l, x_g
773
+ if self.inline:
774
+ out = torch.cat(out, dim=1)
775
+ return out
776
+
777
+
778
+ class ConcatTupleLayer(nn.Module):
779
+ def forward(self, x):
780
+ assert isinstance(x, tuple)
781
+ x_l, x_g = x
782
+ assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
783
+ if not torch.is_tensor(x_g):
784
+ return x_l
785
+ return torch.cat(x, dim=1)
786
+
787
+
788
+ class FFCBlock(torch.nn.Module):
789
+ def __init__(self,
790
+ dim, # Number of output/input channels.
791
+ kernel_size, # Width and height of the convolution kernel.
792
+ padding,
793
+ ratio_gin=0.75,
794
+ ratio_gout=0.75,
795
+ activation='linear', # Activation function: 'relu', 'lrelu', etc.
796
+ ):
797
+ super().__init__()
798
+ if activation == 'linear':
799
+ self.activation = nn.Identity
800
+ else:
801
+ self.activation = nn.ReLU
802
+ self.padding = padding
803
+ self.kernel_size = kernel_size
804
+ self.ffc_block = FFCResnetBlock(dim=dim,
805
+ padding_type='reflect',
806
+ norm_layer=nn.SyncBatchNorm,
807
+ activation_layer=self.activation,
808
+ dilation=1,
809
+ ratio_gin=ratio_gin,
810
+ ratio_gout=ratio_gout)
811
+
812
+ self.concat_layer = ConcatTupleLayer()
813
+
814
+ def forward(self, gen_ft, mask, fname=None):
815
+ x = gen_ft.float()
816
+
817
+ x_l, x_g = x[:, :-self.ffc_block.conv1.ffc.global_in_num], x[:, -self.ffc_block.conv1.ffc.global_in_num:]
818
+ id_l, id_g = x_l, x_g
819
+
820
+ x_l, x_g = self.ffc_block((x_l, x_g), fname=fname)
821
+ x_l, x_g = id_l + x_l, id_g + x_g
822
+ x = self.concat_layer((x_l, x_g))
823
+
824
+ return x + gen_ft.float()
825
+
826
+
827
+ class FFCSkipLayer(torch.nn.Module):
828
+ def __init__(self,
829
+ dim, # Number of input/output channels.
830
+ kernel_size=3, # Convolution kernel size.
831
+ ratio_gin=0.75,
832
+ ratio_gout=0.75,
833
+ ):
834
+ super().__init__()
835
+ self.padding = kernel_size // 2
836
+
837
+ self.ffc_act = FFCBlock(dim=dim, kernel_size=kernel_size, activation=nn.ReLU,
838
+ padding=self.padding, ratio_gin=ratio_gin, ratio_gout=ratio_gout)
839
+
840
+ def forward(self, gen_ft, mask, fname=None):
841
+ x = self.ffc_act(gen_ft, mask, fname=fname)
842
+ return x
843
+
844
+
845
+ class SynthesisBlock(torch.nn.Module):
846
+ def __init__(self,
847
+ in_channels, # Number of input channels, 0 = first block.
848
+ out_channels, # Number of output channels.
849
+ w_dim, # Intermediate latent (W) dimensionality.
850
+ resolution, # Resolution of this block.
851
+ img_channels, # Number of output color channels.
852
+ is_last, # Is this the last block?
853
+ architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
854
+ resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
855
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
856
+ use_fp16=False, # Use FP16 for this block?
857
+ fp16_channels_last=False, # Use channels-last memory format with FP16?
858
+ **layer_kwargs, # Arguments for SynthesisLayer.
859
+ ):
860
+ assert architecture in ['orig', 'skip', 'resnet']
861
+ super().__init__()
862
+ self.in_channels = in_channels
863
+ self.w_dim = w_dim
864
+ self.resolution = resolution
865
+ self.img_channels = img_channels
866
+ self.is_last = is_last
867
+ self.architecture = architecture
868
+ self.use_fp16 = use_fp16
869
+ self.channels_last = (use_fp16 and fp16_channels_last)
870
+ self.register_buffer('resample_filter', setup_filter(resample_filter))
871
+ self.num_conv = 0
872
+ self.num_torgb = 0
873
+ self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1}
874
+
875
+ if in_channels != 0 and resolution >= 8:
876
+ self.ffc_skip = nn.ModuleList()
877
+ for _ in range(self.res_ffc[resolution]):
878
+ self.ffc_skip.append(FFCSkipLayer(dim=out_channels))
879
+
880
+ if in_channels == 0:
881
+ self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
882
+
883
+ if in_channels != 0:
884
+ self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim * 3, resolution=resolution, up=2,
885
+ resample_filter=resample_filter, conv_clamp=conv_clamp,
886
+ channels_last=self.channels_last, **layer_kwargs)
887
+ self.num_conv += 1
888
+
889
+ self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim * 3, resolution=resolution,
890
+ conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
891
+ self.num_conv += 1
892
+
893
+ if is_last or architecture == 'skip':
894
+ self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim * 3,
895
+ conv_clamp=conv_clamp, channels_last=self.channels_last)
896
+ self.num_torgb += 1
897
+
898
+ if in_channels != 0 and architecture == 'resnet':
899
+ self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
900
+ resample_filter=resample_filter, channels_last=self.channels_last)
901
+
902
+ def forward(self, x, mask, feats, img, ws, fname=None, force_fp32=False, fused_modconv=None, **layer_kwargs):
903
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
904
+ dtype = torch.float32
905
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
906
+ if fused_modconv is None:
907
+ fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
908
+
909
+ x = x.to(dtype=dtype, memory_format=memory_format)
910
+ x_skip = feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format)
911
+
912
+ # Main layers.
913
+ if self.in_channels == 0:
914
+ x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs)
915
+ elif self.architecture == 'resnet':
916
+ y = self.skip(x, gain=np.sqrt(0.5))
917
+ x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
918
+ if len(self.ffc_skip) > 0:
919
+ mask = F.interpolate(mask, size=x_skip.shape[2:], )
920
+ z = x + x_skip
921
+ for fres in self.ffc_skip:
922
+ z = fres(z, mask)
923
+ x = x + z
924
+ else:
925
+ x = x + x_skip
926
+ x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
927
+ x = y.add_(x)
928
+ else:
929
+ x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
930
+ if len(self.ffc_skip) > 0:
931
+ mask = F.interpolate(mask, size=x_skip.shape[2:], )
932
+ z = x + x_skip
933
+ for fres in self.ffc_skip:
934
+ z = fres(z, mask)
935
+ x = x + z
936
+ else:
937
+ x = x + x_skip
938
+ x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs)
939
+ # ToRGB.
940
+ if img is not None:
941
+ img = upsample2d(img, self.resample_filter)
942
+ if self.is_last or self.architecture == 'skip':
943
+ y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv)
944
+ y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
945
+ img = img.add_(y) if img is not None else y
946
+
947
+ x = x.to(dtype=dtype)
948
+ assert x.dtype == dtype
949
+ assert img is None or img.dtype == torch.float32
950
+ return x, img
951
+
952
+
953
+ class SynthesisNetwork(torch.nn.Module):
954
+ def __init__(self,
955
+ w_dim, # Intermediate latent (W) dimensionality.
956
+ z_dim, # Output Latent (Z) dimensionality.
957
+ img_resolution, # Output image resolution.
958
+ img_channels, # Number of color channels.
959
+ channel_base=16384, # Overall multiplier for the number of channels.
960
+ channel_max=512, # Maximum number of channels in any layer.
961
+ num_fp16_res=0, # Use FP16 for the N highest resolutions.
962
+ **block_kwargs, # Arguments for SynthesisBlock.
963
+ ):
964
+ assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
965
+ super().__init__()
966
+ self.w_dim = w_dim
967
+ self.img_resolution = img_resolution
968
+ self.img_resolution_log2 = int(np.log2(img_resolution))
969
+ self.img_channels = img_channels
970
+ self.block_resolutions = [2 ** i for i in range(3, self.img_resolution_log2 + 1)]
971
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
972
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
973
+
974
+ self.foreword = SynthesisForeword(img_channels=img_channels, in_channels=min(channel_base // 4, channel_max),
975
+ z_dim=z_dim * 2, resolution=4)
976
+
977
+ self.num_ws = self.img_resolution_log2 * 2 - 2
978
+ for res in self.block_resolutions:
979
+ if res // 2 in channels_dict.keys():
980
+ in_channels = channels_dict[res // 2] if res > 4 else 0
981
+ else:
982
+ in_channels = min(channel_base // (res // 2), channel_max)
983
+ out_channels = channels_dict[res]
984
+ use_fp16 = (res >= fp16_resolution)
985
+ use_fp16 = False
986
+ is_last = (res == self.img_resolution)
987
+ block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
988
+ img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
989
+ setattr(self, f'b{res}', block)
990
+
991
+ def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
992
+
993
+ img = None
994
+
995
+ x, img = self.foreword(x_global, ws, feats, img)
996
+
997
+ for res in self.block_resolutions:
998
+ block = getattr(self, f'b{res}')
999
+ mod_vector0 = []
1000
+ mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5])
1001
+ mod_vector0.append(x_global.clone())
1002
+ mod_vector0 = torch.cat(mod_vector0, dim=1)
1003
+
1004
+ mod_vector1 = []
1005
+ mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4])
1006
+ mod_vector1.append(x_global.clone())
1007
+ mod_vector1 = torch.cat(mod_vector1, dim=1)
1008
+
1009
+ mod_vector_rgb = []
1010
+ mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3])
1011
+ mod_vector_rgb.append(x_global.clone())
1012
+ mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1)
1013
+ x, img = block(x, mask, feats, img, (mod_vector0, mod_vector1, mod_vector_rgb), fname=fname, **block_kwargs)
1014
+ return img
1015
+
1016
+
1017
+ class MappingNetwork(torch.nn.Module):
1018
+ def __init__(self,
1019
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
1020
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
1021
+ w_dim, # Intermediate latent (W) dimensionality.
1022
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
1023
+ num_layers=8, # Number of mapping layers.
1024
+ embed_features=None, # Label embedding dimensionality, None = same as w_dim.
1025
+ layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
1026
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
1027
+ lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
1028
+ w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
1029
+ ):
1030
+ super().__init__()
1031
+ self.z_dim = z_dim
1032
+ self.c_dim = c_dim
1033
+ self.w_dim = w_dim
1034
+ self.num_ws = num_ws
1035
+ self.num_layers = num_layers
1036
+ self.w_avg_beta = w_avg_beta
1037
+
1038
+ if embed_features is None:
1039
+ embed_features = w_dim
1040
+ if c_dim == 0:
1041
+ embed_features = 0
1042
+ if layer_features is None:
1043
+ layer_features = w_dim
1044
+ features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
1045
+
1046
+ if c_dim > 0:
1047
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
1048
+ for idx in range(num_layers):
1049
+ in_features = features_list[idx]
1050
+ out_features = features_list[idx + 1]
1051
+ layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
1052
+ setattr(self, f'fc{idx}', layer)
1053
+
1054
+ if num_ws is not None and w_avg_beta is not None:
1055
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
1056
+
1057
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
1058
+ # Embed, normalize, and concat inputs.
1059
+ x = None
1060
+ with torch.autograd.profiler.record_function('input'):
1061
+ if self.z_dim > 0:
1062
+ x = normalize_2nd_moment(z.to(torch.float32))
1063
+ if self.c_dim > 0:
1064
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
1065
+ x = torch.cat([x, y], dim=1) if x is not None else y
1066
+
1067
+ # Main layers.
1068
+ for idx in range(self.num_layers):
1069
+ layer = getattr(self, f'fc{idx}')
1070
+ x = layer(x)
1071
+
1072
+ # Update moving average of W.
1073
+ if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
1074
+ with torch.autograd.profiler.record_function('update_w_avg'):
1075
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
1076
+
1077
+ # Broadcast.
1078
+ if self.num_ws is not None:
1079
+ with torch.autograd.profiler.record_function('broadcast'):
1080
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
1081
+
1082
+ # Apply truncation.
1083
+ if truncation_psi != 1:
1084
+ with torch.autograd.profiler.record_function('truncate'):
1085
+ assert self.w_avg_beta is not None
1086
+ if self.num_ws is None or truncation_cutoff is None:
1087
+ x = self.w_avg.lerp(x, truncation_psi)
1088
+ else:
1089
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
1090
+ return x
1091
+
1092
+
1093
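+ # FcF generator: the encoder extracts multi-scale features and a global code from
+ # the masked input, the mapping network produces modulation vectors, and the
+ # synthesis network (with FFC skip blocks) decodes them into the inpainted image.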
+ class Generator(torch.nn.Module):
1094
+ def __init__(self,
1095
+ z_dim, # Input latent (Z) dimensionality.
1096
+ c_dim, # Conditioning label (C) dimensionality.
1097
+ w_dim, # Intermediate latent (W) dimensionality.
1098
+ img_resolution, # Output resolution.
1099
+ img_channels, # Number of output color channels.
1100
+ encoder_kwargs={}, # Arguments for EncoderNetwork.
1101
+ mapping_kwargs={}, # Arguments for MappingNetwork.
1102
+ synthesis_kwargs={}, # Arguments for SynthesisNetwork.
1103
+ ):
1104
+ super().__init__()
1105
+ self.z_dim = z_dim
1106
+ self.c_dim = c_dim
1107
+ self.w_dim = w_dim
1108
+ self.img_resolution = img_resolution
1109
+ self.img_channels = img_channels
1110
+ self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution,
1111
+ img_channels=img_channels, **encoder_kwargs)
1112
+ self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution,
1113
+ img_channels=img_channels, **synthesis_kwargs)
1114
+ self.num_ws = self.synthesis.num_ws
1115
+ self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
1116
+
1117
+ def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
1118
+ mask = img[:, -1].unsqueeze(1)
1119
+ x_global, z, feats = self.encoder(img, c)
1120
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
1121
+ img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
1122
+ return img
1123
+
1124
+
1125
+ FCF_MODEL_URL = os.environ.get(
1126
+ "FCF_MODEL_URL",
1127
+ "https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth",
1128
+ )
1129
+
1130
+
1131
+ class FcF(InpaintModel):
1132
+ min_size = 512
1133
+ pad_mod = 512
1134
+ pad_to_square = True
1135
+
1136
+ def init_model(self, device, **kwargs):
1137
+ seed = 0
1138
+ random.seed(seed)
1139
+ np.random.seed(seed)
1140
+ torch.manual_seed(seed)
1141
+ torch.cuda.manual_seed_all(seed)
1142
+ torch.backends.cudnn.deterministic = True
1143
+ torch.backends.cudnn.benchmark = False
1144
+
1145
+ kwargs = {'channel_base': 1 * 32768, 'channel_max': 512, 'num_fp16_res': 4, 'conv_clamp': 256}
1146
+ G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3,
1147
+ synthesis_kwargs=kwargs, encoder_kwargs=kwargs, mapping_kwargs={'num_layers': 2})
1148
+ self.model = load_model(G, FCF_MODEL_URL, device)
1149
+ self.label = torch.zeros([1, self.model.c_dim], device=device)
1150
+
1151
+ @staticmethod
1152
+ def is_downloaded() -> bool:
1153
+ return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))
1154
+
1155
+ @torch.no_grad()
1156
+ def __call__(self, image, mask, config: Config):
1157
+ """
1158
+ images: [H, W, C] RGB, not normalized
1159
+ masks: [H, W]
1160
+ return: BGR IMAGE
1161
+ """
1162
+ if image.shape[0] == 512 and image.shape[1] == 512:
1163
+ return self._pad_forward(image, mask, config)
1164
+
1165
+ boxes = boxes_from_mask(mask)
1166
+ crop_result = []
1167
+ config.hd_strategy_crop_margin = 128
1168
+ for box in boxes:
1169
+ crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
1170
+ origin_size = crop_image.shape[:2]
1171
+ resize_image = resize_max_size(crop_image, size_limit=512)
1172
+ resize_mask = resize_max_size(crop_mask, size_limit=512)
1173
+ inpaint_result = self._pad_forward(resize_image, resize_mask, config)
1174
+
1175
+ # only paste masked area result
1176
+ inpaint_result = cv2.resize(inpaint_result, (origin_size[1], origin_size[0]), interpolation=cv2.INTER_CUBIC)
1177
+
1178
+ original_pixel_indices = crop_mask < 127
1179
+ inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][original_pixel_indices]
1180
+
1181
+ crop_result.append((inpaint_result, crop_box))
1182
+
1183
+ inpaint_result = image[:, :, ::-1]
1184
+ for crop_image, crop_box in crop_result:
1185
+ x1, y1, x2, y2 = crop_box
1186
+ inpaint_result[y1:y2, x1:x2, :] = crop_image
1187
+
1188
+ return inpaint_result
1189
+
1190
+ def forward(self, image, mask, config: Config):
1191
+ """Input images and output images have same size
1192
+ images: [H, W, C] RGB
1193
+ masks: [H, W] mask area == 255
1194
+ return: BGR IMAGE
1195
+ """
1196
+
1197
+ image = norm_img(image) # [0, 1]
1198
+ image = image * 2 - 1 # [0, 1] -> [-1, 1]
1199
+ mask = (mask > 120) * 255
1200
+ mask = norm_img(mask)
1201
+
1202
+ image = torch.from_numpy(image).unsqueeze(0).to(self.device)
1203
+ mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
1204
+
1205
+ erased_img = image * (1 - mask)
1206
+ input_image = torch.cat([0.5 - mask, erased_img], dim=1)
1207
+
1208
+ output = self.model(input_image, self.label, truncation_psi=0.1, noise_mode='none')
1209
+ output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
1210
+ output = output[0].cpu().numpy()
1211
+ cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
1212
+ return cur_res
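A minimal sketch of the tensor bookkeeping in FcF.forward above, using toy shapes and a placeholder instead of the real generator (no model weights involved):

    import torch

    image = torch.rand(1, 3, 512, 512) * 2 - 1                 # toy image already mapped to [-1, 1]
    mask = (torch.rand(1, 1, 512, 512) > 0.5).float()          # toy binary mask, 1 = hole

    erased_img = image * (1 - mask)                            # drop the masked pixels
    input_image = torch.cat([0.5 - mask, erased_img], dim=1)   # 4-channel input: mask code + erased RGB

    output = torch.tanh(torch.randn(1, 3, 512, 512))           # placeholder for self.model(input_image, ...)
    output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)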
lama_cleaner/model/lama.py ADDED
@@ -0,0 +1,61 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from loguru import logger
7
+
8
+ from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
9
+ from lama_cleaner.model.base import InpaintModel
10
+ from lama_cleaner.schema import Config
11
+
12
+ LAMA_MODEL_URL = os.environ.get(
13
+ "LAMA_MODEL_URL",
14
+ "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
15
+ )
16
+
17
+
18
+ class LaMa(InpaintModel):
19
+ pad_mod = 8
20
+
21
+ def init_model(self, device, **kwargs):
22
+ if os.environ.get("LAMA_MODEL"):
23
+ model_path = os.environ.get("LAMA_MODEL")
24
+ if not os.path.exists(model_path):
25
+ raise FileNotFoundError(
26
+ f"lama torchscript model not found: {model_path}"
27
+ )
28
+ else:
29
+ model_path = download_model(LAMA_MODEL_URL)
30
+ # TODO used to create a lambda docker image
31
+ # model_path = '../app/big-lama.pt'
32
+ logger.info(f"Load LaMa model from: {model_path}")
33
+ model = torch.jit.load(model_path, map_location="cpu")
34
+ model = model.to(device)
35
+ model.eval()
36
+ self.model = model
37
+ self.model_path = model_path
38
+
39
+ @staticmethod
40
+ def is_downloaded() -> bool:
41
+ return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
42
+
43
+ def forward(self, image, mask, config: Config):
44
+ """Input image and output image have same size
45
+ image: [H, W, C] RGB
46
+ mask: [H, W]
47
+ return: BGR IMAGE
48
+ """
49
+ image = norm_img(image)
50
+ mask = norm_img(mask)
51
+
52
+ mask = (mask > 0) * 1
53
+ image = torch.from_numpy(image).unsqueeze(0).to(self.device)
54
+ mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
55
+
56
+ inpainted_image = self.model(image, mask)
57
+
58
+ cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
59
+ cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
60
+ cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
61
+ return cur_res
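A minimal usage sketch for LaMa.forward above. The shapes and the norm_img stand-in are assumptions (the real helper lives in lama_cleaner/helper.py and is not part of this diff):

    import numpy as np
    import torch

    image = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)   # hypothetical RGB input
    mask = (np.random.rand(256, 256) > 0.9).astype(np.uint8) * 255     # hypothetical mask, 255 = hole

    # Stand-in for norm_img (assumed: float32, CHW, scaled to [0, 1]).
    image = np.transpose(image, (2, 0, 1)).astype(np.float32) / 255.0
    mask = mask[np.newaxis, ...].astype(np.float32) / 255.0
    mask = ((mask > 0) * 1).astype(np.float32)                         # binarize, as in forward above

    image_t = torch.from_numpy(image).unsqueeze(0)                     # [1, 3, H, W]
    mask_t = torch.from_numpy(mask).unsqueeze(0)                       # [1, 1, H, W]
    # inpainted = model(image_t, mask_t) would then return a [1, 3, H, W] image in [0, 1].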
lama_cleaner/model/ldm.py ADDED
@@ -0,0 +1,310 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from lama_cleaner.model.base import InpaintModel
7
+ from lama_cleaner.model.ddim_sampler import DDIMSampler
8
+ from lama_cleaner.model.plms_sampler import PLMSSampler
9
+ from lama_cleaner.schema import Config, LDMSampler
10
+
11
+ torch.manual_seed(42)
12
+ import torch.nn as nn
13
+ from lama_cleaner.helper import (
14
+ norm_img,
15
+ get_cache_path_by_url,
16
+ load_jit_model,
17
+ )
18
+ from lama_cleaner.model.utils import (
19
+ make_beta_schedule,
20
+ timestep_embedding,
21
+ )
22
+
23
+ LDM_ENCODE_MODEL_URL = os.environ.get(
24
+ "LDM_ENCODE_MODEL_URL",
25
+ "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
26
+ )
27
+
28
+ LDM_DECODE_MODEL_URL = os.environ.get(
29
+ "LDM_DECODE_MODEL_URL",
30
+ "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
31
+ )
32
+
33
+ LDM_DIFFUSION_MODEL_URL = os.environ.get(
34
+ "LDM_DIFFUSION_MODEL_URL",
35
+ "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
36
+ )
37
+
38
+
39
+ class DDPM(nn.Module):
40
+ # classic DDPM with Gaussian diffusion, in image space
41
+ def __init__(
42
+ self,
43
+ device,
44
+ timesteps=1000,
45
+ beta_schedule="linear",
46
+ linear_start=0.0015,
47
+ linear_end=0.0205,
48
+ cosine_s=0.008,
49
+ original_elbo_weight=0.0,
50
+ v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
51
+ l_simple_weight=1.0,
52
+ parameterization="eps", # all assuming fixed variance schedules
53
+ use_positional_encodings=False,
54
+ ):
55
+ super().__init__()
56
+ self.device = device
57
+ self.parameterization = parameterization
58
+ self.use_positional_encodings = use_positional_encodings
59
+
60
+ self.v_posterior = v_posterior
61
+ self.original_elbo_weight = original_elbo_weight
62
+ self.l_simple_weight = l_simple_weight
63
+
64
+ self.register_schedule(
65
+ beta_schedule=beta_schedule,
66
+ timesteps=timesteps,
67
+ linear_start=linear_start,
68
+ linear_end=linear_end,
69
+ cosine_s=cosine_s,
70
+ )
71
+
72
+ def register_schedule(
73
+ self,
74
+ given_betas=None,
75
+ beta_schedule="linear",
76
+ timesteps=1000,
77
+ linear_start=1e-4,
78
+ linear_end=2e-2,
79
+ cosine_s=8e-3,
80
+ ):
81
+ betas = make_beta_schedule(
82
+ self.device,
83
+ beta_schedule,
84
+ timesteps,
85
+ linear_start=linear_start,
86
+ linear_end=linear_end,
87
+ cosine_s=cosine_s,
88
+ )
89
+ alphas = 1.0 - betas
90
+ alphas_cumprod = np.cumprod(alphas, axis=0)
91
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
92
+
93
+ (timesteps,) = betas.shape
94
+ self.num_timesteps = int(timesteps)
95
+ self.linear_start = linear_start
96
+ self.linear_end = linear_end
97
+ assert (
98
+ alphas_cumprod.shape[0] == self.num_timesteps
99
+ ), "alphas have to be defined for each timestep"
100
+
101
+ to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
102
+
103
+ self.register_buffer("betas", to_torch(betas))
104
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
105
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
106
+
107
+ # calculations for diffusion q(x_t | x_{t-1}) and others
108
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
109
+ self.register_buffer(
110
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
111
+ )
112
+ self.register_buffer(
113
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
114
+ )
115
+ self.register_buffer(
116
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
117
+ )
118
+ self.register_buffer(
119
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
120
+ )
121
+
122
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
123
+ posterior_variance = (1 - self.v_posterior) * betas * (
124
+ 1.0 - alphas_cumprod_prev
125
+ ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
126
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
127
+ self.register_buffer("posterior_variance", to_torch(posterior_variance))
128
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
129
+ self.register_buffer(
130
+ "posterior_log_variance_clipped",
131
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
132
+ )
133
+ self.register_buffer(
134
+ "posterior_mean_coef1",
135
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
136
+ )
137
+ self.register_buffer(
138
+ "posterior_mean_coef2",
139
+ to_torch(
140
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
141
+ ),
142
+ )
143
+
144
+ if self.parameterization == "eps":
145
+ lvlb_weights = self.betas**2 / (
146
+ 2
147
+ * self.posterior_variance
148
+ * to_torch(alphas)
149
+ * (1 - self.alphas_cumprod)
150
+ )
151
+ elif self.parameterization == "x0":
152
+ lvlb_weights = (
153
+ 0.5
154
+ * np.sqrt(torch.Tensor(alphas_cumprod))
155
+ / (2.0 * 1 - torch.Tensor(alphas_cumprod))
156
+ )
157
+ else:
158
+ raise NotImplementedError("mu not supported")
159
+ # TODO how to choose this term
160
+ lvlb_weights[0] = lvlb_weights[1]
161
+ self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
162
+ assert not torch.isnan(self.lvlb_weights).all()
163
+
164
+
165
+ class LatentDiffusion(DDPM):
166
+ def __init__(
167
+ self,
168
+ diffusion_model,
169
+ device,
170
+ cond_stage_key="image",
171
+ cond_stage_trainable=False,
172
+ concat_mode=True,
173
+ scale_factor=1.0,
174
+ scale_by_std=False,
175
+ *args,
176
+ **kwargs,
177
+ ):
178
+ self.num_timesteps_cond = 1
179
+ self.scale_by_std = scale_by_std
180
+ super().__init__(device, *args, **kwargs)
181
+ self.diffusion_model = diffusion_model
182
+ self.concat_mode = concat_mode
183
+ self.cond_stage_trainable = cond_stage_trainable
184
+ self.cond_stage_key = cond_stage_key
185
+ self.num_downs = 2
186
+ self.scale_factor = scale_factor
187
+
188
+ def make_cond_schedule(
189
+ self,
190
+ ):
191
+ self.cond_ids = torch.full(
192
+ size=(self.num_timesteps,),
193
+ fill_value=self.num_timesteps - 1,
194
+ dtype=torch.long,
195
+ )
196
+ ids = torch.round(
197
+ torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
198
+ ).long()
199
+ self.cond_ids[: self.num_timesteps_cond] = ids
200
+
201
+ def register_schedule(
202
+ self,
203
+ given_betas=None,
204
+ beta_schedule="linear",
205
+ timesteps=1000,
206
+ linear_start=1e-4,
207
+ linear_end=2e-2,
208
+ cosine_s=8e-3,
209
+ ):
210
+ super().register_schedule(
211
+ given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
212
+ )
213
+
214
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
215
+ if self.shorten_cond_schedule:
216
+ self.make_cond_schedule()
217
+
218
+ def apply_model(self, x_noisy, t, cond):
219
+ # x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128
220
+ t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
221
+ x_recon = self.diffusion_model(x_noisy, t_emb, cond)
222
+ return x_recon
223
+
224
+
225
+ class LDM(InpaintModel):
226
+ pad_mod = 32
227
+
228
+ def __init__(self, device, fp16: bool = True, **kwargs):
229
+ self.fp16 = fp16
230
+ super().__init__(device)
231
+ self.device = device
232
+
233
+ def init_model(self, device, **kwargs):
234
+ self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
235
+ self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
236
+ self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
237
+ if self.fp16 and "cuda" in str(device):
238
+ self.diffusion_model = self.diffusion_model.half()
239
+ self.cond_stage_model_decode = self.cond_stage_model_decode.half()
240
+ self.cond_stage_model_encode = self.cond_stage_model_encode.half()
241
+
242
+ self.model = LatentDiffusion(self.diffusion_model, device)
243
+
244
+ @staticmethod
245
+ def is_downloaded() -> bool:
246
+ model_paths = [
247
+ get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
248
+ get_cache_path_by_url(LDM_DECODE_MODEL_URL),
249
+ get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
250
+ ]
251
+ return all([os.path.exists(it) for it in model_paths])
252
+
253
+ @torch.cuda.amp.autocast()
254
+ def forward(self, image, mask, config: Config):
255
+ """
256
+ image: [H, W, C] RGB
257
+ mask: [H, W, 1]
258
+ return: BGR IMAGE
259
+ """
260
+ # image [1,3,512,512] float32
261
+ # mask: [1,1,512,512] float32
262
+ # masked_image: [1,3,512,512] float32
263
+ if config.ldm_sampler == LDMSampler.ddim:
264
+ sampler = DDIMSampler(self.model)
265
+ elif config.ldm_sampler == LDMSampler.plms:
266
+ sampler = PLMSSampler(self.model)
267
+ else:
268
+ raise ValueError()
269
+
270
+ steps = config.ldm_steps
271
+ image = norm_img(image)
272
+ mask = norm_img(mask)
273
+
274
+ mask[mask < 0.5] = 0
275
+ mask[mask >= 0.5] = 1
276
+
277
+ image = torch.from_numpy(image).unsqueeze(0).to(self.device)
278
+ mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
279
+ masked_image = (1 - mask) * image
280
+
281
+ mask = self._norm(mask)
282
+ masked_image = self._norm(masked_image)
283
+
284
+ c = self.cond_stage_model_encode(masked_image)
285
+ torch.cuda.empty_cache()
286
+
287
+ cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128
288
+ c = torch.cat((c, cc), dim=1) # 1,4,128,128
289
+
290
+ shape = (c.shape[1] - 1,) + c.shape[2:]
291
+ samples_ddim = sampler.sample(
292
+ steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
293
+ )
294
+ torch.cuda.empty_cache()
295
+ x_samples_ddim = self.cond_stage_model_decode(
296
+ samples_ddim
297
+ ) # samples_ddim: 1, 3, 128, 128 float32
298
+ torch.cuda.empty_cache()
299
+
300
+ # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
301
+ # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
302
+ inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
303
+
304
+ # inpainted = (1 - mask) * image + mask * predicted_image
305
+ inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
306
+ inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
307
+ return inpainted_image
308
+
309
+ def _norm(self, tensor):
310
+ return tensor * 2.0 - 1.0
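A minimal sketch of the conditioning construction in LDM.forward above, with a hypothetical 3-channel latent standing in for the encoder output (shapes follow the inline comments in the code):

    import torch

    image = torch.rand(1, 3, 512, 512)
    mask = (torch.rand(1, 1, 512, 512) > 0.5).float()
    masked_image = (1 - mask) * image

    norm = lambda t: t * 2.0 - 1.0                          # [0, 1] -> [-1, 1], as in _norm above
    mask_n, masked_n = norm(mask), norm(masked_image)

    c = torch.randn(1, 3, 128, 128)                         # placeholder for cond_stage_model_encode(masked_n)
    cc = torch.nn.functional.interpolate(mask_n, size=c.shape[-2:])
    c = torch.cat((c, cc), dim=1)                           # [1, 4, 128, 128]: latent + downscaled mask
    shape = (c.shape[1] - 1,) + c.shape[2:]                 # (3, 128, 128), the shape the sampler fills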
lama_cleaner/model/manga.py ADDED
@@ -0,0 +1,130 @@
1
+ import os
2
+ import random
3
+ import time
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ from loguru import logger
9
+
10
+ from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
11
+ from lama_cleaner.model.base import InpaintModel
12
+ from lama_cleaner.schema import Config
13
+
14
+ # def norm(np_img):
15
+ # return np_img / 255 * 2 - 1.0
16
+ #
17
+ #
18
+ # @torch.no_grad()
19
+ # def run():
20
+ # name = 'manga_1080x740.jpg'
21
+ # img_p = f'/Users/qing/code/github/MangaInpainting/examples/test/imgs/{name}'
22
+ # mask_p = f'/Users/qing/code/github/MangaInpainting/examples/test/masks/mask_{name}'
23
+ # erika_model = torch.jit.load('erika.jit')
24
+ # manga_inpaintor_model = torch.jit.load('manga_inpaintor.jit')
25
+ #
26
+ # img = cv2.imread(img_p)
27
+ # gray_img = cv2.imread(img_p, cv2.IMREAD_GRAYSCALE)
28
+ # mask = cv2.imread(mask_p, cv2.IMREAD_GRAYSCALE)
29
+ #
30
+ # kernel = np.ones((9, 9), dtype=np.uint8)
31
+ # mask = cv2.dilate(mask, kernel, 2)
32
+ # # cv2.imwrite("mask.jpg", mask)
33
+ # # cv2.imshow('dilated_mask', cv2.hconcat([mask, dilated_mask]))
34
+ # # cv2.waitKey(0)
35
+ # # exit()
36
+ #
37
+ # # img = pad(img)
38
+ # gray_img = pad(gray_img).astype(np.float32)
39
+ # mask = pad(mask)
40
+ #
41
+ # # pad_mod = 16
42
+ # import time
43
+ # start = time.time()
44
+ # y = erika_model(torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :]))
45
+ # y = torch.clamp(y, 0, 255)
46
+ # lines = y.cpu().numpy()
47
+ # print(f"erika_model time: {time.time() - start}")
48
+ #
49
+ # cv2.imwrite('lines.png', lines[0][0])
50
+ #
51
+ # start = time.time()
52
+ # masks = torch.from_numpy(mask[np.newaxis, np.newaxis, :, :])
53
+ # masks = torch.where(masks > 0.5, torch.tensor(1.0), torch.tensor(0.0))
54
+ # noise = torch.randn_like(masks)
55
+ #
56
+ # images = torch.from_numpy(norm(gray_img)[np.newaxis, np.newaxis, :, :])
57
+ # lines = torch.from_numpy(norm(lines))
58
+ #
59
+ # outputs = manga_inpaintor_model(images, lines, masks, noise)
60
+ # print(f"manga_inpaintor_model time: {time.time() - start}")
61
+ #
62
+ # outputs_merged = (outputs * masks) + (images * (1 - masks))
63
+ # outputs_merged = outputs_merged * 127.5 + 127.5
64
+ # outputs_merged = outputs_merged.permute(0, 2, 3, 1)[0].detach().cpu().numpy().astype(np.uint8)
65
+ # cv2.imwrite(f'output_{name}', outputs_merged)
66
+
67
+
68
+ MANGA_INPAINTOR_MODEL_URL = os.environ.get(
69
+ "MANGA_INPAINTOR_MODEL_URL",
70
+ "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit"
71
+ )
72
+ MANGA_LINE_MODEL_URL = os.environ.get(
73
+ "MANGA_LINE_MODEL_URL",
74
+ "https://github.com/Sanster/models/releases/download/manga/erika.jit"
75
+ )
76
+
77
+
78
+ class Manga(InpaintModel):
79
+ pad_mod = 16
80
+
81
+ def init_model(self, device, **kwargs):
82
+ self.inpaintor_model = load_jit_model(MANGA_INPAINTOR_MODEL_URL, device)
83
+ self.line_model = load_jit_model(MANGA_LINE_MODEL_URL, device)
84
+ self.seed = 42
85
+
86
+ @staticmethod
87
+ def is_downloaded() -> bool:
88
+ model_paths = [
89
+ get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL),
90
+ get_cache_path_by_url(MANGA_LINE_MODEL_URL),
91
+ ]
92
+ return all([os.path.exists(it) for it in model_paths])
93
+
94
+ def forward(self, image, mask, config: Config):
95
+ """
96
+ image: [H, W, C] RGB
97
+ mask: [H, W, 1]
98
+ return: BGR IMAGE
99
+ """
100
+ seed = self.seed
101
+ random.seed(seed)
102
+ np.random.seed(seed)
103
+ torch.manual_seed(seed)
104
+ torch.cuda.manual_seed_all(seed)
105
+
106
+ gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
107
+ gray_img = torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)).to(self.device)
108
+ start = time.time()
109
+ lines = self.line_model(gray_img)
110
+ torch.cuda.empty_cache()
111
+ lines = torch.clamp(lines, 0, 255)
112
+ logger.info(f"erika_model time: {time.time() - start}")
113
+
114
+ mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device)
115
+ mask = mask.permute(0, 3, 1, 2)
116
+ mask = torch.where(mask > 0.5, 1.0, 0.0)
117
+ noise = torch.randn_like(mask)
118
+ ones = torch.ones_like(mask)
119
+
120
+ gray_img = gray_img / 255 * 2 - 1.0
121
+ lines = lines / 255 * 2 - 1.0
122
+
123
+ start = time.time()
124
+ inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones)
125
+ logger.info(f"image_inpaintor_model time: {time.time() - start}")
126
+
127
+ cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
128
+ cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8)
129
+ cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR)
130
+ return cur_res
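A small numeric sketch of the value ranges used in Manga.forward above: the grayscale input and line-art travel as [0, 255] floats, while the inpaintor works in [-1, 1], and the two mappings are exact inverses of each other:

    import numpy as np

    gray = np.random.randint(0, 255, (1, 1, 64, 64)).astype(np.float32)

    normed = gray / 255 * 2 - 1.0          # [0, 255] -> [-1, 1], applied before inpaintor_model
    restored = normed * 127.5 + 127.5      # [-1, 1] -> [0, 255], applied to the model output

    assert np.allclose(gray, restored, atol=1e-4)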
lama_cleaner/model/mat.py ADDED
@@ -0,0 +1,1444 @@
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.utils.checkpoint as checkpoint
10
+
11
+ from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img
12
+ from lama_cleaner.model.base import InpaintModel
13
+ from lama_cleaner.model.utils import setup_filter, Conv2dLayer, FullyConnectedLayer, conv2d_resample, bias_act, \
14
+ upsample2d, activation_funcs, MinibatchStdLayer, to_2tuple, normalize_2nd_moment
15
+ from lama_cleaner.schema import Config
16
+
17
+
18
+ class ModulatedConv2d(nn.Module):
19
+ def __init__(self,
20
+ in_channels, # Number of input channels.
21
+ out_channels, # Number of output channels.
22
+ kernel_size, # Width and height of the convolution kernel.
23
+ style_dim, # dimension of the style code
24
+ demodulate=True, # perform demodulation
25
+ up=1, # Integer upsampling factor.
26
+ down=1, # Integer downsampling factor.
27
+ resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
28
+ conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
29
+ ):
30
+ super().__init__()
31
+ self.demodulate = demodulate
32
+
33
+ self.weight = torch.nn.Parameter(torch.randn([1, out_channels, in_channels, kernel_size, kernel_size]))
34
+ self.out_channels = out_channels
35
+ self.kernel_size = kernel_size
36
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
37
+ self.padding = self.kernel_size // 2
38
+ self.up = up
39
+ self.down = down
40
+ self.register_buffer('resample_filter', setup_filter(resample_filter))
41
+ self.conv_clamp = conv_clamp
42
+
43
+ self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1)
44
+
45
+ def forward(self, x, style):
46
+ batch, in_channels, height, width = x.shape
47
+ style = self.affine(style).view(batch, 1, in_channels, 1, 1)
48
+ weight = self.weight * self.weight_gain * style
49
+
50
+ if self.demodulate:
51
+ decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
52
+ weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1)
53
+
54
+ weight = weight.view(batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size)
55
+ x = x.view(1, batch * in_channels, height, width)
56
+ x = conv2d_resample(x=x, w=weight, f=self.resample_filter, up=self.up, down=self.down,
57
+ padding=self.padding, groups=batch)
58
+ out = x.view(batch, self.out_channels, *x.shape[2:])
59
+
60
+ return out
61
+
62
+
63
+ class StyleConv(torch.nn.Module):
64
+ def __init__(self,
65
+ in_channels, # Number of input channels.
66
+ out_channels, # Number of output channels.
67
+ style_dim, # Intermediate latent (W) dimensionality.
68
+ resolution, # Resolution of this layer.
69
+ kernel_size=3, # Convolution kernel size.
70
+ up=1, # Integer upsampling factor.
71
+ use_noise=False, # Enable noise input?
72
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
73
+ resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
74
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
75
+ demodulate=True, # perform demodulation
76
+ ):
77
+ super().__init__()
78
+
79
+ self.conv = ModulatedConv2d(in_channels=in_channels,
80
+ out_channels=out_channels,
81
+ kernel_size=kernel_size,
82
+ style_dim=style_dim,
83
+ demodulate=demodulate,
84
+ up=up,
85
+ resample_filter=resample_filter,
86
+ conv_clamp=conv_clamp)
87
+
88
+ self.use_noise = use_noise
89
+ self.resolution = resolution
90
+ if use_noise:
91
+ self.register_buffer('noise_const', torch.randn([resolution, resolution]))
92
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
93
+
94
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
95
+ self.activation = activation
96
+ self.act_gain = activation_funcs[activation].def_gain
97
+ self.conv_clamp = conv_clamp
98
+
99
+ def forward(self, x, style, noise_mode='random', gain=1):
100
+ x = self.conv(x, style)
101
+
102
+ assert noise_mode in ['random', 'const', 'none']
103
+
104
+ if self.use_noise:
105
+ if noise_mode == 'random':
106
+ xh, xw = x.size()[-2:]
107
+ noise = torch.randn([x.shape[0], 1, xh, xw], device=x.device) \
108
+ * self.noise_strength
109
+ if noise_mode == 'const':
110
+ noise = self.noise_const * self.noise_strength
111
+ x = x + noise
112
+
113
+ act_gain = self.act_gain * gain
114
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
115
+ out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
116
+
117
+ return out
118
+
119
+
120
+ class ToRGB(torch.nn.Module):
121
+ def __init__(self,
122
+ in_channels,
123
+ out_channels,
124
+ style_dim,
125
+ kernel_size=1,
126
+ resample_filter=[1, 3, 3, 1],
127
+ conv_clamp=None,
128
+ demodulate=False):
129
+ super().__init__()
130
+
131
+ self.conv = ModulatedConv2d(in_channels=in_channels,
132
+ out_channels=out_channels,
133
+ kernel_size=kernel_size,
134
+ style_dim=style_dim,
135
+ demodulate=demodulate,
136
+ resample_filter=resample_filter,
137
+ conv_clamp=conv_clamp)
138
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
139
+ self.register_buffer('resample_filter', setup_filter(resample_filter))
140
+ self.conv_clamp = conv_clamp
141
+
142
+ def forward(self, x, style, skip=None):
143
+ x = self.conv(x, style)
144
+ out = bias_act(x, self.bias, clamp=self.conv_clamp)
145
+
146
+ if skip is not None:
147
+ if skip.shape != out.shape:
148
+ skip = upsample2d(skip, self.resample_filter)
149
+ out = out + skip
150
+
151
+ return out
152
+
153
+
154
+ def get_style_code(a, b):
155
+ return torch.cat([a, b], dim=1)
156
+
157
+
158
+ class DecBlockFirst(nn.Module):
159
+ def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
160
+ super().__init__()
161
+ self.fc = FullyConnectedLayer(in_features=in_channels * 2,
162
+ out_features=in_channels * 4 ** 2,
163
+ activation=activation)
164
+ self.conv = StyleConv(in_channels=in_channels,
165
+ out_channels=out_channels,
166
+ style_dim=style_dim,
167
+ resolution=4,
168
+ kernel_size=3,
169
+ use_noise=use_noise,
170
+ activation=activation,
171
+ demodulate=demodulate,
172
+ )
173
+ self.toRGB = ToRGB(in_channels=out_channels,
174
+ out_channels=img_channels,
175
+ style_dim=style_dim,
176
+ kernel_size=1,
177
+ demodulate=False,
178
+ )
179
+
180
+ def forward(self, x, ws, gs, E_features, noise_mode='random'):
181
+ x = self.fc(x).view(x.shape[0], -1, 4, 4)
182
+ x = x + E_features[2]
183
+ style = get_style_code(ws[:, 0], gs)
184
+ x = self.conv(x, style, noise_mode=noise_mode)
185
+ style = get_style_code(ws[:, 1], gs)
186
+ img = self.toRGB(x, style, skip=None)
187
+
188
+ return x, img
189
+
190
+
191
+ class DecBlockFirstV2(nn.Module):
192
+ def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
193
+ super().__init__()
194
+ self.conv0 = Conv2dLayer(in_channels=in_channels,
195
+ out_channels=in_channels,
196
+ kernel_size=3,
197
+ activation=activation,
198
+ )
199
+ self.conv1 = StyleConv(in_channels=in_channels,
200
+ out_channels=out_channels,
201
+ style_dim=style_dim,
202
+ resolution=4,
203
+ kernel_size=3,
204
+ use_noise=use_noise,
205
+ activation=activation,
206
+ demodulate=demodulate,
207
+ )
208
+ self.toRGB = ToRGB(in_channels=out_channels,
209
+ out_channels=img_channels,
210
+ style_dim=style_dim,
211
+ kernel_size=1,
212
+ demodulate=False,
213
+ )
214
+
215
+ def forward(self, x, ws, gs, E_features, noise_mode='random'):
216
+ # x = self.fc(x).view(x.shape[0], -1, 4, 4)
217
+ x = self.conv0(x)
218
+ x = x + E_features[2]
219
+ style = get_style_code(ws[:, 0], gs)
220
+ x = self.conv1(x, style, noise_mode=noise_mode)
221
+ style = get_style_code(ws[:, 1], gs)
222
+ img = self.toRGB(x, style, skip=None)
223
+
224
+ return x, img
225
+
226
+
227
+ class DecBlock(nn.Module):
228
+ def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate,
229
+ img_channels): # res = 2, ..., resolution_log2
230
+ super().__init__()
231
+ self.res = res
232
+
233
+ self.conv0 = StyleConv(in_channels=in_channels,
234
+ out_channels=out_channels,
235
+ style_dim=style_dim,
236
+ resolution=2 ** res,
237
+ kernel_size=3,
238
+ up=2,
239
+ use_noise=use_noise,
240
+ activation=activation,
241
+ demodulate=demodulate,
242
+ )
243
+ self.conv1 = StyleConv(in_channels=out_channels,
244
+ out_channels=out_channels,
245
+ style_dim=style_dim,
246
+ resolution=2 ** res,
247
+ kernel_size=3,
248
+ use_noise=use_noise,
249
+ activation=activation,
250
+ demodulate=demodulate,
251
+ )
252
+ self.toRGB = ToRGB(in_channels=out_channels,
253
+ out_channels=img_channels,
254
+ style_dim=style_dim,
255
+ kernel_size=1,
256
+ demodulate=False,
257
+ )
258
+
259
+ def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
260
+ style = get_style_code(ws[:, self.res * 2 - 5], gs)
261
+ x = self.conv0(x, style, noise_mode=noise_mode)
262
+ x = x + E_features[self.res]
263
+ style = get_style_code(ws[:, self.res * 2 - 4], gs)
264
+ x = self.conv1(x, style, noise_mode=noise_mode)
265
+ style = get_style_code(ws[:, self.res * 2 - 3], gs)
266
+ img = self.toRGB(x, style, skip=img)
267
+
268
+ return x, img
269
+
270
+
271
+ class MappingNet(torch.nn.Module):
272
+ def __init__(self,
273
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
274
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
275
+ w_dim, # Intermediate latent (W) dimensionality.
276
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
277
+ num_layers=8, # Number of mapping layers.
278
+ embed_features=None, # Label embedding dimensionality, None = same as w_dim.
279
+ layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
280
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
281
+ lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
282
+ w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
283
+ ):
284
+ super().__init__()
285
+ self.z_dim = z_dim
286
+ self.c_dim = c_dim
287
+ self.w_dim = w_dim
288
+ self.num_ws = num_ws
289
+ self.num_layers = num_layers
290
+ self.w_avg_beta = w_avg_beta
291
+
292
+ if embed_features is None:
293
+ embed_features = w_dim
294
+ if c_dim == 0:
295
+ embed_features = 0
296
+ if layer_features is None:
297
+ layer_features = w_dim
298
+ features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
299
+
300
+ if c_dim > 0:
301
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
302
+ for idx in range(num_layers):
303
+ in_features = features_list[idx]
304
+ out_features = features_list[idx + 1]
305
+ layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
306
+ setattr(self, f'fc{idx}', layer)
307
+
308
+ if num_ws is not None and w_avg_beta is not None:
309
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
310
+
311
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
312
+ # Embed, normalize, and concat inputs.
313
+ x = None
314
+ with torch.autograd.profiler.record_function('input'):
315
+ if self.z_dim > 0:
316
+ x = normalize_2nd_moment(z.to(torch.float32))
317
+ if self.c_dim > 0:
318
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
319
+ x = torch.cat([x, y], dim=1) if x is not None else y
320
+
321
+ # Main layers.
322
+ for idx in range(self.num_layers):
323
+ layer = getattr(self, f'fc{idx}')
324
+ x = layer(x)
325
+
326
+ # Update moving average of W.
327
+ if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
328
+ with torch.autograd.profiler.record_function('update_w_avg'):
329
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
330
+
331
+ # Broadcast.
332
+ if self.num_ws is not None:
333
+ with torch.autograd.profiler.record_function('broadcast'):
334
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
335
+
336
+ # Apply truncation.
337
+ if truncation_psi != 1:
338
+ with torch.autograd.profiler.record_function('truncate'):
339
+ assert self.w_avg_beta is not None
340
+ if self.num_ws is None or truncation_cutoff is None:
341
+ x = self.w_avg.lerp(x, truncation_psi)
342
+ else:
343
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
344
+
345
+ return x
346
+
347
+
348
+ class DisFromRGB(nn.Module):
349
+ def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
350
+ super().__init__()
351
+ self.conv = Conv2dLayer(in_channels=in_channels,
352
+ out_channels=out_channels,
353
+ kernel_size=1,
354
+ activation=activation,
355
+ )
356
+
357
+ def forward(self, x):
358
+ return self.conv(x)
359
+
360
+
361
+ class DisBlock(nn.Module):
362
+ def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
363
+ super().__init__()
364
+ self.conv0 = Conv2dLayer(in_channels=in_channels,
365
+ out_channels=in_channels,
366
+ kernel_size=3,
367
+ activation=activation,
368
+ )
369
+ self.conv1 = Conv2dLayer(in_channels=in_channels,
370
+ out_channels=out_channels,
371
+ kernel_size=3,
372
+ down=2,
373
+ activation=activation,
374
+ )
375
+ self.skip = Conv2dLayer(in_channels=in_channels,
376
+ out_channels=out_channels,
377
+ kernel_size=1,
378
+ down=2,
379
+ bias=False,
380
+ )
381
+
382
+ def forward(self, x):
383
+ skip = self.skip(x, gain=np.sqrt(0.5))
384
+ x = self.conv0(x)
385
+ x = self.conv1(x, gain=np.sqrt(0.5))
386
+ out = skip + x
387
+
388
+ return out
389
+
390
+
391
+ class Discriminator(torch.nn.Module):
392
+ def __init__(self,
393
+ c_dim, # Conditioning label (C) dimensionality.
394
+ img_resolution, # Input resolution.
395
+ img_channels, # Number of input color channels.
396
+ channel_base=32768, # Overall multiplier for the number of channels.
397
+ channel_max=512, # Maximum number of channels in any layer.
398
+ channel_decay=1,
399
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
400
+ activation='lrelu',
401
+ mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
402
+ mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
403
+ ):
404
+ super().__init__()
405
+ self.c_dim = c_dim
406
+ self.img_resolution = img_resolution
407
+ self.img_channels = img_channels
408
+
409
+ resolution_log2 = int(np.log2(img_resolution))
410
+ assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
411
+ self.resolution_log2 = resolution_log2
412
+
413
+ def nf(stage):
414
+ return np.clip(int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max)
415
+
416
+ if cmap_dim is None:
417
+ cmap_dim = nf(2)
418
+ if c_dim == 0:
419
+ cmap_dim = 0
420
+ self.cmap_dim = cmap_dim
421
+
422
+ if c_dim > 0:
423
+ self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)
424
+
425
+ Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
426
+ for res in range(resolution_log2, 2, -1):
427
+ Dis.append(DisBlock(nf(res), nf(res - 1), activation))
428
+
429
+ if mbstd_num_channels > 0:
430
+ Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
431
+ Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
432
+ self.Dis = nn.Sequential(*Dis)
433
+
434
+ self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
435
+ self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
436
+
437
+ def forward(self, images_in, masks_in, c):
438
+ x = torch.cat([masks_in - 0.5, images_in], dim=1)
439
+ x = self.Dis(x)
440
+ x = self.fc1(self.fc0(x.flatten(start_dim=1)))
441
+
442
+ if self.c_dim > 0:
443
+ cmap = self.mapping(None, c)
444
+
445
+ if self.cmap_dim > 0:
446
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
447
+
448
+ return x
449
+
450
+
451
+ def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
452
+ NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
453
+ return NF[2 ** stage]
454
+
455
+
456
+ class Mlp(nn.Module):
457
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
458
+ super().__init__()
459
+ out_features = out_features or in_features
460
+ hidden_features = hidden_features or in_features
461
+ self.fc1 = FullyConnectedLayer(in_features=in_features, out_features=hidden_features, activation='lrelu')
462
+ self.fc2 = FullyConnectedLayer(in_features=hidden_features, out_features=out_features)
463
+
464
+ def forward(self, x):
465
+ x = self.fc1(x)
466
+ x = self.fc2(x)
467
+ return x
468
+
469
+
470
+ def window_partition(x, window_size):
471
+ """
472
+ Args:
473
+ x: (B, H, W, C)
474
+ window_size (int): window size
475
+ Returns:
476
+ windows: (num_windows*B, window_size, window_size, C)
477
+ """
478
+ B, H, W, C = x.shape
479
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
480
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
481
+ return windows
482
+
483
+
484
+ def window_reverse(windows, window_size: int, H: int, W: int):
485
+ """
486
+ Args:
487
+ windows: (num_windows*B, window_size, window_size, C)
488
+ window_size (int): Window size
489
+ H (int): Height of image
490
+ W (int): Width of image
491
+ Returns:
492
+ x: (B, H, W, C)
493
+ """
494
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
495
+ # B = windows.shape[0] / (H * W / window_size / window_size)
496
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
497
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
498
+ return x
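A quick round-trip check for the two helpers above (assumes window_partition and window_reverse are in scope, and that window_size divides H and W, as is the case for the power-of-two resolutions used here):

    import torch

    B, H, W, C, window_size = 2, 8, 8, 16, 4
    x = torch.randn(B, H, W, C)

    windows = window_partition(x, window_size)              # [B * (H//4) * (W//4), 4, 4, C] = [8, 4, 4, 16]
    restored = window_reverse(windows, window_size, H, W)   # back to [B, H, W, C]
    assert torch.equal(x, restored)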
499
+
500
+
501
+ class Conv2dLayerPartial(nn.Module):
502
+ def __init__(self,
503
+ in_channels, # Number of input channels.
504
+ out_channels, # Number of output channels.
505
+ kernel_size, # Width and height of the convolution kernel.
506
+ bias=True, # Apply additive bias before the activation function?
507
+ activation='linear', # Activation function: 'relu', 'lrelu', etc.
508
+ up=1, # Integer upsampling factor.
509
+ down=1, # Integer downsampling factor.
510
+ resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
511
+ conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
512
+ trainable=True, # Update the weights of this layer during training?
513
+ ):
514
+ super().__init__()
515
+ self.conv = Conv2dLayer(in_channels, out_channels, kernel_size, bias, activation, up, down, resample_filter,
516
+ conv_clamp, trainable)
517
+
518
+ self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
519
+ self.slide_winsize = kernel_size ** 2
520
+ self.stride = down
521
+ self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
522
+
523
+ def forward(self, x, mask=None):
524
+ if mask is not None:
525
+ with torch.no_grad():
526
+ if self.weight_maskUpdater.type() != x.type():
527
+ self.weight_maskUpdater = self.weight_maskUpdater.to(x)
528
+ update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
529
+ padding=self.padding)
530
+ mask_ratio = self.slide_winsize / (update_mask + 1e-8)
531
+ update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1
532
+ mask_ratio = torch.mul(mask_ratio, update_mask)
533
+ x = self.conv(x)
534
+ x = torch.mul(x, mask_ratio)
535
+ return x, update_mask
536
+ else:
537
+ x = self.conv(x)
538
+ return x, None
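The mask bookkeeping of the partial convolution above in isolation, as a minimal sketch: count the valid pixels under each kernel window, rescale by slide_winsize / valid_count, and clamp the updated mask to {0, 1}:

    import torch
    import torch.nn.functional as F

    kernel_size = 3
    mask = (torch.rand(1, 1, 8, 8) > 0.5).float()
    weight_mask_updater = torch.ones(1, 1, kernel_size, kernel_size)
    slide_winsize = kernel_size ** 2

    update_mask = F.conv2d(mask, weight_mask_updater, bias=None, stride=1, padding=kernel_size // 2)
    mask_ratio = slide_winsize / (update_mask + 1e-8)       # reweighting factor for the conv output
    update_mask = torch.clamp(update_mask, 0, 1)            # new mask: 1 where any valid pixel was seen
    mask_ratio = mask_ratio * update_mask                   # zero out positions with no valid input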
539
+
540
+
541
+ class WindowAttention(nn.Module):
542
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
543
+ It supports both shifted and non-shifted windows.
544
+ Args:
545
+ dim (int): Number of input channels.
546
+ window_size (tuple[int]): The height and width of the window.
547
+ num_heads (int): Number of attention heads.
548
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
549
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
550
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
551
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
552
+ """
553
+
554
+ def __init__(self, dim, window_size, num_heads, down_ratio=1, qkv_bias=True, qk_scale=None, attn_drop=0.,
555
+ proj_drop=0.):
556
+
557
+ super().__init__()
558
+ self.dim = dim
559
+ self.window_size = window_size # Wh, Ww
560
+ self.num_heads = num_heads
561
+ head_dim = dim // num_heads
562
+ self.scale = qk_scale or head_dim ** -0.5
563
+
564
+ self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
565
+ self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
566
+ self.v = FullyConnectedLayer(in_features=dim, out_features=dim)
567
+ self.proj = FullyConnectedLayer(in_features=dim, out_features=dim)
568
+
569
+ self.softmax = nn.Softmax(dim=-1)
570
+
571
+ def forward(self, x, mask_windows=None, mask=None):
572
+ """
573
+ Args:
574
+ x: input features with shape of (num_windows*B, N, C)
575
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
576
+ """
577
+ B_, N, C = x.shape
578
+ norm_x = F.normalize(x, p=2.0, dim=-1)
579
+ q = self.q(norm_x).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
580
+ k = self.k(norm_x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 3, 1)
581
+ v = self.v(x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
582
+
583
+ attn = (q @ k) * self.scale
584
+
585
+ if mask is not None:
586
+ nW = mask.shape[0]
587
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
588
+ attn = attn.view(-1, self.num_heads, N, N)
589
+
590
+ if mask_windows is not None:
591
+ attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1)
592
+ attn = attn + attn_mask_windows.masked_fill(attn_mask_windows == 0, float(-100.0)).masked_fill(
593
+ attn_mask_windows == 1, float(0.0))
594
+ with torch.no_grad():
595
+ mask_windows = torch.clamp(torch.sum(mask_windows, dim=1, keepdim=True), 0, 1).repeat(1, N, 1)
596
+
597
+ attn = self.softmax(attn)
598
+
599
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
600
+ x = self.proj(x)
601
+ return x, mask_windows
602
+
603
+
604
+ class SwinTransformerBlock(nn.Module):
605
+ r""" Swin Transformer Block.
606
+ Args:
607
+ dim (int): Number of input channels.
608
+ input_resolution (tuple[int]): Input resolution.
609
+ num_heads (int): Number of attention heads.
610
+ window_size (int): Window size.
611
+ shift_size (int): Shift size for SW-MSA.
612
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
613
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
614
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
615
+ drop (float, optional): Dropout rate. Default: 0.0
616
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
617
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
618
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
619
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
620
+ """
621
+
622
+ def __init__(self, dim, input_resolution, num_heads, down_ratio=1, window_size=7, shift_size=0,
623
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
624
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
625
+ super().__init__()
626
+ self.dim = dim
627
+ self.input_resolution = input_resolution
628
+ self.num_heads = num_heads
629
+ self.window_size = window_size
630
+ self.shift_size = shift_size
631
+ self.mlp_ratio = mlp_ratio
632
+ if min(self.input_resolution) <= self.window_size:
633
+ # if window size is larger than input resolution, we don't partition windows
634
+ self.shift_size = 0
635
+ self.window_size = min(self.input_resolution)
636
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
637
+
638
+ if self.shift_size > 0:
639
+ down_ratio = 1
640
+ self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
641
+ down_ratio=down_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
642
+ proj_drop=drop)
643
+
644
+ self.fuse = FullyConnectedLayer(in_features=dim * 2, out_features=dim, activation='lrelu')
645
+
646
+ mlp_hidden_dim = int(dim * mlp_ratio)
647
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
648
+
649
+ if self.shift_size > 0:
650
+ attn_mask = self.calculate_mask(self.input_resolution)
651
+ else:
652
+ attn_mask = None
653
+
654
+ self.register_buffer("attn_mask", attn_mask)
655
+
656
+ def calculate_mask(self, x_size):
657
+ # calculate attention mask for SW-MSA
658
+ H, W = x_size
659
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
660
+ h_slices = (slice(0, -self.window_size),
661
+ slice(-self.window_size, -self.shift_size),
662
+ slice(-self.shift_size, None))
663
+ w_slices = (slice(0, -self.window_size),
664
+ slice(-self.window_size, -self.shift_size),
665
+ slice(-self.shift_size, None))
666
+ cnt = 0
667
+ for h in h_slices:
668
+ for w in w_slices:
669
+ img_mask[:, h, w, :] = cnt
670
+ cnt += 1
671
+
672
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
673
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
674
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
675
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
676
+
677
+ return attn_mask
678
+
679
+ def forward(self, x, x_size, mask=None):
680
+ # H, W = self.input_resolution
681
+ H, W = x_size
682
+ B, L, C = x.shape
683
+ # assert L == H * W, "input feature has wrong size"
684
+
685
+ shortcut = x
686
+ x = x.view(B, H, W, C)
687
+ if mask is not None:
688
+ mask = mask.view(B, H, W, 1)
689
+
690
+ # cyclic shift
691
+ if self.shift_size > 0:
692
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
693
+ if mask is not None:
694
+ shifted_mask = torch.roll(mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
695
+ else:
696
+ shifted_x = x
697
+ if mask is not None:
698
+ shifted_mask = mask
699
+
700
+ # partition windows
701
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
702
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
703
+ if mask is not None:
704
+ mask_windows = window_partition(shifted_mask, self.window_size)
705
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1)
706
+ else:
707
+ mask_windows = None
708
+
709
+ # W-MSA/SW-MSA (compatible with testing on images whose shapes are a multiple of the window size)
710
+ if self.input_resolution == x_size:
711
+ attn_windows, mask_windows = self.attn(x_windows, mask_windows,
712
+ mask=self.attn_mask) # nW*B, window_size*window_size, C
713
+ else:
714
+ attn_windows, mask_windows = self.attn(x_windows, mask_windows, mask=self.calculate_mask(x_size).to(
715
+ x.device)) # nW*B, window_size*window_size, C
716
+
717
+ # merge windows
718
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
719
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
720
+ if mask is not None:
721
+ mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1)
722
+ shifted_mask = window_reverse(mask_windows, self.window_size, H, W)
723
+
724
+ # reverse cyclic shift
725
+ if self.shift_size > 0:
726
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
727
+ if mask is not None:
728
+ mask = torch.roll(shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
729
+ else:
730
+ x = shifted_x
731
+ if mask is not None:
732
+ mask = shifted_mask
733
+ x = x.view(B, H * W, C)
734
+ if mask is not None:
735
+ mask = mask.view(B, H * W, 1)
736
+
737
+ # FFN
738
+ x = self.fuse(torch.cat([shortcut, x], dim=-1))
739
+ x = self.mlp(x)
740
+
741
+ return x, mask
742
+
743
+
744
+ class PatchMerging(nn.Module):
745
+ def __init__(self, in_channels, out_channels, down=2):
746
+ super().__init__()
747
+ self.conv = Conv2dLayerPartial(in_channels=in_channels,
748
+ out_channels=out_channels,
749
+ kernel_size=3,
750
+ activation='lrelu',
751
+ down=down,
752
+ )
753
+ self.down = down
754
+
755
+ def forward(self, x, x_size, mask=None):
756
+ x = token2feature(x, x_size)
757
+ if mask is not None:
758
+ mask = token2feature(mask, x_size)
759
+ x, mask = self.conv(x, mask)
760
+ if self.down != 1:
761
+ ratio = 1 / self.down
762
+ x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio))
763
+ x = feature2token(x)
764
+ if mask is not None:
765
+ mask = feature2token(mask)
766
+ return x, x_size, mask
767
+
768
+
769
+ class PatchUpsampling(nn.Module):
770
+ def __init__(self, in_channels, out_channels, up=2):
771
+ super().__init__()
772
+ self.conv = Conv2dLayerPartial(in_channels=in_channels,
773
+ out_channels=out_channels,
774
+ kernel_size=3,
775
+ activation='lrelu',
776
+ up=up,
777
+ )
778
+ self.up = up
779
+
780
+ def forward(self, x, x_size, mask=None):
781
+ x = token2feature(x, x_size)
782
+ if mask is not None:
783
+ mask = token2feature(mask, x_size)
784
+ x, mask = self.conv(x, mask)
785
+ if self.up != 1:
786
+ x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up))
787
+ x = feature2token(x)
788
+ if mask is not None:
789
+ mask = feature2token(mask)
790
+ return x, x_size, mask
791
+
792
+
793
+ class BasicLayer(nn.Module):
794
+ """ A basic Swin Transformer layer for one stage.
795
+ Args:
796
+ dim (int): Number of input channels.
797
+ input_resolution (tuple[int]): Input resolution.
798
+ depth (int): Number of blocks.
799
+ num_heads (int): Number of attention heads.
800
+ window_size (int): Local window size.
801
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
802
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
803
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
804
+ drop (float, optional): Dropout rate. Default: 0.0
805
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
806
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
807
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
808
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
809
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
810
+ """
811
+
812
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size, down_ratio=1,
813
+ mlp_ratio=2., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
814
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
815
+
816
+ super().__init__()
817
+ self.dim = dim
818
+ self.input_resolution = input_resolution
819
+ self.depth = depth
820
+ self.use_checkpoint = use_checkpoint
821
+
822
+ # patch merging layer
823
+ if downsample is not None:
824
+ # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
825
+ self.downsample = downsample
826
+ else:
827
+ self.downsample = None
828
+
829
+ # build blocks
830
+ self.blocks = nn.ModuleList([
831
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
832
+ num_heads=num_heads, down_ratio=down_ratio, window_size=window_size,
833
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
834
+ mlp_ratio=mlp_ratio,
835
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
836
+ drop=drop, attn_drop=attn_drop,
837
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
838
+ norm_layer=norm_layer)
839
+ for i in range(depth)])
840
+
841
+ self.conv = Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, activation='lrelu')
842
+
843
+ def forward(self, x, x_size, mask=None):
844
+ if self.downsample is not None:
845
+ x, x_size, mask = self.downsample(x, x_size, mask)
846
+ identity = x
847
+ for blk in self.blocks:
848
+ if self.use_checkpoint:
849
+ x, mask = checkpoint.checkpoint(blk, x, x_size, mask)
850
+ else:
851
+ x, mask = blk(x, x_size, mask)
852
+ if mask is not None:
853
+ mask = token2feature(mask, x_size)
854
+ x, mask = self.conv(token2feature(x, x_size), mask)
855
+ x = feature2token(x) + identity
856
+ if mask is not None:
857
+ mask = feature2token(mask)
858
+ return x, x_size, mask
859
+
860
+
861
+ class ToToken(nn.Module):
862
+ def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1):
863
+ super().__init__()
864
+
865
+ self.proj = Conv2dLayerPartial(in_channels=in_channels, out_channels=dim, kernel_size=kernel_size,
866
+ activation='lrelu')
867
+
868
+ def forward(self, x, mask):
869
+ x, mask = self.proj(x, mask)
870
+
871
+ return x, mask
872
+
873
+
874
+ class EncFromRGB(nn.Module):
875
+ def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
876
+ super().__init__()
877
+ self.conv0 = Conv2dLayer(in_channels=in_channels,
878
+ out_channels=out_channels,
879
+ kernel_size=1,
880
+ activation=activation,
881
+ )
882
+ self.conv1 = Conv2dLayer(in_channels=out_channels,
883
+ out_channels=out_channels,
884
+ kernel_size=3,
885
+ activation=activation,
886
+ )
887
+
888
+ def forward(self, x):
889
+ x = self.conv0(x)
890
+ x = self.conv1(x)
891
+
892
+ return x
893
+
894
+
895
+ class ConvBlockDown(nn.Module):
896
+ def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log
897
+ super().__init__()
898
+
899
+ self.conv0 = Conv2dLayer(in_channels=in_channels,
900
+ out_channels=out_channels,
901
+ kernel_size=3,
902
+ activation=activation,
903
+ down=2,
904
+ )
905
+ self.conv1 = Conv2dLayer(in_channels=out_channels,
906
+ out_channels=out_channels,
907
+ kernel_size=3,
908
+ activation=activation,
909
+ )
910
+
911
+ def forward(self, x):
912
+ x = self.conv0(x)
913
+ x = self.conv1(x)
914
+
915
+ return x
916
+
917
+
918
+ def token2feature(x, x_size):
919
+ B, N, C = x.shape
920
+ h, w = x_size
921
+ x = x.permute(0, 2, 1).reshape(B, C, h, w)
922
+ return x
923
+
924
+
925
+ def feature2token(x):
926
+ B, C, H, W = x.shape
927
+ x = x.view(B, C, -1).transpose(1, 2)
928
+ return x
929
+
930
+
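A quick shape check (editorial sketch, assuming only PyTorch) of the conversion token2feature and feature2token perform between [B, C, H, W] feature maps and [B, H*W, C] token sequences; the round trip is lossless.

import torch

feat = torch.randn(2, 180, 16, 16)                        # [B, C, H, W]
tokens = feat.view(2, 180, -1).transpose(1, 2)            # feature2token -> [2, 256, 180]
back = tokens.permute(0, 2, 1).reshape(2, 180, 16, 16)    # token2feature -> [2, 180, 16, 16]
assert torch.equal(feat, back)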
931
+ class Encoder(nn.Module):
932
+ def __init__(self, res_log2, img_channels, activation, patch_size=5, channels=16, drop_path_rate=0.1):
933
+ super().__init__()
934
+
935
+ self.resolution = []
936
+
937
+ for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16
938
+ res = 2 ** i
939
+ self.resolution.append(res)
940
+ if i == res_log2:
941
+ block = EncFromRGB(img_channels * 2 + 1, nf(i), activation)
942
+ else:
943
+ block = ConvBlockDown(nf(i + 1), nf(i), activation)
944
+ setattr(self, 'EncConv_Block_%dx%d' % (res, res), block)
945
+
946
+ def forward(self, x):
947
+ out = {}
948
+ for res in self.resolution:
949
+ res_log2 = int(np.log2(res))
950
+ x = getattr(self, 'EncConv_Block_%dx%d' % (res, res))(x)
951
+ out[res_log2] = x
952
+
953
+ return out
954
+
955
+
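A small sketch (editorial, assuming only NumPy) of which blocks the Encoder registers for a 512x512 input: the loop over range(res_log2, 3, -1) produces one convolutional stage per resolution, from the input size down to 16x16.

import numpy as np

res_log2 = int(np.log2(512))  # 9 for a 512x512 input
names = ['EncConv_Block_%dx%d' % (2 ** i, 2 ** i) for i in range(res_log2, 3, -1)]
print(names)  # ['EncConv_Block_512x512', 'EncConv_Block_256x256', ..., 'EncConv_Block_16x16']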
956
+ class ToStyle(nn.Module):
957
+ def __init__(self, in_channels, out_channels, activation, drop_rate):
958
+ super().__init__()
959
+ self.conv = nn.Sequential(
960
+ Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
961
+ down=2),
962
+ Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
963
+ down=2),
964
+ Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
965
+ down=2),
966
+ )
967
+
968
+ self.pool = nn.AdaptiveAvgPool2d(1)
969
+ self.fc = FullyConnectedLayer(in_features=in_channels,
970
+ out_features=out_channels,
971
+ activation=activation)
972
+ # self.dropout = nn.Dropout(drop_rate)
973
+
974
+ def forward(self, x):
975
+ x = self.conv(x)
976
+ x = self.pool(x)
977
+ x = self.fc(x.flatten(start_dim=1))
978
+ # x = self.dropout(x)
979
+
980
+ return x
981
+
982
+
983
+ class DecBlockFirstV2(nn.Module):
984
+ def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
985
+ super().__init__()
986
+ self.res = res
987
+
988
+ self.conv0 = Conv2dLayer(in_channels=in_channels,
989
+ out_channels=in_channels,
990
+ kernel_size=3,
991
+ activation=activation,
992
+ )
993
+ self.conv1 = StyleConv(in_channels=in_channels,
994
+ out_channels=out_channels,
995
+ style_dim=style_dim,
996
+ resolution=2 ** res,
997
+ kernel_size=3,
998
+ use_noise=use_noise,
999
+ activation=activation,
1000
+ demodulate=demodulate,
1001
+ )
1002
+ self.toRGB = ToRGB(in_channels=out_channels,
1003
+ out_channels=img_channels,
1004
+ style_dim=style_dim,
1005
+ kernel_size=1,
1006
+ demodulate=False,
1007
+ )
1008
+
1009
+ def forward(self, x, ws, gs, E_features, noise_mode='random'):
1010
+ # x = self.fc(x).view(x.shape[0], -1, 4, 4)
1011
+ x = self.conv0(x)
1012
+ x = x + E_features[self.res]
1013
+ style = get_style_code(ws[:, 0], gs)
1014
+ x = self.conv1(x, style, noise_mode=noise_mode)
1015
+ style = get_style_code(ws[:, 1], gs)
1016
+ img = self.toRGB(x, style, skip=None)
1017
+
1018
+ return x, img
1019
+
1020
+
1021
+ class DecBlock(nn.Module):
1022
+ def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate,
1023
+ img_channels): # res = 4, ..., resolution_log2
1024
+ super().__init__()
1025
+ self.res = res
1026
+
1027
+ self.conv0 = StyleConv(in_channels=in_channels,
1028
+ out_channels=out_channels,
1029
+ style_dim=style_dim,
1030
+ resolution=2 ** res,
1031
+ kernel_size=3,
1032
+ up=2,
1033
+ use_noise=use_noise,
1034
+ activation=activation,
1035
+ demodulate=demodulate,
1036
+ )
1037
+ self.conv1 = StyleConv(in_channels=out_channels,
1038
+ out_channels=out_channels,
1039
+ style_dim=style_dim,
1040
+ resolution=2 ** res,
1041
+ kernel_size=3,
1042
+ use_noise=use_noise,
1043
+ activation=activation,
1044
+ demodulate=demodulate,
1045
+ )
1046
+ self.toRGB = ToRGB(in_channels=out_channels,
1047
+ out_channels=img_channels,
1048
+ style_dim=style_dim,
1049
+ kernel_size=1,
1050
+ demodulate=False,
1051
+ )
1052
+
1053
+ def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
1054
+ style = get_style_code(ws[:, self.res * 2 - 9], gs)
1055
+ x = self.conv0(x, style, noise_mode=noise_mode)
1056
+ x = x + E_features[self.res]
1057
+ style = get_style_code(ws[:, self.res * 2 - 8], gs)
1058
+ x = self.conv1(x, style, noise_mode=noise_mode)
1059
+ style = get_style_code(ws[:, self.res * 2 - 7], gs)
1060
+ img = self.toRGB(x, style, skip=img)
1061
+
1062
+ return x, img
1063
+
1064
+
1065
+ class Decoder(nn.Module):
1066
+ def __init__(self, res_log2, activation, style_dim, use_noise, demodulate, img_channels):
1067
+ super().__init__()
1068
+ self.Dec_16x16 = DecBlockFirstV2(4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels)
1069
+ for res in range(5, res_log2 + 1):
1070
+ setattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res),
1071
+ DecBlock(res, nf(res - 1), nf(res), activation, style_dim, use_noise, demodulate, img_channels))
1072
+ self.res_log2 = res_log2
1073
+
1074
+ def forward(self, x, ws, gs, E_features, noise_mode='random'):
1075
+ x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
1076
+ for res in range(5, self.res_log2 + 1):
1077
+ block = getattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res))
1078
+ x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode)
1079
+
1080
+ return img
1081
+
1082
+
1083
+ class DecStyleBlock(nn.Module):
1084
+ def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
1085
+ super().__init__()
1086
+ self.res = res
1087
+
1088
+ self.conv0 = StyleConv(in_channels=in_channels,
1089
+ out_channels=out_channels,
1090
+ style_dim=style_dim,
1091
+ resolution=2 ** res,
1092
+ kernel_size=3,
1093
+ up=2,
1094
+ use_noise=use_noise,
1095
+ activation=activation,
1096
+ demodulate=demodulate,
1097
+ )
1098
+ self.conv1 = StyleConv(in_channels=out_channels,
1099
+ out_channels=out_channels,
1100
+ style_dim=style_dim,
1101
+ resolution=2 ** res,
1102
+ kernel_size=3,
1103
+ use_noise=use_noise,
1104
+ activation=activation,
1105
+ demodulate=demodulate,
1106
+ )
1107
+ self.toRGB = ToRGB(in_channels=out_channels,
1108
+ out_channels=img_channels,
1109
+ style_dim=style_dim,
1110
+ kernel_size=1,
1111
+ demodulate=False,
1112
+ )
1113
+
1114
+ def forward(self, x, img, style, skip, noise_mode='random'):
1115
+ x = self.conv0(x, style, noise_mode=noise_mode)
1116
+ x = x + skip
1117
+ x = self.conv1(x, style, noise_mode=noise_mode)
1118
+ img = self.toRGB(x, style, skip=img)
1119
+
1120
+ return x, img
1121
+
1122
+
1123
+ class FirstStage(nn.Module):
1124
+ def __init__(self, img_channels, img_resolution=256, dim=180, w_dim=512, use_noise=False, demodulate=True,
1125
+ activation='lrelu'):
1126
+ super().__init__()
1127
+ res = 64
1128
+
1129
+ self.conv_first = Conv2dLayerPartial(in_channels=img_channels + 1, out_channels=dim, kernel_size=3,
1130
+ activation=activation)
1131
+ self.enc_conv = nn.ModuleList()
1132
+ down_time = int(np.log2(img_resolution // res))
1133
+ # the image size determines how many down-sampling layers precede the Swin Transformer stages
1134
+ for i in range(down_time): # from input size to 64
1135
+ self.enc_conv.append(
1136
+ Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation)
1137
+ )
1138
+
1139
+ # from 64 -> 16 -> 64
1140
+ depths = [2, 3, 4, 3, 2]
1141
+ ratios = [1, 1 / 2, 1 / 2, 2, 2]
1142
+ num_heads = 6
1143
+ window_sizes = [8, 16, 16, 16, 8]
1144
+ drop_path_rate = 0.1
1145
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
1146
+
1147
+ self.tran = nn.ModuleList()
1148
+ for i, depth in enumerate(depths):
1149
+ res = int(res * ratios[i])
1150
+ if ratios[i] < 1:
1151
+ merge = PatchMerging(dim, dim, down=int(1 / ratios[i]))
1152
+ elif ratios[i] > 1:
1153
+ merge = PatchUpsampling(dim, dim, up=ratios[i])
1154
+ else:
1155
+ merge = None
1156
+ self.tran.append(
1157
+ BasicLayer(dim=dim, input_resolution=[res, res], depth=depth, num_heads=num_heads,
1158
+ window_size=window_sizes[i], drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
1159
+ downsample=merge)
1160
+ )
1161
+
1162
+ # global style
1163
+ down_conv = []
1164
+ for i in range(int(np.log2(16))):
1165
+ down_conv.append(
1166
+ Conv2dLayer(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation))
1167
+ down_conv.append(nn.AdaptiveAvgPool2d((1, 1)))
1168
+ self.down_conv = nn.Sequential(*down_conv)
1169
+ self.to_style = FullyConnectedLayer(in_features=dim, out_features=dim * 2, activation=activation)
1170
+ self.ws_style = FullyConnectedLayer(in_features=w_dim, out_features=dim, activation=activation)
1171
+ self.to_square = FullyConnectedLayer(in_features=dim, out_features=16 * 16, activation=activation)
1172
+
1173
+ style_dim = dim * 3
1174
+ self.dec_conv = nn.ModuleList()
1175
+ for i in range(down_time): # from 64 to input size
1176
+ res = res * 2
1177
+ self.dec_conv.append(
1178
+ DecStyleBlock(res, dim, dim, activation, style_dim, use_noise, demodulate, img_channels))
1179
+
1180
+ def forward(self, images_in, masks_in, ws, noise_mode='random'):
1181
+ x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1)
1182
+
1183
+ skips = []
1184
+ x, mask = self.conv_first(x, masks_in) # input size
1185
+ skips.append(x)
1186
+ for i, block in enumerate(self.enc_conv): # input size to 64
1187
+ x, mask = block(x, mask)
1188
+ if i != len(self.enc_conv) - 1:
1189
+ skips.append(x)
1190
+
1191
+ x_size = x.size()[-2:]
1192
+ x = feature2token(x)
1193
+ mask = feature2token(mask)
1194
+ mid = len(self.tran) // 2
1195
+ for i, block in enumerate(self.tran): # 64 to 16
1196
+ if i < mid:
1197
+ x, x_size, mask = block(x, x_size, mask)
1198
+ skips.append(x)
1199
+ elif i > mid:
1200
+ x, x_size, mask = block(x, x_size, None)
1201
+ x = x + skips[mid - i]
1202
+ else:
1203
+ x, x_size, mask = block(x, x_size, None)
1204
+
1205
+ mul_map = torch.ones_like(x) * 0.5
1206
+ mul_map = F.dropout(mul_map, training=True)
1207
+ ws = self.ws_style(ws[:, -1])
1208
+ add_n = self.to_square(ws).unsqueeze(1)
1209
+ add_n = F.interpolate(add_n, size=x.size(1), mode='linear', align_corners=False).squeeze(1).unsqueeze(
1210
+ -1)
1211
+ x = x * mul_map + add_n * (1 - mul_map)
1212
+ gs = self.to_style(self.down_conv(token2feature(x, x_size)).flatten(start_dim=1))
1213
+ style = torch.cat([gs, ws], dim=1)
1214
+
1215
+ x = token2feature(x, x_size).contiguous()
1216
+ img = None
1217
+ for i, block in enumerate(self.dec_conv):
1218
+ x, img = block(x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode)
1219
+
1220
+ # composite: keep the known pixels, use the prediction only inside the hole
1221
+ img = img * (1 - masks_in) + images_in * masks_in
1222
+
1223
+ return img
1224
+
1225
+
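A sketch (editorial, assuming only PyTorch) of the stage plan FirstStage builds for its transformer body: with depths [2, 3, 4, 3, 2] and ratios [1, 1/2, 1/2, 2, 2] the token grid goes 64 -> 32 -> 16 -> 32 -> 64, and the 14 blocks share a linearly increasing stochastic-depth rate.

import torch

res, depths, ratios = 64, [2, 3, 4, 3, 2], [1, 1 / 2, 1 / 2, 2, 2]
stage_res = []
for r in ratios:
    res = int(res * r)
    stage_res.append(res)
print(stage_res)                     # [64, 32, 16, 32, 64]
dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))]
print(len(dpr), round(dpr[-1], 3))   # 14 blocks, drop-path rate ramps up to 0.1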
1226
+ class SynthesisNet(nn.Module):
1227
+ def __init__(self,
1228
+ w_dim, # Intermediate latent (W) dimensionality.
1229
+ img_resolution, # Output image resolution.
1230
+ img_channels=3, # Number of color channels.
1231
+ channel_base=32768, # Overall multiplier for the number of channels.
1232
+ channel_decay=1.0,
1233
+ channel_max=512, # Maximum number of channels in any layer.
1234
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
1235
+ drop_rate=0.5,
1236
+ use_noise=False,
1237
+ demodulate=True,
1238
+ ):
1239
+ super().__init__()
1240
+ resolution_log2 = int(np.log2(img_resolution))
1241
+ assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
1242
+
1243
+ self.num_layers = resolution_log2 * 2 - 3 * 2
1244
+ self.img_resolution = img_resolution
1245
+ self.resolution_log2 = resolution_log2
1246
+
1247
+ # first stage
1248
+ self.first_stage = FirstStage(img_channels, img_resolution=img_resolution, w_dim=w_dim, use_noise=False,
1249
+ demodulate=demodulate)
1250
+
1251
+ # second stage
1252
+ self.enc = Encoder(resolution_log2, img_channels, activation, patch_size=5, channels=16)
1253
+ self.to_square = FullyConnectedLayer(in_features=w_dim, out_features=16 * 16, activation=activation)
1254
+ self.to_style = ToStyle(in_channels=nf(4), out_channels=nf(2) * 2, activation=activation, drop_rate=drop_rate)
1255
+ style_dim = w_dim + nf(2) * 2
1256
+ self.dec = Decoder(resolution_log2, activation, style_dim, use_noise, demodulate, img_channels)
1257
+
1258
+ def forward(self, images_in, masks_in, ws, noise_mode='random', return_stg1=False):
1259
+ out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode)
1260
+
1261
+ # encoder
1262
+ x = images_in * masks_in + out_stg1 * (1 - masks_in)
1263
+ x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1)
1264
+ E_features = self.enc(x)
1265
+
1266
+ fea_16 = E_features[4]
1267
+ mul_map = torch.ones_like(fea_16) * 0.5
1268
+ mul_map = F.dropout(mul_map, training=True)
1269
+ add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1)
1270
+ add_n = F.interpolate(add_n, size=fea_16.size()[-2:], mode='bilinear', align_corners=False)
1271
+ fea_16 = fea_16 * mul_map + add_n * (1 - mul_map)
1272
+ E_features[4] = fea_16
1273
+
1274
+ # style
1275
+ gs = self.to_style(fea_16)
1276
+
1277
+ # decoder
1278
+ img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode)
1279
+
1280
+ # composite: keep the known pixels, use the prediction only inside the hole
1281
+ img = img * (1 - masks_in) + images_in * masks_in
1282
+
1283
+ if not return_stg1:
1284
+ return img
1285
+ else:
1286
+ return img, out_stg1
1287
+
1288
+
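A tiny numeric sketch (editorial, assuming only PyTorch) of the compositing step used twice above: with masks_in equal to 1 on known pixels and 0 in the hole, the generator output only ever replaces the hole.

import torch

images_in = torch.full((1, 3, 4, 4), 0.5)     # known image content
prediction = torch.zeros_like(images_in)      # stand-in for the network output
masks_in = torch.ones_like(images_in)
masks_in[..., 1:3, 1:3] = 0                   # 0 marks the hole to fill
composite = prediction * (1 - masks_in) + images_in * masks_in
assert torch.equal(composite[..., 0, 0], images_in[..., 0, 0])    # known pixel is kept
assert torch.equal(composite[..., 1, 1], prediction[..., 1, 1])   # hole pixel is replaced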
1289
+ class Generator(nn.Module):
1290
+ def __init__(self,
1291
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
1292
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
1293
+ w_dim, # Intermediate latent (W) dimensionality.
1294
+ img_resolution, # resolution of generated image
1295
+ img_channels, # Number of input color channels.
1296
+ synthesis_kwargs={}, # Arguments for SynthesisNetwork.
1297
+ mapping_kwargs={}, # Arguments for MappingNetwork.
1298
+ ):
1299
+ super().__init__()
1300
+ self.z_dim = z_dim
1301
+ self.c_dim = c_dim
1302
+ self.w_dim = w_dim
1303
+ self.img_resolution = img_resolution
1304
+ self.img_channels = img_channels
1305
+
1306
+ self.synthesis = SynthesisNet(w_dim=w_dim,
1307
+ img_resolution=img_resolution,
1308
+ img_channels=img_channels,
1309
+ **synthesis_kwargs)
1310
+ self.mapping = MappingNet(z_dim=z_dim,
1311
+ c_dim=c_dim,
1312
+ w_dim=w_dim,
1313
+ num_ws=self.synthesis.num_layers,
1314
+ **mapping_kwargs)
1315
+
1316
+ def forward(self, images_in, masks_in, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False,
1317
+ noise_mode='none', return_stg1=False):
1318
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff,
1319
+ skip_w_avg_update=skip_w_avg_update)
1320
+ img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode)
1321
+ return img
1322
+
1323
+
1324
+ class Discriminator(torch.nn.Module):
1325
+ def __init__(self,
1326
+ c_dim, # Conditioning label (C) dimensionality.
1327
+ img_resolution, # Input resolution.
1328
+ img_channels, # Number of input color channels.
1329
+ channel_base=32768, # Overall multiplier for the number of channels.
1330
+ channel_max=512, # Maximum number of channels in any layer.
1331
+ channel_decay=1,
1332
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
1333
+ activation='lrelu',
1334
+ mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
1335
+ mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
1336
+ ):
1337
+ super().__init__()
1338
+ self.c_dim = c_dim
1339
+ self.img_resolution = img_resolution
1340
+ self.img_channels = img_channels
1341
+
1342
+ resolution_log2 = int(np.log2(img_resolution))
1343
+ assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
1344
+ self.resolution_log2 = resolution_log2
1345
+
1346
+ if cmap_dim is None:
1347
+ cmap_dim = nf(2)
1348
+ if c_dim == 0:
1349
+ cmap_dim = 0
1350
+ self.cmap_dim = cmap_dim
1351
+
1352
+ if c_dim > 0:
1353
+ self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)
1354
+
1355
+ Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
1356
+ for res in range(resolution_log2, 2, -1):
1357
+ Dis.append(DisBlock(nf(res), nf(res - 1), activation))
1358
+
1359
+ if mbstd_num_channels > 0:
1360
+ Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
1361
+ Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
1362
+ self.Dis = nn.Sequential(*Dis)
1363
+
1364
+ self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
1365
+ self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
1366
+
1367
+ # for 64x64
1368
+ Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)]
1369
+ for res in range(resolution_log2, 2, -1):
1370
+ Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation))
1371
+
1372
+ if mbstd_num_channels > 0:
1373
+ Dis_stg1.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
1374
+ Dis_stg1.append(Conv2dLayer(nf(2) // 2 + mbstd_num_channels, nf(2) // 2, kernel_size=3, activation=activation))
1375
+ self.Dis_stg1 = nn.Sequential(*Dis_stg1)
1376
+
1377
+ self.fc0_stg1 = FullyConnectedLayer(nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation)
1378
+ self.fc1_stg1 = FullyConnectedLayer(nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim)
1379
+
1380
+ def forward(self, images_in, masks_in, images_stg1, c):
1381
+ x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1))
1382
+ x = self.fc1(self.fc0(x.flatten(start_dim=1)))
1383
+
1384
+ x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1))
1385
+ x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1)))
1386
+
1387
+ if self.c_dim > 0:
1388
+ cmap = self.mapping(None, c)
1389
+
1390
+ if self.cmap_dim > 0:
1391
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
1392
+ x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
1393
+
1394
+ return x, x_stg1
1395
+
1396
+
1397
+ MAT_MODEL_URL = os.environ.get(
1398
+ "MAT_MODEL_URL",
1399
+ "https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth",
1400
+ )
1401
+
1402
+
1403
+ class MAT(InpaintModel):
1404
+ min_size = 512
1405
+ pad_mod = 512
1406
+ pad_to_square = True
1407
+
1408
+ def init_model(self, device, **kwargs):
1409
+ seed = 240 # fixed seed so the sampled latent z is reproducible
1410
+ random.seed(seed)
1411
+ np.random.seed(seed)
1412
+ torch.manual_seed(seed)
1413
+
1414
+ G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3)
1415
+ self.model = load_model(G, MAT_MODEL_URL, device)
1416
+ self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device) # [1, 512]
1417
+ self.label = torch.zeros([1, self.model.c_dim], device=device)
1418
+
1419
+ @staticmethod
1420
+ def is_downloaded() -> bool:
1421
+ return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
1422
+
1423
+ def forward(self, image, mask, config: Config):
1424
+ """Input images and output images have same size
1425
+ images: [H, W, C] RGB
1426
+ masks: [H, W] mask area == 255
1427
+ return: BGR IMAGE
1428
+ """
1429
+
1430
+ image = norm_img(image) # [0, 1]
1431
+ image = image * 2 - 1 # [0, 1] -> [-1, 1]
1432
+
1433
+ mask = (mask > 127) * 255
1434
+ mask = 255 - mask
1435
+ mask = norm_img(mask)
1436
+
1437
+ image = torch.from_numpy(image).unsqueeze(0).to(self.device)
1438
+ mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
1439
+
1440
+ output = self.model(image, mask, self.z, self.label, truncation_psi=1, noise_mode='none')
1441
+ output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
1442
+ output = output[0].cpu().numpy()
1443
+ cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
1444
+ return cur_res
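For reference, a standalone sketch (editorial, assuming only NumPy; norm_img additionally moves channels first) of the mask convention MAT.forward sets up: the incoming mask marks the region to inpaint with 255, while the generator expects 1 for known pixels and 0 for the hole.

import numpy as np

mask = np.array([[0, 255], [255, 0]], dtype=np.uint8)  # 255 = region to remove
mask = (mask > 127) * 255
mask = 255 - mask
mask = mask.astype('float32') / 255.0                  # roughly the scaling norm_img applies
print(mask)  # [[1. 0.] [0. 1.]] -> 1 where pixels are kept, 0 where MAT fills in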
lama_cleaner/model/opencv2.py ADDED
@@ -0,0 +1,25 @@
1
+ import cv2
2
+
3
+ from lama_cleaner.model.base import InpaintModel
4
+ from lama_cleaner.schema import Config
5
+
6
+ flag_map = {
7
+ "INPAINT_NS": cv2.INPAINT_NS,
8
+ "INPAINT_TELEA": cv2.INPAINT_TELEA
9
+ }
10
+
11
+ class OpenCV2(InpaintModel):
12
+ pad_mod = 1
13
+
14
+ @staticmethod
15
+ def is_downloaded() -> bool:
16
+ return True
17
+
18
+ def forward(self, image, mask, config: Config):
19
+ """Input image and output image have same size
20
+ image: [H, W, C] RGB
21
+ mask: [H, W, 1]
22
+ return: BGR IMAGE
23
+ """
24
+ cur_res = cv2.inpaint(image[:, :, ::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag])
25
+ return cur_res
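A self-contained usage sketch (editorial, with toy data) of the cv2.inpaint call this wrapper forwards to; pixels where the mask is 255 are filled from their surroundings.

import cv2
import numpy as np

image = np.full((64, 64, 3), 200, dtype=np.uint8)
image[20:40, 20:40] = 0                       # a dark square we want removed
mask = np.zeros((64, 64), dtype=np.uint8)
mask[20:40, 20:40] = 255                      # inpaint exactly that square
result = cv2.inpaint(image, mask, inpaintRadius=5, flags=cv2.INPAINT_TELEA)
print(result.shape, result.dtype)             # (64, 64, 3) uint8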