New logic
- app.py +6 -6
- super_resolve.py +87 -118
app.py
CHANGED
@@ -25,15 +25,15 @@ def load_thera_model(repo_id, filename):
     model_path = hf_hub_download(repo_id=repo_id, filename=filename)
     with open(model_path, 'rb') as fh:
         check = pickle.load(fh)
-    #
-
+    # Carregar estrutura completa de variáveis
+    variables = check['model']  # Deve conter {'params': ...}
     backbone, size = check['backbone'], check['size']
     model = build_thera(3, backbone, size)
-    return model,
+    return model, variables


 print("Carregando Thera EDSR...")
-model_edsr,
+model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")

 # 2. Carregar SDXL + LoRA ---------------------------------------------------------------------
 print("Carregando SDXL + LoRA...")
@@ -60,9 +60,9 @@ def full_pipeline(image, prompt, scale_factor=2.0):
         source_jax = jax.device_put(source, JAX_DEVICE)
         t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

-        # Chamada corrigida com estrutura de
+        # Chamada corrigida com estrutura de variáveis correta
         upscaled = model_edsr.apply(
-
+            variables_edsr,  # Estrutura completa {'params': ...}
             source_jax,
             t,
             target_shape
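The app.py change comes down to keeping the full Flax variable collection stored in the checkpoint (`check['model']`, expected to be `{'params': ...}`) and passing that whole structure as the first argument of `model_edsr.apply(...)`. As a rough illustration of why that structure matters, here is a minimal flax.linen sketch; the `TinyNet` module is a made-up stand-in, not the Thera EDSR backbone, and only the `apply(variables, ...)` calling convention is the point:

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyNet(nn.Module):
    # Hypothetical stand-in module; the real backbone comes from build_thera().
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=3)(x)

model = TinyNet()
x = jnp.ones((1, 8, 8, 3))
variables = model.init(jax.random.PRNGKey(0), x)  # a dict of collections: {'params': ...}
y = model.apply(variables, x)                     # apply() expects that same full structure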
super_resolve.py
CHANGED
@@ -1,130 +1,99 @@
-
-
+#!/usr/bin/env python
+
+from argparse import ArgumentParser, Namespace
+import pickle
+
 import jax
+from jax import jit
 import jax.numpy as jnp
 import numpy as np
 from PIL import Image
-
-import warnings
-from huggingface_hub import hf_hub_download
-from diffusers import StableDiffusionXLImg2ImgPipeline
-from transformers import DPTImageProcessor, DPTForDepthEstimation
+
 from model import build_thera
+from utils import make_grid, interpolate_grid
+
+MEAN = jnp.array([.4488, .4371, .4040])
+VAR = jnp.array([.25, .25, .25])
+PATCH_SIZE = 256
+
+
+def process_single(source, apply_encoder, apply_decoder, params, target_shape):
+    t = jnp.float32((target_shape[0] / source.shape[1])**-2)[None]
+    coords_nearest = jnp.asarray(make_grid(target_shape)[None])
+    source_up = interpolate_grid(coords_nearest, source[None])
+    source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]
+
+    encoding = apply_encoder(params, source)
+    coords = jnp.asarray(make_grid(source_up.shape[1:3])[None])  # global sampling coords
+    out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)
+
+    for h_min in range(0, coords.shape[1], PATCH_SIZE):
+        h_max = min(h_min + PATCH_SIZE, coords.shape[1])
+        for w_min in range(0, coords.shape[2], PATCH_SIZE):
+            # apply decoder with one patch of coordinates
+            w_max = min(w_min + PATCH_SIZE, coords.shape[2])
+            coords_patch = coords[:, h_min:h_max, w_min:w_max]
+            out_patch = apply_decoder(params, encoding, coords_patch, t)
+            out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)
+
+    out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
+    out += source_up
+    return out
+

-
-
-
+def process(source, model, params, target_shape, do_ensemble=True):
+    apply_encoder = jit(model.apply_encoder)
+    apply_decoder = jit(model.apply_decoder)

-
-
-
+    outs = []
+    for i_rot in range(4 if do_ensemble else 1):
+        source_ = jnp.rot90(source, k=i_rot, axes=(-3, -2))
+        target_shape_ = tuple(reversed(target_shape)) if i_rot % 2 else target_shape
+        out = process_single(source_, apply_encoder, apply_decoder, params, target_shape_)
+        outs.append(jnp.rot90(out, k=i_rot, axes=(-2, -3)))

+    out = jnp.stack(outs).mean(0).clip(0., 1.)
+    return jnp.rint(out[0] * 255).astype(jnp.uint8)

-
-def
-
-
+
+def main(args: Namespace):
+    source = np.asarray(Image.open(args.in_file)) / 255.
+
+    if args.scale is not None:
+        if args.size is not None:
+            raise ValueError('Cannot specify both size and scale')
+        target_shape = (
+            round(source.shape[0] * args.scale),
+            round(source.shape[1] * args.scale),
+        )
+    elif args.size is not None:
+        target_shape = args.size
+    else:
+        raise ValueError('Must specify either size or scale')
+
+    with open(args.checkpoint, 'rb') as fh:
         check = pickle.load(fh)
-
-
-    backbone, size = check['backbone'], check['size']
+    params, backbone, size = check['model'], check['backbone'], check['size']
+
     model = build_thera(3, backbone, size)
-    return model, variables
-
-
-print("Carregando Thera EDSR...")
-model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
-
-# 2. Carregar SDXL + LoRA ---------------------------------------------------------------------
-print("Carregando SDXL + LoRA...")
-pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=torch.float32
-).to(TORCH_DEVICE)
-pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
-
-# 3. Carregar modelo de profundidade ----------------------------------------------------------
-print("Carregando DPT Depth...")
-feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
-depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
-
-
-# Pipeline principal --------------------------------------------------------------------------
-def full_pipeline(image, prompt, scale_factor=2.0):
-    try:
-        # 1. Super Resolução com Thera
-        image = image.convert("RGB")
-        source = np.array(image) / 255.0
-        target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
-
-        source_jax = jax.device_put(source, JAX_DEVICE)
-        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
-
-        # Chamada corrigida com estrutura de variáveis correta
-        upscaled = model_edsr.apply(
-            variables_edsr,  # Estrutura completa {'params': ...}
-            source_jax,
-            t,
-            target_shape
-        )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            size=bas_relief.size[::-1],
-            mode="bicubic"
-        ).squeeze().cpu().numpy()
-
-        depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
-        depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
-
-        return upscaled_pil, bas_relief, depth_pil
-
-    except Exception as e:
-        raise gr.Error(f"Erro no processamento: {str(e)}")
-
-
-# Interface Gradio ----------------------------------------------------------------------------
-with gr.Blocks(title="Super Res + Bas-Relief") as app:
-    gr.Markdown("## 🔍 Super Resolução + 🗿 Bas-Relief + 🗺️ Profundidade")
-
-    with gr.Row():
-        with gr.Column():
-            img_input = gr.Image(type="pil", label="Imagem de Entrada")
-            prompt = gr.Textbox(
-                label="Descrição do Relevo",
-                value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution."
-            )
-            scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
-            btn = gr.Button("Processar")
-
-        with gr.Column():
-            img_upscaled = gr.Image(label="Imagem Super Resolvida")
-            img_basrelief = gr.Image(label="Resultado Bas-Relief")
-            img_depth = gr.Image(label="Mapa de Profundidade")
-
-    btn.click(
-        full_pipeline,
-        inputs=[img_input, prompt, scale],
-        outputs=[img_upscaled, img_basrelief, img_depth]
-    )
-
-if __name__ == "__main__":
-    app.launch(share=false)
+    out = process(source, model, params, target_shape, not args.no_ensemble)
+
+    Image.fromarray(np.asarray(out)).save(args.out_file)
+
+
+def parse_args() -> Namespace:
+    parser = ArgumentParser()
+    parser.add_argument('in_file')
+    parser.add_argument('out_file')
+    parser.add_argument('--scale', type=float, help='Scale factor for super-resolution')
+    parser.add_argument('--size', type=int, nargs=2,
+                        help='Target size (h, w), mutually exclusive with --scale')
+    parser.add_argument('--checkpoint', help='Path to checkpoint file')
+    parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble')
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)
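In the new process_single, the decoder is evaluated on PATCH_SIZE x PATCH_SIZE tiles of the target coordinate grid, and each tile is written into a pre-allocated output with `out.at[...].set(...)`, which keeps peak memory bounded for large targets. Below is a toy sketch of that tiling pattern only; `eval_patch` is a hypothetical stand-in for the jitted Thera decoder:

import jax.numpy as jnp

PATCH = 4

def eval_patch(patch):
    # Hypothetical per-tile computation standing in for apply_decoder(...).
    return patch * 2.0

coords = jnp.arange(100, dtype=jnp.float32).reshape(1, 10, 10)
out = jnp.full_like(coords, jnp.nan)

for h0 in range(0, coords.shape[1], PATCH):
    h1 = min(h0 + PATCH, coords.shape[1])
    for w0 in range(0, coords.shape[2], PATCH):
        w1 = min(w0 + PATCH, coords.shape[2])
        out = out.at[:, h0:h1, w0:w1].set(eval_patch(coords[:, h0:h1, w0:w1]))

assert not bool(jnp.isnan(out).any())  # every tile was filled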
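The new process() applies a geometric self-ensemble: the source is rotated by k*90 degrees, super-resolved, rotated back, and the up to four outputs are averaged before rounding to uint8; `--no-ensemble` skips it. Judging from the argument parser, the script is presumably run as `python super_resolve.py in.png out.png --scale 2 --checkpoint model.pkl`, or with `--size H W` instead of `--scale`. Here is a self-contained sketch of the rotate / process / un-rotate / average pattern, using a hypothetical identity "upscaler" so the example runs on its own:

import jax.numpy as jnp

def upscale(x):
    # Hypothetical stand-in for process_single(); the real call changes resolution.
    return x

img = jnp.arange(18, dtype=jnp.float32).reshape(3, 3, 2)  # H, W, C

outs = []
for k in range(4):
    rotated = jnp.rot90(img, k=k, axes=(-3, -2))     # rotate the input in the H/W plane
    out = upscale(rotated)
    outs.append(jnp.rot90(out, k=k, axes=(-2, -3)))  # rotate the result back

ensembled = jnp.stack(outs).mean(0)
assert bool(jnp.allclose(ensembled, img))  # with an identity upscaler the average is the input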