ds1david commited on
Commit
46bb495
·
1 Parent(s): a02c6d7
Files changed (1) hide show
  1. app.py +76 -22
app.py CHANGED
@@ -6,11 +6,24 @@ import numpy as np
6
  from PIL import Image
7
  import pickle
8
  import warnings
 
 
9
  from huggingface_hub import hf_hub_download
10
  from diffusers import StableDiffusionXLImg2ImgPipeline
11
  from transformers import DPTImageProcessor, DPTForDepthEstimation
12
  from model import build_thera
13
 
 
 
 
 
 
 
 
 
 
 
 
14
  # Configurações e supressão de avisos
15
  warnings.filterwarnings("ignore", category=FutureWarning)
16
  warnings.filterwarnings("ignore", category=UserWarning)
@@ -22,56 +35,86 @@ TORCH_DEVICE = "cpu"
22
 
23
  # 1. Carregar modelos do Thera ----------------------------------------------------------------
24
  def load_thera_model(repo_id, filename):
25
- model_path = hf_hub_download(repo_id=repo_id, filename=filename)
26
- with open(model_path, 'rb') as fh:
27
- check = pickle.load(fh)
28
- # Carregar estrutura completa de variáveis
29
- variables = check['model'] # Deve conter {'params': ...}
30
- backbone, size = check['backbone'], check['size']
31
- model = build_thera(3, backbone, size)
32
- return model, variables
 
 
 
 
 
33
 
34
 
35
- print("Carregando Thera EDSR...")
36
  model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
37
 
38
  # 2. Carregar SDXL + LoRA ---------------------------------------------------------------------
39
- print("Carregando SDXL + LoRA...")
40
- pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
41
- "stabilityai/stable-diffusion-xl-base-1.0",
42
- torch_dtype=torch.float32
43
- ).to(TORCH_DEVICE)
44
- pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
 
 
 
 
 
45
 
46
  # 3. Carregar modelo de profundidade ----------------------------------------------------------
47
- print("Carregando DPT Depth...")
48
- feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
49
- depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
 
 
 
 
 
50
 
51
 
52
  # Pipeline principal --------------------------------------------------------------------------
53
- def full_pipeline(image, prompt, scale_factor=2.0):
54
  try:
 
 
55
  # 1. Super Resolução com Thera
 
56
  image = image.convert("RGB")
 
 
57
  source = np.array(image) / 255.0
 
58
  target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
 
59
 
 
60
  source_jax = jax.device_put(source, JAX_DEVICE)
61
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
62
 
63
- # Chamada corrigida com estrutura de variáveis correta
64
  upscaled = model_edsr.apply(
65
- variables_edsr, # Estrutura completa {'params': ...}
66
  source_jax,
67
  t,
68
  target_shape
69
  )
 
70
 
 
71
  upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
 
72
 
73
  # 2. Gerar Bas-Relief
 
74
  full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
 
 
 
75
  bas_relief = pipe(
76
  prompt=full_prompt,
77
  image=upscaled_pil,
@@ -79,9 +122,13 @@ def full_pipeline(image, prompt, scale_factor=2.0):
79
  num_inference_steps=25,
80
  guidance_scale=7.5
81
  ).images[0]
 
82
 
83
  # 3. Calcular Depth Map
 
 
84
  inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
 
85
  with torch.no_grad():
86
  outputs = depth_model(**inputs)
87
  depth = outputs.predicted_depth
@@ -92,12 +139,18 @@ def full_pipeline(image, prompt, scale_factor=2.0):
92
  mode="bicubic"
93
  ).squeeze().cpu().numpy()
94
 
95
- depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
 
 
 
96
  depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
 
97
 
 
98
  return upscaled_pil, bas_relief, depth_pil
99
 
100
  except Exception as e:
 
101
  raise gr.Error(f"Erro no processamento: {str(e)}")
102
 
103
 
@@ -127,4 +180,5 @@ with gr.Blocks(title="Super Res + Bas-Relief") as app:
127
  )
128
 
129
  if __name__ == "__main__":
 
130
  app.launch(share=False)
 
6
  from PIL import Image
7
  import pickle
8
  import warnings
9
+ import logging
10
+ from datetime import datetime
11
  from huggingface_hub import hf_hub_download
12
  from diffusers import StableDiffusionXLImg2ImgPipeline
13
  from transformers import DPTImageProcessor, DPTForDepthEstimation
14
  from model import build_thera
15
 
16
+ # Configuração de logging
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format='%(asctime)s - %(levelname)s - %(message)s',
20
+ handlers=[
21
+ logging.FileHandler("processing.log"),
22
+ logging.StreamHandler()
23
+ ]
24
+ )
25
+ logger = logging.getLogger(__name__)
26
+
27
  # Configurações e supressão de avisos
28
  warnings.filterwarnings("ignore", category=FutureWarning)
29
  warnings.filterwarnings("ignore", category=UserWarning)
 
35
 
36
  # 1. Carregar modelos do Thera ----------------------------------------------------------------
37
  def load_thera_model(repo_id, filename):
38
+ try:
39
+ logger.info(f"Iniciando carregamento do modelo Thera de {repo_id}")
40
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
41
+ with open(model_path, 'rb') as fh:
42
+ check = pickle.load(fh)
43
+ variables = check['model']
44
+ backbone, size = check['backbone'], check['size']
45
+ model = build_thera(3, backbone, size)
46
+ logger.info("Modelo Thera carregado com sucesso")
47
+ return model, variables
48
+ except Exception as e:
49
+ logger.error(f"Falha ao carregar modelo Thera: {str(e)}")
50
+ raise
51
 
52
 
53
+ logger.info("Carregando Thera EDSR...")
54
  model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
55
 
56
  # 2. Carregar SDXL + LoRA ---------------------------------------------------------------------
57
+ try:
58
+ logger.info("Iniciando carregamento do SDXL + LoRA...")
59
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
60
+ "stabilityai/stable-diffusion-xl-base-1.0",
61
+ torch_dtype=torch.float32
62
+ ).to(TORCH_DEVICE)
63
+ pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
64
+ logger.info("SDXL + LoRA carregado com sucesso")
65
+ except Exception as e:
66
+ logger.error(f"Falha ao carregar SDXL: {str(e)}")
67
+ raise
68
 
69
  # 3. Carregar modelo de profundidade ----------------------------------------------------------
70
+ try:
71
+ logger.info("Iniciando carregamento do DPT Depth...")
72
+ feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
73
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
74
+ logger.info("Modelo DPT carregado com sucesso")
75
+ except Exception as e:
76
+ logger.error(f"Falha ao carregar DPT: {str(e)}")
77
+ raise
78
 
79
 
80
  # Pipeline principal --------------------------------------------------------------------------
81
+ def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
82
  try:
83
+ progress(0, desc="Iniciando processamento...")
84
+
85
  # 1. Super Resolução com Thera
86
+ progress(0.1, desc="Convertendo imagem para RGB...")
87
  image = image.convert("RGB")
88
+
89
+ progress(0.2, desc="Preparando entrada para super-resolução...")
90
  source = np.array(image) / 255.0
91
+ original_size = image.size
92
  target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
93
+ logger.info(f"Super-resolução: {original_size} → {target_shape} (scale: {scale_factor}x)")
94
 
95
+ progress(0.3, desc="Processando com Thera...")
96
  source_jax = jax.device_put(source, JAX_DEVICE)
97
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
98
 
99
+ start_time = datetime.now()
100
  upscaled = model_edsr.apply(
101
+ variables_edsr,
102
  source_jax,
103
  t,
104
  target_shape
105
  )
106
+ logger.info(f"Super-resolução concluída em {datetime.now() - start_time}")
107
 
108
+ progress(0.5, desc="Convertendo resultado...")
109
  upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
110
+ logger.info(f"Tamanho após super-resolução: {upscaled_pil.size}")
111
 
112
  # 2. Gerar Bas-Relief
113
+ progress(0.6, desc="Gerando Bas-Relief...")
114
  full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
115
+ logger.info(f"Prompt final: {full_prompt}")
116
+
117
+ start_time = datetime.now()
118
  bas_relief = pipe(
119
  prompt=full_prompt,
120
  image=upscaled_pil,
 
122
  num_inference_steps=25,
123
  guidance_scale=7.5
124
  ).images[0]
125
+ logger.info(f"Bas-Relief gerado em {datetime.now() - start_time}")
126
 
127
  # 3. Calcular Depth Map
128
+ progress(0.8, desc="Calculando mapa de profundidade...")
129
+ start_time = datetime.now()
130
  inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
131
+
132
  with torch.no_grad():
133
  outputs = depth_model(**inputs)
134
  depth = outputs.predicted_depth
 
139
  mode="bicubic"
140
  ).squeeze().cpu().numpy()
141
 
142
+ progress(0.9, desc="Processando mapa de profundidade...")
143
+ depth_min = depth_map.min()
144
+ depth_max = depth_map.max()
145
+ depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
146
  depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
147
+ logger.info(f"Profundidade calculada em {datetime.now() - start_time} | Range: {depth_min:.2f}-{depth_max:.2f}")
148
 
149
+ progress(1.0, desc="Finalizado!")
150
  return upscaled_pil, bas_relief, depth_pil
151
 
152
  except Exception as e:
153
+ logger.error(f"Erro no processamento: {str(e)}", exc_info=True)
154
  raise gr.Error(f"Erro no processamento: {str(e)}")
155
 
156
 
 
180
  )
181
 
182
  if __name__ == "__main__":
183
+ logger.info("Iniciando aplicação Gradio")
184
  app.launch(share=False)