ds1david committed on
Commit eb02bc3 · 1 Parent(s): 42a2e7b
Files changed (2)
  1. app.py +1 -1
  2. super_resolve.py +118 -87
app.py CHANGED
@@ -127,4 +127,4 @@ with gr.Blocks(title="Super Res + Bas-Relief") as app:
     )
 
 if __name__ == "__main__":
-    app.launch(share=False)  # Enabling public link
+    app.launch(share=False)
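Note: the comment dropped here ("Enabling public link", originally in Portuguese) contradicted the code, since `share=False` serves the app locally only. A minimal sketch of the flag's behavior (standard Gradio API, not part of this commit):

# share=False (default): serve only on the local URL, e.g. http://127.0.0.1:7860
# share=True: additionally create a temporary public *.gradio.live tunnel URL
app.launch(share=False)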
super_resolve.py CHANGED
@@ -1,99 +1,130 @@
-#!/usr/bin/env python
-
-from argparse import ArgumentParser, Namespace
-import pickle
-
-import jax
-from jax import jit
-import jax.numpy as jnp
-import numpy as np
-from PIL import Image
-
-from model import build_thera
-from utils import make_grid, interpolate_grid
-
-MEAN = jnp.array([.4488, .4371, .4040])
-VAR = jnp.array([.25, .25, .25])
-PATCH_SIZE = 256
-
-
-def process_single(source, apply_encoder, apply_decoder, params, target_shape):
-    t = jnp.float32((target_shape[0] / source.shape[1])**-2)[None]
-    coords_nearest = jnp.asarray(make_grid(target_shape)[None])
-    source_up = interpolate_grid(coords_nearest, source[None])
-    source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]
-
-    encoding = apply_encoder(params, source)
-    coords = jnp.asarray(make_grid(source_up.shape[1:3])[None])  # global sampling coords
-    out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)
-
-    for h_min in range(0, coords.shape[1], PATCH_SIZE):
-        h_max = min(h_min + PATCH_SIZE, coords.shape[1])
-        for w_min in range(0, coords.shape[2], PATCH_SIZE):
-            # apply decoder with one patch of coordinates
-            w_max = min(w_min + PATCH_SIZE, coords.shape[2])
-            coords_patch = coords[:, h_min:h_max, w_min:w_max]
-            out_patch = apply_decoder(params, encoding, coords_patch, t)
-            out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)
-
-    out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
-    out += source_up
-    return out
-
-
-def process(source, model, params, target_shape, do_ensemble=True):
-    apply_encoder = jit(model.apply_encoder)
-    apply_decoder = jit(model.apply_decoder)
-
-    outs = []
-    for i_rot in range(4 if do_ensemble else 1):
-        source_ = jnp.rot90(source, k=i_rot, axes=(-3, -2))
-        target_shape_ = tuple(reversed(target_shape)) if i_rot % 2 else target_shape
-        out = process_single(source_, apply_encoder, apply_decoder, params, target_shape_)
-        outs.append(jnp.rot90(out, k=i_rot, axes=(-2, -3)))
-
-    out = jnp.stack(outs).mean(0).clip(0., 1.)
-    return jnp.rint(out[0] * 255).astype(jnp.uint8)
-
-
-def main(args: Namespace):
-    source = np.asarray(Image.open(args.in_file)) / 255.
-
-    if args.scale is not None:
-        if args.size is not None:
-            raise ValueError('Cannot specify both size and scale')
-        target_shape = (
-            round(source.shape[0] * args.scale),
-            round(source.shape[1] * args.scale),
-        )
-    elif args.size is not None:
-        target_shape = args.size
-    else:
-        raise ValueError('Must specify either size or scale')
-
-    with open(args.checkpoint, 'rb') as fh:
-        check = pickle.load(fh)
-    params, backbone, size = check['model'], check['backbone'], check['size']
-
-    model = build_thera(3, backbone, size)
-
-    out = process(source, model, params, target_shape, not args.no_ensemble)
-
-    Image.fromarray(np.asarray(out)).save(args.out_file)
-
-
-def parse_args() -> Namespace:
-    parser = ArgumentParser()
-    parser.add_argument('in_file')
-    parser.add_argument('out_file')
-    parser.add_argument('--scale', type=float, help='Scale factor for super-resolution')
-    parser.add_argument('--size', type=int, nargs=2,
-                        help='Target size (h, w), mutually exclusive with --scale')
-    parser.add_argument('--checkpoint', help='Path to checkpoint file')
-    parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble')
-    return parser.parse_args()
-
-
-if __name__ == '__main__':
-    args = parse_args()
-    main(args)
+import gradio as gr
+import torch
+import jax
+import jax.numpy as jnp
+import numpy as np
+from PIL import Image
+import pickle
+import warnings
+from huggingface_hub import hf_hub_download
+from diffusers import StableDiffusionXLImg2ImgPipeline
+from transformers import DPTImageProcessor, DPTForDepthEstimation
+from model import build_thera
+
+# Configuration and warning suppression
+warnings.filterwarnings("ignore", category=FutureWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
+
+# Configure devices
+JAX_DEVICE = jax.devices("cpu")[0]
+TORCH_DEVICE = "cpu"
+
+
+# 1. Load Thera models -------------------------------------------------------------------------
+def load_thera_model(repo_id, filename):
+    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
+    with open(model_path, 'rb') as fh:
+        check = pickle.load(fh)
+        # Load the full variables structure
+        variables = check['model']  # Should contain {'params': ...}
+        backbone, size = check['backbone'], check['size']
+    model = build_thera(3, backbone, size)
+    return model, variables
+
+
+print("Loading Thera EDSR...")
+model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
+
+# 2. Load SDXL + LoRA --------------------------------------------------------------------------
+print("Loading SDXL + LoRA...")
+pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float32
+).to(TORCH_DEVICE)
+pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
+
+# 3. Load depth model --------------------------------------------------------------------------
+print("Loading DPT Depth...")
+feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
+depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
+
+
+# Main pipeline --------------------------------------------------------------------------------
+def full_pipeline(image, prompt, scale_factor=2.0):
+    try:
+        # 1. Super-resolution with Thera
+        image = image.convert("RGB")
+        source = np.array(image) / 255.0
+        target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
+
+        source_jax = jax.device_put(source, JAX_DEVICE)
+        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
+
+        # Corrected call with the proper variables structure
+        upscaled = model_edsr.apply(
+            variables_edsr,  # Full structure {'params': ...}
+            source_jax,
+            t,
+            target_shape
+        )
+
+        upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
+
+        # 2. Generate bas-relief
+        full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
+        bas_relief = pipe(
+            prompt=full_prompt,
+            image=upscaled_pil,
+            strength=0.7,
+            num_inference_steps=25,
+            guidance_scale=7.5
+        ).images[0]
+
+        # 3. Compute depth map
+        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
+        with torch.no_grad():
+            outputs = depth_model(**inputs)
+            depth = outputs.predicted_depth
+
+        depth_map = torch.nn.functional.interpolate(
+            depth.unsqueeze(1),
+            size=bas_relief.size[::-1],
+            mode="bicubic"
+        ).squeeze().cpu().numpy()
+
+        depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
+        depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
+
+        return upscaled_pil, bas_relief, depth_pil
+
+    except Exception as e:
+        raise gr.Error(f"Processing error: {str(e)}")
+
+
+# Gradio interface -----------------------------------------------------------------------------
+with gr.Blocks(title="Super Res + Bas-Relief") as app:
+    gr.Markdown("## 🔍 Super Resolution + 🗿 Bas-Relief + 🗺️ Depth")
+
+    with gr.Row():
+        with gr.Column():
+            img_input = gr.Image(type="pil", label="Input Image")
+            prompt = gr.Textbox(
+                label="Relief Description",
+                value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution."
+            )
+            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
+            btn = gr.Button("Process")
+
+        with gr.Column():
+            img_upscaled = gr.Image(label="Super-Resolved Image")
+            img_basrelief = gr.Image(label="Bas-Relief Result")
+            img_depth = gr.Image(label="Depth Map")
+
+    btn.click(
+        full_pipeline,
+        inputs=[img_input, prompt, scale],
+        outputs=[img_upscaled, img_basrelief, img_depth]
+    )
+
+if __name__ == "__main__":
+    app.launch(share=False)
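Note: the rewrite replaces the CLI script's patch-based, geo-ensembled Thera inference with a single `model_edsr.apply(...)` call, so the four-rotation self-ensemble from the removed `process()` is lost. If it proves useful again, a sketch along these lines could restore it on top of the new call (hypothetical helper; assumes `apply(variables, source, t, target_shape)` takes an HWC array in [0, 1] and returns one, as `full_pipeline` implies):

import jax.numpy as jnp

def apply_with_geo_ensemble(model, variables, source, t, target_shape):
    # Average predictions over the four 90-degree rotations, as the old process() did.
    outs = []
    for k in range(4):
        src_rot = jnp.rot90(source, k=k, axes=(0, 1))
        # Odd rotations swap the height/width of the requested output shape.
        shape_rot = target_shape if k % 2 == 0 else target_shape[::-1]
        out = model.apply(variables, src_rot, t, shape_rot)
        outs.append(jnp.rot90(out, k=-k, axes=(0, 1)))  # undo the rotation
    return jnp.stack(outs).mean(0).clip(0.0, 1.0)

This costs four forward passes instead of one, which is presumably why the CPU-bound Space drops it. Separately, `depth_map.max() - depth_map.min()` is zero for a constant depth map; adding a small epsilon to that denominator would guard against division by zero.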