|
import pickle |
|
import json |
|
import os |
|
|
|
import gradio as gr |
|
from PIL import Image |
|
import numpy as np |
|
import jax |
|
|
|
from gradio_dualvision import DualVisionApp |
|
from gradio_dualvision.gradio_patches.radio import Radio |
|
from huggingface_hub import hf_hub_download |
|
from model import build_thera |
|
from super_resolve import process |
|
|
|
# Hugging Face Hub repositories holding the pretrained Thera checkpoints,
# and the per-dimension cap on uploaded image size.
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"
MAX_SIZE = 600

# Log which accelerator JAX picked up (CPU / GPU / TPU) at startup.
devices = jax.devices()
print(f"JAX devices: {devices}")
print(f"JAX device type: {devices[0].device_kind}")
|
|
|
def _load_model(repo_id):
    """Download a Thera checkpoint from the Hub and build its model.

    Args:
        repo_id: Hugging Face Hub repository id containing ``model.pkl``.

    Returns:
        Tuple ``(params, model)``: the pretrained parameters and the Thera
        model built from the checkpoint's backbone/size configuration.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
    with open(model_path, 'rb') as fh:
        # NOTE(security): pickle.load executes arbitrary code on load; only
        # use checkpoints from trusted repositories.
        check = pickle.load(fh)
    params, backbone, size = check['model'], check['backbone'], check['size']
    # 3 = number of output (RGB) channels.
    return params, build_thera(3, backbone, size)


# Load both backbones once at startup; the app switches between them per request.
params_edsr, model_edsr = _load_model(REPO_ID_EDSR)
params_rdn, model_rdn = _load_model(REPO_ID_RDN)
|
|
|
|
|
class TheraApp(DualVisionApp):
    """Dual-pane Gradio app for Thera arbitrary-scale super-resolution.

    Shows the Thera output next to a nearest-neighbor upscale of the input
    for comparison, with user controls for the scaling factor, the backbone
    (EDSR or RDN), and optional ensembling.
    """

    # Fallback values used when a control value is missing from kwargs.
    DEFAULT_SCALE = 3.92
    DEFAULT_DO_ENSEMBLE = False
    DEFAULT_MODEL = 'edsr'

    def make_header(self):
        """Render the static title and badge header above the app."""
        gr.Markdown(
            """
## Thera: Aliasing-Free Arbitrary-Scale Super-Resolution with Neural Heat Fields
<p align="center">
<a title="Website" href="https://therasr.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/badge/%E2%99%A5%20Project%20-Website-blue">
</a>
<a title="arXiv" href="https://arxiv.org/pdf/2311.17643" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/badge/%F0%9F%93%84%20Read%20-Paper-AF3436">
</a>
<a title="Github" href="https://github.com/prs-eth/thera" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/thera?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
</p>
<p align="center" style="margin-top: 0px;">
<strong>Upload a photo or select an example below to do arbitrary-scale super-resolution in real time!</strong>
</p>
"""
        )

    def build_user_components(self):
        """Create the user-facing controls.

        Returns:
            Dict mapping setting keys to their Gradio components; the keys
            must match the kwargs consumed by :meth:`process`.
        """
        with gr.Row():
            scale = gr.Slider(
                label="Scaling factor",
                minimum=1,
                maximum=6,
                step=0.01,
                value=self.DEFAULT_SCALE,
            )
            model = gr.Radio(
                [
                    ("EDSR", 'edsr'),
                    ("RDN", 'rdn'),
                ],
                label="Backbone",
                value=self.DEFAULT_MODEL,
            )
            do_ensemble = gr.Radio(
                [
                    ("No", False),
                    ("Yes", True),
                ],
                label="Do Ensemble",
                value=self.DEFAULT_DO_ENSEMBLE,
            )
        return {
            "scale": scale,
            "model": model,
            "do_ensemble": do_ensemble,
        }

    def process(self, image_in: Image.Image, **kwargs):
        """Super-resolve ``image_in`` and return comparison modalities.

        Args:
            image_in: Input RGB image.
            **kwargs: Optional ``scale`` (float), ``model`` ('edsr' | 'rdn'),
                and ``do_ensemble`` (bool); class defaults apply otherwise.

        Returns:
            Tuple ``(out_modalities, out_settings)``: a dict of named PIL
            image modalities and a dict echoing the settings actually used.

        Raises:
            gr.Error: If either input dimension exceeds ``MAX_SIZE``.
            NotImplementedError: For an unknown backbone key.
        """
        scale = kwargs.get("scale", self.DEFAULT_SCALE)
        do_ensemble = kwargs.get("do_ensemble", self.DEFAULT_DO_ENSEMBLE)
        model = kwargs.get("model", self.DEFAULT_MODEL)

        if max(image_in.size) > MAX_SIZE:
            raise gr.Error(f"We have currently limited the size of uploaded images to {MAX_SIZE}x{MAX_SIZE}"
                           f" pixels, to enable a smooth experience to all users.")

        # Normalize to float in [0, 1]; the super-resolution pipeline expects
        # an HWC float array.
        source = np.asarray(image_in) / 255.

        # Target resolution in (H, W), rounded to whole pixels. Note that PIL
        # sizes are (W, H) while numpy shapes are (H, W, C).
        target_shape = (
            round(source.shape[0] * scale),
            round(source.shape[1] * scale),
        )

        if model == 'edsr':
            m, p = model_edsr, params_edsr
        elif model == 'rdn':
            m, p = model_rdn, params_rdn
        else:
            # Fixed: was NotImplementedError('model:', model), which produced
            # a confusing tuple-style message.
            raise NotImplementedError(f"Unknown model: {model!r}")

        # `process` here is the module-level super_resolve.process, not this
        # method (the global is looked up, not self.process).
        out = process(source, m, p, target_shape, do_ensemble=do_ensemble)
        out = Image.fromarray(np.asarray(out))

        # Nearest-neighbor upscale of the input for side-by-side comparison.
        nearest = image_in.resize(out.size, Image.NEAREST)

        out_modalities = {
            "nearest": nearest,
            "out": out,
        }
        out_settings = {
            'scale': scale,
            'model': model,
            'do_ensemble': do_ensemble,
        }
        return out_modalities, out_settings

    def process_components(
        self, image_in, modality_selector_left, modality_selector_right, **kwargs
    ):
        """Validate the input, run :meth:`process`, and assemble UI updates.

        Accepts either an example path (str; an optional sidecar
        ``<path>.settings.json`` supplies per-example settings) or an
        uploaded PIL image. Current UI kwargs override sidecar settings.

        Returns:
            The component updates DualVisionApp expects: results state, the
            left/right image pair, two modality Radio selectors, and the
            re-instantiated settings components.

        Raises:
            gr.Error: On missing/invalid input or if :meth:`process` violates
                its output contract.
        """
        if image_in is None:
            raise gr.Error("Input image is required")

        image_settings = {}
        if isinstance(image_in, str):
            # Example image: load per-example settings if a sidecar exists.
            image_settings_path = image_in + ".settings.json"
            if os.path.isfile(image_settings_path):
                with open(image_settings_path, "r") as f:
                    image_settings = json.load(f)
            image_in = Image.open(image_in).convert("RGB")
        else:
            if not isinstance(image_in, Image.Image):
                raise gr.Error(f"Input must be a PIL image, got {type(image_in)}")
            image_in = image_in.convert("RGB")
        # Live UI values take precedence over sidecar settings.
        image_settings.update(kwargs)

        results_dict, results_settings = self.process(image_in, **image_settings)

        # Enforce the output contract of `process`.
        if not isinstance(results_dict, dict):
            raise gr.Error(
                f"`process` must return a dict[str, PIL.Image]. Got type: {type(results_dict)}"
            )
        if len(results_dict) == 0:
            raise gr.Error("`process` did not return any modalities")
        for k, v in results_dict.items():
            if not isinstance(k, str):
                raise gr.Error(
                    f"Output dict must have string keys. Found key of type {type(k)}: {repr(k)}"
                )
            if k == self.key_original_image:
                raise gr.Error(
                    f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input"
                )
            if not isinstance(v, Image.Image):
                raise gr.Error(
                    f"Value for key '{k}' must be a PIL Image, got type {type(v)}"
                )
        if len(results_settings) != len(self.input_keys):
            raise gr.Error(
                f"Expected number of settings ({len(self.input_keys)}), returned ({len(results_settings)})"
            )
        if any(k not in results_settings for k in self.input_keys):
            # Fixed: message typo ("setgings") and dropped the no-op f-string.
            raise gr.Error("Mismatching settings keys")

        # Re-instantiate each settings component with the value returned by
        # `process`, so the UI reflects the settings actually used.
        results_settings = {
            k: cls(**ctor_args, value=results_settings[k])
            for k, cls, ctor_args in zip(
                self.input_keys, self.input_cls, self.input_kwargs
            )
        }

        # The original input is exposed as an extra (reserved-key) modality.
        results_dict = {
            **results_dict,
            self.key_original_image: image_in,
        }

        results_state = [[v, k] for k, v in results_dict.items()]
        modalities = list(results_dict.keys())

        # Keep current selections when still valid; otherwise default to the
        # first modality on the left and the second on the right.
        modality_left = (
            modality_selector_left
            if modality_selector_left in modalities
            else modalities[0]
        )
        modality_right = (
            modality_selector_right
            if modality_selector_right in modalities
            else modalities[1]
        )

        return [
            results_state,
            [
                results_dict[modality_left],
                results_dict[modality_right],
            ],
            Radio(
                choices=modalities,
                value=modality_left,
                label="Left",
                key="Left",
            ),
            Radio(
                choices=modalities if self.left_selector_visible else modalities[1:],
                value=modality_right,
                label="Right",
                key="Right",
            ),
            *results_settings.values(),
        ]
|
|
|
|
|
# Instantiate the app; the DualVisionApp context manager builds the Gradio
# Blocks UI (header, controls, dual image panes) on enter.
with TheraApp(
    title="Thera Arbitrary-Scale Super-Resolution",
    examples_path="files",  # directory of example images (+ optional .settings.json sidecars)
    examples_per_page=12,
    squeeze_canvas=True,
    advanced_settings_can_be_half_width=False,
) as demo:
    # Enable request queueing, then serve on all interfaces (0.0.0.0:7860),
    # the conventional setup for containerized / Spaces deployments.
    demo.queue(
        api_open=False,  # hide the programmatic API endpoints
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
|
|