|
import asyncio |
|
import io |
|
import json |
|
import os |
|
import sys |
|
from typing import IO |
|
|
|
import click |
|
from PIL import Image |
|
|
|
from ..bg import remove |
|
from ..session_factory import new_session |
|
from ..sessions import sessions_names |
|
|
|
|
|
@click.command(
    name="b",
    help="for a byte stream as input",
)
@click.option(
    "-m",
    "--model",
    default="u2net",
    type=click.Choice(sessions_names),
    show_default=True,
    show_choices=True,
    help="model name",
)
@click.option(
    "-a",
    "--alpha-matting",
    is_flag=True,
    show_default=True,
    help="use alpha matting",
)
@click.option(
    "-af",
    "--alpha-matting-foreground-threshold",
    default=240,
    type=int,
    show_default=True,
    help="trimap fg threshold",
)
@click.option(
    "-ab",
    "--alpha-matting-background-threshold",
    default=10,
    type=int,
    show_default=True,
    help="trimap bg threshold",
)
@click.option(
    "-ae",
    "--alpha-matting-erode-size",
    default=10,
    type=int,
    show_default=True,
    help="erode size",
)
@click.option(
    "-om",
    "--only-mask",
    is_flag=True,
    show_default=True,
    help="output only the mask",
)
@click.option(
    "-ppm",
    "--post-process-mask",
    is_flag=True,
    show_default=True,
    help="post process the mask",
)
@click.option(
    "-bgc",
    "--bgcolor",
    default=None,
    type=(int, int, int, int),
    nargs=4,
    help="Background color (R G B A) to replace the removed background with",
)
@click.option("-x", "--extras", type=str)
@click.option(
    "-o",
    "--output_specifier",
    type=str,
    help="printf-style specifier for output filenames (e.g. 'output-%d.png')",
)
@click.argument(
    "image_width",
    # A frame must contain at least one pixel; 0 would make every read
    # empty and negative sizes would crash readexactly().
    type=click.IntRange(min=1),
)
@click.argument(
    "image_height",
    type=click.IntRange(min=1),
)
def rs_command(
    model: str,
    extras: str,
    image_width: int,
    image_height: int,
    output_specifier: str,
    **kwargs
) -> None:
    """Remove the background from a stream of raw RGB frames read on stdin.

    Each frame is exactly ``image_width * image_height * 3`` bytes and is
    decoded with PIL's ``"RGB"`` mode. Every frame is passed through
    ``remove`` with the chosen model session and the remaining CLI options
    (``**kwargs``), optionally extended/overridden by the ``--extras`` JSON.

    Output:
      * with ``--output_specifier``: each result is saved as a PNG named
        ``output_specifier % frame_index`` (frame_index starts at 0);
      * otherwise: each result is PNG-encoded and written to stdout, one
        PNG after another with no separator.

    Processing stops at end of stdin; a trailing partial frame is dropped.

    Raises:
        click.BadParameter: if ``--output_specifier`` cannot be formatted
            with a single integer frame index.
    """
    # --extras is optional JSON forwarded to ``remove`` on top of the CLI
    # options. It stays best-effort, but a bad value is reported instead of
    # being silently dropped. The report goes to stderr because stdout may be
    # carrying the image stream.
    if extras:
        try:
            kwargs.update(json.loads(extras))
        except (ValueError, TypeError) as err:
            click.echo(f"warning: ignoring invalid --extras value: {err}", err=True)

    if output_specifier:
        # Expand "~" once so directory creation and the per-frame file
        # names use the same path.
        output_specifier = os.path.expanduser(output_specifier)

        # Fail before loading the model if the pattern cannot take a frame
        # index (e.g. no placeholder, or more than one).
        try:
            output_specifier % 0
        except (TypeError, ValueError) as err:
            raise click.BadParameter(
                "must contain a single printf-style placeholder for the "
                f"frame index, e.g. 'output-%d.png' ({err})",
                param_hint="'--output_specifier'",
            ) from err

        output_dir = os.path.dirname(os.path.abspath(output_specifier))
        os.makedirs(output_dir, exist_ok=True)

    session = new_session(model)
    bytes_per_img = image_width * image_height * 3

    def img_to_byte_array(img: Image.Image) -> bytes:
        """Encode ``img`` as PNG and return the encoded bytes."""
        buff = io.BytesIO()
        img.save(buff, format="PNG")
        return buff.getvalue()

    async def connect_stdin_stdout():
        """Wrap stdin/stdout in an asyncio (StreamReader, StreamWriter) pair."""
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader()
        protocol = asyncio.StreamReaderProtocol(reader)

        await loop.connect_read_pipe(lambda: protocol, sys.stdin)
        w_transport, w_protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, sys.stdout
        )
        # With a zero high-water mark the protocol pauses whenever any bytes
        # are still buffered, so ``await writer.drain()`` returns only once
        # the transport buffer is empty. Otherwise drain() may return while
        # data is still queued, and that data is lost when asyncio.run()
        # closes the loop after the last frame.
        w_transport.set_write_buffer_limits(high=0)

        writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
        return reader, writer

    async def main():
        """Read frames until EOF, process each one and emit the result."""
        reader, writer = await connect_stdin_stdout()

        idx = 0
        while True:
            try:
                img_bytes = await reader.readexactly(bytes_per_img)
            except asyncio.IncompleteReadError:
                # EOF (possibly mid-frame): drop any partial frame and stop.
                break

            img = Image.frombytes("RGB", (image_width, image_height), img_bytes)
            output = remove(img, session=session, **kwargs)

            if output_specifier:
                output.save((output_specifier % idx), format="PNG")
            else:
                writer.write(img_to_byte_array(output))
                # Flush now. remove() blocks the event loop, so queued bytes
                # would otherwise pile up in memory and could be left
                # unwritten when the loop shuts down.
                await writer.drain()

            idx += 1

    asyncio.run(main())
|
|