ds1david committed
Commit a02c6d7 · 1 Parent(s): eb02bc3
Files changed (2):
  1. app.py +6 -6
  2. super_resolve.py +87 -118
app.py CHANGED
@@ -25,15 +25,15 @@ def load_thera_model(repo_id, filename):
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(model_path, 'rb') as fh:
        check = pickle.load(fh)
-     # Adjust the parameter structure
-     params = check['model']['params']  # Access the parameters correctly
+     # Load the full variables structure
+     variables = check['model']  # Should contain {'params': ...}
    backbone, size = check['backbone'], check['size']
    model = build_thera(3, backbone, size)
-     return model, params
+     return model, variables


print("Carregando Thera EDSR...")
- model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
+ model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")

# 2. Load SDXL + LoRA ---------------------------------------------------------------------
print("Carregando SDXL + LoRA...")
@@ -60,9 +60,9 @@ def full_pipeline(image, prompt, scale_factor=2.0):
        source_jax = jax.device_put(source, JAX_DEVICE)
        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

-         # Corrected call with the proper parameter structure
+         # Corrected call with the correct variables structure
        upscaled = model_edsr.apply(
-             params_edsr,  # Correctly structured parameters
+             variables_edsr,  # Full structure {'params': ...}
            source_jax,
            t,
            target_shape
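
The app.py change above hands `model_edsr.apply` the entire checkpoint variables mapping rather than the bare `params` subtree. As a minimal Flax sketch of why `apply` expects `{'params': ...}` — using a hypothetical `nn.Dense` stand-in rather than the actual Thera module, so read it as an assumption about the Flax API shape, not project code:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical stand-in for the model returned by build_thera(...)
model = nn.Dense(features=4)

# init() returns the full variables mapping, i.e. {'params': {...}}
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))

# apply() expects that same mapping back; passing variables['params'] alone
# drops the 'params' collection name and makes the parameter lookup fail.
y = model.apply(variables, jnp.ones((1, 3)))
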
super_resolve.py CHANGED
@@ -1,130 +1,99 @@
- import gradio as gr
- import torch
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
- import pickle
- import warnings
- from huggingface_hub import hf_hub_download
- from diffusers import StableDiffusionXLImg2ImgPipeline
- from transformers import DPTImageProcessor, DPTForDepthEstimation
from model import build_thera

- # Settings and warning suppression
- warnings.filterwarnings("ignore", category=FutureWarning)
- warnings.filterwarnings("ignore", category=UserWarning)

- # Configure devices
- JAX_DEVICE = jax.devices("cpu")[0]
- TORCH_DEVICE = "cpu"


- # 1. Load the Thera models ----------------------------------------------------------------
- def load_thera_model(repo_id, filename):
-     model_path = hf_hub_download(repo_id=repo_id, filename=filename)
-     with open(model_path, 'rb') as fh:
        check = pickle.load(fh)
-     # Load the full variables structure
-     variables = check['model']  # Should contain {'params': ...}
-     backbone, size = check['backbone'], check['size']
    model = build_thera(3, backbone, size)
-     return model, variables
-
-
- print("Carregando Thera EDSR...")
- model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
-
- # 2. Load SDXL + LoRA ---------------------------------------------------------------------
- print("Carregando SDXL + LoRA...")
- pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
-     "stabilityai/stable-diffusion-xl-base-1.0",
-     torch_dtype=torch.float32
- ).to(TORCH_DEVICE)
- pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
-
- # 3. Load the depth model ----------------------------------------------------------
- print("Carregando DPT Depth...")
- feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
- depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
-
-
- # Main pipeline --------------------------------------------------------------------------
- def full_pipeline(image, prompt, scale_factor=2.0):
-     try:
-         # 1. Super-resolution with Thera
-         image = image.convert("RGB")
-         source = np.array(image) / 255.0
-         target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
-
-         source_jax = jax.device_put(source, JAX_DEVICE)
-         t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
-
-         # Corrected call with the correct variables structure
-         upscaled = model_edsr.apply(
-             variables_edsr,  # Full structure {'params': ...}
-             source_jax,
-             t,
-             target_shape
-         )

-         upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
-
-         # 2. Generate the bas-relief
-         full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
-         bas_relief = pipe(
-             prompt=full_prompt,
-             image=upscaled_pil,
-             strength=0.7,
-             num_inference_steps=25,
-             guidance_scale=7.5
-         ).images[0]
-
-         # 3. Compute the depth map
-         inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
-         with torch.no_grad():
-             outputs = depth_model(**inputs)
-             depth = outputs.predicted_depth
-
-         depth_map = torch.nn.functional.interpolate(
-             depth.unsqueeze(1),
-             size=bas_relief.size[::-1],
-             mode="bicubic"
-         ).squeeze().cpu().numpy()
-
-         depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
-         depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
-
-         return upscaled_pil, bas_relief, depth_pil
-
-     except Exception as e:
-         raise gr.Error(f"Erro no processamento: {str(e)}")
-
-
- # Gradio interface ----------------------------------------------------------------------------
- with gr.Blocks(title="Super Res + Bas-Relief") as app:
-     gr.Markdown("## 🔍 Super Resolução + 🗿 Bas-Relief + 🗺️ Profundidade")
-
-     with gr.Row():
-         with gr.Column():
-             img_input = gr.Image(type="pil", label="Imagem de Entrada")
-             prompt = gr.Textbox(
-                 label="Descrição do Relevo",
-                 value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution."
-             )
-             scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
-             btn = gr.Button("Processar")
-
-         with gr.Column():
-             img_upscaled = gr.Image(label="Imagem Super Resolvida")
-             img_basrelief = gr.Image(label="Resultado Bas-Relief")
-             img_depth = gr.Image(label="Mapa de Profundidade")
-
-     btn.click(
-         full_pipeline,
-         inputs=[img_input, prompt, scale],
-         outputs=[img_upscaled, img_basrelief, img_depth]
-     )
-
- if __name__ == "__main__":
-     app.launch(share=false)
+ #!/usr/bin/env python
+
+ from argparse import ArgumentParser, Namespace
+ import pickle
+
import jax
+ from jax import jit
import jax.numpy as jnp
import numpy as np
from PIL import Image
+
from model import build_thera
+ from utils import make_grid, interpolate_grid
+
+ MEAN = jnp.array([.4488, .4371, .4040])
+ VAR = jnp.array([.25, .25, .25])
+ PATCH_SIZE = 256
+
+
+ def process_single(source, apply_encoder, apply_decoder, params, target_shape):
+     t = jnp.float32((target_shape[0] / source.shape[1])**-2)[None]
+     coords_nearest = jnp.asarray(make_grid(target_shape)[None])
+     source_up = interpolate_grid(coords_nearest, source[None])
+     source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]
+
+     encoding = apply_encoder(params, source)
+     coords = jnp.asarray(make_grid(source_up.shape[1:3])[None])  # global sampling coords
+     out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)
+
+     for h_min in range(0, coords.shape[1], PATCH_SIZE):
+         h_max = min(h_min + PATCH_SIZE, coords.shape[1])
+         for w_min in range(0, coords.shape[2], PATCH_SIZE):
+             # apply decoder with one patch of coordinates
+             w_max = min(w_min + PATCH_SIZE, coords.shape[2])
+             coords_patch = coords[:, h_min:h_max, w_min:w_max]
+             out_patch = apply_decoder(params, encoding, coords_patch, t)
+             out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)
+
+     out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
+     out += source_up
+     return out
+

+ def process(source, model, params, target_shape, do_ensemble=True):
+     apply_encoder = jit(model.apply_encoder)
+     apply_decoder = jit(model.apply_decoder)

+     outs = []
+     for i_rot in range(4 if do_ensemble else 1):
+         source_ = jnp.rot90(source, k=i_rot, axes=(-3, -2))
+         target_shape_ = tuple(reversed(target_shape)) if i_rot % 2 else target_shape
+         out = process_single(source_, apply_encoder, apply_decoder, params, target_shape_)
+         outs.append(jnp.rot90(out, k=i_rot, axes=(-2, -3)))

+     out = jnp.stack(outs).mean(0).clip(0., 1.)
+     return jnp.rint(out[0] * 255).astype(jnp.uint8)

+
+ def main(args: Namespace):
+     source = np.asarray(Image.open(args.in_file)) / 255.
+
+     if args.scale is not None:
+         if args.size is not None:
+             raise ValueError('Cannot specify both size and scale')
+         target_shape = (
+             round(source.shape[0] * args.scale),
+             round(source.shape[1] * args.scale),
+         )
+     elif args.size is not None:
+         target_shape = args.size
+     else:
+         raise ValueError('Must specify either size or scale')
+
+     with open(args.checkpoint, 'rb') as fh:
        check = pickle.load(fh)
+     params, backbone, size = check['model'], check['backbone'], check['size']
+
    model = build_thera(3, backbone, size)

+     out = process(source, model, params, target_shape, not args.no_ensemble)
+
+     Image.fromarray(np.asarray(out)).save(args.out_file)
+
+
+ def parse_args() -> Namespace:
+     parser = ArgumentParser()
+     parser.add_argument('in_file')
+     parser.add_argument('out_file')
+     parser.add_argument('--scale', type=float, help='Scale factor for super-resolution')
+     parser.add_argument('--size', type=int, nargs=2,
+                         help='Target size (h, w), mutually exclusive with --scale')
+     parser.add_argument('--checkpoint', help='Path to checkpoint file')
+     parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble')
+     return parser.parse_args()
+
+
+ if __name__ == '__main__':
+     args = parse_args()
+     main(args)
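
Based on the arguments defined in parse_args above, the new super_resolve.py would be invoked roughly like this; the image and checkpoint paths are placeholders, not files shipped with this commit:

# upscale by a fixed factor with a local Thera checkpoint
python super_resolve.py input.png output.png --scale 3.5 --checkpoint thera-edsr-pro.pkl

# or request an explicit (height, width) and skip the geometric self-ensemble
python super_resolve.py input.png output.png --size 1024 1024 --checkpoint thera-edsr-pro.pkl --no-ensemble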