ds1david commited on
Commit
b82dc7d
·
1 Parent(s): 1557411

fixing bugs

Browse files
Files changed (3) hide show
  1. app.py +119 -185
  2. requirements.txt +10 -38
  3. utils.py +24 -57
app.py CHANGED
@@ -1,195 +1,129 @@
1
- # app.py
2
  import gradio as gr
3
  import torch
4
- import jax
5
- import jax.numpy as jnp
6
  import numpy as np
7
- from PIL import Image
8
  import pickle
9
- import logging
10
  from huggingface_hub import hf_hub_download
11
- from diffusers import StableDiffusionXLImg2ImgPipeline
12
- from transformers import DPTImageProcessor, DPTForDepthEstimation
13
  from model import build_thera
14
- from utils import make_grid, interpolate_grid
15
-
16
- # Configuração de logging
17
- logging.basicConfig(
18
- level=logging.INFO,
19
- format='%(asctime)s - %(levelname)s - %(message)s',
20
- handlers=[logging.FileHandler("processing.log"), logging.StreamHandler()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  )
22
- logger = logging.getLogger(__name__)
23
 
24
- # Configurações
25
- JAX_DEVICE = jax.devices("cpu")[0]
26
- TORCH_DEVICE = "cpu"
27
-
28
-
29
- def load_thera_model(repo_id: str, filename: str):
30
- """Carrega modelo com múltiplas verificações"""
31
- try:
32
- model_path = hf_hub_download(repo_id=repo_id, filename=filename)
33
- with open(model_path, 'rb') as fh:
34
- checkpoint = pickle.load(fh)
35
-
36
- # Verificar estrutura do checkpoint
37
- required_keys = {'model', 'backbone', 'size'}
38
- if not required_keys.issubset(checkpoint.keys()):
39
- missing = required_keys - checkpoint.keys()
40
- raise ValueError(f"Checkpoint corrompido. Chaves faltando: {missing}")
41
-
42
- return build_thera(3, checkpoint['backbone'], checkpoint['size']), checkpoint['model']
43
- except Exception as e:
44
- logger.error(f"Erro ao carregar modelo: {str(e)}")
45
- raise
46
-
47
-
48
- # Inicialização segura
49
- try:
50
- logger.info("Inicializando modelos...")
51
- model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
52
-
53
- # Pipeline SDXL
54
- pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
55
- "stabilityai/stable-diffusion-xl-base-1.0",
56
- torch_dtype=torch.float32
57
- ).to(TORCH_DEVICE)
58
- pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
59
-
60
- # Modelo de profundidade
61
- feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
62
- depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
63
-
64
- except Exception as e:
65
- logger.error(f"Falha crítica na inicialização: {str(e)}")
66
- raise
67
-
68
-
69
- def safe_resize(original: tuple[int, int], scale: float) -> tuple[int, int]:
70
- """Calcula tamanho garantindo estabilidade numérica"""
71
- h, w = original
72
- new_h = int(h * scale)
73
- new_w = int(w * scale)
74
-
75
- # Ajustar para múltiplo de 8
76
- new_h = max(32, new_h - new_h % 8)
77
- new_w = max(32, new_w - new_w % 8)
78
-
79
- return (new_h, new_w)
80
-
81
-
82
- def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
83
- """Pipeline completo com tratamento de erros robusto"""
84
- try:
85
- # Verificação inicial
86
- if not image:
87
- raise ValueError("Nenhuma imagem fornecida")
88
-
89
- # Conversão segura para RGB
90
- image = image.convert("RGB")
91
- orig_w, orig_h = image.size
92
- logger.info(f"Processando imagem: {orig_w}x{orig_h}")
93
-
94
- # Cálculo do novo tamanho
95
- new_h, new_w = safe_resize((orig_h, orig_w), scale_factor)
96
- logger.info(f"Novo tamanho calculado: {new_h}x{new_w}")
97
-
98
- # Gerar grid de coordenadas
99
- grid = make_grid((new_h, new_w))
100
- logger.debug(f"Grid gerado: {grid.shape}")
101
-
102
- # Verificação crítica do grid
103
- if grid.shape[1:3] != (new_h, new_w):
104
- raise RuntimeError(
105
- f"Incompatibilidade de dimensões: "
106
- f"Grid {grid.shape[1:3]} vs Alvo {new_h}x{new_w}"
107
- )
108
-
109
- # Pré-processamento da imagem
110
- source = jnp.array(image).astype(jnp.float32) / 255.0
111
- source = source[jnp.newaxis, ...] # Adicionar dimensão de batch
112
-
113
- # Parâmetro de escala
114
- t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
115
-
116
- # Processamento Thera
117
- upscaled = model_edsr.apply(params_edsr, source, t, (new_h, new_w))
118
-
119
- # Conversão para PIL
120
- upscaled_img = Image.fromarray((np.array(upscaled[0]) * 255).astype(np.uint8))
121
- logger.info(f"Imagem super-resolvida: {upscaled_img.size}")
122
-
123
- # Geração do Bas-Relief
124
- result = pipe(
125
- prompt=f"BAS-RELIEF {prompt}, ultra detailed, 8K resolution",
126
- image=upscaled_img,
127
- strength=0.7,
128
- num_inference_steps=30,
129
- guidance_scale=7.5
130
- )
131
- bas_relief = result.images[0]
132
- logger.info(f"Bas-Relief gerado: {bas_relief.size}")
133
-
134
- # Cálculo da profundidade
135
- inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
136
- with torch.no_grad():
137
- depth = depth_model(**inputs).predicted_depth
138
-
139
- # Redimensionamento
140
- depth_map = torch.nn.functional.interpolate(
141
- depth.unsqueeze(1),
142
- size=bas_relief.size[::-1],
143
- mode="bicubic"
144
- ).squeeze().cpu().numpy()
145
-
146
- # Normalização e conversão
147
- depth_min = depth_map.min()
148
- depth_max = depth_map.max()
149
- depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
150
- depth_img = Image.fromarray((depth_normalized * 255).astype(np.uint8))
151
- logger.info("Mapa de profundidade calculado")
152
-
153
- return upscaled_img, bas_relief, depth_img
154
-
155
- except Exception as e:
156
- logger.error(f"ERRO NO PIPELINE: {str(e)}", exc_info=True)
157
- raise gr.Error(f"Falha no processamento: {str(e)}")
158
-
159
-
160
- # Interface Gradio
161
- with gr.Blocks(title="SuperRes+BasRelief Pro", theme=gr.themes.Soft()) as app:
162
- gr.Markdown("# 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")
163
-
164
- with gr.Row():
165
- input_col = gr.Column()
166
- output_col = gr.Column()
167
-
168
- with input_col:
169
- img_input = gr.Image(label="Carregar Imagem", type="pil", height=300)
170
- prompt = gr.Textbox(
171
- label="Descrição do Relevo",
172
- value="A insanely detailed and complex engraving relief, ultra-high definition",
173
- placeholder="Descreva o estilo desejado..."
174
- )
175
- scale = gr.Slider(1.0, 4.0, value=2.0, step=0.1, label="Fator de Escala")
176
- process_btn = gr.Button("Iniciar Processamento", variant="primary")
177
-
178
- with output_col:
179
- with gr.Tabs():
180
- with gr.TabItem("Super Resolução"):
181
- upscaled_output = gr.Image(label="Resultado", show_label=False)
182
- with gr.TabItem("Bas-Relief"):
183
- basrelief_output = gr.Image(label="Relevo", show_label=False)
184
- with gr.TabItem("Profundidade"):
185
- depth_output = gr.Image(label="Mapa 3D", show_label=False)
186
-
187
- process_btn.click(
188
- full_pipeline,
189
- inputs=[img_input, prompt, scale],
190
- outputs=[upscaled_output, basrelief_output, depth_output],
191
- api_name="processar"
192
  )
193
 
194
- if __name__ == "__main__":
195
- app.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import gradio as gr
2
  import torch
 
 
3
  import numpy as np
4
+ import jax
5
  import pickle
6
+ from PIL import Image
7
  from huggingface_hub import hf_hub_download
 
 
8
  from model import build_thera
9
+ from super_resolve import process
10
+ from diffusers import StableDiffusionXLPipeline
11
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
12
+
13
+ # ========== Configuração do Thera ==========
14
+ REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
15
+ REPO_ID_RDN = "prs-eth/thera-rdn-pro"
16
+
17
+
18
+ # Carregar modelos Thera
19
+ def load_thera_model(repo_id):
20
+ model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
21
+ with open(model_path, 'rb') as fh:
22
+ check = pickle.load(fh)
23
+ params, backbone, size = check['model'], check['backbone'], check['size']
24
+ model = build_thera(3, backbone, size)
25
+ return model, params
26
+
27
+
28
+ model_edsr, params_edsr = load_thera_model(REPO_ID_EDSR)
29
+ model_rdn, params_rdn = load_thera_model(REPO_ID_RDN)
30
+
31
+ # ========== Configuração do SDXL + Depth ==========
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
34
+
35
+ # Carregar modelos de geração
36
+ pipe = StableDiffusionXLPipeline.from_pretrained(
37
+ "stabilityai/stable-diffusion-xl-base-1.0",
38
+ torch_dtype=torch_dtype
39
+ ).to(device)
40
+
41
+ pipe.load_lora_weights(
42
+ "KappaNeuro/bas-relief",
43
+ weight_name="BAS-RELIEF.safetensors",
44
+ peft_backend="peft"
45
  )
 
46
 
47
+ feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
48
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
49
+
50
+
51
+ # ========== Funções Principais ==========
52
+ def super_resolution(image, scale_factor, model_type):
53
+ model = model_edsr if model_type == "EDSR" else model_rdn
54
+ params = params_edsr if model_type == "EDSR" else params_rdn
55
+
56
+ source = np.asarray(image) / 255.
57
+ target_shape = (
58
+ round(source.shape[0] * scale_factor),
59
+ round(source.shape[1] * scale_factor),
60
+ )
61
+
62
+ output = process(source, model, params, target_shape, do_ensemble=True)
63
+ return Image.fromarray(np.asarray(output))
64
+
65
+
66
+ def generate_bas_relief(prompt):
67
+ full_prompt = f"BAS-RELIEF {prompt}"
68
+ image = pipe(
69
+ prompt=full_prompt,
70
+ num_inference_steps=25,
71
+ guidance_scale=7.5,
72
+ height=512,
73
+ width=512
74
+ ).images[0]
75
+
76
+ inputs = feature_extractor(image, return_tensors="pt").to(device)
77
+ with torch.no_grad():
78
+ outputs = depth_model(**inputs)
79
+ depth_map = outputs.predicted_depth
80
+
81
+ depth_map = torch.nn.functional.interpolate(
82
+ depth_map.unsqueeze(1),
83
+ size=image.size[::-1],
84
+ mode="bicubic"
85
+ ).squeeze().cpu().numpy()
86
+
87
+ depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
88
+ depth_map = (depth_map * 255).astype(np.uint8)
89
+
90
+ return image, Image.fromarray(depth_map)
91
+
92
+
93
+ # ========== Interface Gradio ==========
94
+ with gr.Blocks(title="TheraSR + Bas-Relief Generator") as app:
95
+ gr.Markdown("# 🔥 TheraSR + Bas-Relief Generator")
96
+ gr.Markdown("Combine aliasing-free super-resolution with artistic bas-relief generation")
97
+
98
+ with gr.Tabs():
99
+ with gr.TabItem("🖼 Super-Resolution"):
100
+ with gr.Row():
101
+ sr_input = gr.Image(label="Input Image", type="pil")
102
+ sr_output = gr.Image(label="Super-Resolution Result")
103
+ sr_scale = gr.Slider(1.0, 6.0, value=2.0, label="Scale Factor")
104
+ sr_model = gr.Radio(["EDSR", "RDN"], value="EDSR", label="Model Type")
105
+ sr_btn = gr.Button("Enhance Resolution")
106
+
107
+ with gr.TabItem("🎨 Generate Bas-Relief"):
108
+ with gr.Row():
109
+ text_input = gr.Textbox(label="Art Prompt", placeholder="Roman soldier marble relief...")
110
+ with gr.Row():
111
+ gen_output = gr.Image(label="Generated Art")
112
+ depth_output = gr.Image(label="Depth Map")
113
+ gen_btn = gr.Button("Generate Artwork")
114
+
115
+ # Event Handlers
116
+ sr_btn.click(
117
+ super_resolution,
118
+ inputs=[sr_input, sr_scale, sr_model],
119
+ outputs=sr_output
120
+ )
121
+
122
+ gen_btn.click(
123
+ generate_bas_relief,
124
+ inputs=text_input,
125
+ outputs=[gen_output, depth_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  )
127
 
128
+ # Configuração do Hugging Face
129
+ app.launch(debug=False, share=True)
requirements.txt CHANGED
@@ -1,39 +1,11 @@
1
- -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2
-
3
- ConfigArgParse==1.7
4
- Pillow==10.0.0
5
- chex==0.1.7
6
- diffusers
7
- einops==0.6.1
8
- flax==0.6.10
9
- flaxmodels==0.1.3
10
  huggingface_hub
11
- jax==0.4.11
12
- jaxlib==0.4.11+cuda11.cudnn86
13
- jaxtyping==0.2.20
14
- ml-dtypes==0.1.0
15
- numpy==1.24.1
16
- nvidia-cublas-cu11==11.11.3.6
17
- nvidia-cuda-cupti-cu11==11.8.87
18
- nvidia-cuda-nvcc-cu11==11.8.89
19
- nvidia-cuda-runtime-cu11==11.8.89
20
- nvidia-cudnn-cu11==8.9.2.26
21
- nvidia-cufft-cu11==10.9.0.58
22
- nvidia-cusolver-cu11==11.4.1.48
23
- nvidia-cusparse-cu11==11.7.5.86
24
- opt-einsum==3.3.0
25
- optax==0.2.0
26
- orbax-checkpoint==0.2.4
27
- peft
28
- Pillow==10.0.0
29
- scipy==1.10.1
30
- timm==0.9.6
31
- torch
32
- torchvision
33
- tqdm==4.65.0
34
- transformers==4.46.3
35
- wandb
36
-
37
- gradio==4.44.1
38
- gradio_imageslider==0.0.20
39
- spaces
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu118
2
+ torch==2.0.1
3
+ torchvision==0.15.2
4
+ jax[cuda11_pip]==0.4.13
5
+ flax==0.7.4
6
+ diffusers==0.24.0
7
+ transformers==4.35.2
8
+ peft==0.6.2
9
+ gradio==4.12.0
10
  huggingface_hub
11
+ pillow==10.0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py CHANGED
@@ -1,69 +1,36 @@
1
- # utils.py
 
2
  import jax
3
- import jax.numpy as jnp
4
  import numpy as np
5
- from functools import partial
6
 
7
 
8
- def repeat_vmap(fun, in_axes=None):
9
- if in_axes is None:
10
- in_axes = [0]
11
  for axes in in_axes:
12
  fun = jax.vmap(fun, in_axes=axes)
13
  return fun
14
 
15
 
16
- def make_grid(target_shape: tuple[int, int]):
17
- """Gera grid de coordenadas com validação rigorosa"""
18
- h, w = target_shape
19
-
20
- # Garantir tamanho mínimo e divisibilidade
21
- h = max(32, h)
22
- w = max(32, w)
23
- h = h if h % 8 == 0 else h + (8 - h % 8)
24
- w = w if w % 8 == 0 else w + (8 - w % 8)
25
-
26
- # Espaçamento preciso
27
- y_coords = np.linspace(-0.5 + 1 / (2 * h), 0.5 - 1 / (2 * h), h)
28
- x_coords = np.linspace(-0.5 + 1 / (2 * w), 0.5 - 1 / (2 * w), w)
29
-
30
- # Criar grid 4D (1, H, W, 2)
31
- grid = np.stack(np.meshgrid(y_coords, x_coords, indexing='ij'), axis=-1)
32
- return grid[np.newaxis, ...]
33
 
34
 
35
  def interpolate_grid(coords, grid, order=0):
36
- """Interpolação segura com verificações em tempo real"""
37
- try:
38
- # Converter e garantir formato 4D
39
- coords = jnp.asarray(coords)
40
- original_shape = coords.shape
41
-
42
- # Adicionar dimensões faltantes
43
- while coords.ndim < 4:
44
- coords = coords[jnp.newaxis, ...]
45
-
46
- # Validação final
47
- if coords.shape[-1] != 2 or coords.ndim != 4:
48
- raise ValueError(
49
- f"Formato inválido: {original_shape} {coords.shape}. "
50
- f"Esperado (B, H, W, 2)"
51
- )
52
-
53
- # Transformação de coordenadas
54
- coords = coords.transpose((0, 3, 1, 2))
55
- coords = coords.at[:, 0].set(
56
- coords[:, 0] * (grid.shape[-3] - 1) + (grid.shape[-3] - 1) / 2
57
- )
58
- coords = coords.at[:, 1].set(
59
- coords[:, 1] * (grid.shape[-2] - 1) + (grid.shape[-2] - 1) / 2
60
- )
61
-
62
- # Interpolação vetorizada
63
- map_coordinates = partial(jax.scipy.ndimage.map_coordinates,
64
- order=order,
65
- mode='nearest')
66
- return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)
67
-
68
- except Exception as e:
69
- raise RuntimeError(f"Erro na interpolação: {str(e)}") from e
 
1
+ from functools import partial
2
+
3
  import jax
 
4
  import numpy as np
 
5
 
6
 
7
+ def repeat_vmap(fun, in_axes=[0]):
 
 
8
  for axes in in_axes:
9
  fun = jax.vmap(fun, in_axes=axes)
10
  return fun
11
 
12
 
13
+ def make_grid(patch_size: int | tuple[int, int]):
14
+ if isinstance(patch_size, int):
15
+ patch_size = (patch_size, patch_size)
16
+ offset_h, offset_w = 1 / (2 * np.array(patch_size))
17
+ space_h = np.linspace(-0.5 + offset_h, 0.5 - offset_h, patch_size[0])
18
+ space_w = np.linspace(-0.5 + offset_w, 0.5 - offset_w, patch_size[1])
19
+ return np.stack(np.meshgrid(space_h, space_w, indexing='ij'), axis=-1) # [h, w]
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  def interpolate_grid(coords, grid, order=0):
23
+ """
24
+ args:
25
+ coords: Tensor of shape (B, H, W, 2) with coordinates in [-0.5, 0.5]
26
+ grid: Tensor of shape (B, H', W', C)
27
+ returns:
28
+ Tensor of shape (B, H, W, C) with interpolated values
29
+ """
30
+ # convert [-0.5, 0.5] -> [0, size], where pixel centers are expected at
31
+ # [-0.5 + 1 / (2*size), ..., 0.5 - 1 / (2*size)]
32
+ coords = coords.transpose((0, 3, 1, 2))
33
+ coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
34
+ coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
35
+ map_coordinates = partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest')
36
+ return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)