ds1david commited on
Commit
d160dc6
·
1 Parent(s): 4a3fe77

fixing bugs

Browse files
Files changed (2) hide show
  1. app.py +48 -82
  2. utils.py +29 -31
app.py CHANGED
@@ -1,16 +1,19 @@
 
 
 
 
1
  import gradio as gr
2
- import torch
3
  import jax
4
  import jax.numpy as jnp
5
  import numpy as np
 
6
  from PIL import Image
7
- import pickle
8
- import warnings
9
- import logging
10
- from huggingface_hub import hf_hub_download
11
  from diffusers import StableDiffusionXLImg2ImgPipeline
 
12
  from transformers import DPTImageProcessor, DPTForDepthEstimation
 
13
  from model import build_thera
 
14
 
15
  # Configuração de logging
16
  logging.basicConfig(
@@ -23,111 +26,84 @@ logging.basicConfig(
23
  )
24
  logger = logging.getLogger(__name__)
25
 
26
- # Configurações e supressão de avisos
27
- warnings.filterwarnings("ignore", category=FutureWarning)
28
- warnings.filterwarnings("ignore", category=UserWarning)
29
-
30
- # Configurar dispositivos
31
  JAX_DEVICE = jax.devices("cpu")[0]
32
  TORCH_DEVICE = "cpu"
33
 
34
 
35
- # 1. Carregar modelos do Thera ----------------------------------------------------------------
36
  def load_thera_model(repo_id, filename):
37
  try:
38
- logger.info(f"Carregando modelo Thera de {repo_id}")
39
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
40
  with open(model_path, 'rb') as fh:
41
  check = pickle.load(fh)
42
  variables = check['model']
43
  backbone, size = check['backbone'], check['size']
44
- model = build_thera(3, backbone, size)
45
- return model, variables
46
  except Exception as e:
47
- logger.error(f"Erro ao carregar modelo: {str(e)}")
48
  raise
49
 
50
 
51
- logger.info("Carregando Thera EDSR...")
52
  model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
53
-
54
- # 2. Carregar SDXL + LoRA ---------------------------------------------------------------------
55
- try:
56
- logger.info("Carregando SDXL + LoRA...")
57
- pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
58
- "stabilityai/stable-diffusion-xl-base-1.0",
59
- torch_dtype=torch.float32
60
- ).to(TORCH_DEVICE)
61
- pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
62
- except Exception as e:
63
- logger.error(f"Erro ao carregar SDXL: {str(e)}")
64
- raise
65
-
66
- # 3. Carregar modelo de profundidade ----------------------------------------------------------
67
- try:
68
- logger.info("Carregando DPT Depth...")
69
- feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
70
- depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
71
- except Exception as e:
72
- logger.error(f"Erro ao carregar DPT: {str(e)}")
73
- raise
74
 
75
 
76
  def adjust_size(size):
77
- """Garante que o tamanho seja divisível por 8"""
78
- return (size // 8) * 8
79
 
80
 
81
  def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
82
  try:
83
- progress(0.1, desc="Pré-processamento...")
84
-
85
- # Converter e verificar imagem
86
  image = image.convert("RGB")
87
  source = np.array(image) / 255.0
88
 
89
- # Adicionar dimensão de batch se necessário
90
- if source.ndim == 3:
91
- source = source[np.newaxis, ...]
92
-
93
- # Ajustar tamanho alvo
94
  target_shape = (
95
  adjust_size(int(image.height * scale_factor)),
96
  adjust_size(int(image.width * scale_factor))
97
  )
 
 
 
 
 
98
 
99
- progress(0.3, desc="Super-resolução...")
100
- source_jax = jax.device_put(source, JAX_DEVICE)
 
101
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
102
 
103
- # Processar com Thera
104
  upscaled = model_edsr.apply(
105
  variables_edsr,
106
  source_jax,
107
  t,
108
  target_shape
109
  )
 
110
 
111
- # Remover dimensão de batch se necessário
112
- if upscaled.ndim == 4:
113
- upscaled = upscaled[0]
114
-
115
- upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
116
-
117
- progress(0.6, desc="Gerando Bas-Relief...")
118
- full_prompt = f"BAS-RELIEF {prompt}, ultra detailed engraving, 16K resolution"
119
  bas_relief = pipe(
120
- prompt=full_prompt,
121
  image=upscaled_pil,
122
  strength=0.7,
123
  num_inference_steps=25
124
  ).images[0]
125
 
 
126
  progress(0.8, desc="Calculando profundidade...")
127
  inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
128
  with torch.no_grad():
129
- outputs = depth_model(**inputs)
130
- depth = outputs.predicted_depth
131
 
132
  depth_map = torch.nn.functional.interpolate(
133
  depth.unsqueeze(1),
@@ -141,34 +117,24 @@ def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
141
  return upscaled_pil, bas_relief, depth_pil
142
 
143
  except Exception as e:
144
- logger.error(f"Erro: {str(e)}", exc_info=True)
145
- raise gr.Error(f"Erro: {str(e)}")
146
 
147
 
148
- # Interface Gradio ----------------------------------------------------------------------------
149
  with gr.Blocks(title="SuperRes + BasRelief") as app:
150
- gr.Markdown("## 🖼️ Super Resolução + Bas-Relief + Mapa de Profundidade")
151
-
152
  with gr.Row():
153
  with gr.Column():
154
- img_input = gr.Image(type="pil", label="Imagem de Entrada")
155
- prompt = gr.Textbox(
156
- label="Descrição",
157
- value="insanely detailed and complex engraving relief, ultra-high definition"
158
- )
159
- scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
160
- btn = gr.Button("Processar")
161
-
162
  with gr.Column():
163
- img_upscaled = gr.Image(label="Super Resolvida")
164
  img_basrelief = gr.Image(label="Bas-Relief")
165
  img_depth = gr.Image(label="Profundidade")
166
-
167
- btn.click(
168
- full_pipeline,
169
- inputs=[img_input, prompt, scale],
170
- outputs=[img_upscaled, img_basrelief, img_depth]
171
- )
172
 
173
  if __name__ == "__main__":
174
- app.launch() # Sem compartilhamento público
 
1
+ import logging
2
+ import pickle
3
+ import warnings
4
+
5
  import gradio as gr
 
6
  import jax
7
  import jax.numpy as jnp
8
  import numpy as np
9
+ import torch
10
  from PIL import Image
 
 
 
 
11
  from diffusers import StableDiffusionXLImg2ImgPipeline
12
+ from huggingface_hub import hf_hub_download
13
  from transformers import DPTImageProcessor, DPTForDepthEstimation
14
+
15
  from model import build_thera
16
+ from utils import make_grid
17
 
18
  # Configuração de logging
19
  logging.basicConfig(
 
26
  )
27
  logger = logging.getLogger(__name__)
28
 
29
+ # Configurações
30
+ warnings.filterwarnings("ignore")
 
 
 
31
  JAX_DEVICE = jax.devices("cpu")[0]
32
  TORCH_DEVICE = "cpu"
33
 
34
 
 
35
  def load_thera_model(repo_id, filename):
36
  try:
 
37
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
38
  with open(model_path, 'rb') as fh:
39
  check = pickle.load(fh)
40
  variables = check['model']
41
  backbone, size = check['backbone'], check['size']
42
+ return build_thera(3, backbone, size), variables
 
43
  except Exception as e:
44
+ logger.error(f"Erro ao carregar Thera: {str(e)}")
45
  raise
46
 
47
 
48
+ logger.info("Carregando modelos...")
49
  model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
50
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
51
+ "stabilityai/stable-diffusion-xl-base-1.0",
52
+ torch_dtype=torch.float32
53
+ ).to(TORCH_DEVICE)
54
+ pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
55
+ feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
56
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  def adjust_size(size):
60
+ return max(8, (size // 8) * 8)
 
61
 
62
 
63
  def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
64
  try:
65
+ progress(0.1, desc="Iniciando...")
 
 
66
  image = image.convert("RGB")
67
  source = np.array(image) / 255.0
68
 
69
+ # Ajuste de dimensões
 
 
 
 
70
  target_shape = (
71
  adjust_size(int(image.height * scale_factor)),
72
  adjust_size(int(image.width * scale_factor))
73
  )
74
+ logger.info(f"Transformação: {image.size} → {target_shape}")
75
+
76
+ # Gerar grid
77
+ coords = make_grid(target_shape)
78
+ logger.debug(f"Coords shape: {coords.shape}")
79
 
80
+ # Super-resolução
81
+ progress(0.3, desc="Processando super-resolução...")
82
+ source_jax = jax.device_put(source[np.newaxis, ...], JAX_DEVICE)
83
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
84
 
 
85
  upscaled = model_edsr.apply(
86
  variables_edsr,
87
  source_jax,
88
  t,
89
  target_shape
90
  )
91
+ upscaled_pil = Image.fromarray((np.array(upscaled[0]) * 255).astype(np.uint8))
92
 
93
+ # Bas-Relief
94
+ progress(0.6, desc="Gerando relevo...")
 
 
 
 
 
 
95
  bas_relief = pipe(
96
+ prompt=f"BAS-RELIEF {prompt}, ultra detailed engraving, 16K resolution",
97
  image=upscaled_pil,
98
  strength=0.7,
99
  num_inference_steps=25
100
  ).images[0]
101
 
102
+ # Depth Map
103
  progress(0.8, desc="Calculando profundidade...")
104
  inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
105
  with torch.no_grad():
106
+ depth = depth_model(**inputs).predicted_depth
 
107
 
108
  depth_map = torch.nn.functional.interpolate(
109
  depth.unsqueeze(1),
 
117
  return upscaled_pil, bas_relief, depth_pil
118
 
119
  except Exception as e:
120
+ logger.error(f"ERRO: {str(e)}", exc_info=True)
121
+ raise gr.Error(f"Erro no processamento: {str(e)}")
122
 
123
 
124
+ # Interface
125
  with gr.Blocks(title="SuperRes + BasRelief") as app:
126
+ gr.Markdown("## 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")
 
127
  with gr.Row():
128
  with gr.Column():
129
+ img_input = gr.Image(type="pil", label="Entrada")
130
+ prompt = gr.Textbox("Escultura detalhada em mármore, alto relevo", label="Descrição")
131
+ scale = gr.Slider(1.0, 4.0, value=2.0, label="Escala")
132
+ btn = gr.Button("Processar ▶️")
 
 
 
 
133
  with gr.Column():
134
+ img_upscaled = gr.Image(label="Super Resolução")
135
  img_basrelief = gr.Image(label="Bas-Relief")
136
  img_depth = gr.Image(label="Profundidade")
137
+ btn.click(full_pipeline, [img_input, prompt, scale], [img_upscaled, img_basrelief, img_depth])
 
 
 
 
 
138
 
139
  if __name__ == "__main__":
140
+ app.launch()
utils.py CHANGED
@@ -1,6 +1,6 @@
1
  from functools import partial
2
-
3
  import jax
 
4
  import numpy as np
5
 
6
 
@@ -14,46 +14,44 @@ def repeat_vmap(fun, in_axes=None):
14
 
15
  def make_grid(patch_size: int | tuple[int, int]):
16
  if isinstance(patch_size, int):
17
- patch_size = (patch_size, patch_size)
 
18
  offset_h, offset_w = 1 / (2 * np.array(patch_size))
19
  space_h = np.linspace(-0.5 + offset_h, 0.5 - offset_h, patch_size[0])
20
  space_w = np.linspace(-0.5 + offset_w, 0.5 - offset_w, patch_size[1])
21
- return np.stack(np.meshgrid(space_h, space_w, indexing='ij'), axis=-1) # [h, w]
 
 
22
 
23
 
24
  def interpolate_grid(coords, grid, order=0):
25
- """
26
- Args:
27
  coords: Tensor de shape (B, H, W, 2) ou (H, W, 2)
28
  grid: Tensor de shape (B, H', W', C)
 
29
  """
30
- # Adicionar dimensão de batch se necessário
31
- if coords.ndim == 3:
32
- coords = coords[np.newaxis, ...]
 
 
33
 
34
- # Verificar dimensões
35
- assert coords.ndim == 4, f"Dimensões inválidas para coords: {coords.shape}"
36
- assert grid.ndim == 4, f"Dimensões inválidas para grid: {grid.shape}"
37
 
38
- # Ajustar transposição de forma segura
39
- try:
40
  coords = coords.transpose((0, 3, 1, 2))
41
- except ValueError as e:
42
- raise ValueError(f"Falha na transposição: {coords.shape} (0,3,1,2)") from e
43
-
44
- # Conversão de coordenadas
45
- coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
46
- coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
47
-
48
- # Interpolação com JAX
49
- map_coordinates = partial(jax.scipy.ndimage.map_coordinates,
50
- order=order,
51
- mode='nearest')
52
-
53
- return jax.vmap( # Sobre batches
54
- jax.vmap( # Sobre canais
55
- map_coordinates,
56
- in_axes=(2, None), # (C, H', W'), (B, 2, H, W)
57
  out_axes=2
58
- )
59
- )(grid, coords)
 
 
 
 
1
  from functools import partial
 
2
  import jax
3
+ import jax.numpy as jnp
4
  import numpy as np
5
 
6
 
 
14
 
15
  def make_grid(patch_size: int | tuple[int, int]):
16
  if isinstance(patch_size, int):
17
+ patch_size = (max(1, patch_size), max(1, patch_size))
18
+
19
  offset_h, offset_w = 1 / (2 * np.array(patch_size))
20
  space_h = np.linspace(-0.5 + offset_h, 0.5 - offset_h, patch_size[0])
21
  space_w = np.linspace(-0.5 + offset_w, 0.5 - offset_w, patch_size[1])
22
+
23
+ grid = np.stack(np.meshgrid(space_h, space_w, indexing='ij'), axis=-1)
24
+ return grid[np.newaxis, ...] # Adiciona dimensão de batch
25
 
26
 
27
  def interpolate_grid(coords, grid, order=0):
28
+ """Args:
 
29
  coords: Tensor de shape (B, H, W, 2) ou (H, W, 2)
30
  grid: Tensor de shape (B, H', W', C)
31
+ order: default 0
32
  """
33
+ try:
34
+ # Converter para array JAX e ajustar dimensões
35
+ coords = jnp.asarray(coords)
36
+ while coords.ndim < 4:
37
+ coords = coords[jnp.newaxis, ...]
38
 
39
+ # Verificação final de dimensões
40
+ if coords.shape[-1] != 2 or coords.ndim != 4:
41
+ raise ValueError(f"Formato inválido: {coords.shape}. Esperado (B, H, W, 2)")
42
 
43
+ # Transformação de coordenadas
 
44
  coords = coords.transpose((0, 3, 1, 2))
45
+ coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
46
+ coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
47
+
48
+ # Função de interpolação vetorizada
49
+ map_fn = jax.vmap(jax.vmap(
50
+ partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest'),
51
+ in_axes=(2, None),
 
 
 
 
 
 
 
 
 
52
  out_axes=2
53
+ ))
54
+ return map_fn(grid, coords)
55
+
56
+ except Exception as e:
57
+ raise RuntimeError(f"Falha na interpolação: {str(e)}") from e