ds1david commited on
Commit
d85fde4
Β·
1 Parent(s): 5c0e13b

fixing bugs

Browse files
Files changed (2) hide show
  1. app.py +43 -53
  2. utils.py +32 -170
app.py CHANGED
@@ -7,7 +7,6 @@ from PIL import Image
7
  import pickle
8
  import warnings
9
  import logging
10
- from datetime import datetime
11
  from huggingface_hub import hf_hub_download
12
  from diffusers import StableDiffusionXLImg2ImgPipeline
13
  from transformers import DPTImageProcessor, DPTForDepthEstimation
@@ -36,17 +35,16 @@ TORCH_DEVICE = "cpu"
36
  # 1. Carregar modelos do Thera ----------------------------------------------------------------
37
  def load_thera_model(repo_id, filename):
38
  try:
39
- logger.info(f"Iniciando carregamento do modelo Thera de {repo_id}")
40
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
41
  with open(model_path, 'rb') as fh:
42
  check = pickle.load(fh)
43
  variables = check['model']
44
  backbone, size = check['backbone'], check['size']
45
  model = build_thera(3, backbone, size)
46
- logger.info("Modelo Thera carregado com sucesso")
47
  return model, variables
48
  except Exception as e:
49
- logger.error(f"Falha ao carregar modelo Thera: {str(e)}")
50
  raise
51
 
52
 
@@ -55,80 +53,78 @@ model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.p
55
 
56
  # 2. Carregar SDXL + LoRA ---------------------------------------------------------------------
57
  try:
58
- logger.info("Iniciando carregamento do SDXL + LoRA...")
59
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
60
  "stabilityai/stable-diffusion-xl-base-1.0",
61
  torch_dtype=torch.float32
62
  ).to(TORCH_DEVICE)
63
  pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
64
- logger.info("SDXL + LoRA carregado com sucesso")
65
  except Exception as e:
66
- logger.error(f"Falha ao carregar SDXL: {str(e)}")
67
  raise
68
 
69
  # 3. Carregar modelo de profundidade ----------------------------------------------------------
70
  try:
71
- logger.info("Iniciando carregamento do DPT Depth...")
72
  feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
73
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
74
- logger.info("Modelo DPT carregado com sucesso")
75
  except Exception as e:
76
- logger.error(f"Falha ao carregar DPT: {str(e)}")
77
  raise
78
 
79
 
80
- # Pipeline principal --------------------------------------------------------------------------
 
 
 
 
81
  def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
82
  try:
83
- progress(0, desc="Iniciando processamento...")
84
 
85
- # 1. Super Resolução com Thera
86
- progress(0.1, desc="Convertendo imagem para RGB...")
87
  image = image.convert("RGB")
88
-
89
- progress(0.2, desc="Preparando entrada para super-resolução...")
90
  source = np.array(image) / 255.0
91
- original_size = image.size
92
- target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
93
- logger.info(f"Super-resolução: {original_size} β†’ {target_shape} (scale: {scale_factor}x)")
94
 
95
- progress(0.3, desc="Processando com Thera...")
 
 
 
 
 
 
 
 
 
 
96
  source_jax = jax.device_put(source, JAX_DEVICE)
97
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
98
 
99
- start_time = datetime.now()
100
  upscaled = model_edsr.apply(
101
  variables_edsr,
102
  source_jax,
103
  t,
104
  target_shape
105
  )
106
- logger.info(f"Super-resolução concluída em {datetime.now() - start_time}")
107
 
108
- progress(0.5, desc="Convertendo resultado...")
 
 
 
109
  upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
110
- logger.info(f"Tamanho após super-resolução: {upscaled_pil.size}")
111
 
112
- # 2. Gerar Bas-Relief
113
  progress(0.6, desc="Gerando Bas-Relief...")
114
- full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
115
- logger.info(f"Prompt final: {full_prompt}")
116
-
117
- start_time = datetime.now()
118
  bas_relief = pipe(
119
  prompt=full_prompt,
120
  image=upscaled_pil,
121
  strength=0.7,
122
- num_inference_steps=25,
123
- guidance_scale=7.5
124
  ).images[0]
125
- logger.info(f"Bas-Relief gerado em {datetime.now() - start_time}")
126
 
127
- # 3. Calcular Depth Map
128
- progress(0.8, desc="Calculando mapa de profundidade...")
129
- start_time = datetime.now()
130
  inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
131
-
132
  with torch.no_grad():
133
  outputs = depth_model(**inputs)
134
  depth = outputs.predicted_depth
@@ -139,39 +135,34 @@ def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
139
  mode="bicubic"
140
  ).squeeze().cpu().numpy()
141
 
142
- progress(0.9, desc="Processando mapa de profundidade...")
143
- depth_min = depth_map.min()
144
- depth_max = depth_map.max()
145
- depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
146
  depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
147
- logger.info(f"Profundidade calculada em {datetime.now() - start_time} | Range: {depth_min:.2f}-{depth_max:.2f}")
148
 
149
- progress(1.0, desc="Finalizado!")
150
  return upscaled_pil, bas_relief, depth_pil
151
 
152
  except Exception as e:
153
- logger.error(f"Erro no processamento: {str(e)}", exc_info=True)
154
- raise gr.Error(f"Erro no processamento: {str(e)}")
155
 
156
 
157
  # Interface Gradio ----------------------------------------------------------------------------
158
- with gr.Blocks(title="Super Res + Bas-Relief") as app:
159
- gr.Markdown("## πŸ” Super Resolução + πŸ—Ώ Bas-Relief + πŸ—ΊοΈ Profundidade")
160
 
161
  with gr.Row():
162
  with gr.Column():
163
  img_input = gr.Image(type="pil", label="Imagem de Entrada")
164
  prompt = gr.Textbox(
165
- label="Descrição do Relevo",
166
- value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution."
167
  )
168
  scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
169
  btn = gr.Button("Processar")
170
 
171
  with gr.Column():
172
- img_upscaled = gr.Image(label="Imagem Super Resolvida")
173
- img_basrelief = gr.Image(label="Resultado Bas-Relief")
174
- img_depth = gr.Image(label="Mapa de Profundidade")
175
 
176
  btn.click(
177
  full_pipeline,
@@ -180,5 +171,4 @@ with gr.Blocks(title="Super Res + Bas-Relief") as app:
180
  )
181
 
182
  if __name__ == "__main__":
183
- logger.info("Iniciando aplicação Gradio")
184
- app.launch(share=False)
 
7
  import pickle
8
  import warnings
9
  import logging
 
10
  from huggingface_hub import hf_hub_download
11
  from diffusers import StableDiffusionXLImg2ImgPipeline
12
  from transformers import DPTImageProcessor, DPTForDepthEstimation
 
35
  # 1. Carregar modelos do Thera ----------------------------------------------------------------
36
  def load_thera_model(repo_id, filename):
37
  try:
38
+ logger.info(f"Carregando modelo Thera de {repo_id}")
39
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
40
  with open(model_path, 'rb') as fh:
41
  check = pickle.load(fh)
42
  variables = check['model']
43
  backbone, size = check['backbone'], check['size']
44
  model = build_thera(3, backbone, size)
 
45
  return model, variables
46
  except Exception as e:
47
+ logger.error(f"Erro ao carregar modelo: {str(e)}")
48
  raise
49
 
50
 
 
53
 
54
  # 2. Carregar SDXL + LoRA ---------------------------------------------------------------------
55
  try:
56
+ logger.info("Carregando SDXL + LoRA...")
57
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
58
  "stabilityai/stable-diffusion-xl-base-1.0",
59
  torch_dtype=torch.float32
60
  ).to(TORCH_DEVICE)
61
  pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
 
62
  except Exception as e:
63
+ logger.error(f"Erro ao carregar SDXL: {str(e)}")
64
  raise
65
 
66
  # 3. Carregar modelo de profundidade ----------------------------------------------------------
67
  try:
68
+ logger.info("Carregando DPT Depth...")
69
  feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
70
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
 
71
  except Exception as e:
72
+ logger.error(f"Erro ao carregar DPT: {str(e)}")
73
  raise
74
 
75
 
76
+ def adjust_size(size):
77
+ """Garante que o tamanho seja divisΓ­vel por 8"""
78
+ return (size // 8) * 8
79
+
80
+
81
  def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
82
  try:
83
+ progress(0.1, desc="PrΓ©-processamento...")
84
 
85
+ # Converter e verificar imagem
 
86
  image = image.convert("RGB")
 
 
87
  source = np.array(image) / 255.0
 
 
 
88
 
89
+ # Adicionar dimensΓ£o de batch se necessΓ‘rio
90
+ if source.ndim == 3:
91
+ source = source[np.newaxis, ...]
92
+
93
+ # Ajustar tamanho alvo
94
+ target_shape = (
95
+ adjust_size(int(image.height * scale_factor)),
96
+ adjust_size(int(image.width * scale_factor))
97
+ )
98
+
99
+ progress(0.3, desc="Super-resolução...")
100
  source_jax = jax.device_put(source, JAX_DEVICE)
101
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
102
 
103
+ # Processar com Thera
104
  upscaled = model_edsr.apply(
105
  variables_edsr,
106
  source_jax,
107
  t,
108
  target_shape
109
  )
 
110
 
111
+ # Remover dimensΓ£o de batch se necessΓ‘rio
112
+ if upscaled.ndim == 4:
113
+ upscaled = upscaled[0]
114
+
115
  upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
 
116
 
 
117
  progress(0.6, desc="Gerando Bas-Relief...")
118
+ full_prompt = f"BAS-RELIEF {prompt}, ultra detailed engraving, 16K resolution"
 
 
 
119
  bas_relief = pipe(
120
  prompt=full_prompt,
121
  image=upscaled_pil,
122
  strength=0.7,
123
+ num_inference_steps=25
 
124
  ).images[0]
 
125
 
126
+ progress(0.8, desc="Calculando profundidade...")
 
 
127
  inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
 
128
  with torch.no_grad():
129
  outputs = depth_model(**inputs)
130
  depth = outputs.predicted_depth
 
135
  mode="bicubic"
136
  ).squeeze().cpu().numpy()
137
 
138
+ depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
 
 
 
139
  depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
 
140
 
 
141
  return upscaled_pil, bas_relief, depth_pil
142
 
143
  except Exception as e:
144
+ logger.error(f"Erro: {str(e)}", exc_info=True)
145
+ raise gr.Error(f"Erro: {str(e)}")
146
 
147
 
148
  # Interface Gradio ----------------------------------------------------------------------------
149
+ with gr.Blocks(title="SuperRes + BasRelief") as app:
150
+ gr.Markdown("## πŸ–ΌοΈ Super Resolução + Bas-Relief + Mapa de Profundidade")
151
 
152
  with gr.Row():
153
  with gr.Column():
154
  img_input = gr.Image(type="pil", label="Imagem de Entrada")
155
  prompt = gr.Textbox(
156
+ label="Descrição",
157
+ value="insanely detailed and complex engraving relief, ultra-high definition"
158
  )
159
  scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
160
  btn = gr.Button("Processar")
161
 
162
  with gr.Column():
163
+ img_upscaled = gr.Image(label="Super Resolvida")
164
+ img_basrelief = gr.Image(label="Bas-Relief")
165
+ img_depth = gr.Image(label="Profundidade")
166
 
167
  btn.click(
168
  full_pipeline,
 
171
  )
172
 
173
  if __name__ == "__main__":
174
+ app.launch() # Sem compartilhamento pΓΊblico
 
utils.py CHANGED
@@ -1,174 +1,36 @@
1
- import gradio as gr
2
- import torch
3
  import jax
4
- import jax.numpy as jnp
5
  import numpy as np
6
- from PIL import Image
7
- import pickle
8
- import warnings
9
- import logging
10
- from huggingface_hub import hf_hub_download
11
- from diffusers import StableDiffusionXLImg2ImgPipeline
12
- from transformers import DPTImageProcessor, DPTForDepthEstimation
13
- from model import build_thera
14
-
15
- # Configuração de logging
16
- logging.basicConfig(
17
- level=logging.INFO,
18
- format='%(asctime)s - %(levelname)s - %(message)s',
19
- handlers=[
20
- logging.FileHandler("processing.log"),
21
- logging.StreamHandler()
22
- ]
23
- )
24
- logger = logging.getLogger(__name__)
25
-
26
- # Configuraçáes e supressão de avisos
27
- warnings.filterwarnings("ignore", category=FutureWarning)
28
- warnings.filterwarnings("ignore", category=UserWarning)
29
-
30
- # Configurar dispositivos
31
- JAX_DEVICE = jax.devices("cpu")[0]
32
- TORCH_DEVICE = "cpu"
33
-
34
-
35
- # 1. Carregar modelos do Thera ----------------------------------------------------------------
36
- def load_thera_model(repo_id, filename):
37
- try:
38
- logger.info(f"Carregando modelo Thera de {repo_id}")
39
- model_path = hf_hub_download(repo_id=repo_id, filename=filename)
40
- with open(model_path, 'rb') as fh:
41
- check = pickle.load(fh)
42
- variables = check['model']
43
- backbone, size = check['backbone'], check['size']
44
- model = build_thera(3, backbone, size)
45
- return model, variables
46
- except Exception as e:
47
- logger.error(f"Erro ao carregar modelo: {str(e)}")
48
- raise
49
-
50
-
51
- logger.info("Carregando Thera EDSR...")
52
- model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
53
-
54
- # 2. Carregar SDXL + LoRA ---------------------------------------------------------------------
55
- try:
56
- logger.info("Carregando SDXL + LoRA...")
57
- pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
58
- "stabilityai/stable-diffusion-xl-base-1.0",
59
- torch_dtype=torch.float32
60
- ).to(TORCH_DEVICE)
61
- pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
62
- except Exception as e:
63
- logger.error(f"Erro ao carregar SDXL: {str(e)}")
64
- raise
65
-
66
- # 3. Carregar modelo de profundidade ----------------------------------------------------------
67
- try:
68
- logger.info("Carregando DPT Depth...")
69
- feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
70
- depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
71
- except Exception as e:
72
- logger.error(f"Erro ao carregar DPT: {str(e)}")
73
- raise
74
-
75
-
76
- def adjust_size(size):
77
- """Garante que o tamanho seja divisΓ­vel por 8"""
78
- return (size // 8) * 8
79
-
80
-
81
- def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
82
- try:
83
- progress(0.1, desc="PrΓ©-processamento...")
84
-
85
- # Converter e verificar imagem
86
- image = image.convert("RGB")
87
- source = np.array(image) / 255.0
88
-
89
- # Adicionar dimensΓ£o de batch se necessΓ‘rio
90
- if source.ndim == 3:
91
- source = source[np.newaxis, ...]
92
-
93
- # Ajustar tamanho alvo
94
- target_shape = (
95
- adjust_size(int(image.height * scale_factor)),
96
- adjust_size(int(image.width * scale_factor))
97
- )
98
-
99
- progress(0.3, desc="Super-resolução...")
100
- source_jax = jax.device_put(source, JAX_DEVICE)
101
- t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
102
-
103
- # Processar com Thera
104
- upscaled = model_edsr.apply(
105
- variables_edsr,
106
- source_jax,
107
- t,
108
- target_shape
109
- )
110
-
111
- # Remover dimensΓ£o de batch se necessΓ‘rio
112
- if upscaled.ndim == 4:
113
- upscaled = upscaled[0]
114
-
115
- upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
116
-
117
- progress(0.6, desc="Gerando Bas-Relief...")
118
- full_prompt = f"BAS-RELIEF {prompt}, ultra detailed engraving, 16K resolution"
119
- bas_relief = pipe(
120
- prompt=full_prompt,
121
- image=upscaled_pil,
122
- strength=0.7,
123
- num_inference_steps=25
124
- ).images[0]
125
-
126
- progress(0.8, desc="Calculando profundidade...")
127
- inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
128
- with torch.no_grad():
129
- outputs = depth_model(**inputs)
130
- depth = outputs.predicted_depth
131
-
132
- depth_map = torch.nn.functional.interpolate(
133
- depth.unsqueeze(1),
134
- size=bas_relief.size[::-1],
135
- mode="bicubic"
136
- ).squeeze().cpu().numpy()
137
-
138
- depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
139
- depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
140
-
141
- return upscaled_pil, bas_relief, depth_pil
142
-
143
- except Exception as e:
144
- logger.error(f"Erro: {str(e)}", exc_info=True)
145
- raise gr.Error(f"Erro: {str(e)}")
146
-
147
-
148
- # Interface Gradio ----------------------------------------------------------------------------
149
- with gr.Blocks(title="SuperRes + BasRelief") as app:
150
- gr.Markdown("## πŸ–ΌοΈ Super Resolução + Bas-Relief + Mapa de Profundidade")
151
-
152
- with gr.Row():
153
- with gr.Column():
154
- img_input = gr.Image(type="pil", label="Imagem de Entrada")
155
- prompt = gr.Textbox(
156
- label="Descrição",
157
- value="insanely detailed and complex engraving relief, ultra-high definition"
158
- )
159
- scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
160
- btn = gr.Button("Processar")
161
-
162
- with gr.Column():
163
- img_upscaled = gr.Image(label="Super Resolvida")
164
- img_basrelief = gr.Image(label="Bas-Relief")
165
- img_depth = gr.Image(label="Profundidade")
166
 
167
- btn.click(
168
- full_pipeline,
169
- inputs=[img_input, prompt, scale],
170
- outputs=[img_upscaled, img_basrelief, img_depth]
171
- )
172
 
173
- if __name__ == "__main__":
174
- app.launch() # Sem compartilhamento pΓΊblico
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
  import jax
 
4
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
 
 
 
 
 
6
 
7
+ def repeat_vmap(fun, in_axes=[0]):
8
+ for axes in in_axes:
9
+ fun = jax.vmap(fun, in_axes=axes)
10
+ return fun
11
+
12
+
13
+ def make_grid(patch_size: int | tuple[int, int]):
14
+ if isinstance(patch_size, int):
15
+ patch_size = (patch_size, patch_size)
16
+ offset_h, offset_w = 1 / (2 * np.array(patch_size))
17
+ space_h = np.linspace(-0.5 + offset_h, 0.5 - offset_h, patch_size[0])
18
+ space_w = np.linspace(-0.5 + offset_w, 0.5 - offset_w, patch_size[1])
19
+ return np.stack(np.meshgrid(space_h, space_w, indexing='ij'), axis=-1) # [h, w]
20
+
21
+
22
+ def interpolate_grid(coords, grid, order=0):
23
+ """
24
+ args:
25
+ coords: Tensor of shape (B, H, W, 2) with coordinates in [-0.5, 0.5]
26
+ grid: Tensor of shape (B, H', W', C)
27
+ returns:
28
+ Tensor of shape (B, H, W, C) with interpolated values
29
+ """
30
+ # convert [-0.5, 0.5] -> [0, size], where pixel centers are expected at
31
+ # [-0.5 + 1 / (2*size), ..., 0.5 - 1 / (2*size)]
32
+ coords = coords.transpose((0, 3, 1, 2))
33
+ coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
34
+ coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
35
+ map_coordinates = partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest')
36
+ return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)