# thera / app.py
# Alexander Becker
# Limit upload image size
# d3c591a
import pickle
import json
import os
import gradio as gr
from PIL import Image
import numpy as np
import jax
from gradio_dualvision import DualVisionApp
from gradio_dualvision.gradio_patches.radio import Radio
from huggingface_hub import hf_hub_download
from model import build_thera
from super_resolve import process
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"
# Largest accepted side length (in pixels) of an input image; larger uploads
# are rejected in TheraApp.process.
MAX_SIZE = 600

print(f"JAX devices: {jax.devices()}")
print(f"JAX device type: {jax.devices()[0].device_kind}")


def _load_thera(repo_id):
    """Download a Thera checkpoint from the Hugging Face Hub and build its model.

    Args:
        repo_id: Hub repository holding a ``model.pkl`` checkpoint whose
            unpickled object is a mapping with ``'model'``, ``'backbone'``
            and ``'size'`` entries.

    Returns:
        ``(model, params)``: the model built by ``build_thera`` and the
        parameters stored under ``'model'`` in the checkpoint.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. This is only acceptable because the checkpoint comes from a fixed,
    # trusted repository; never point this at user-supplied repo ids.
    with open(model_path, 'rb') as fh:
        check = pickle.load(fh)
    # First argument presumably is the number of output channels (RGB) —
    # TODO confirm against model.build_thera.
    model = build_thera(3, check['backbone'], check['size'])
    return model, check['model']


# Both backbones are loaded eagerly at import time so switching between them
# in the UI does not trigger a download.
model_edsr, params_edsr = _load_thera(REPO_ID_EDSR)
model_rdn, params_rdn = _load_thera(REPO_ID_RDN)
class TheraApp(DualVisionApp):
    """Gradio app for Thera arbitrary-scale super-resolution.

    Builds on ``DualVisionApp``. It adds a Markdown header and controls for
    the scaling factor, backbone and ensembling. Its ``process`` step runs
    the selected Thera model on the input image.
    """

    # Initial control values; also the fallbacks used by `process` when a
    # setting is not supplied.
    DEFAULT_SCALE = 3.92
    DEFAULT_DO_ENSEMBLE = False
    DEFAULT_MODEL = 'edsr'

    def make_header(self):
        """Render the title, project/paper/GitHub badges and a usage hint."""
        gr.Markdown(
            """
## Thera: Aliasing-Free Arbitrary-Scale Super-Resolution with Neural Heat Fields
<p align="center">
<a title="Website" href="https://therasr.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/badge/%E2%99%A5%20Project%20-Website-blue">
</a>
<a title="arXiv" href="https://arxiv.org/pdf/2311.17643" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/badge/%F0%9F%93%84%20Read%20-Paper-AF3436">
</a>
<a title="Github" href="https://github.com/prs-eth/thera" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/thera?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
</p>
<p align="center" style="margin-top: 0px;">
<strong>Upload a photo or select an example below to do arbitrary-scale super-resolution in real time!</strong>
</p>
"""
        )

    def build_user_components(self):
        """Create the user-facing controls and return them keyed by setting name.

        The keys ("scale", "model", "do_ensemble") are the same names that
        ``process`` reads from its keyword arguments.
        """
        with gr.Row():
            scale = gr.Slider(
                label="Scaling factor",
                minimum=1,
                maximum=6,
                step=0.01,
                value=self.DEFAULT_SCALE,
            )
            model = gr.Radio(
                [
                    ("EDSR", 'edsr'),
                    ("RDN", 'rdn'),
                ],
                label="Backbone",
                value=self.DEFAULT_MODEL,
            )
            do_ensemble = gr.Radio(
                [
                    ("No", False),
                    ("Yes", True),
                ],
                label="Do Ensemble",
                value=self.DEFAULT_DO_ENSEMBLE,
            )
        return {
            "scale": scale,
            "model": model,
            "do_ensemble": do_ensemble,
        }

    def process(self, image_in: Image.Image, **kwargs):
        """Super-resolve ``image_in`` and return the images plus the settings used.

        Args:
            image_in: Input image; ``process_components`` converts it to RGB
                before calling this method.
            **kwargs: Optional ``scale`` (scaling factor), ``model``
                (``'edsr'`` or ``'rdn'``) and ``do_ensemble``. Missing keys
                fall back to the class defaults.

        Returns:
            ``(modalities, settings)``. ``modalities`` maps ``"nearest"``
            (nearest-neighbour upscaling, for visual comparison) and ``"out"``
            (the Thera result) to PIL images of identical size. ``settings``
            echoes the scale, model and ensembling flag that were used.

        Raises:
            gr.Error: If either side of the input exceeds ``MAX_SIZE`` pixels.
            NotImplementedError: If ``model`` is not a known backbone.
        """
        scale = kwargs.get("scale", self.DEFAULT_SCALE)
        do_ensemble = kwargs.get("do_ensemble", self.DEFAULT_DO_ENSEMBLE)
        model = kwargs.get("model", self.DEFAULT_MODEL)

        # Reject large uploads before doing any work, to keep the shared demo responsive.
        if max(image_in.size) > MAX_SIZE:
            raise gr.Error(f"We have currently limited the size of uploaded images to {MAX_SIZE}x{MAX_SIZE}"
                           f" pixels, to enable a smooth experience to all users.")

        source = np.asarray(image_in) / 255.

        # Target (height, width): numpy image arrays are indexed rows-first.
        target_shape = (
            round(source.shape[0] * scale),
            round(source.shape[1] * scale),
        )

        if model == 'edsr':
            m, p = model_edsr, params_edsr
        elif model == 'rdn':
            m, p = model_rdn, params_rdn
        else:
            raise NotImplementedError(f"Unknown model: {model!r}")

        # `process` here is the module-level super_resolve.process, not this method.
        out = process(source, m, p, target_shape, do_ensemble=do_ensemble)
        out = Image.fromarray(np.asarray(out))
        # Nearest-neighbour baseline at exactly the output size, for side-by-side viewing.
        nearest = image_in.resize(out.size, Image.NEAREST)

        out_modalities = {
            "nearest": nearest,
            "out": out,
        }
        out_settings = {
            'scale': scale,
            'model': model,
            'do_ensemble': do_ensemble,
        }
        return out_modalities, out_settings

    def process_components(
        self, image_in, modality_selector_left, modality_selector_right, **kwargs
    ):
        """Load and validate the input, run ``process`` and build component updates.

        Args:
            image_in: Either a file path or a PIL image. For a path, a sibling
                ``<path>.settings.json`` file, if present, supplies default
                settings (presumably saved alongside example images).
            modality_selector_left: Currently selected left modality name; kept
                if still available, otherwise the first modality is used.
            modality_selector_right: Currently selected right modality name;
                kept if still available, otherwise the second modality is used.
            **kwargs: Current setting values; they override settings loaded
                from the JSON file.

        Returns:
            A list containing, in order:
              * the gallery state, as ``[image, name]`` pairs;
              * the ``[left, right]`` images for the slider;
              * the left and right ``Radio`` selectors;
              * one freshly built component per setting, carrying the value
                that was used.

        Raises:
            gr.Error: If the input is missing or of the wrong type, or if
                ``process`` returns malformed results.
        """
        if image_in is None:
            raise gr.Error("Input image is required")

        image_settings = {}
        if isinstance(image_in, str):
            image_settings_path = image_in + ".settings.json"
            if os.path.isfile(image_settings_path):
                with open(image_settings_path, "r", encoding="utf-8") as f:
                    image_settings = json.load(f)
            # The context manager closes the file handle deterministically.
            # convert() returns a new image that no longer depends on the file.
            with Image.open(image_in) as opened:
                image_in = opened.convert("RGB")
        else:
            if not isinstance(image_in, Image.Image):
                raise gr.Error(f"Input must be a PIL image, got {type(image_in)}")
            image_in = image_in.convert("RGB")

        image_settings.update(kwargs)

        results_dict, results_settings = self.process(image_in, **image_settings)

        # Validate the shape of what `process` returned before building the UI.
        if not isinstance(results_dict, dict):
            raise gr.Error(
                f"`process` must return a dict[str, PIL.Image]. Got type: {type(results_dict)}"
            )
        if len(results_dict) == 0:
            raise gr.Error("`process` did not return any modalities")
        for k, v in results_dict.items():
            if not isinstance(k, str):
                raise gr.Error(
                    f"Output dict must have string keys. Found key of type {type(k)}: {repr(k)}"
                )
            if k == self.key_original_image:
                raise gr.Error(
                    f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input"
                )
            if not isinstance(v, Image.Image):
                raise gr.Error(
                    f"Value for key '{k}' must be a PIL Image, got type {type(v)}"
                )
        if len(results_settings) != len(self.input_keys):
            raise gr.Error(
                f"Expected number of settings ({len(self.input_keys)}), returned ({len(results_settings)})"
            )
        if any(k not in results_settings for k in self.input_keys):
            raise gr.Error("Mismatching settings keys")

        # input_keys / input_cls / input_kwargs come from DualVisionApp. Rebuild
        # each settings component so the UI shows the value that was actually used.
        results_settings = {
            k: cls(**ctor_args, value=results_settings[k])
            for k, cls, ctor_args in zip(
                self.input_keys, self.input_cls, self.input_kwargs
            )
        }

        results_dict = {
            **results_dict,
            self.key_original_image: image_in,
        }

        results_state = [[v, k] for k, v in results_dict.items()]
        modalities = list(results_dict.keys())

        # There is at least one processed modality plus the original image, so
        # modalities[1] always exists.
        modality_left = (
            modality_selector_left
            if modality_selector_left in modalities
            else modalities[0]
        )
        modality_right = (
            modality_selector_right
            if modality_selector_right in modalities
            else modalities[1]
        )

        return [
            results_state,  # goes to a gr.Gallery
            [
                results_dict[modality_left],
                results_dict[modality_right],
            ],  # ImageSliderPlus
            Radio(
                choices=modalities,
                value=modality_left,
                label="Left",
                key="Left",
            ),
            Radio(
                choices=modalities if self.left_selector_visible else modalities[1:],
                value=modality_right,
                label="Right",
                key="Right",
            ),
            *results_settings.values(),
        ]
# Build the app and serve it. Request queueing is enabled with the
# programmatic API closed, and the server listens on all interfaces at a
# fixed port.
with TheraApp(
    title="Thera Arbitrary-Scale Super-Resolution",
    examples_path="files",
    examples_per_page=12,
    squeeze_canvas=True,
    advanced_settings_can_be_half_width=False,
    # NOTE(review): option left disabled; presumably toggles Hugging Face
    # ZeroGPU support in gradio_dualvision — confirm before re-enabling.
    # spaces_zero_gpu_enabled=True,
) as demo:
    demo.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )