ds1david commited on
Commit
5779b8d
·
1 Parent(s): eb719b4

fixing bugs

Browse files
Files changed (1) hide show
  1. app.py +48 -52
app.py CHANGED
@@ -40,23 +40,19 @@ class CustomLogger:
40
  def error(self, text):
41
  self.logger.error(f"✗ {text}")
42
 
43
- def warning(self, text):
44
- self.logger.warning(f"⚠ {text}")
45
-
46
 
47
  logger = CustomLogger(__name__)
48
 
49
- # ================== CONFIGURAÇÃO DE HARDWARE ==================
50
  device = "cuda" if torch.cuda.is_available() else "cpu"
51
- torch_dtype = torch.float32 # Forçar precisão única para compatibilidade
52
- logger.divider("Inicialização do Sistema")
53
- logger.success(f"Dispositivo detectado: {device.upper()}")
54
- logger.success(f"Modo de precisão: float32")
55
 
56
 
57
  # ================== CARREGAMENTO DE MODELOS ==================
58
  def carregar_modelo_thera(repo_id):
59
- """Carrega modelos Thera com tratamento de erros robusto"""
60
  try:
61
  logger.divider(f"Carregando {repo_id}")
62
  model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
@@ -67,30 +63,33 @@ def carregar_modelo_thera(repo_id):
67
  logger.success(f"{repo_id} carregado")
68
  return model, params
69
  except Exception as e:
70
- logger.error(f"Falha ao carregar {repo_id}: {str(e)}")
71
  return None, None
72
 
73
 
74
- # Carregar modelos principais
75
- modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro")
76
- modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro")
 
 
 
77
 
78
- # ================== MODELOS DE ARTE (CARREGAMENTO CONDICIONAL) ==================
79
  pipe = None
80
  modelo_profundidade = None
81
- processador_profundidade = None
82
 
83
  try:
84
- logger.divider("Inicializando Componentes Artísticos")
85
 
86
- # Pipeline de estilo
87
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
88
  "stabilityai/stable-diffusion-xl-base-1.0",
89
  torch_dtype=torch_dtype,
90
- use_safetensors=True
 
91
  ).to(device)
92
 
93
- # Adapter LoRA
94
  pipe.load_lora_weights(
95
  "KappaNeuro/bas-relief",
96
  weight_name="BAS-RELIEF.safetensors"
@@ -98,61 +97,60 @@ try:
98
 
99
  # Modelo de profundidade
100
  processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
101
- modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
102
 
103
- logger.success("Componentes artísticos prontos")
104
  except Exception as e:
105
- logger.warning(f"Recursos artísticos desativados: {str(e)}")
106
  pipe = None
107
- modelo_profundidade = None
108
 
109
 
110
- # ================== PIPELINE PRINCIPAL ==================
111
- def processar_imagem(imagem, escala, modelo, prompt):
112
- """Fluxo completo de processamento com fallbacks"""
113
  try:
114
- logger.divider("Novo Processamento")
115
 
116
- # Converter entrada para PIL
117
  if not isinstance(imagem, Image.Image):
118
  imagem = Image.fromarray(imagem)
119
 
120
  # ========= 1. SUPER-RESOLUÇÃO =========
121
- logger.etapa("Super-Resolução Thera")
122
  modelo_sr = modelo_edsr if modelo == "EDSR" else modelo_rdn
123
  params_sr = params_edsr if modelo == "EDSR" else params_rdn
124
 
125
  sr_jax = process(
126
- np.array(imagem) / 255.0,
127
  modelo_sr,
128
  params_sr,
129
- (int(imagem.height * escala),
130
- int(imagem.width * escala)),
131
  True
132
  )
133
 
134
  sr_pil = Image.fromarray(np.array(sr_jax)).convert("RGB")
135
- logger.success(f"Resolução: {sr_pil.size[0]}x{sr_pil.size[1]}")
136
 
137
  # ========= 2. ESTILO BAIXO-RELEVO =========
138
- arte_pil = sr_pil # Fallback padrão
139
  if pipe:
140
  try:
141
  logger.etapa("Aplicando Estilo")
142
  arte_pil = pipe(
143
- prompt=f"BAS-RELIEF {prompt}, marble texture, cinematic lighting",
144
  image=sr_pil,
145
  strength=0.6,
146
  num_inference_steps=25,
147
- guidance_scale=7.0
 
148
  ).images[0]
149
  logger.success("Estilo aplicado")
150
  except Exception as e:
151
  logger.error(f"Erro no estilo: {str(e)}")
152
 
153
  # ========= 3. MAPA DE PROFUNDIDADE =========
154
- mapa_pil = arte_pil # Fallback padrão
155
- if modelo_profundidade and arte_pil:
156
  try:
157
  logger.etapa("Calculando Profundidade")
158
  inputs = processador_profundidade(arte_pil, return_tensors="pt").to(device)
@@ -160,13 +158,13 @@ def processar_imagem(imagem, escala, modelo, prompt):
160
  depth = modelo_profundidade(**inputs).predicted_depth
161
 
162
  depth = torch.nn.functional.interpolate(
163
- depth.unsqueeze(1).float(),
164
  size=arte_pil.size[::-1],
165
  mode="bicubic"
166
  ).squeeze().cpu().numpy()
167
 
168
- depth_normalized = (depth - depth.min()) / (depth.max() - depth.min())
169
- mapa_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
170
  logger.success("Profundidade calculada")
171
  except Exception as e:
172
  logger.error(f"Erro na profundidade: {str(e)}")
@@ -180,18 +178,16 @@ def processar_imagem(imagem, escala, modelo, prompt):
180
 
181
  # ================== INTERFACE GRADIO ==================
182
  with gr.Blocks(title="TheraSR Universal", theme=gr.themes.Soft()) as app:
183
- gr.Markdown("# 🏛 TheraSR - Super Resolução & Arte")
184
 
185
  with gr.Row():
186
  with gr.Column():
187
- input_image = gr.Image(label="Imagem de Entrada", type="pil")
188
- scale = gr.Slider(1.0, 4.0, value=2.0,
189
- label="Fator de Escala", step=0.1)
190
- model_select = gr.Radio(["EDSR", "RDN"],
191
- value="EDSR", label="Modelo")
192
- style_prompt = gr.Textbox(
193
- label="Descrição do Estilo",
194
- value="ancient greek marble浮雕, ultra detailed, 8k"
195
  )
196
  btn_process = gr.Button("Processar", variant="primary")
197
 
@@ -201,8 +197,8 @@ with gr.Blocks(title="TheraSR Universal", theme=gr.themes.Soft()) as app:
201
  output_depth = gr.Image(label="Mapa de Profundidade", interactive=False)
202
 
203
  btn_process.click(
204
- processar_imagem,
205
- inputs=[input_image, scale, model_select, style_prompt],
206
  outputs=[output_sr, output_art, output_depth]
207
  )
208
 
 
40
  def error(self, text):
41
  self.logger.error(f"✗ {text}")
42
 
 
 
 
43
 
44
  logger = CustomLogger(__name__)
45
 
46
+ # ================== CONFIGURAÇÃO FORÇADA ==================
47
  device = "cuda" if torch.cuda.is_available() else "cpu"
48
+ torch_dtype = torch.float32 # Forçar float32 universalmente
49
+ logger.divider("Configuração Forçada")
50
+ logger.success(f"Dispositivo: {device.upper()}")
51
+ logger.success(f"Precisão: {str(torch_dtype).replace('torch.', '')}")
52
 
53
 
54
  # ================== CARREGAMENTO DE MODELOS ==================
55
  def carregar_modelo_thera(repo_id):
 
56
  try:
57
  logger.divider(f"Carregando {repo_id}")
58
  model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
 
63
  logger.success(f"{repo_id} carregado")
64
  return model, params
65
  except Exception as e:
66
+ logger.error(f"Falha no carregamento: {str(e)}")
67
  return None, None
68
 
69
 
70
+ try:
71
+ modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro")
72
+ modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro")
73
+ except Exception as e:
74
+ logger.error("Falha crítica nos modelos Thera")
75
+ raise
76
 
77
+ # ================== PIPELINE ARTÍSTICO ==================
78
  pipe = None
79
  modelo_profundidade = None
 
80
 
81
  try:
82
+ logger.divider("Configurando Componentes Artísticos")
83
 
84
+ # Pipeline principal
85
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
86
  "stabilityai/stable-diffusion-xl-base-1.0",
87
  torch_dtype=torch_dtype,
88
+ use_safetensors=True,
89
+ variant=None # Forçar carregamento sem fp16
90
  ).to(device)
91
 
92
+ # LoRA
93
  pipe.load_lora_weights(
94
  "KappaNeuro/bas-relief",
95
  weight_name="BAS-RELIEF.safetensors"
 
97
 
98
  # Modelo de profundidade
99
  processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
100
+ modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).float()
101
 
102
+ logger.success("Componentes artísticos em float32")
103
  except Exception as e:
104
+ logger.warning(f"Recursos artísticos limitados: {str(e)}")
105
  pipe = None
 
106
 
107
 
108
+ # ================== PROCESSAMENTO PRINCIPAL ==================
109
+ def processar_imagem_completa(imagem, escala, modelo, prompt):
 
110
  try:
111
+ logger.divider("Iniciando Processamento")
112
 
113
+ # Converter entrada
114
  if not isinstance(imagem, Image.Image):
115
  imagem = Image.fromarray(imagem)
116
 
117
  # ========= 1. SUPER-RESOLUÇÃO =========
118
+ logger.etapa("Processando Super-Resolução")
119
  modelo_sr = modelo_edsr if modelo == "EDSR" else modelo_rdn
120
  params_sr = params_edsr if modelo == "EDSR" else params_rdn
121
 
122
  sr_jax = process(
123
+ np.array(imagem) / 255.,
124
  modelo_sr,
125
  params_sr,
126
+ (round(imagem.height * escala),
127
+ round(imagem.width * escala)),
128
  True
129
  )
130
 
131
  sr_pil = Image.fromarray(np.array(sr_jax)).convert("RGB")
132
+ logger.success(f"SR: {sr_pil.size[0]}x{sr_pil.size[1]}")
133
 
134
  # ========= 2. ESTILO BAIXO-RELEVO =========
135
+ arte_pil = sr_pil # Fallback
136
  if pipe:
137
  try:
138
  logger.etapa("Aplicando Estilo")
139
  arte_pil = pipe(
140
+ prompt=f"BAS-RELIEF {prompt}, marble texture, 8k",
141
  image=sr_pil,
142
  strength=0.6,
143
  num_inference_steps=25,
144
+ guidance_scale=7.0,
145
+ generator=torch.Generator(device).manual_seed(42)
146
  ).images[0]
147
  logger.success("Estilo aplicado")
148
  except Exception as e:
149
  logger.error(f"Erro no estilo: {str(e)}")
150
 
151
  # ========= 3. MAPA DE PROFUNDIDADE =========
152
+ mapa_pil = arte_pil # Fallback
153
+ if modelo_profundidade:
154
  try:
155
  logger.etapa("Calculando Profundidade")
156
  inputs = processador_profundidade(arte_pil, return_tensors="pt").to(device)
 
158
  depth = modelo_profundidade(**inputs).predicted_depth
159
 
160
  depth = torch.nn.functional.interpolate(
161
+ depth.unsqueeze(1),
162
  size=arte_pil.size[::-1],
163
  mode="bicubic"
164
  ).squeeze().cpu().numpy()
165
 
166
+ depth = (depth - depth.min()) / (depth.max() - depth.min())
167
+ mapa_pil = Image.fromarray((depth * 255).astype(np.uint8))
168
  logger.success("Profundidade calculada")
169
  except Exception as e:
170
  logger.error(f"Erro na profundidade: {str(e)}")
 
178
 
179
  # ================== INTERFACE GRADIO ==================
180
  with gr.Blocks(title="TheraSR Universal", theme=gr.themes.Soft()) as app:
181
+ gr.Markdown("# 🏛 TheraSR - Processamento Completo em Float32")
182
 
183
  with gr.Row():
184
  with gr.Column():
185
+ input_img = gr.Image(label="Imagem de Entrada", type="pil")
186
+ slider_scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
187
+ radio_model = gr.Radio(["EDSR", "RDN"], value="EDSR", label="Modelo")
188
+ text_prompt = gr.Textbox(
189
+ label="Prompt de Estilo",
190
+ value="ancient marble浮雕, ultra detailed, 8k cinematic"
 
 
191
  )
192
  btn_process = gr.Button("Processar", variant="primary")
193
 
 
197
  output_depth = gr.Image(label="Mapa de Profundidade", interactive=False)
198
 
199
  btn_process.click(
200
+ processar_imagem_completa,
201
+ inputs=[input_img, slider_scale, radio_model, text_prompt],
202
  outputs=[output_sr, output_art, output_depth]
203
  )
204