ds1david committed on
Commit
5c0e13b
·
1 Parent(s): 46bb495
Files changed (1) hide show
  1. utils.py +170 -32
utils.py CHANGED
@@ -1,36 +1,174 @@
1
- from functools import partial
2
-
3
  import jax
 
4
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
 
 
 
 
 
6
 
7
def repeat_vmap(fun, in_axes=(0,)):
    """Wrap ``fun`` in one ``jax.vmap`` layer per entry of ``in_axes``.

    Entries are applied in order, so the first entry becomes the innermost
    ``vmap`` and the last entry the outermost one.

    Args:
        fun: Callable to vectorize.
        in_axes: Iterable of ``in_axes`` specifications, one per ``jax.vmap``
            layer. Defaults to a single layer mapping over axis 0.

    Returns:
        ``fun`` wrapped in one ``jax.vmap`` layer for every entry of
        ``in_axes``.
    """
    # The default is an immutable tuple (not a list) so that it can never be
    # mutated and silently shared between calls.
    for axes in in_axes:
        fun = jax.vmap(fun, in_axes=axes)
    return fun
11
-
12
-
13
- def make_grid(patch_size: int | tuple[int, int]):
14
- if isinstance(patch_size, int):
15
- patch_size = (patch_size, patch_size)
16
- offset_h, offset_w = 1 / (2 * np.array(patch_size))
17
- space_h = np.linspace(-0.5 + offset_h, 0.5 - offset_h, patch_size[0])
18
- space_w = np.linspace(-0.5 + offset_w, 0.5 - offset_w, patch_size[1])
19
- return np.stack(np.meshgrid(space_h, space_w, indexing='ij'), axis=-1) # [h, w]
20
-
21
-
22
def interpolate_grid(coords, grid, order=0):
    """Sample ``grid`` at the normalized positions given by ``coords``.

    Args:
        coords: Array of shape (B, H, W, 2) with coordinates in [-0.5, 0.5].
        grid: Array of shape (B, H', W', C).
        order: Interpolation order forwarded to
            ``jax.scipy.ndimage.map_coordinates``.

    Returns:
        Array of shape (B, H, W, C) with the interpolated values.
    """
    grid_h, grid_w = grid.shape[-3], grid.shape[-2]

    # Bring the coordinate axis forward: (B, H, W, 2) -> (B, 2, H, W).
    index_coords = coords.transpose((0, 3, 1, 2))
    # Rescale normalized coordinates to pixel indices. Pixel centers are
    # expected at [-0.5 + 1 / (2*size), ..., 0.5 - 1 / (2*size)], which this
    # affine map sends to 0 ... size - 1.
    index_coords = index_coords.at[:, 0].set(index_coords[:, 0] * grid_h + (grid_h - 1) / 2)
    index_coords = index_coords.at[:, 1].set(index_coords[:, 1] * grid_w + (grid_w - 1) / 2)

    def sample_channel(channel, positions):
        # Out-of-range positions are clamped to the border ('nearest' mode).
        return jax.scipy.ndimage.map_coordinates(channel, positions, order=order, mode='nearest')

    # Inner vmap: one lookup per channel of a single image, sharing positions.
    per_image = jax.vmap(sample_channel, in_axes=(2, None), out_axes=2)
    # Outer vmap: iterate over the batch axis of both grid and coordinates.
    return jax.vmap(per_image)(grid, index_coords)
 
1
+ import gradio as gr
2
+ import torch
3
  import jax
4
+ import jax.numpy as jnp
5
  import numpy as np
6
+ from PIL import Image
7
+ import pickle
8
+ import warnings
9
+ import logging
10
+ from huggingface_hub import hf_hub_download
11
+ from diffusers import StableDiffusionXLImg2ImgPipeline
12
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
13
+ from model import build_thera
14
+
15
+ # Logging configuration: INFO level, written to processing.log and to the console
16
+ logging.basicConfig(
17
+ level=logging.INFO,
18
+ format='%(asctime)s - %(levelname)s - %(message)s',
19
+ handlers=[
20
+ logging.FileHandler("processing.log"),
21
+ logging.StreamHandler()
22
+ ]
23
+ )
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Settings: suppress FutureWarning and UserWarning messages
27
+ warnings.filterwarnings("ignore", category=FutureWarning)
28
+ warnings.filterwarnings("ignore", category=UserWarning)
29
+
30
+ # Device configuration: both JAX and PyTorch are pinned to the CPU
31
+ JAX_DEVICE = jax.devices("cpu")[0]
32
+ TORCH_DEVICE = "cpu"
33
+
34
+
35
# 1. Load the Thera models --------------------------------------------------------------------
def load_thera_model(repo_id, filename):
    """Download a Thera checkpoint from the Hugging Face Hub and build its model.

    Args:
        repo_id: Hub repository id passed to ``hf_hub_download``.
        filename: Name of the checkpoint file inside that repository.

    Returns:
        Tuple ``(model, variables)``: the object returned by ``build_thera``
        and the ``'model'`` entry of the unpickled checkpoint.

    Raises:
        Exception: Any download, unpickling or build failure is logged
            (with traceback) and re-raised unchanged.
    """
    try:
        logger.info(f"Carregando modelo Thera de {repo_id}")
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        with open(model_path, 'rb') as fh:
            # SECURITY NOTE: pickle.load can execute arbitrary code embedded
            # in the file. Only load checkpoints from repositories you trust.
            check = pickle.load(fh)
        variables = check['model']
        backbone, size = check['backbone'], check['size']
        # NOTE(review): 3 is presumably the number of output channels (RGB);
        # confirm against build_thera's signature.
        model = build_thera(3, backbone, size)
        return model, variables
    except Exception as e:
        # exc_info=True keeps the traceback in the log, matching the error
        # handling used by full_pipeline.
        logger.error(f"Erro ao carregar modelo: {str(e)}", exc_info=True)
        raise
49
+
50
+
51
+ logger.info("Carregando Thera EDSR...")
52
+ model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
53
+
54
+ # 2. Load SDXL + LoRA -------------------------------------------------------------------------
55
+ try:
56
+ logger.info("Carregando SDXL + LoRA...")
57
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
58
+ "stabilityai/stable-diffusion-xl-base-1.0",
59
+ torch_dtype=torch.float32
60
+ ).to(TORCH_DEVICE)
61
+ pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
62
+ except Exception as e:
63
+ logger.error(f"Erro ao carregar SDXL: {str(e)}")
64
+ raise
65
+
66
+ # 3. Load the depth-estimation model ----------------------------------------------------------
67
+ try:
68
+ logger.info("Carregando DPT Depth...")
69
+ feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
70
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
71
+ except Exception as e:
72
+ logger.error(f"Erro ao carregar DPT: {str(e)}")
73
+ raise
74
+
75
+
76
def adjust_size(size):
    """Round ``size`` down to a multiple of 8, never returning less than 8.

    Args:
        size: Requested size in pixels.

    Returns:
        The largest multiple of 8 that does not exceed ``size``, or 8 when
        ``size`` is smaller than 8.
    """
    # Without the lower bound, any size below 8 would collapse to 0 and
    # produce an empty target dimension.
    return max(8, (size // 8) * 8)
79
+
80
+
81
def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
    """Run super-resolution, bas-relief generation and depth estimation.

    Args:
        image: Input PIL image.
        prompt: Text inserted into the bas-relief prompt template.
        scale_factor: Upscaling factor applied to the image height and width.
        progress: Gradio progress tracker used to report the pipeline stages.

    Returns:
        Tuple ``(upscaled_pil, bas_relief, depth_pil)``: the super-resolved
        image, the image produced by the SDXL pipeline, and the depth map
        normalized to an 8-bit image.

    Raises:
        gr.Error: Any failure is logged with its traceback and re-raised as a
            Gradio error chained to the original exception.
    """
    try:
        progress(0.1, desc="Pré-processamento...")

        # Convert to RGB and scale pixel values to [0, 1].
        image = image.convert("RGB")
        source = np.array(image) / 255.0

        # Add a leading batch dimension if needed.
        if source.ndim == 3:
            source = source[np.newaxis, ...]

        # Target (height, width), each passed through adjust_size.
        target_shape = (
            adjust_size(int(image.height * scale_factor)),
            adjust_size(int(image.width * scale_factor)),
        )

        progress(0.3, desc="Super-resolução...")
        source_jax = jax.device_put(source, JAX_DEVICE)
        # NOTE(review): presumably the scale-dependent conditioning input
        # expected by the Thera model; confirm against the model definition.
        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

        # Run the Thera model.
        upscaled = model_edsr.apply(
            variables_edsr,
            source_jax,
            t,
            target_shape,
        )

        # Drop the batch dimension if present.
        if upscaled.ndim == 4:
            upscaled = upscaled[0]

        # Clip to [0, 1] before the uint8 cast: a model output slightly outside
        # that range would otherwise overflow and show up as wrapped pixels.
        upscaled_np = np.clip(np.array(upscaled), 0.0, 1.0)
        upscaled_pil = Image.fromarray((upscaled_np * 255).astype(np.uint8))

        progress(0.6, desc="Gerando Bas-Relief...")
        full_prompt = f"BAS-RELIEF {prompt}, ultra detailed engraving, 16K resolution"
        bas_relief = pipe(
            prompt=full_prompt,
            image=upscaled_pil,
            strength=0.7,
            num_inference_steps=25,
        ).images[0]

        progress(0.8, desc="Calculando profundidade...")
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            outputs = depth_model(**inputs)
            depth = outputs.predicted_depth

        # PIL reports size as (width, height); reverse it to (height, width)
        # for torch's interpolate.
        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],
            mode="bicubic",
        ).squeeze().cpu().numpy()

        # Normalize to [0, 1]; a constant depth map has zero range, so guard
        # against dividing by zero (which would yield NaNs).
        depth_min = depth_map.min()
        depth_range = depth_map.max() - depth_min
        if depth_range > 0:
            depth_normalized = (depth_map - depth_min) / depth_range
        else:
            depth_normalized = np.zeros_like(depth_map)
        depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))

        return upscaled_pil, bas_relief, depth_pil

    except Exception as e:
        logger.error(f"Erro: {str(e)}", exc_info=True)
        raise gr.Error(f"Erro: {str(e)}") from e
146
+
147
+
148
+ # Gradio interface: inputs (image, prompt, scale) on one column, the three result images on the other
149
+ with gr.Blocks(title="SuperRes + BasRelief") as app:
150
+ gr.Markdown("## 🖼️ Super Resolução + Bas-Relief + Mapa de Profundidade")
151
+
152
+ with gr.Row():
153
+ with gr.Column():
154
+ img_input = gr.Image(type="pil", label="Imagem de Entrada")
155
+ prompt = gr.Textbox(
156
+ label="Descrição",
157
+ value="insanely detailed and complex engraving relief, ultra-high definition"
158
+ )
159
+ scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
160
+ btn = gr.Button("Processar")
161
+
162
+ with gr.Column():
163
+ img_upscaled = gr.Image(label="Super Resolvida")
164
+ img_basrelief = gr.Image(label="Bas-Relief")
165
+ img_depth = gr.Image(label="Profundidade")
166
 
167
+ btn.click(
168
+ full_pipeline,
169
+ inputs=[img_input, prompt, scale],
170
+ outputs=[img_upscaled, img_basrelief, img_depth]
171
+ )
172
 
173
+ if __name__ == "__main__":
174
+ app.launch() # No public sharing (launch is called without share=True)