New logic
Browse files- app.py +1 -1
- super_resolve.py +118 -87
app.py
CHANGED
@@ -127,4 +127,4 @@ with gr.Blocks(title="Super Res + Bas-Relief") as app:
|
|
127 |
)
|
128 |
|
129 |
if __name__ == "__main__":
|
130 |
-
app.launch(share=False)
|
|
|
127 |
)
|
128 |
|
129 |
if __name__ == "__main__":
|
130 |
+
app.launch(share=False)
|
super_resolve.py
CHANGED
@@ -1,99 +1,130 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
from argparse import ArgumentParser, Namespace
|
4 |
-
import pickle
|
5 |
-
|
6 |
import jax
|
7 |
-
from jax import jit
|
8 |
import jax.numpy as jnp
|
9 |
import numpy as np
|
10 |
from PIL import Image
|
11 |
-
|
|
|
|
|
|
|
|
|
12 |
from model import build_thera
|
13 |
-
from utils import make_grid, interpolate_grid
|
14 |
-
|
15 |
-
# Three-element mean that process_single passes to jax.nn.standardize and adds
# back to the decoder output afterwards.
MEAN = jnp.array([.4488, .4371, .4040])
|
16 |
-
# Three-element variance paired with MEAN; process_single multiplies the decoder
# output by sqrt(VAR) when undoing the standardization.
VAR = jnp.array([.25, .25, .25])
|
17 |
-
# Step and maximum edge length (in grid cells) of the coordinate tiles that
# process_single feeds to the decoder one at a time.
PATCH_SIZE = 256
|
18 |
-
|
19 |
-
|
20 |
-
# Super-resolve one image: encode `source`, decode the target grid tile by tile,
# and add the decoded result to a copy of `source` resampled at `target_shape`.
# Returns the array `out`, which keeps the leading axis added via `[None]`.
def process_single(source, apply_encoder, apply_decoder, params, target_shape):
|
21 |
-
# Scale value handed to the decoder, as a 1-element float32 array.
# NOTE(review): the ratio pairs target_shape[0] with source.shape[1]; confirm
# this index pairing is intended for non-square inputs.
t = jnp.float32((target_shape[0] / source.shape[1])**-2)[None]
|
22 |
-
# Grid built for the requested output shape, with a leading axis added.
coords_nearest = jnp.asarray(make_grid(target_shape)[None])
|
23 |
-
# Source resampled on that grid; it is added to the decoder output at the end.
source_up = interpolate_grid(coords_nearest, source[None])
|
24 |
-
# Standardize with MEAN/VAR and add a leading axis before encoding.
source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]
|
25 |
-
|
26 |
-
encoding = apply_encoder(params, source)
|
27 |
-
coords = jnp.asarray(make_grid(source_up.shape[1:3])[None]) # global sampling coords
|
28 |
-
# Output buffer pre-filled with NaN; the two loops below cover the full range
# of axes 1 and 2, so every position is overwritten by a decoded tile.
out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)
|
29 |
-
|
30 |
-
# Walk the grid in PATCH_SIZE steps; min() clips the last tile on each axis.
for h_min in range(0, coords.shape[1], PATCH_SIZE):
|
31 |
-
h_max = min(h_min + PATCH_SIZE, coords.shape[1])
|
32 |
-
for w_min in range(0, coords.shape[2], PATCH_SIZE):
|
33 |
-
# apply decoder with one patch of coordinates
|
34 |
-
w_max = min(w_min + PATCH_SIZE, coords.shape[2])
|
35 |
-
coords_patch = coords[:, h_min:h_max, w_min:w_max]
|
36 |
-
out_patch = apply_decoder(params, encoding, coords_patch, t)
|
37 |
-
# JAX arrays are immutable: .at[...].set() returns an updated copy.
out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)
|
38 |
-
|
39 |
-
# Undo the standardization (multiply by sqrt(VAR), add MEAN) ...
out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
|
40 |
-
# ... then add the resampled source.
out += source_up
|
41 |
-
return out
|
42 |
-
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
target_shape_ = tuple(reversed(target_shape)) if i_rot % 2 else target_shape
|
52 |
-
out = process_single(source_, apply_encoder, apply_decoder, params, target_shape_)
|
53 |
-
outs.append(jnp.rot90(out, k=i_rot, axes=(-2, -3)))
|
54 |
|
55 |
-
out = jnp.stack(outs).mean(0).clip(0., 1.)
|
56 |
-
return jnp.rint(out[0] * 255).astype(jnp.uint8)
|
57 |
|
58 |
-
|
59 |
-
def
|
60 |
-
|
61 |
-
|
62 |
-
if args.scale is not None:
|
63 |
-
if args.size is not None:
|
64 |
-
raise ValueError('Cannot specify both size and scale')
|
65 |
-
target_shape = (
|
66 |
-
round(source.shape[0] * args.scale),
|
67 |
-
round(source.shape[1] * args.scale),
|
68 |
-
)
|
69 |
-
elif args.size is not None:
|
70 |
-
target_shape = args.size
|
71 |
-
else:
|
72 |
-
raise ValueError('Must specify either size or scale')
|
73 |
-
|
74 |
-
with open(args.checkpoint, 'rb') as fh:
|
75 |
check = pickle.load(fh)
|
76 |
-
|
77 |
-
|
|
|
78 |
model = build_thera(3, backbone, size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
|
|
|
|
|
|
3 |
import jax
|
|
|
4 |
import jax.numpy as jnp
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
7 |
+
import pickle
|
8 |
+
import warnings
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
from diffusers import StableDiffusionXLImg2ImgPipeline
|
11 |
+
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
12 |
from model import build_thera
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
# Configuration and warning suppression: ignore every FutureWarning and
# UserWarning raised after this point.
|
15 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
16 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
17 |
|
18 |
+
# Device configuration: JAX uses its first CPU device and PyTorch is pinned to
# "cpu" as well.
|
19 |
+
JAX_DEVICE = jax.devices("cpu")[0]
|
20 |
+
TORCH_DEVICE = "cpu"
|
|
|
|
|
|
|
21 |
|
|
|
|
|
22 |
|
23 |
+
# 1. Carregar modelos do Thera ----------------------------------------------------------------
|
24 |
+
def load_thera_model(repo_id, filename):
    """Download a pickled Thera checkpoint from the Hugging Face Hub and build its model.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Name of the checkpoint file inside that repository.

    Returns:
        A ``(model, variables)`` tuple: the module returned by
        ``build_thera(3, backbone, size)`` and the value stored under the
        checkpoint's ``'model'`` key.

    Raises:
        KeyError: If the checkpoint lacks ``'model'``, ``'backbone'`` or ``'size'``.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # The checkpoint is fetched from a remote repository, so only point this
    # function at repositories you trust.
    with open(model_path, 'rb') as fh:
        check = pickle.load(fh)

    # Fail early with a descriptive message instead of a bare KeyError on the
    # first missing entry.
    missing = [key for key in ('model', 'backbone', 'size') if key not in check]
    if missing:
        raise KeyError(f"checkpoint {filename!r} from {repo_id!r} is missing keys: {missing}")

    # Full variable structure; expected to look like {'params': ...}.
    variables = check['model']
    backbone, size = check['backbone'], check['size']
    model = build_thera(3, backbone, size)
    return model, variables
|
33 |
+
|
34 |
+
|
35 |
+
# 1. Thera EDSR super-resolution model (downloaded and built at import time).
print("Carregando Thera EDSR...")
model_edsr, variables_edsr = load_thera_model(
    repo_id="prs-eth/thera-edsr-pro",
    filename="model.pkl",
)

# 2. SDXL img2img pipeline with the bas-relief LoRA weights ----------------------------------
print("Carregando SDXL + LoRA...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
)
pipe = pipe.to(TORCH_DEVICE)
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
)

# 3. DPT depth-estimation model and its image processor ---------------------------------------
print("Carregando DPT Depth...")
_DPT_CHECKPOINT = "Intel/dpt-large"
feature_extractor = DPTImageProcessor.from_pretrained(_DPT_CHECKPOINT)
depth_model = DPTForDepthEstimation.from_pretrained(_DPT_CHECKPOINT)
depth_model = depth_model.to(TORCH_DEVICE)
|
50 |
+
|
51 |
+
|
52 |
+
# Pipeline principal --------------------------------------------------------------------------
|
53 |
+
def full_pipeline(image, prompt, scale_factor=2.0):
    """Run super-resolution, bas-relief stylization and depth estimation on one image.

    Args:
        image: Input PIL image (the Gradio input component is declared with
            ``type="pil"``). ``None`` is rejected with a ``gr.Error``.
        prompt: Text inserted into the diffusion prompt right after the
            ``BAS-RELIEF`` token.
        scale_factor: Upscaling factor applied to both image dimensions.

    Returns:
        A ``(upscaled_pil, bas_relief, depth_pil)`` tuple of PIL images: the
        Thera output, the SDXL img2img output, and the depth map of the latter
        normalized to 0..255.

    Raises:
        gr.Error: If no image is supplied or any stage fails; the original
            exception is attached as ``__cause__``.
    """
    if image is None:
        # Gradio passes None when the image component is empty; report that
        # directly instead of surfacing an AttributeError message.
        raise gr.Error("Nenhuma imagem de entrada fornecida.")

    try:
        # 1. Super-resolution with Thera
        image = image.convert("RGB")
        source = np.array(image) / 255.0
        # round() rather than int(): a product such as 114.99999999999999 must
        # not be truncated to one pixel less than requested.
        target_shape = (
            round(image.height * scale_factor),
            round(image.width * scale_factor),
        )

        source_jax = jax.device_put(source, JAX_DEVICE)
        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

        # The full variable structure ({'params': ...}) is passed to apply().
        # NOTE(review): the input is passed as-is (no standardization, no batch
        # axis) -- confirm this matches the call signature of the Thera module.
        upscaled = model_edsr.apply(
            variables_edsr,
            source_jax,
            t,
            target_shape,
        )

        # Clip to [0, 1] and round before the uint8 cast: casting out-of-range
        # floats wraps around and yields speckled pixels.
        upscaled_u8 = np.rint(np.clip(np.asarray(upscaled), 0.0, 1.0) * 255).astype(np.uint8)
        upscaled_pil = Image.fromarray(upscaled_u8)

        # 2. Generate the bas-relief
        full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
        bas_relief = pipe(
            prompt=full_prompt,
            image=upscaled_pil,
            strength=0.7,
            num_inference_steps=25,
            guidance_scale=7.5,
        ).images[0]

        # 3. Compute the depth map
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            outputs = depth_model(**inputs)
            depth = outputs.predicted_depth

        # PIL's size is (width, height); interpolate expects (height, width).
        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],
            mode="bicubic",
        ).squeeze().cpu().numpy()

        depth_min = depth_map.min()
        depth_range = depth_map.max() - depth_min
        if depth_range > 0:
            depth_normalized = (depth_map - depth_min) / depth_range
        else:
            # A constant depth map would divide by zero and produce NaNs.
            depth_normalized = np.zeros_like(depth_map)
        depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))

        return upscaled_pil, bas_relief, depth_pil

    except Exception as e:
        # Chain the cause so the original traceback stays visible in the logs.
        raise gr.Error(f"Erro no processamento: {e}") from e
|
102 |
+
|
103 |
+
|
104 |
+
# Gradio interface ----------------------------------------------------------------------------
|
105 |
+
with gr.Blocks(title="Super Res + Bas-Relief") as app:
|
106 |
+
gr.Markdown("## 🔍 Super Resolução + 🗿 Bas-Relief + 🗺️ Profundidade")
|
107 |
+
|
108 |
+
with gr.Row():
|
109 |
+
# Input column: source image, relief description, scale factor and run button.
with gr.Column():
|
110 |
+
img_input = gr.Image(type="pil", label="Imagem de Entrada")
|
111 |
+
# NOTE(review): full_pipeline already appends a near-identical phrase after the
# user prompt, so the default value below ends up duplicated in the final
# prompt -- confirm that this is intended.
prompt = gr.Textbox(
|
112 |
+
label="Descrição do Relevo",
|
113 |
+
value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution."
|
114 |
+
)
|
115 |
+
scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
|
116 |
+
btn = gr.Button("Processar")
|
117 |
+
|
118 |
+
# Output column: one image per value returned by full_pipeline.
with gr.Column():
|
119 |
+
img_upscaled = gr.Image(label="Imagem Super Resolvida")
|
120 |
+
img_basrelief = gr.Image(label="Resultado Bas-Relief")
|
121 |
+
img_depth = gr.Image(label="Mapa de Profundidade")
|
122 |
+
|
123 |
+
# On click, call full_pipeline(image, prompt, scale_factor) with the three inputs
# and route its 3-tuple return value to the three output images, in order.
btn.click(
|
124 |
+
full_pipeline,
|
125 |
+
inputs=[img_input, prompt, scale],
|
126 |
+
outputs=[img_upscaled, img_basrelief, img_depth]
|
127 |
+
)
|
128 |
+
|
129 |
+
if __name__ == "__main__":
    # Python's boolean literal is `False`; the lowercase `false` used before is
    # an undefined name and raised NameError as soon as the script was run.
    app.launch(share=False)
|