Misdi KenjieDec commited on
Commit
cc38472
·
0 Parent(s):

Duplicate from KenjieDec/RemBG

Browse files

Co-authored-by: Kenjie Dec <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Daniel Gatis
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Rembg
3
+ emoji: 👀
4
+ colorFrom: pink
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 3.0.20
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: KenjieDec/RemBG
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
anime-girl.jpg ADDED
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Modified from Akhaliq Hugging Face Demo
2
+ ## https://huggingface.co/akhaliq
3
+
4
+ import gradio as gr
5
+ import os
6
+ import cv2
7
+
8
+ def inference(file, mask, model):
9
+ im = cv2.imread(file, cv2.IMREAD_COLOR)
10
+ cv2.imwrite(os.path.join("input.png"), im)
11
+
12
+ from rembg import new_session, remove
13
+
14
+ input_path = 'input.png'
15
+ output_path = 'output.png'
16
+
17
+ with open(input_path, 'rb') as i:
18
+ with open(output_path, 'wb') as o:
19
+ input = i.read()
20
+ output = remove(
21
+ input,
22
+ session = new_session(model),
23
+ only_mask = (True if mask == "Mask only" else False)
24
+ )
25
+
26
+
27
+
28
+ o.write(output)
29
+ return os.path.join("output.png")
30
+
31
+ title = "RemBG"
32
+ description = "Gradio demo for RemBG. To use it, simply upload your image and wait. Read more at the link below."
33
+ article = "<p style='text-align: center;'><a href='https://github.com/danielgatis/rembg' target='_blank'>Github Repo</a></p>"
34
+
35
+
36
+ gr.Interface(
37
+ inference,
38
+ [
39
+ gr.inputs.Image(type="filepath", label="Input"),
40
+ gr.inputs.Radio(
41
+ [
42
+ "Default",
43
+ "Mask only"
44
+ ],
45
+ type="value",
46
+ default="Default",
47
+ label="Choices"
48
+ ),
49
+ gr.inputs.Dropdown([
50
+ "u2net",
51
+ "u2netp",
52
+ "u2net_human_seg",
53
+ "u2net_cloth_seg",
54
+ "silueta",
55
+ "isnet-general-use",
56
+ "isnet-anime",
57
+ "sam",
58
+ ],
59
+ type="value",
60
+ default="isnet-general-use",
61
+ label="Models"
62
+ ),
63
+ ],
64
+ gr.outputs.Image(type="filepath", label="Output"),
65
+ title=title,
66
+ description=description,
67
+ article=article,
68
+ examples=[["lion.png", "Default", "u2net"], ["girl.jpg", "Default", "u2net"], ["anime-girl.jpg", "Default", "isnet-anime"]],
69
+ enable_queue=True
70
+ ).launch()
girl.jpg ADDED
lion.png ADDED
rembg/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from . import _version
2
+
3
+ __version__ = _version.get_versions()["version"]
4
+
5
+ from .bg import remove
6
+ from .session_factory import new_session
rembg/_version.py ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file helps to compute a version number in source trees obtained from
2
+ # git-archive tarball (such as those provided by githubs download-from-tag
3
+ # feature). Distribution tarballs (built by setup.py sdist) and build
4
+ # directories (produced by setup.py build) will contain a much shorter file
5
+ # that just contains the computed version number.
6
+
7
+ # This file is released into the public domain. Generated by
8
+ # versioneer-0.21 (https://github.com/python-versioneer/python-versioneer)
9
+
10
+ """Git implementation of _version.py."""
11
+
12
+ import errno
13
+ import os
14
+ import re
15
+ import subprocess
16
+ import sys
17
+ from typing import Callable, Dict
18
+
19
+
20
+ def get_keywords():
21
+ """Get the keywords needed to look up the version information."""
22
+ # these strings will be replaced by git during git-archive.
23
+ # setup.py/versioneer.py will grep for the variable names, so they must
24
+ # each be defined on a line of their own. _version.py will just call
25
+ # get_keywords().
26
+ git_refnames = " (HEAD -> main, tag: v2.0.43)"
27
+ git_full = "848a38e4cc5cf41522974dea00848596105b1dfa"
28
+ git_date = "2023-06-02 09:20:57 -0300"
29
+ keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
+ return keywords
31
+
32
+
33
+ class VersioneerConfig:
34
+ """Container for Versioneer configuration parameters."""
35
+
36
+
37
+ def get_config():
38
+ """Create, populate and return the VersioneerConfig() object."""
39
+ # these strings are filled in when 'setup.py versioneer' creates
40
+ # _version.py
41
+ cfg = VersioneerConfig()
42
+ cfg.VCS = "git"
43
+ cfg.style = "pep440"
44
+ cfg.tag_prefix = "v"
45
+ cfg.parentdir_prefix = "rembg-"
46
+ cfg.versionfile_source = "rembg/_version.py"
47
+ cfg.verbose = False
48
+ return cfg
49
+
50
+
51
+ class NotThisMethod(Exception):
52
+ """Exception raised if a method is not valid for the current scenario."""
53
+
54
+
55
+ LONG_VERSION_PY: Dict[str, str] = {}
56
+ HANDLERS: Dict[str, Dict[str, Callable]] = {}
57
+
58
+
59
+ def register_vcs_handler(vcs, method): # decorator
60
+ """Create decorator to mark a method as the handler of a VCS."""
61
+
62
+ def decorate(f):
63
+ """Store f in HANDLERS[vcs][method]."""
64
+ if vcs not in HANDLERS:
65
+ HANDLERS[vcs] = {}
66
+ HANDLERS[vcs][method] = f
67
+ return f
68
+
69
+ return decorate
70
+
71
+
72
+ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
73
+ """Call the given command(s)."""
74
+ assert isinstance(commands, list)
75
+ process = None
76
+ for command in commands:
77
+ try:
78
+ dispcmd = str([command] + args)
79
+ # remember shell=False, so use git.cmd on windows, not just git
80
+ process = subprocess.Popen(
81
+ [command] + args,
82
+ cwd=cwd,
83
+ env=env,
84
+ stdout=subprocess.PIPE,
85
+ stderr=(subprocess.PIPE if hide_stderr else None),
86
+ )
87
+ break
88
+ except OSError:
89
+ e = sys.exc_info()[1]
90
+ if e.errno == errno.ENOENT:
91
+ continue
92
+ if verbose:
93
+ print("unable to run %s" % dispcmd)
94
+ print(e)
95
+ return None, None
96
+ else:
97
+ if verbose:
98
+ print("unable to find command, tried %s" % (commands,))
99
+ return None, None
100
+ stdout = process.communicate()[0].strip().decode()
101
+ if process.returncode != 0:
102
+ if verbose:
103
+ print("unable to run %s (error)" % dispcmd)
104
+ print("stdout was %s" % stdout)
105
+ return None, process.returncode
106
+ return stdout, process.returncode
107
+
108
+
109
+ def versions_from_parentdir(parentdir_prefix, root, verbose):
110
+ """Try to determine the version from the parent directory name.
111
+
112
+ Source tarballs conventionally unpack into a directory that includes both
113
+ the project name and a version string. We will also support searching up
114
+ two directory levels for an appropriately named parent directory
115
+ """
116
+ rootdirs = []
117
+
118
+ for _ in range(3):
119
+ dirname = os.path.basename(root)
120
+ if dirname.startswith(parentdir_prefix):
121
+ return {
122
+ "version": dirname[len(parentdir_prefix) :],
123
+ "full-revisionid": None,
124
+ "dirty": False,
125
+ "error": None,
126
+ "date": None,
127
+ }
128
+ rootdirs.append(root)
129
+ root = os.path.dirname(root) # up a level
130
+
131
+ if verbose:
132
+ print(
133
+ "Tried directories %s but none started with prefix %s"
134
+ % (str(rootdirs), parentdir_prefix)
135
+ )
136
+ raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
137
+
138
+
139
+ @register_vcs_handler("git", "get_keywords")
140
+ def git_get_keywords(versionfile_abs):
141
+ """Extract version information from the given file."""
142
+ # the code embedded in _version.py can just fetch the value of these
143
+ # keywords. When used from setup.py, we don't want to import _version.py,
144
+ # so we do it with a regexp instead. This function is not used from
145
+ # _version.py.
146
+ keywords = {}
147
+ try:
148
+ with open(versionfile_abs, "r") as fobj:
149
+ for line in fobj:
150
+ if line.strip().startswith("git_refnames ="):
151
+ mo = re.search(r'=\s*"(.*)"', line)
152
+ if mo:
153
+ keywords["refnames"] = mo.group(1)
154
+ if line.strip().startswith("git_full ="):
155
+ mo = re.search(r'=\s*"(.*)"', line)
156
+ if mo:
157
+ keywords["full"] = mo.group(1)
158
+ if line.strip().startswith("git_date ="):
159
+ mo = re.search(r'=\s*"(.*)"', line)
160
+ if mo:
161
+ keywords["date"] = mo.group(1)
162
+ except OSError:
163
+ pass
164
+ return keywords
165
+
166
+
167
+ @register_vcs_handler("git", "keywords")
168
+ def git_versions_from_keywords(keywords, tag_prefix, verbose):
169
+ """Get version information from git keywords."""
170
+ if "refnames" not in keywords:
171
+ raise NotThisMethod("Short version file found")
172
+ date = keywords.get("date")
173
+ if date is not None:
174
+ # Use only the last line. Previous lines may contain GPG signature
175
+ # information.
176
+ date = date.splitlines()[-1]
177
+
178
+ # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
179
+ # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
180
+ # -like" string, which we must then edit to make compliant), because
181
+ # it's been around since git-1.5.3, and it's too difficult to
182
+ # discover which version we're using, or to work around using an
183
+ # older one.
184
+ date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
185
+ refnames = keywords["refnames"].strip()
186
+ if refnames.startswith("$Format"):
187
+ if verbose:
188
+ print("keywords are unexpanded, not using")
189
+ raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
190
+ refs = {r.strip() for r in refnames.strip("()").split(",")}
191
+ # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
192
+ # just "foo-1.0". If we see a "tag: " prefix, prefer those.
193
+ TAG = "tag: "
194
+ tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
195
+ if not tags:
196
+ # Either we're using git < 1.8.3, or there really are no tags. We use
197
+ # a heuristic: assume all version tags have a digit. The old git %d
198
+ # expansion behaves like git log --decorate=short and strips out the
199
+ # refs/heads/ and refs/tags/ prefixes that would let us distinguish
200
+ # between branches and tags. By ignoring refnames without digits, we
201
+ # filter out many common branch names like "release" and
202
+ # "stabilization", as well as "HEAD" and "master".
203
+ tags = {r for r in refs if re.search(r"\d", r)}
204
+ if verbose:
205
+ print("discarding '%s', no digits" % ",".join(refs - tags))
206
+ if verbose:
207
+ print("likely tags: %s" % ",".join(sorted(tags)))
208
+ for ref in sorted(tags):
209
+ # sorting will prefer e.g. "2.0" over "2.0rc1"
210
+ if ref.startswith(tag_prefix):
211
+ r = ref[len(tag_prefix) :]
212
+ # Filter out refs that exactly match prefix or that don't start
213
+ # with a number once the prefix is stripped (mostly a concern
214
+ # when prefix is '')
215
+ if not re.match(r"\d", r):
216
+ continue
217
+ if verbose:
218
+ print("picking %s" % r)
219
+ return {
220
+ "version": r,
221
+ "full-revisionid": keywords["full"].strip(),
222
+ "dirty": False,
223
+ "error": None,
224
+ "date": date,
225
+ }
226
+ # no suitable tags, so version is "0+unknown", but full hex is still there
227
+ if verbose:
228
+ print("no suitable tags, using unknown + full revision id")
229
+ return {
230
+ "version": "0+unknown",
231
+ "full-revisionid": keywords["full"].strip(),
232
+ "dirty": False,
233
+ "error": "no suitable tags",
234
+ "date": None,
235
+ }
236
+
237
+
238
+ @register_vcs_handler("git", "pieces_from_vcs")
239
+ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
240
+ """Get version from 'git describe' in the root of the source tree.
241
+
242
+ This only gets called if the git-archive 'subst' keywords were *not*
243
+ expanded, and _version.py hasn't already been rewritten with a short
244
+ version string, meaning we're inside a checked out source tree.
245
+ """
246
+ GITS = ["git"]
247
+ TAG_PREFIX_REGEX = "*"
248
+ if sys.platform == "win32":
249
+ GITS = ["git.cmd", "git.exe"]
250
+ TAG_PREFIX_REGEX = r"\*"
251
+
252
+ _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
253
+ if rc != 0:
254
+ if verbose:
255
+ print("Directory %s not under git control" % root)
256
+ raise NotThisMethod("'git rev-parse --git-dir' returned error")
257
+
258
+ # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
259
+ # if there isn't one, this yields HEX[-dirty] (no NUM)
260
+ describe_out, rc = runner(
261
+ GITS,
262
+ [
263
+ "describe",
264
+ "--tags",
265
+ "--dirty",
266
+ "--always",
267
+ "--long",
268
+ "--match",
269
+ "%s%s" % (tag_prefix, TAG_PREFIX_REGEX),
270
+ ],
271
+ cwd=root,
272
+ )
273
+ # --long was added in git-1.5.5
274
+ if describe_out is None:
275
+ raise NotThisMethod("'git describe' failed")
276
+ describe_out = describe_out.strip()
277
+ full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
278
+ if full_out is None:
279
+ raise NotThisMethod("'git rev-parse' failed")
280
+ full_out = full_out.strip()
281
+
282
+ pieces = {}
283
+ pieces["long"] = full_out
284
+ pieces["short"] = full_out[:7] # maybe improved later
285
+ pieces["error"] = None
286
+
287
+ branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root)
288
+ # --abbrev-ref was added in git-1.6.3
289
+ if rc != 0 or branch_name is None:
290
+ raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
291
+ branch_name = branch_name.strip()
292
+
293
+ if branch_name == "HEAD":
294
+ # If we aren't exactly on a branch, pick a branch which represents
295
+ # the current commit. If all else fails, we are on a branchless
296
+ # commit.
297
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
298
+ # --contains was added in git-1.5.4
299
+ if rc != 0 or branches is None:
300
+ raise NotThisMethod("'git branch --contains' returned error")
301
+ branches = branches.split("\n")
302
+
303
+ # Remove the first line if we're running detached
304
+ if "(" in branches[0]:
305
+ branches.pop(0)
306
+
307
+ # Strip off the leading "* " from the list of branches.
308
+ branches = [branch[2:] for branch in branches]
309
+ if "master" in branches:
310
+ branch_name = "master"
311
+ elif not branches:
312
+ branch_name = None
313
+ else:
314
+ # Pick the first branch that is returned. Good or bad.
315
+ branch_name = branches[0]
316
+
317
+ pieces["branch"] = branch_name
318
+
319
+ # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
320
+ # TAG might have hyphens.
321
+ git_describe = describe_out
322
+
323
+ # look for -dirty suffix
324
+ dirty = git_describe.endswith("-dirty")
325
+ pieces["dirty"] = dirty
326
+ if dirty:
327
+ git_describe = git_describe[: git_describe.rindex("-dirty")]
328
+
329
+ # now we have TAG-NUM-gHEX or HEX
330
+
331
+ if "-" in git_describe:
332
+ # TAG-NUM-gHEX
333
+ mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
334
+ if not mo:
335
+ # unparsable. Maybe git-describe is misbehaving?
336
+ pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
337
+ return pieces
338
+
339
+ # tag
340
+ full_tag = mo.group(1)
341
+ if not full_tag.startswith(tag_prefix):
342
+ if verbose:
343
+ fmt = "tag '%s' doesn't start with prefix '%s'"
344
+ print(fmt % (full_tag, tag_prefix))
345
+ pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
346
+ full_tag,
347
+ tag_prefix,
348
+ )
349
+ return pieces
350
+ pieces["closest-tag"] = full_tag[len(tag_prefix) :]
351
+
352
+ # distance: number of commits since tag
353
+ pieces["distance"] = int(mo.group(2))
354
+
355
+ # commit: short hex revision ID
356
+ pieces["short"] = mo.group(3)
357
+
358
+ else:
359
+ # HEX: no tags
360
+ pieces["closest-tag"] = None
361
+ count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
362
+ pieces["distance"] = int(count_out) # total number of commits
363
+
364
+ # commit date: see ISO-8601 comment in git_versions_from_keywords()
365
+ date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
366
+ # Use only the last line. Previous lines may contain GPG signature
367
+ # information.
368
+ date = date.splitlines()[-1]
369
+ pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
370
+
371
+ return pieces
372
+
373
+
374
+ def plus_or_dot(pieces):
375
+ """Return a + if we don't already have one, else return a ."""
376
+ if "+" in pieces.get("closest-tag", ""):
377
+ return "."
378
+ return "+"
379
+
380
+
381
+ def render_pep440(pieces):
382
+ """Build up version string, with post-release "local version identifier".
383
+
384
+ Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
385
+ get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
386
+
387
+ Exceptions:
388
+ 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
389
+ """
390
+ if pieces["closest-tag"]:
391
+ rendered = pieces["closest-tag"]
392
+ if pieces["distance"] or pieces["dirty"]:
393
+ rendered += plus_or_dot(pieces)
394
+ rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
395
+ if pieces["dirty"]:
396
+ rendered += ".dirty"
397
+ else:
398
+ # exception #1
399
+ rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
400
+ if pieces["dirty"]:
401
+ rendered += ".dirty"
402
+ return rendered
403
+
404
+
405
+ def render_pep440_branch(pieces):
406
+ """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
407
+
408
+ The ".dev0" means not master branch. Note that .dev0 sorts backwards
409
+ (a feature branch will appear "older" than the master branch).
410
+
411
+ Exceptions:
412
+ 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
413
+ """
414
+ if pieces["closest-tag"]:
415
+ rendered = pieces["closest-tag"]
416
+ if pieces["distance"] or pieces["dirty"]:
417
+ if pieces["branch"] != "master":
418
+ rendered += ".dev0"
419
+ rendered += plus_or_dot(pieces)
420
+ rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
421
+ if pieces["dirty"]:
422
+ rendered += ".dirty"
423
+ else:
424
+ # exception #1
425
+ rendered = "0"
426
+ if pieces["branch"] != "master":
427
+ rendered += ".dev0"
428
+ rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
429
+ if pieces["dirty"]:
430
+ rendered += ".dirty"
431
+ return rendered
432
+
433
+
434
+ def pep440_split_post(ver):
435
+ """Split pep440 version string at the post-release segment.
436
+
437
+ Returns the release segments before the post-release and the
438
+ post-release version number (or -1 if no post-release segment is present).
439
+ """
440
+ vc = str.split(ver, ".post")
441
+ return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
442
+
443
+
444
+ def render_pep440_pre(pieces):
445
+ """TAG[.postN.devDISTANCE] -- No -dirty.
446
+
447
+ Exceptions:
448
+ 1: no tags. 0.post0.devDISTANCE
449
+ """
450
+ if pieces["closest-tag"]:
451
+ if pieces["distance"]:
452
+ # update the post release segment
453
+ tag_version, post_version = pep440_split_post(pieces["closest-tag"])
454
+ rendered = tag_version
455
+ if post_version is not None:
456
+ rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
457
+ else:
458
+ rendered += ".post0.dev%d" % (pieces["distance"])
459
+ else:
460
+ # no commits, use the tag as the version
461
+ rendered = pieces["closest-tag"]
462
+ else:
463
+ # exception #1
464
+ rendered = "0.post0.dev%d" % pieces["distance"]
465
+ return rendered
466
+
467
+
468
+ def render_pep440_post(pieces):
469
+ """TAG[.postDISTANCE[.dev0]+gHEX] .
470
+
471
+ The ".dev0" means dirty. Note that .dev0 sorts backwards
472
+ (a dirty tree will appear "older" than the corresponding clean one),
473
+ but you shouldn't be releasing software with -dirty anyways.
474
+
475
+ Exceptions:
476
+ 1: no tags. 0.postDISTANCE[.dev0]
477
+ """
478
+ if pieces["closest-tag"]:
479
+ rendered = pieces["closest-tag"]
480
+ if pieces["distance"] or pieces["dirty"]:
481
+ rendered += ".post%d" % pieces["distance"]
482
+ if pieces["dirty"]:
483
+ rendered += ".dev0"
484
+ rendered += plus_or_dot(pieces)
485
+ rendered += "g%s" % pieces["short"]
486
+ else:
487
+ # exception #1
488
+ rendered = "0.post%d" % pieces["distance"]
489
+ if pieces["dirty"]:
490
+ rendered += ".dev0"
491
+ rendered += "+g%s" % pieces["short"]
492
+ return rendered
493
+
494
+
495
+ def render_pep440_post_branch(pieces):
496
+ """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
497
+
498
+ The ".dev0" means not master branch.
499
+
500
+ Exceptions:
501
+ 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
502
+ """
503
+ if pieces["closest-tag"]:
504
+ rendered = pieces["closest-tag"]
505
+ if pieces["distance"] or pieces["dirty"]:
506
+ rendered += ".post%d" % pieces["distance"]
507
+ if pieces["branch"] != "master":
508
+ rendered += ".dev0"
509
+ rendered += plus_or_dot(pieces)
510
+ rendered += "g%s" % pieces["short"]
511
+ if pieces["dirty"]:
512
+ rendered += ".dirty"
513
+ else:
514
+ # exception #1
515
+ rendered = "0.post%d" % pieces["distance"]
516
+ if pieces["branch"] != "master":
517
+ rendered += ".dev0"
518
+ rendered += "+g%s" % pieces["short"]
519
+ if pieces["dirty"]:
520
+ rendered += ".dirty"
521
+ return rendered
522
+
523
+
524
+ def render_pep440_old(pieces):
525
+ """TAG[.postDISTANCE[.dev0]] .
526
+
527
+ The ".dev0" means dirty.
528
+
529
+ Exceptions:
530
+ 1: no tags. 0.postDISTANCE[.dev0]
531
+ """
532
+ if pieces["closest-tag"]:
533
+ rendered = pieces["closest-tag"]
534
+ if pieces["distance"] or pieces["dirty"]:
535
+ rendered += ".post%d" % pieces["distance"]
536
+ if pieces["dirty"]:
537
+ rendered += ".dev0"
538
+ else:
539
+ # exception #1
540
+ rendered = "0.post%d" % pieces["distance"]
541
+ if pieces["dirty"]:
542
+ rendered += ".dev0"
543
+ return rendered
544
+
545
+
546
+ def render_git_describe(pieces):
547
+ """TAG[-DISTANCE-gHEX][-dirty].
548
+
549
+ Like 'git describe --tags --dirty --always'.
550
+
551
+ Exceptions:
552
+ 1: no tags. HEX[-dirty] (note: no 'g' prefix)
553
+ """
554
+ if pieces["closest-tag"]:
555
+ rendered = pieces["closest-tag"]
556
+ if pieces["distance"]:
557
+ rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
558
+ else:
559
+ # exception #1
560
+ rendered = pieces["short"]
561
+ if pieces["dirty"]:
562
+ rendered += "-dirty"
563
+ return rendered
564
+
565
+
566
+ def render_git_describe_long(pieces):
567
+ """TAG-DISTANCE-gHEX[-dirty].
568
+
569
+ Like 'git describe --tags --dirty --always -long'.
570
+ The distance/hash is unconditional.
571
+
572
+ Exceptions:
573
+ 1: no tags. HEX[-dirty] (note: no 'g' prefix)
574
+ """
575
+ if pieces["closest-tag"]:
576
+ rendered = pieces["closest-tag"]
577
+ rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
578
+ else:
579
+ # exception #1
580
+ rendered = pieces["short"]
581
+ if pieces["dirty"]:
582
+ rendered += "-dirty"
583
+ return rendered
584
+
585
+
586
+ def render(pieces, style):
587
+ """Render the given version pieces into the requested style."""
588
+ if pieces["error"]:
589
+ return {
590
+ "version": "unknown",
591
+ "full-revisionid": pieces.get("long"),
592
+ "dirty": None,
593
+ "error": pieces["error"],
594
+ "date": None,
595
+ }
596
+
597
+ if not style or style == "default":
598
+ style = "pep440" # the default
599
+
600
+ if style == "pep440":
601
+ rendered = render_pep440(pieces)
602
+ elif style == "pep440-branch":
603
+ rendered = render_pep440_branch(pieces)
604
+ elif style == "pep440-pre":
605
+ rendered = render_pep440_pre(pieces)
606
+ elif style == "pep440-post":
607
+ rendered = render_pep440_post(pieces)
608
+ elif style == "pep440-post-branch":
609
+ rendered = render_pep440_post_branch(pieces)
610
+ elif style == "pep440-old":
611
+ rendered = render_pep440_old(pieces)
612
+ elif style == "git-describe":
613
+ rendered = render_git_describe(pieces)
614
+ elif style == "git-describe-long":
615
+ rendered = render_git_describe_long(pieces)
616
+ else:
617
+ raise ValueError("unknown style '%s'" % style)
618
+
619
+ return {
620
+ "version": rendered,
621
+ "full-revisionid": pieces["long"],
622
+ "dirty": pieces["dirty"],
623
+ "error": None,
624
+ "date": pieces.get("date"),
625
+ }
626
+
627
+
628
+ def get_versions():
629
+ """Get version information or return default if unable to do so."""
630
+ # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
631
+ # __file__, we can work backwards from there to the root. Some
632
+ # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
633
+ # case we can only use expanded keywords.
634
+
635
+ cfg = get_config()
636
+ verbose = cfg.verbose
637
+
638
+ try:
639
+ return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
640
+ except NotThisMethod:
641
+ pass
642
+
643
+ try:
644
+ root = os.path.realpath(__file__)
645
+ # versionfile_source is the relative path from the top of the source
646
+ # tree (where the .git directory might live) to this file. Invert
647
+ # this to find the root from __file__.
648
+ for _ in cfg.versionfile_source.split("/"):
649
+ root = os.path.dirname(root)
650
+ except NameError:
651
+ return {
652
+ "version": "0+unknown",
653
+ "full-revisionid": None,
654
+ "dirty": None,
655
+ "error": "unable to find root of source tree",
656
+ "date": None,
657
+ }
658
+
659
+ try:
660
+ pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
661
+ return render(pieces, cfg.style)
662
+ except NotThisMethod:
663
+ pass
664
+
665
+ try:
666
+ if cfg.parentdir_prefix:
667
+ return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
668
+ except NotThisMethod:
669
+ pass
670
+
671
+ return {
672
+ "version": "0+unknown",
673
+ "full-revisionid": None,
674
+ "dirty": None,
675
+ "error": "unable to compute version",
676
+ "date": None,
677
+ }
rembg/bg.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from enum import Enum
3
+ from typing import Any, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ from cv2 import (
7
+ BORDER_DEFAULT,
8
+ MORPH_ELLIPSE,
9
+ MORPH_OPEN,
10
+ GaussianBlur,
11
+ getStructuringElement,
12
+ morphologyEx,
13
+ )
14
+ from PIL import Image, ImageOps
15
+ from PIL.Image import Image as PILImage
16
+ from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
17
+ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
18
+ from pymatting.util.util import stack_images
19
+ from scipy.ndimage import binary_erosion
20
+
21
+ from .session_factory import new_session
22
+ from .sessions import sessions_class
23
+ from .sessions.base import BaseSession
24
+
25
+ kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
26
+
27
+
28
+ class ReturnType(Enum):
29
+ BYTES = 0
30
+ PILLOW = 1
31
+ NDARRAY = 2
32
+
33
+
34
+ def alpha_matting_cutout(
35
+ img: PILImage,
36
+ mask: PILImage,
37
+ foreground_threshold: int,
38
+ background_threshold: int,
39
+ erode_structure_size: int,
40
+ ) -> PILImage:
41
+ if img.mode == "RGBA" or img.mode == "CMYK":
42
+ img = img.convert("RGB")
43
+
44
+ img = np.asarray(img)
45
+ mask = np.asarray(mask)
46
+
47
+ is_foreground = mask > foreground_threshold
48
+ is_background = mask < background_threshold
49
+
50
+ structure = None
51
+ if erode_structure_size > 0:
52
+ structure = np.ones(
53
+ (erode_structure_size, erode_structure_size), dtype=np.uint8
54
+ )
55
+
56
+ is_foreground = binary_erosion(is_foreground, structure=structure)
57
+ is_background = binary_erosion(is_background, structure=structure, border_value=1)
58
+
59
+ trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
60
+ trimap[is_foreground] = 255
61
+ trimap[is_background] = 0
62
+
63
+ img_normalized = img / 255.0
64
+ trimap_normalized = trimap / 255.0
65
+
66
+ alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
67
+ foreground = estimate_foreground_ml(img_normalized, alpha)
68
+ cutout = stack_images(foreground, alpha)
69
+
70
+ cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
71
+ cutout = Image.fromarray(cutout)
72
+
73
+ return cutout
74
+
75
+
76
+ def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
77
+ empty = Image.new("RGBA", (img.size), 0)
78
+ cutout = Image.composite(img, empty, mask)
79
+ return cutout
80
+
81
+
82
+ def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
83
+ pivot = imgs.pop(0)
84
+ for im in imgs:
85
+ pivot = get_concat_v(pivot, im)
86
+ return pivot
87
+
88
+
89
+ def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
90
+ dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
91
+ dst.paste(img1, (0, 0))
92
+ dst.paste(img2, (0, img1.height))
93
+ return dst
94
+
95
+
96
+ def post_process(mask: np.ndarray) -> np.ndarray:
97
+ """
98
+ Post Process the mask for a smooth boundary by applying Morphological Operations
99
+ Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757
100
+ args:
101
+ mask: Binary Numpy Mask
102
+ """
103
+ mask = morphologyEx(mask, MORPH_OPEN, kernel)
104
+ mask = GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT)
105
+ mask = np.where(mask < 127, 0, 255).astype(np.uint8) # convert again to binary
106
+ return mask
107
+
108
+
109
+ def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
110
+ r, g, b, a = color
111
+ colored_image = Image.new("RGBA", img.size, (r, g, b, a))
112
+ colored_image.paste(img, mask=img)
113
+
114
+ return colored_image
115
+
116
+
117
+ def fix_image_orientation(img: PILImage) -> PILImage:
118
+ return ImageOps.exif_transpose(img)
119
+
120
+
121
+ def download_models() -> None:
122
+ for session in sessions_class:
123
+ session.download_models()
124
+
125
+
126
+ def remove(
127
+ data: Union[bytes, PILImage, np.ndarray],
128
+ alpha_matting: bool = False,
129
+ alpha_matting_foreground_threshold: int = 240,
130
+ alpha_matting_background_threshold: int = 10,
131
+ alpha_matting_erode_size: int = 10,
132
+ session: Optional[BaseSession] = None,
133
+ only_mask: bool = False,
134
+ post_process_mask: bool = False,
135
+ bgcolor: Optional[Tuple[int, int, int, int]] = None,
136
+ *args: Optional[Any],
137
+ **kwargs: Optional[Any]
138
+ ) -> Union[bytes, PILImage, np.ndarray]:
139
+ if isinstance(data, PILImage):
140
+ return_type = ReturnType.PILLOW
141
+ img = data
142
+ elif isinstance(data, bytes):
143
+ return_type = ReturnType.BYTES
144
+ img = Image.open(io.BytesIO(data))
145
+ elif isinstance(data, np.ndarray):
146
+ return_type = ReturnType.NDARRAY
147
+ img = Image.fromarray(data)
148
+ else:
149
+ raise ValueError("Input type {} is not supported.".format(type(data)))
150
+
151
+ # Fix image orientation
152
+ img = fix_image_orientation(img)
153
+
154
+ if session is None:
155
+ session = new_session("u2net", *args, **kwargs)
156
+
157
+ masks = session.predict(img, *args, **kwargs)
158
+ cutouts = []
159
+
160
+ for mask in masks:
161
+ if post_process_mask:
162
+ mask = Image.fromarray(post_process(np.array(mask)))
163
+
164
+ if only_mask:
165
+ cutout = mask
166
+
167
+ elif alpha_matting:
168
+ try:
169
+ cutout = alpha_matting_cutout(
170
+ img,
171
+ mask,
172
+ alpha_matting_foreground_threshold,
173
+ alpha_matting_background_threshold,
174
+ alpha_matting_erode_size,
175
+ )
176
+ except ValueError:
177
+ cutout = naive_cutout(img, mask)
178
+
179
+ else:
180
+ cutout = naive_cutout(img, mask)
181
+
182
+ cutouts.append(cutout)
183
+
184
+ cutout = img
185
+ if len(cutouts) > 0:
186
+ cutout = get_concat_v_multi(cutouts)
187
+
188
+ if bgcolor is not None and not only_mask:
189
+ cutout = apply_background_color(cutout, bgcolor)
190
+
191
+ if ReturnType.PILLOW == return_type:
192
+ return cutout
193
+
194
+ if ReturnType.NDARRAY == return_type:
195
+ return np.asarray(cutout)
196
+
197
+ bio = io.BytesIO()
198
+ cutout.save(bio, "PNG")
199
+ bio.seek(0)
200
+
201
+ return bio.read()
rembg/cli.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+
3
+ from . import _version
4
+ from .commands import command_functions
5
+
6
+
7
+ @click.group()
8
+ @click.version_option(version=_version.get_versions()["version"])
9
+ def main() -> None:
10
+ pass
11
+
12
+
13
+ for command in command_functions:
14
+ main.add_command(command)
rembg/commands/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from importlib import import_module
2
+ from pathlib import Path
3
+ from pkgutil import iter_modules
4
+
5
+ command_functions = []
6
+
7
+ package_dir = Path(__file__).resolve().parent
8
+ for _b, module_name, _p in iter_modules([str(package_dir)]):
9
+ module = import_module(f"{__name__}.{module_name}")
10
+ for attribute_name in dir(module):
11
+ attribute = getattr(module, attribute_name)
12
+ if attribute_name.endswith("_command"):
13
+ command_functions.append(attribute)
rembg/commands/b_command.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import io
3
+ import json
4
+ import os
5
+ import sys
6
+ from typing import IO
7
+
8
+ import click
9
+ from PIL import Image
10
+
11
+ from ..bg import remove
12
+ from ..session_factory import new_session
13
+ from ..sessions import sessions_names
14
+
15
+
16
+ @click.command(
17
+ name="b",
18
+ help="for a byte stream as input",
19
+ )
20
+ @click.option(
21
+ "-m",
22
+ "--model",
23
+ default="u2net",
24
+ type=click.Choice(sessions_names),
25
+ show_default=True,
26
+ show_choices=True,
27
+ help="model name",
28
+ )
29
+ @click.option(
30
+ "-a",
31
+ "--alpha-matting",
32
+ is_flag=True,
33
+ show_default=True,
34
+ help="use alpha matting",
35
+ )
36
+ @click.option(
37
+ "-af",
38
+ "--alpha-matting-foreground-threshold",
39
+ default=240,
40
+ type=int,
41
+ show_default=True,
42
+ help="trimap fg threshold",
43
+ )
44
+ @click.option(
45
+ "-ab",
46
+ "--alpha-matting-background-threshold",
47
+ default=10,
48
+ type=int,
49
+ show_default=True,
50
+ help="trimap bg threshold",
51
+ )
52
+ @click.option(
53
+ "-ae",
54
+ "--alpha-matting-erode-size",
55
+ default=10,
56
+ type=int,
57
+ show_default=True,
58
+ help="erode size",
59
+ )
60
+ @click.option(
61
+ "-om",
62
+ "--only-mask",
63
+ is_flag=True,
64
+ show_default=True,
65
+ help="output only the mask",
66
+ )
67
+ @click.option(
68
+ "-ppm",
69
+ "--post-process-mask",
70
+ is_flag=True,
71
+ show_default=True,
72
+ help="post process the mask",
73
+ )
74
+ @click.option(
75
+ "-bgc",
76
+ "--bgcolor",
77
+ default=None,
78
+ type=(int, int, int, int),
79
+ nargs=4,
80
+ help="Background color (R G B A) to replace the removed background with",
81
+ )
82
+ @click.option("-x", "--extras", type=str)
83
+ @click.option(
84
+ "-o",
85
+ "--output_specifier",
86
+ type=str,
87
+ help="printf-style specifier for output filenames (e.g. 'output-%d.png'))",
88
+ )
89
+ @click.argument(
90
+ "image_width",
91
+ type=int,
92
+ )
93
+ @click.argument(
94
+ "image_height",
95
+ type=int,
96
+ )
97
+ def rs_command(
98
+ model: str,
99
+ extras: str,
100
+ image_width: int,
101
+ image_height: int,
102
+ output_specifier: str,
103
+ **kwargs
104
+ ) -> None:
105
+ try:
106
+ kwargs.update(json.loads(extras))
107
+ except Exception:
108
+ pass
109
+
110
+ session = new_session(model)
111
+ bytes_per_img = image_width * image_height * 3
112
+
113
+ if output_specifier:
114
+ output_dir = os.path.dirname(
115
+ os.path.abspath(os.path.expanduser(output_specifier))
116
+ )
117
+
118
+ if not os.path.isdir(output_dir):
119
+ os.makedirs(output_dir, exist_ok=True)
120
+
121
+ def img_to_byte_array(img: Image) -> bytes:
122
+ buff = io.BytesIO()
123
+ img.save(buff, format="PNG")
124
+ return buff.getvalue()
125
+
126
+ async def connect_stdin_stdout():
127
+ loop = asyncio.get_event_loop()
128
+ reader = asyncio.StreamReader()
129
+ protocol = asyncio.StreamReaderProtocol(reader)
130
+
131
+ await loop.connect_read_pipe(lambda: protocol, sys.stdin)
132
+ w_transport, w_protocol = await loop.connect_write_pipe(
133
+ asyncio.streams.FlowControlMixin, sys.stdout
134
+ )
135
+
136
+ writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
137
+ return reader, writer
138
+
139
+ async def main():
140
+ reader, writer = await connect_stdin_stdout()
141
+
142
+ idx = 0
143
+ while True:
144
+ try:
145
+ img_bytes = await reader.readexactly(bytes_per_img)
146
+ if not img_bytes:
147
+ break
148
+
149
+ img = Image.frombytes("RGB", (image_width, image_height), img_bytes)
150
+ output = remove(img, session=session, **kwargs)
151
+
152
+ if output_specifier:
153
+ output.save((output_specifier % idx), format="PNG")
154
+ else:
155
+ writer.write(img_to_byte_array(output))
156
+
157
+ idx += 1
158
+ except asyncio.IncompleteReadError:
159
+ break
160
+
161
+ asyncio.run(main())
rembg/commands/i_command.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ from typing import IO
4
+
5
+ import click
6
+
7
+ from ..bg import remove
8
+ from ..session_factory import new_session
9
+ from ..sessions import sessions_names
10
+
11
+
12
+ @click.command(
13
+ name="i",
14
+ help="for a file as input",
15
+ )
16
+ @click.option(
17
+ "-m",
18
+ "--model",
19
+ default="u2net",
20
+ type=click.Choice(sessions_names),
21
+ show_default=True,
22
+ show_choices=True,
23
+ help="model name",
24
+ )
25
+ @click.option(
26
+ "-a",
27
+ "--alpha-matting",
28
+ is_flag=True,
29
+ show_default=True,
30
+ help="use alpha matting",
31
+ )
32
+ @click.option(
33
+ "-af",
34
+ "--alpha-matting-foreground-threshold",
35
+ default=240,
36
+ type=int,
37
+ show_default=True,
38
+ help="trimap fg threshold",
39
+ )
40
+ @click.option(
41
+ "-ab",
42
+ "--alpha-matting-background-threshold",
43
+ default=10,
44
+ type=int,
45
+ show_default=True,
46
+ help="trimap bg threshold",
47
+ )
48
+ @click.option(
49
+ "-ae",
50
+ "--alpha-matting-erode-size",
51
+ default=10,
52
+ type=int,
53
+ show_default=True,
54
+ help="erode size",
55
+ )
56
+ @click.option(
57
+ "-om",
58
+ "--only-mask",
59
+ is_flag=True,
60
+ show_default=True,
61
+ help="output only the mask",
62
+ )
63
+ @click.option(
64
+ "-ppm",
65
+ "--post-process-mask",
66
+ is_flag=True,
67
+ show_default=True,
68
+ help="post process the mask",
69
+ )
70
+ @click.option(
71
+ "-bgc",
72
+ "--bgcolor",
73
+ default=None,
74
+ type=(int, int, int, int),
75
+ nargs=4,
76
+ help="Background color (R G B A) to replace the removed background with",
77
+ )
78
+ @click.option("-x", "--extras", type=str)
79
+ @click.argument(
80
+ "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
81
+ )
82
+ @click.argument(
83
+ "output",
84
+ default=(None if sys.stdin.isatty() else "-"),
85
+ type=click.File("wb", lazy=True),
86
+ )
87
+ def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
88
+ try:
89
+ kwargs.update(json.loads(extras))
90
+ except Exception:
91
+ pass
92
+
93
+ output.write(remove(input.read(), session=new_session(model), **kwargs))
rembg/commands/p_command.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pathlib
3
+ import time
4
+ from typing import cast
5
+
6
+ import click
7
+ import filetype
8
+ from tqdm import tqdm
9
+ from watchdog.events import FileSystemEvent, FileSystemEventHandler
10
+ from watchdog.observers import Observer
11
+
12
+ from ..bg import remove
13
+ from ..session_factory import new_session
14
+ from ..sessions import sessions_names
15
+
16
+
17
+ @click.command(
18
+ name="p",
19
+ help="for a folder as input",
20
+ )
21
+ @click.option(
22
+ "-m",
23
+ "--model",
24
+ default="u2net",
25
+ type=click.Choice(sessions_names),
26
+ show_default=True,
27
+ show_choices=True,
28
+ help="model name",
29
+ )
30
+ @click.option(
31
+ "-a",
32
+ "--alpha-matting",
33
+ is_flag=True,
34
+ show_default=True,
35
+ help="use alpha matting",
36
+ )
37
+ @click.option(
38
+ "-af",
39
+ "--alpha-matting-foreground-threshold",
40
+ default=240,
41
+ type=int,
42
+ show_default=True,
43
+ help="trimap fg threshold",
44
+ )
45
+ @click.option(
46
+ "-ab",
47
+ "--alpha-matting-background-threshold",
48
+ default=10,
49
+ type=int,
50
+ show_default=True,
51
+ help="trimap bg threshold",
52
+ )
53
+ @click.option(
54
+ "-ae",
55
+ "--alpha-matting-erode-size",
56
+ default=10,
57
+ type=int,
58
+ show_default=True,
59
+ help="erode size",
60
+ )
61
+ @click.option(
62
+ "-om",
63
+ "--only-mask",
64
+ is_flag=True,
65
+ show_default=True,
66
+ help="output only the mask",
67
+ )
68
+ @click.option(
69
+ "-ppm",
70
+ "--post-process-mask",
71
+ is_flag=True,
72
+ show_default=True,
73
+ help="post process the mask",
74
+ )
75
+ @click.option(
76
+ "-w",
77
+ "--watch",
78
+ default=False,
79
+ is_flag=True,
80
+ show_default=True,
81
+ help="watches a folder for changes",
82
+ )
83
+ @click.option(
84
+ "-bgc",
85
+ "--bgcolor",
86
+ default=None,
87
+ type=(int, int, int, int),
88
+ nargs=4,
89
+ help="Background color (R G B A) to replace the removed background with",
90
+ )
91
+ @click.option("-x", "--extras", type=str)
92
+ @click.argument(
93
+ "input",
94
+ type=click.Path(
95
+ exists=True,
96
+ path_type=pathlib.Path,
97
+ file_okay=False,
98
+ dir_okay=True,
99
+ readable=True,
100
+ ),
101
+ )
102
+ @click.argument(
103
+ "output",
104
+ type=click.Path(
105
+ exists=False,
106
+ path_type=pathlib.Path,
107
+ file_okay=False,
108
+ dir_okay=True,
109
+ writable=True,
110
+ ),
111
+ )
112
+ def p_command(
113
+ model: str,
114
+ extras: str,
115
+ input: pathlib.Path,
116
+ output: pathlib.Path,
117
+ watch: bool,
118
+ **kwargs,
119
+ ) -> None:
120
+ try:
121
+ kwargs.update(json.loads(extras))
122
+ except Exception:
123
+ pass
124
+
125
+ session = new_session(model)
126
+
127
+ def process(each_input: pathlib.Path) -> None:
128
+ try:
129
+ mimetype = filetype.guess(each_input)
130
+ if mimetype is None:
131
+ return
132
+ if mimetype.mime.find("image") < 0:
133
+ return
134
+
135
+ each_output = (output / each_input.name).with_suffix(".png")
136
+ each_output.parents[0].mkdir(parents=True, exist_ok=True)
137
+
138
+ if not each_output.exists():
139
+ each_output.write_bytes(
140
+ cast(
141
+ bytes,
142
+ remove(each_input.read_bytes(), session=session, **kwargs),
143
+ )
144
+ )
145
+
146
+ if watch:
147
+ print(
148
+ f"processed: {each_input.absolute()} -> {each_output.absolute()}"
149
+ )
150
+ except Exception as e:
151
+ print(e)
152
+
153
+ inputs = list(input.glob("**/*"))
154
+ if not watch:
155
+ inputs = tqdm(inputs)
156
+
157
+ for each_input in inputs:
158
+ if not each_input.is_dir():
159
+ process(each_input)
160
+
161
+ if watch:
162
+ observer = Observer()
163
+
164
+ class EventHandler(FileSystemEventHandler):
165
+ def on_any_event(self, event: FileSystemEvent) -> None:
166
+ if not (
167
+ event.is_directory or event.event_type in ["deleted", "closed"]
168
+ ):
169
+ process(pathlib.Path(event.src_path))
170
+
171
+ event_handler = EventHandler()
172
+ observer.schedule(event_handler, input, recursive=False)
173
+ observer.start()
174
+
175
+ try:
176
+ while True:
177
+ time.sleep(1)
178
+
179
+ finally:
180
+ observer.stop()
181
+ observer.join()
rembg/commands/s_command.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import webbrowser
4
+ from typing import Optional, Tuple, cast
5
+
6
+ import aiohttp
7
+ import click
8
+ import gradio as gr
9
+ import uvicorn
10
+ from asyncer import asyncify
11
+ from fastapi import Depends, FastAPI, File, Form, Query
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+ from starlette.responses import Response
14
+
15
+ from .._version import get_versions
16
+ from ..bg import remove
17
+ from ..session_factory import new_session
18
+ from ..sessions import sessions_names
19
+ from ..sessions.base import BaseSession
20
+
21
+
22
+ @click.command(
23
+ name="s",
24
+ help="for a http server",
25
+ )
26
+ @click.option(
27
+ "-p",
28
+ "--port",
29
+ default=5000,
30
+ type=int,
31
+ show_default=True,
32
+ help="port",
33
+ )
34
+ @click.option(
35
+ "-l",
36
+ "--log_level",
37
+ default="info",
38
+ type=str,
39
+ show_default=True,
40
+ help="log level",
41
+ )
42
+ @click.option(
43
+ "-t",
44
+ "--threads",
45
+ default=None,
46
+ type=int,
47
+ show_default=True,
48
+ help="number of worker threads",
49
+ )
50
+ def s_command(port: int, log_level: str, threads: int) -> None:
51
+ sessions: dict[str, BaseSession] = {}
52
+ tags_metadata = [
53
+ {
54
+ "name": "Background Removal",
55
+ "description": "Endpoints that perform background removal with different image sources.",
56
+ "externalDocs": {
57
+ "description": "GitHub Source",
58
+ "url": "https://github.com/danielgatis/rembg",
59
+ },
60
+ },
61
+ ]
62
+ app = FastAPI(
63
+ title="Rembg",
64
+ description="Rembg is a tool to remove images background. That is it.",
65
+ version=get_versions()["version"],
66
+ contact={
67
+ "name": "Daniel Gatis",
68
+ "url": "https://github.com/danielgatis",
69
+ "email": "[email protected]",
70
+ },
71
+ license_info={
72
+ "name": "MIT License",
73
+ "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
74
+ },
75
+ openapi_tags=tags_metadata,
76
+ docs_url="/api",
77
+ )
78
+
79
+ app.add_middleware(
80
+ CORSMiddleware,
81
+ allow_credentials=True,
82
+ allow_origins=["*"],
83
+ allow_methods=["*"],
84
+ allow_headers=["*"],
85
+ )
86
+
87
+ class CommonQueryParams:
88
+ def __init__(
89
+ self,
90
+ model: str = Query(
91
+ description="Model to use when processing image",
92
+ regex=r"(" + "|".join(sessions_names) + ")",
93
+ default="u2net",
94
+ ),
95
+ a: bool = Query(default=False, description="Enable Alpha Matting"),
96
+ af: int = Query(
97
+ default=240,
98
+ ge=0,
99
+ le=255,
100
+ description="Alpha Matting (Foreground Threshold)",
101
+ ),
102
+ ab: int = Query(
103
+ default=10,
104
+ ge=0,
105
+ le=255,
106
+ description="Alpha Matting (Background Threshold)",
107
+ ),
108
+ ae: int = Query(
109
+ default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
110
+ ),
111
+ om: bool = Query(default=False, description="Only Mask"),
112
+ ppm: bool = Query(default=False, description="Post Process Mask"),
113
+ bgc: Optional[str] = Query(default=None, description="Background Color"),
114
+ extras: Optional[str] = Query(
115
+ default=None, description="Extra parameters as JSON"
116
+ ),
117
+ ):
118
+ self.model = model
119
+ self.a = a
120
+ self.af = af
121
+ self.ab = ab
122
+ self.ae = ae
123
+ self.om = om
124
+ self.ppm = ppm
125
+ self.extras = extras
126
+ self.bgc = (
127
+ cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
128
+ if bgc
129
+ else None
130
+ )
131
+
132
+ class CommonQueryPostParams:
133
+ def __init__(
134
+ self,
135
+ model: str = Form(
136
+ description="Model to use when processing image",
137
+ regex=r"(" + "|".join(sessions_names) + ")",
138
+ default="u2net",
139
+ ),
140
+ a: bool = Form(default=False, description="Enable Alpha Matting"),
141
+ af: int = Form(
142
+ default=240,
143
+ ge=0,
144
+ le=255,
145
+ description="Alpha Matting (Foreground Threshold)",
146
+ ),
147
+ ab: int = Form(
148
+ default=10,
149
+ ge=0,
150
+ le=255,
151
+ description="Alpha Matting (Background Threshold)",
152
+ ),
153
+ ae: int = Form(
154
+ default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
155
+ ),
156
+ om: bool = Form(default=False, description="Only Mask"),
157
+ ppm: bool = Form(default=False, description="Post Process Mask"),
158
+ bgc: Optional[str] = Query(default=None, description="Background Color"),
159
+ extras: Optional[str] = Query(
160
+ default=None, description="Extra parameters as JSON"
161
+ ),
162
+ ):
163
+ self.model = model
164
+ self.a = a
165
+ self.af = af
166
+ self.ab = ab
167
+ self.ae = ae
168
+ self.om = om
169
+ self.ppm = ppm
170
+ self.extras = extras
171
+ self.bgc = (
172
+ cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
173
+ if bgc
174
+ else None
175
+ )
176
+
177
+ def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
178
+ kwargs = {}
179
+
180
+ if commons.extras:
181
+ try:
182
+ kwargs.update(json.loads(commons.extras))
183
+ except Exception:
184
+ pass
185
+
186
+ return Response(
187
+ remove(
188
+ content,
189
+ session=sessions.setdefault(commons.model, new_session(commons.model)),
190
+ alpha_matting=commons.a,
191
+ alpha_matting_foreground_threshold=commons.af,
192
+ alpha_matting_background_threshold=commons.ab,
193
+ alpha_matting_erode_size=commons.ae,
194
+ only_mask=commons.om,
195
+ post_process_mask=commons.ppm,
196
+ bgcolor=commons.bgc,
197
+ **kwargs,
198
+ ),
199
+ media_type="image/png",
200
+ )
201
+
202
+ @app.on_event("startup")
203
+ def startup():
204
+ try:
205
+ webbrowser.open(f"http://localhost:{port}")
206
+ except Exception:
207
+ pass
208
+
209
+ if threads is not None:
210
+ from anyio import CapacityLimiter
211
+ from anyio.lowlevel import RunVar
212
+
213
+ RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
214
+
215
+ @app.get(
216
+ path="/api/remove",
217
+ tags=["Background Removal"],
218
+ summary="Remove from URL",
219
+ description="Removes the background from an image obtained by retrieving an URL.",
220
+ )
221
+ async def get_index(
222
+ url: str = Query(
223
+ default=..., description="URL of the image that has to be processed."
224
+ ),
225
+ commons: CommonQueryParams = Depends(),
226
+ ):
227
+ async with aiohttp.ClientSession() as session:
228
+ async with session.get(url) as response:
229
+ file = await response.read()
230
+ return await asyncify(im_without_bg)(file, commons)
231
+
232
+ @app.post(
233
+ path="/api/remove",
234
+ tags=["Background Removal"],
235
+ summary="Remove from Stream",
236
+ description="Removes the background from an image sent within the request itself.",
237
+ )
238
+ async def post_index(
239
+ file: bytes = File(
240
+ default=...,
241
+ description="Image file (byte stream) that has to be processed.",
242
+ ),
243
+ commons: CommonQueryPostParams = Depends(),
244
+ ):
245
+ return await asyncify(im_without_bg)(file, commons) # type: ignore
246
+
247
+ def gr_app(app):
248
+ def inference(input_path, model):
249
+ output_path = "output.png"
250
+ with open(input_path, "rb") as i:
251
+ with open(output_path, "wb") as o:
252
+ input = i.read()
253
+ output = remove(input, session=new_session(model))
254
+ o.write(output)
255
+ return os.path.join(output_path)
256
+
257
+ interface = gr.Interface(
258
+ inference,
259
+ [
260
+ gr.components.Image(type="filepath", label="Input"),
261
+ gr.components.Dropdown(
262
+ [
263
+ "u2net",
264
+ "u2netp",
265
+ "u2net_human_seg",
266
+ "u2net_cloth_seg",
267
+ "silueta",
268
+ "isnet-general-use",
269
+ "isnet-anime",
270
+ ],
271
+ value="u2net",
272
+ label="Models",
273
+ ),
274
+ ],
275
+ gr.components.Image(type="filepath", label="Output"),
276
+ )
277
+
278
+ interface.queue(concurrency_count=3)
279
+ app = gr.mount_gradio_app(app, interface, path="/")
280
+ return app
281
+
282
+ print(f"To access the API documentation, go to http://localhost:{port}/api")
283
+ print(f"To access the UI, go to http://localhost:{port}")
284
+
285
+ uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)
rembg/session_base.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple
2
+
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ from PIL import Image
6
+ from PIL.Image import Image as PILImage
7
+
8
+
9
+ class BaseSession:
10
+ def __init__(self, model_name: str, inner_session: ort.InferenceSession):
11
+ self.model_name = model_name
12
+ self.inner_session = inner_session
13
+
14
+ def normalize(
15
+ self,
16
+ img: PILImage,
17
+ mean: Tuple[float, float, float],
18
+ std: Tuple[float, float, float],
19
+ size: Tuple[int, int],
20
+ ) -> Dict[str, np.ndarray]:
21
+ im = img.convert("RGB").resize(size, Image.LANCZOS)
22
+
23
+ im_ary = np.array(im)
24
+ im_ary = im_ary / np.max(im_ary)
25
+
26
+ tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
27
+ tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
28
+ tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]
29
+ tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]
30
+
31
+ tmpImg = tmpImg.transpose((2, 0, 1))
32
+
33
+ return {
34
+ self.inner_session.get_inputs()[0]
35
+ .name: np.expand_dims(tmpImg, 0)
36
+ .astype(np.float32)
37
+ }
38
+
39
+ def predict(self, img: PILImage) -> List[PILImage]:
40
+ raise NotImplementedError
rembg/session_cloth.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from PIL.Image import Image as PILImage
6
+ from scipy.special import log_softmax
7
+
8
+ from .session_base import BaseSession
9
+
10
+ pallete1 = [
11
+ 0,
12
+ 0,
13
+ 0,
14
+ 255,
15
+ 255,
16
+ 255,
17
+ 0,
18
+ 0,
19
+ 0,
20
+ 0,
21
+ 0,
22
+ 0,
23
+ ]
24
+
25
+ pallete2 = [
26
+ 0,
27
+ 0,
28
+ 0,
29
+ 0,
30
+ 0,
31
+ 0,
32
+ 255,
33
+ 255,
34
+ 255,
35
+ 0,
36
+ 0,
37
+ 0,
38
+ ]
39
+
40
+ pallete3 = [
41
+ 0,
42
+ 0,
43
+ 0,
44
+ 0,
45
+ 0,
46
+ 0,
47
+ 0,
48
+ 0,
49
+ 0,
50
+ 255,
51
+ 255,
52
+ 255,
53
+ ]
54
+
55
+
56
+ class ClothSession(BaseSession):
57
+ def predict(self, img: PILImage) -> List[PILImage]:
58
+ ort_outs = self.inner_session.run(
59
+ None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768))
60
+ )
61
+
62
+ pred = ort_outs
63
+ pred = log_softmax(pred[0], 1)
64
+ pred = np.argmax(pred, axis=1, keepdims=True)
65
+ pred = np.squeeze(pred, 0)
66
+ pred = np.squeeze(pred, 0)
67
+
68
+ mask = Image.fromarray(pred.astype("uint8"), mode="L")
69
+ mask = mask.resize(img.size, Image.LANCZOS)
70
+
71
+ masks = []
72
+
73
+ mask1 = mask.copy()
74
+ mask1.putpalette(pallete1)
75
+ mask1 = mask1.convert("RGB").convert("L")
76
+ masks.append(mask1)
77
+
78
+ mask2 = mask.copy()
79
+ mask2.putpalette(pallete2)
80
+ mask2 = mask2.convert("RGB").convert("L")
81
+ masks.append(mask2)
82
+
83
+ mask3 = mask.copy()
84
+ mask3.putpalette(pallete3)
85
+ mask3 = mask3.convert("RGB").convert("L")
86
+ masks.append(mask3)
87
+
88
+ return masks
rembg/session_factory.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Type
3
+
4
+ import onnxruntime as ort
5
+
6
+ from .sessions import sessions_class
7
+ from .sessions.base import BaseSession
8
+ from .sessions.u2net import U2netSession
9
+
10
+
11
+ def new_session(
12
+ model_name: str = "u2net", providers=None, *args, **kwargs
13
+ ) -> BaseSession:
14
+ session_class: Type[BaseSession] = U2netSession
15
+
16
+ for sc in sessions_class:
17
+ if sc.name() == model_name:
18
+ session_class = sc
19
+ break
20
+
21
+ sess_opts = ort.SessionOptions()
22
+
23
+ if "OMP_NUM_THREADS" in os.environ:
24
+ sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
25
+
26
+ return session_class(model_name, sess_opts, providers, *args, **kwargs)
rembg/session_simple.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from PIL.Image import Image as PILImage
6
+
7
+ from .session_base import BaseSession
8
+
9
+
10
+ class SimpleSession(BaseSession):
11
+ def predict(self, img: PILImage) -> List[PILImage]:
12
+ ort_outs = self.inner_session.run(
13
+ None,
14
+ self.normalize(
15
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
16
+ ),
17
+ )
18
+
19
+ pred = ort_outs[0][:, 0, :, :]
20
+
21
+ ma = np.max(pred)
22
+ mi = np.min(pred)
23
+
24
+ pred = (pred - mi) / (ma - mi)
25
+ pred = np.squeeze(pred)
26
+
27
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
+ mask = mask.resize(img.size, Image.LANCZOS)
29
+
30
+ return [mask]
rembg/sessions/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from importlib import import_module
2
+ from inspect import isclass
3
+ from pathlib import Path
4
+ from pkgutil import iter_modules
5
+
6
+ from .base import BaseSession
7
+
8
+ sessions_class = []
9
+ sessions_names = []
10
+
11
+ package_dir = Path(__file__).resolve().parent
12
+ for _b, module_name, _p in iter_modules([str(package_dir)]):
13
+ module = import_module(f"{__name__}.{module_name}")
14
+ for attribute_name in dir(module):
15
+ attribute = getattr(module, attribute_name)
16
+ if (
17
+ isclass(attribute)
18
+ and issubclass(attribute, BaseSession)
19
+ and attribute != BaseSession
20
+ ):
21
+ sessions_class.append(attribute)
22
+ sessions_names.append(attribute.name())
rembg/sessions/base.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Tuple
3
+
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+
10
+ class BaseSession:
11
+ def __init__(
12
+ self,
13
+ model_name: str,
14
+ sess_opts: ort.SessionOptions,
15
+ providers=None,
16
+ *args,
17
+ **kwargs
18
+ ):
19
+ self.model_name = model_name
20
+
21
+ self.providers = []
22
+
23
+ _providers = ort.get_available_providers()
24
+ if providers:
25
+ for provider in providers:
26
+ if provider in _providers:
27
+ self.providers.append(provider)
28
+ else:
29
+ self.providers.extend(_providers)
30
+
31
+ self.inner_session = ort.InferenceSession(
32
+ str(self.__class__.download_models()),
33
+ providers=self.providers,
34
+ sess_options=sess_opts,
35
+ )
36
+
37
+ def normalize(
38
+ self,
39
+ img: PILImage,
40
+ mean: Tuple[float, float, float],
41
+ std: Tuple[float, float, float],
42
+ size: Tuple[int, int],
43
+ *args,
44
+ **kwargs
45
+ ) -> Dict[str, np.ndarray]:
46
+ im = img.convert("RGB").resize(size, Image.LANCZOS)
47
+
48
+ im_ary = np.array(im)
49
+ im_ary = im_ary / np.max(im_ary)
50
+
51
+ tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
52
+ tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
53
+ tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]
54
+ tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]
55
+
56
+ tmpImg = tmpImg.transpose((2, 0, 1))
57
+
58
+ return {
59
+ self.inner_session.get_inputs()[0]
60
+ .name: np.expand_dims(tmpImg, 0)
61
+ .astype(np.float32)
62
+ }
63
+
64
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
65
+ raise NotImplementedError
66
+
67
+ @classmethod
68
+ def checksum_disabled(cls, *args, **kwargs):
69
+ return os.getenv("MODEL_CHECKSUM_DISABLED", None) is not None
70
+
71
+ @classmethod
72
+ def u2net_home(cls, *args, **kwargs):
73
+ return os.path.expanduser(
74
+ os.getenv(
75
+ "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
76
+ )
77
+ )
78
+
79
+ @classmethod
80
+ def download_models(cls, *args, **kwargs):
81
+ raise NotImplementedError
82
+
83
+ @classmethod
84
+ def name(cls, *args, **kwargs):
85
+ raise NotImplementedError
rembg/sessions/dis.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class DisSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
17
+ )
18
+
19
+ pred = ort_outs[0][:, 0, :, :]
20
+
21
+ ma = np.max(pred)
22
+ mi = np.min(pred)
23
+
24
+ pred = (pred - mi) / (ma - mi)
25
+ pred = np.squeeze(pred)
26
+
27
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
+ mask = mask.resize(img.size, Image.LANCZOS)
29
+
30
+ return [mask]
31
+
32
+ @classmethod
33
+ def download_models(cls, *args, **kwargs):
34
+ fname = f"{cls.name()}.onnx"
35
+ pooch.retrieve(
36
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
37
+ "md5:fc16ebd8b0c10d971d3513d564d01e29",
38
+ fname=fname,
39
+ path=cls.u2net_home(),
40
+ progressbar=True,
41
+ )
42
+
43
+ return os.path.join(cls.u2net_home(), fname)
44
+
45
+ @classmethod
46
+ def name(cls, *args, **kwargs):
47
+ return "isnet-general-use"
rembg/sessions/dis_anime.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class DisSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
17
+ )
18
+
19
+ pred = ort_outs[0][:, 0, :, :]
20
+
21
+ ma = np.max(pred)
22
+ mi = np.min(pred)
23
+
24
+ pred = (pred - mi) / (ma - mi)
25
+ pred = np.squeeze(pred)
26
+
27
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
+ mask = mask.resize(img.size, Image.LANCZOS)
29
+
30
+ return [mask]
31
+
32
+ @classmethod
33
+ def download_models(cls, *args, **kwargs):
34
+ fname = f"{cls.name()}.onnx"
35
+ pooch.retrieve(
36
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
37
+ None
38
+ if cls.checksum_disabled(*args, **kwargs)
39
+ else "md5:6f184e756bb3bd901c8849220a83e38e",
40
+ fname=fname,
41
+ path=cls.u2net_home(*args, **kwargs),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "isnet-anime"
rembg/sessions/dis_general_use.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class DisSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
17
+ )
18
+
19
+ pred = ort_outs[0][:, 0, :, :]
20
+
21
+ ma = np.max(pred)
22
+ mi = np.min(pred)
23
+
24
+ pred = (pred - mi) / (ma - mi)
25
+ pred = np.squeeze(pred)
26
+
27
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
+ mask = mask.resize(img.size, Image.LANCZOS)
29
+
30
+ return [mask]
31
+
32
+ @classmethod
33
+ def download_models(cls, *args, **kwargs):
34
+ fname = f"{cls.name()}.onnx"
35
+ pooch.retrieve(
36
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
37
+ None
38
+ if cls.checksum_disabled(*args, **kwargs)
39
+ else "md5:fc16ebd8b0c10d971d3513d564d01e29",
40
+ fname=fname,
41
+ path=cls.u2net_home(*args, **kwargs),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "isnet-general-use"
rembg/sessions/sam.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ import pooch
7
+ from PIL import Image
8
+ from PIL.Image import Image as PILImage
9
+
10
+ from .base import BaseSession
11
+
12
+
13
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
14
+ scale = long_side_length * 1.0 / max(oldh, oldw)
15
+ newh, neww = oldh * scale, oldw * scale
16
+ neww = int(neww + 0.5)
17
+ newh = int(newh + 0.5)
18
+ return (newh, neww)
19
+
20
+
21
+ def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
22
+ old_h, old_w = original_size
23
+ new_h, new_w = get_preprocess_shape(
24
+ original_size[0], original_size[1], target_length
25
+ )
26
+ coords = coords.copy().astype(float)
27
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
28
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
29
+ return coords
30
+
31
+
32
+ def resize_longes_side(img: PILImage, size=1024):
33
+ w, h = img.size
34
+ if h > w:
35
+ new_h, new_w = size, int(w * size / h)
36
+ else:
37
+ new_h, new_w = int(h * size / w), size
38
+
39
+ return img.resize((new_w, new_h))
40
+
41
+
42
+ def pad_to_square(img: np.ndarray, size=1024):
43
+ h, w = img.shape[:2]
44
+ padh = size - h
45
+ padw = size - w
46
+ img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
47
+ img = img.astype(np.float32)
48
+ return img
49
+
50
+
51
+ class SamSession(BaseSession):
52
+ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
53
+ self.model_name = model_name
54
+ paths = self.__class__.download_models()
55
+ self.encoder = ort.InferenceSession(
56
+ str(paths[0]),
57
+ providers=ort.get_available_providers(),
58
+ sess_options=sess_opts,
59
+ )
60
+ self.decoder = ort.InferenceSession(
61
+ str(paths[1]),
62
+ providers=ort.get_available_providers(),
63
+ sess_options=sess_opts,
64
+ )
65
+
66
+ def normalize(
67
+ self,
68
+ img: np.ndarray,
69
+ mean=(123.675, 116.28, 103.53),
70
+ std=(58.395, 57.12, 57.375),
71
+ size=(1024, 1024),
72
+ *args,
73
+ **kwargs,
74
+ ):
75
+ pixel_mean = np.array([*mean]).reshape(1, 1, -1)
76
+ pixel_std = np.array([*std]).reshape(1, 1, -1)
77
+ x = (img - pixel_mean) / pixel_std
78
+ return x
79
+
80
+ def predict(
81
+ self,
82
+ img: PILImage,
83
+ *args,
84
+ **kwargs,
85
+ ) -> List[PILImage]:
86
+ # Preprocess image
87
+ image = resize_longes_side(img)
88
+ image = np.array(image)
89
+ image = self.normalize(image)
90
+ image = pad_to_square(image)
91
+
92
+ input_labels = kwargs.get("input_labels")
93
+ input_points = kwargs.get("input_points")
94
+
95
+ if input_labels is None:
96
+ raise ValueError("input_labels is required")
97
+ if input_points is None:
98
+ raise ValueError("input_points is required")
99
+
100
+ # Transpose
101
+ image = image.transpose(2, 0, 1)[None, :, :, :]
102
+ # Run encoder (Image embedding)
103
+ encoded = self.encoder.run(None, {"x": image})
104
+ image_embedding = encoded[0]
105
+
106
+ # Add a batch index, concatenate a padding point, and transform.
107
+ onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
108
+ None, :, :
109
+ ]
110
+ onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
111
+ None, :
112
+ ].astype(np.float32)
113
+ onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
114
+
115
+ # Create an empty mask input and an indicator for no mask.
116
+ onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
117
+ onnx_has_mask_input = np.zeros(1, dtype=np.float32)
118
+
119
+ decoder_inputs = {
120
+ "image_embeddings": image_embedding,
121
+ "point_coords": onnx_coord,
122
+ "point_labels": onnx_label,
123
+ "mask_input": onnx_mask_input,
124
+ "has_mask_input": onnx_has_mask_input,
125
+ "orig_im_size": np.array(img.size[::-1], dtype=np.float32),
126
+ }
127
+
128
+ masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
129
+ masks = masks > 0.0
130
+ masks = [
131
+ Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
132
+ for i in range(masks.shape[0])
133
+ ]
134
+
135
+ return masks
136
+
137
+ @classmethod
138
+ def download_models(cls, *args, **kwargs):
139
+ fname_encoder = f"{cls.name()}_encoder.onnx"
140
+ fname_decoder = f"{cls.name()}_decoder.onnx"
141
+
142
+ pooch.retrieve(
143
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
144
+ None
145
+ if cls.checksum_disabled(*args, **kwargs)
146
+ else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
147
+ fname=fname_encoder,
148
+ path=cls.u2net_home(*args, **kwargs),
149
+ progressbar=True,
150
+ )
151
+
152
+ pooch.retrieve(
153
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
154
+ None
155
+ if cls.checksum_disabled(*args, **kwargs)
156
+ else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
157
+ fname=fname_decoder,
158
+ path=cls.u2net_home(*args, **kwargs),
159
+ progressbar=True,
160
+ )
161
+
162
+ return (
163
+ os.path.join(cls.u2net_home(), fname_encoder),
164
+ os.path.join(cls.u2net_home(), fname_decoder),
165
+ )
166
+
167
+ @classmethod
168
+ def name(cls, *args, **kwargs):
169
+ return "sam"
rembg/sessions/silueta.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class SiluetaSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
39
+ None
40
+ if cls.checksum_disabled(*args, **kwargs)
41
+ else "md5:55e59e0d8062d2f5d013f4725ee84782",
42
+ fname=fname,
43
+ path=cls.u2net_home(*args, **kwargs),
44
+ progressbar=True,
45
+ )
46
+
47
+ return os.path.join(cls.u2net_home(), fname)
48
+
49
+ @classmethod
50
+ def name(cls, *args, **kwargs):
51
+ return "silueta"
rembg/sessions/u2net.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class U2netSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
39
+ None
40
+ if cls.checksum_disabled(*args, **kwargs)
41
+ else "md5:60024c5c889badc19c04ad937298a77b",
42
+ fname=fname,
43
+ path=cls.u2net_home(*args, **kwargs),
44
+ progressbar=True,
45
+ )
46
+
47
+ return os.path.join(cls.u2net_home(), fname)
48
+
49
+ @classmethod
50
+ def name(cls, *args, **kwargs):
51
+ return "u2net"
rembg/sessions/u2net_cloth_seg.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+ from scipy.special import log_softmax
9
+
10
+ from .base import BaseSession
11
+
12
+ pallete1 = [
13
+ 0,
14
+ 0,
15
+ 0,
16
+ 255,
17
+ 255,
18
+ 255,
19
+ 0,
20
+ 0,
21
+ 0,
22
+ 0,
23
+ 0,
24
+ 0,
25
+ ]
26
+
27
+ pallete2 = [
28
+ 0,
29
+ 0,
30
+ 0,
31
+ 0,
32
+ 0,
33
+ 0,
34
+ 255,
35
+ 255,
36
+ 255,
37
+ 0,
38
+ 0,
39
+ 0,
40
+ ]
41
+
42
+ pallete3 = [
43
+ 0,
44
+ 0,
45
+ 0,
46
+ 0,
47
+ 0,
48
+ 0,
49
+ 0,
50
+ 0,
51
+ 0,
52
+ 255,
53
+ 255,
54
+ 255,
55
+ ]
56
+
57
+
58
+ class Unet2ClothSession(BaseSession):
59
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
60
+ ort_outs = self.inner_session.run(
61
+ None,
62
+ self.normalize(
63
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768)
64
+ ),
65
+ )
66
+
67
+ pred = ort_outs
68
+ pred = log_softmax(pred[0], 1)
69
+ pred = np.argmax(pred, axis=1, keepdims=True)
70
+ pred = np.squeeze(pred, 0)
71
+ pred = np.squeeze(pred, 0)
72
+
73
+ mask = Image.fromarray(pred.astype("uint8"), mode="L")
74
+ mask = mask.resize(img.size, Image.LANCZOS)
75
+
76
+ masks = []
77
+
78
+ mask1 = mask.copy()
79
+ mask1.putpalette(pallete1)
80
+ mask1 = mask1.convert("RGB").convert("L")
81
+ masks.append(mask1)
82
+
83
+ mask2 = mask.copy()
84
+ mask2.putpalette(pallete2)
85
+ mask2 = mask2.convert("RGB").convert("L")
86
+ masks.append(mask2)
87
+
88
+ mask3 = mask.copy()
89
+ mask3.putpalette(pallete3)
90
+ mask3 = mask3.convert("RGB").convert("L")
91
+ masks.append(mask3)
92
+
93
+ return masks
94
+
95
+ @classmethod
96
+ def download_models(cls, *args, **kwargs):
97
+ fname = f"{cls.name()}.onnx"
98
+ pooch.retrieve(
99
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
100
+ None
101
+ if cls.checksum_disabled(*args, **kwargs)
102
+ else "md5:2434d1f3cb744e0e49386c906e5a08bb",
103
+ fname=fname,
104
+ path=cls.u2net_home(*args, **kwargs),
105
+ progressbar=True,
106
+ )
107
+
108
+ return os.path.join(cls.u2net_home(), fname)
109
+
110
+ @classmethod
111
+ def name(cls, *args, **kwargs):
112
+ return "u2net_cloth_seg"
rembg/sessions/u2net_human_seg.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class U2netHumanSegSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
39
+ None
40
+ if cls.checksum_disabled(*args, **kwargs)
41
+ else "md5:c09ddc2e0104f800e3e1bb4652583d1f",
42
+ fname=fname,
43
+ path=cls.u2net_home(*args, **kwargs),
44
+ progressbar=True,
45
+ )
46
+
47
+ return os.path.join(cls.u2net_home(), fname)
48
+
49
+ @classmethod
50
+ def name(cls, *args, **kwargs):
51
+ return "u2net_human_seg"
rembg/sessions/u2netp.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class U2netpSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
39
+ None
40
+ if cls.checksum_disabled(*args, **kwargs)
41
+ else "md5:8e83ca70e441ab06c318d82300c84806",
42
+ fname=fname,
43
+ path=cls.u2net_home(*args, **kwargs),
44
+ progressbar=True,
45
+ )
46
+
47
+ return os.path.join(cls.u2net_home(), fname)
48
+
49
+ @classmethod
50
+ def name(cls, *args, **kwargs):
51
+ return "u2netp"
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.8.1
2
+ asyncer==0.0.2
3
+ click==8.1.3
4
+ fastapi==0.87.0
5
+ filetype==1.2.0
6
+ pooch==1.6.0
7
+ imagehash==4.3.1
8
+ numpy==1.23.5
9
+ onnxruntime==1.13.1
10
+ opencv-python-headless==4.6.0.66
11
+ pillow==9.3.0
12
+ pymatting==1.1.8
13
+ python-multipart==0.0.5
14
+ scikit-image==0.19.3
15
+ scipy==1.9.3
16
+ tqdm==4.64.1
17
+ uvicorn==0.20.0
18
+ watchdog==2.1.9