import json
import os
import pickle

import gradio as gr
import jax
import numpy as np
from gradio_dualvision import DualVisionApp
from gradio_dualvision.gradio_patches.radio import Radio
from huggingface_hub import hf_hub_download
from PIL import Image

from model import build_thera
from super_resolve import process

REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"
MAX_SIZE = 600  # maximum input width/height in pixels

print(f"JAX devices: {jax.devices()}")
print(f"JAX device type: {jax.devices()[0].device_kind}")


def load_checkpoint(repo_id):
    """Download a Thera checkpoint from the Hub and build the model."""
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
    with open(model_path, 'rb') as fh:
        check = pickle.load(fh)
    params, backbone, size = check['model'], check['backbone'], check['size']
    return build_thera(3, backbone, size), params


model_edsr, params_edsr = load_checkpoint(REPO_ID_EDSR)
model_rdn, params_rdn = load_checkpoint(REPO_ID_RDN)


class TheraApp(DualVisionApp):
    DEFAULT_SCALE = 3.92
    DEFAULT_DO_ENSEMBLE = False
    DEFAULT_MODEL = 'edsr'

    def make_header(self):
        gr.Markdown(
            """
## Thera: Aliasing-Free Arbitrary-Scale Super-Resolution with Neural Heat Fields
Upload a photo or select an example below to perform arbitrary-scale super-resolution in real time!
            """
        )

    def build_user_components(self):
        with gr.Row():
            scale = gr.Slider(
                label="Scaling factor",
                minimum=1,
                maximum=6,
                step=0.01,
                value=self.DEFAULT_SCALE,
            )
            model = gr.Radio(
                [
                    ("EDSR", 'edsr'),
                    ("RDN", 'rdn'),
                ],
                label="Backbone",
                value=self.DEFAULT_MODEL,
            )
            do_ensemble = gr.Radio(
                [
                    ("No", False),
                    ("Yes", True),
                ],
                label="Do Ensemble",
                value=self.DEFAULT_DO_ENSEMBLE,
            )
        return {
            "scale": scale,
            "model": model,
            "do_ensemble": do_ensemble,
        }

    def process(self, image_in: Image.Image, **kwargs):
        scale = kwargs.get("scale", self.DEFAULT_SCALE)
        do_ensemble = kwargs.get("do_ensemble", self.DEFAULT_DO_ENSEMBLE)
        model = kwargs.get("model", self.DEFAULT_MODEL)

        if max(image_in.size) > MAX_SIZE:
            raise gr.Error(
                f"Uploaded images are currently limited to {MAX_SIZE}x{MAX_SIZE} pixels"
                f" to ensure a smooth experience for all users."
            )

        source = np.asarray(image_in) / 255.

        # target (height, width) after scaling
        target_shape = (
            round(source.shape[0] * scale),
            round(source.shape[1] * scale),
        )

        if model == 'edsr':
            m, p = model_edsr, params_edsr
        elif model == 'rdn':
            m, p = model_rdn, params_rdn
        else:
            raise NotImplementedError(f"Unknown model: {model}")

        out = process(source, m, p, target_shape, do_ensemble=do_ensemble)
        out = Image.fromarray(np.asarray(out))

        # nearest-neighbor upscaling of the input, for side-by-side comparison
        nearest = image_in.resize(out.size, Image.NEAREST)

        out_modalities = {
            "nearest": nearest,
            "out": out,
        }
        out_settings = {
            'scale': scale,
            'model': model,
            'do_ensemble': do_ensemble,
        }
        return out_modalities, out_settings

    def process_components(
        self, image_in, modality_selector_left, modality_selector_right, **kwargs
    ):
        if image_in is None:
            raise gr.Error("Input image is required")

        # Example images may ship with a sidecar `<name>.settings.json`
        # that pre-populates the UI settings.
        image_settings = {}
        if isinstance(image_in, str):
            image_settings_path = image_in + ".settings.json"
            if os.path.isfile(image_settings_path):
                with open(image_settings_path, "r") as f:
                    image_settings = json.load(f)
            image_in = Image.open(image_in).convert("RGB")
        else:
            if not isinstance(image_in, Image.Image):
                raise gr.Error(f"Input must be a PIL image, got {type(image_in)}")
            image_in = image_in.convert("RGB")
        image_settings.update(kwargs)

        results_dict, results_settings = self.process(image_in, **image_settings)

        if not isinstance(results_dict, dict):
            raise gr.Error(
                f"`process` must return a dict[str, PIL.Image]. Got type: {type(results_dict)}"
            )
        if len(results_dict) == 0:
            raise gr.Error("`process` did not return any modalities")
        for k, v in results_dict.items():
            if not isinstance(k, str):
                raise gr.Error(
                    f"Output dict must have string keys. Found key of type {type(k)}: {repr(k)}"
                )
            if k == self.key_original_image:
                raise gr.Error(
                    f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input"
                )
            if not isinstance(v, Image.Image):
                raise gr.Error(
                    f"Value for key '{k}' must be a PIL Image, got type {type(v)}"
                )
        if len(results_settings) != len(self.input_keys):
            raise gr.Error(
                f"Expected {len(self.input_keys)} settings, got {len(results_settings)}"
            )
        if any(k not in results_settings for k in self.input_keys):
            raise gr.Error("Mismatching settings keys")

        # Re-instantiate the settings widgets with the values returned by `process`.
        results_settings = {
            k: cls(**ctor_args, value=results_settings[k])
            for k, cls, ctor_args in zip(
                self.input_keys, self.input_cls, self.input_kwargs
            )
        }
        results_dict = {
            **results_dict,
            self.key_original_image: image_in,
        }
        results_state = [[v, k] for k, v in results_dict.items()]

        modalities = list(results_dict.keys())
        modality_left = (
            modality_selector_left
            if modality_selector_left in modalities
            else modalities[0]
        )
        modality_right = (
            modality_selector_right
            if modality_selector_right in modalities
            else modalities[1]
        )

        return [
            results_state,  # goes to a gr.Gallery
            [
                results_dict[modality_left],
                results_dict[modality_right],
            ],  # ImageSliderPlus
            Radio(
                choices=modalities,
                value=modality_left,
                label="Left",
                key="Left",
            ),
            Radio(
                choices=modalities if self.left_selector_visible else modalities[1:],
                value=modality_right,
                label="Right",
                key="Right",
            ),
            *results_settings.values(),
        ]


with TheraApp(
    title="Thera Arbitrary-Scale Super-Resolution",
    examples_path="files",
    examples_per_page=12,
    squeeze_canvas=True,
    advanced_settings_can_be_half_width=False,
    # spaces_zero_gpu_enabled=True,
) as demo:
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
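
# Usage sketch (hypothetical): the loaded models can also be driven headlessly,
# without the Gradio UI. The file name below is an assumption; `process`,
# `model_edsr`, and `params_edsr` are the objects defined above.
#
#   img = Image.open("files/example.png").convert("RGB")
#   src = np.asarray(img) / 255.
#   shape = (round(src.shape[0] * 2), round(src.shape[1] * 2))  # 2x upscale
#   out = process(src, model_edsr, params_edsr, shape, do_ensemble=True)
#   Image.fromarray(np.asarray(out)).save("out_2x.png")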