# Author: balthou
# Commit 7aaaf1f: "update slider names, add md"
import sys
sys.path.append("src")
from interactive_pipe import interactive, KeyboardControl
from rstor.analyzis.interactive.metrics import plug_configure_metrics
from rstor.analyzis.interactive.crop import plug_crop_selector
from rstor.analyzis.interactive.images import image_selector
from interactive_pipe.data_objects.image import Image
from batch_processing import Batch
import argparse
from rstor.analyzis.parser import get_parser
from pathlib import Path
from rstor.analyzis.interactive.model_selection import get_default_models
from rstor.analyzis.interactive.pipelines import natural_inference_pipeline, morph_canvas, CANVAS
from interactive_pipe import interactive_pipeline
def plug_morph_canvas():
    """Register a keyboard-driven ``canvas`` control on ``morph_canvas``.

    The control starts on the first entry of ``CANVAS`` and is stepped with
    the "p" key (``keyup="p"``); ``modulo=True`` presumably makes it wrap
    around at the end of ``CANVAS`` -- confirm against interactive_pipe.
    Nothing is returned; the registration is the side effect.
    """
    canvas_selector = KeyboardControl(
        CANVAS[0],
        CANVAS,
        name="canvas",
        keyup="p",
        modulo=True,
    )
    register = interactive(canvas=canvas_selector)
    register(morph_canvas)
def image_loading_batch(input: Path, args: argparse.Namespace) -> dict:
    """Describe a single image file for ``batch_processing.Batch.run``.

    Args:
        input: path of the image file to describe.
        args: parsed arguments; only ``args.disable_preload`` is read.

    Returns:
        A dict with keys ``"name"`` (file name), ``"path"`` (the given path)
        and ``"buffer"``. The buffer holds ``Image.from_file(input).data``
        when preloading is enabled, and ``None`` otherwise, so the image can
        be read later.
    """
    buffer = None if args.disable_preload else Image.from_file(input).data
    return {"name": input.name, "path": input, "buffer": buffer}
def _read_markdown_description(path: Path = Path("description.md")) -> str:
    """Return the markdown shown in the GUI: a title line followed by *path*'s text.

    The file is read as UTF-8 and closed immediately. *path* is resolved
    relative to the current working directory.
    """
    header = "# 🔍 Blind image deblurring - READ MORE HERE \n"
    return header + path.read_text(encoding="utf-8")


def main(argv, low_resource_demo=True):
    """Parse *argv*, load the input images and launch the interactive pipeline.

    Args:
        argv: command-line arguments handed to ``batch_processing.Batch`` and
            parsed with the project parser plus ``-nop/--disable-preload``.
        low_resource_demo: forwarded to the crop selector as ``low_resources``;
            when True the metrics configuration control is not registered.

    Raises:
        ValueError: if no input image was found.
    """
    batch = Batch(argv)
    batch.set_io_description(
        input_help='input image files',
        output_help=argparse.SUPPRESS,
    )
    parser = get_parser()
    parser.add_argument("-nop", "--disable-preload", action="store_true", help="Disable images preload")
    args = batch.parse_args(parser)
    backend = args.backend
    if backend == "qt":
        # Images are loaded sequentially with the Qt backend
        # (NOTE(review): presumably multiprocessing conflicts with Qt -- confirm).
        batch.set_multiprocessing_enabled(False)
    img_list = batch.run(image_loading_batch)
    if not img_list:
        # Without this check the image selector would get the range [0, -1].
        raise ValueError("No input image found: check the input images pattern")
    last_index = len(img_list) - 1
    if args.keyboard:
        image_control = KeyboardControl(0, [0, last_index], keydown="3", keyup="9", modulo=True)
    else:
        image_control = (0, [0, last_index], "input image selector")
    interactive(image_index=image_control)(image_selector)
    plug_crop_selector(num_pad=args.keyboard, low_resources=low_resource_demo)
    if not low_resource_demo:
        plug_configure_metrics(key_shortcut="a")
    # The canvas selector is only registered for non-gradio backends.
    if backend != "gradio":
        plug_morph_canvas()
    model_dict = get_default_models(args.experiments, Path(args.models_storage), keyboard_control=args.keyboard)
    interactive_pipeline(
        gui=backend,
        cache=True,
        safe_input_buffer_deepcopy=False,
        sliders_layout="smart",
        sliders_per_row_layout=3,
        markdown_description=_read_markdown_description(),
    )(natural_inference_pipeline)(
        img_list,
        model_dict,
    )
if __name__ == "__main__":
    # Fixed arguments for the gradio demo: experiments 6002 and 5002 run on
    # the images matching __dataset/sample/*.*g. To drive the script from
    # the command line instead, call main(sys.argv[1:]).
    demo_argv = ["-e", "6002", "5002", "-i", "__dataset/sample/*.*g", "-b", "gradio"]
    main(demo_argv, low_resource_demo=True)