ds1david committed on
Commit
a7111d1
·
1 Parent(s): 3920f5c
Files changed (1) hide show
  1. app.py +30 -15
app.py CHANGED
@@ -1,20 +1,26 @@
1
  import gradio as gr
2
  import torch
3
  import jax
 
4
  import numpy as np
5
  from PIL import Image
6
  import pickle
 
7
  from huggingface_hub import hf_hub_download
8
  from diffusers import StableDiffusionXLImg2ImgPipeline
9
- from transformers import DPTFeatureExtractor, DPTForDepthEstimation
10
- from model import build_thera # Importar do código original do Thera
 
 
 
 
11
 
12
  # Configurar dispositivos
13
  JAX_DEVICE = jax.devices("cpu")[0]
14
  TORCH_DEVICE = "cpu"
15
 
16
 
17
- # 1. Carregar modelos do Thera ------------------------------------------------------------------
18
  def load_thera_model(repo_id, filename):
19
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
20
  with open(model_path, 'rb') as fh:
@@ -27,7 +33,7 @@ def load_thera_model(repo_id, filename):
27
  print("Carregando Thera EDSR...")
28
  model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
29
 
30
- # 2. Carregar SDXL + LoRA ----------------------------------------------------------------------
31
  print("Carregando SDXL + LoRA...")
32
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
33
  "stabilityai/stable-diffusion-xl-base-1.0",
@@ -35,33 +41,36 @@ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
35
  ).to(TORCH_DEVICE)
36
  pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
37
 
38
- # 3. Carregar modelo de profundidade -----------------------------------------------------------
39
  print("Carregando DPT Depth...")
40
- feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
41
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
42
 
43
 
44
- # Pipeline principal ---------------------------------------------------------------------------
45
  def full_pipeline(image, prompt, scale_factor=2.0):
46
  try:
47
  # 1. Super Resolução com Thera
48
- source = np.array(image.convert("RGB")) / 255.0
 
49
  target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
50
 
51
- # Converter para JAX array
52
  source_jax = jax.device_put(source, JAX_DEVICE)
 
53
 
54
  # Processar com Thera
55
  upscaled = model_edsr.apply(
56
  params_edsr,
57
  source_jax,
 
58
  target_shape,
59
  do_ensemble=True
60
  )
61
  upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
62
 
63
  # 2. Gerar Bas-Relief
64
- full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
65
  bas_relief = pipe(
66
  prompt=full_prompt,
67
  image=upscaled_pil,
@@ -82,26 +91,32 @@ def full_pipeline(image, prompt, scale_factor=2.0):
82
  mode="bicubic"
83
  ).squeeze().cpu().numpy()
84
 
85
- return upscaled_pil, bas_relief, (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
 
 
 
86
 
87
  except Exception as e:
88
  raise gr.Error(f"Erro no processamento: {str(e)}")
89
 
90
 
91
- # Interface Gradio -----------------------------------------------------------------------------
92
  with gr.Blocks(title="Super Res + Bas-Relief") as app:
93
  gr.Markdown("## 🔍 Super Resolução + 🗿 Bas-Relief + 🗺️ Profundidade")
94
 
95
  with gr.Row():
96
  with gr.Column():
97
  img_input = gr.Image(type="pil", label="Imagem de Entrada")
98
- prompt = gr.Textbox("insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution.", label="Descrição")
 
 
 
99
  scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
100
  btn = gr.Button("Processar")
101
 
102
  with gr.Column():
103
- img_upscaled = gr.Image(label="Super Resolvida")
104
- img_basrelief = gr.Image(label="Bas-Relief")
105
  img_depth = gr.Image(label="Mapa de Profundidade")
106
 
107
  btn.click(
 
1
  import gradio as gr
2
  import torch
3
  import jax
4
+ import jax.numpy as jnp
5
  import numpy as np
6
  from PIL import Image
7
  import pickle
8
+ import warnings
9
  from huggingface_hub import hf_hub_download
10
  from diffusers import StableDiffusionXLImg2ImgPipeline
11
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
12
+ from model import build_thera
13
+
14
+ # Configurações e supressão de avisos
15
+ warnings.filterwarnings("ignore", category=FutureWarning)
16
+ warnings.filterwarnings("ignore", category=UserWarning)
17
 
18
  # Configurar dispositivos
19
  JAX_DEVICE = jax.devices("cpu")[0]
20
  TORCH_DEVICE = "cpu"
21
 
22
 
23
+ # 1. Carregar modelos do Thera ----------------------------------------------------------------
24
  def load_thera_model(repo_id, filename):
25
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
26
  with open(model_path, 'rb') as fh:
 
33
  print("Carregando Thera EDSR...")
34
  model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
35
 
36
+ # 2. Carregar SDXL + LoRA ---------------------------------------------------------------------
37
  print("Carregando SDXL + LoRA...")
38
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
39
  "stabilityai/stable-diffusion-xl-base-1.0",
 
41
  ).to(TORCH_DEVICE)
42
  pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
43
 
44
+ # 3. Carregar modelo de profundidade ----------------------------------------------------------
45
  print("Carregando DPT Depth...")
46
+ feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
47
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
48
 
49
 
50
+ # Pipeline principal --------------------------------------------------------------------------
51
  def full_pipeline(image, prompt, scale_factor=2.0):
52
  try:
53
  # 1. Super Resolução com Thera
54
+ image = image.convert("RGB")
55
+ source = np.array(image) / 255.0
56
  target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
57
 
58
+ # Preparar parâmetros para JAX
59
  source_jax = jax.device_put(source, JAX_DEVICE)
60
+ t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
61
 
62
  # Processar com Thera
63
  upscaled = model_edsr.apply(
64
  params_edsr,
65
  source_jax,
66
+ t,
67
  target_shape,
68
  do_ensemble=True
69
  )
70
  upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
71
 
72
  # 2. Gerar Bas-Relief
73
+ full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
74
  bas_relief = pipe(
75
  prompt=full_prompt,
76
  image=upscaled_pil,
 
91
  mode="bicubic"
92
  ).squeeze().cpu().numpy()
93
 
94
+ depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
95
+ depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
96
+
97
+ return upscaled_pil, bas_relief, depth_pil
98
 
99
  except Exception as e:
100
  raise gr.Error(f"Erro no processamento: {str(e)}")
101
 
102
 
103
+ # Interface Gradio ----------------------------------------------------------------------------
104
  with gr.Blocks(title="Super Res + Bas-Relief") as app:
105
  gr.Markdown("## 🔍 Super Resolução + 🗿 Bas-Relief + 🗺️ Profundidade")
106
 
107
  with gr.Row():
108
  with gr.Column():
109
  img_input = gr.Image(type="pil", label="Imagem de Entrada")
110
+ prompt = gr.Textbox(
111
+ label="Descrição do Relevo",
112
+ value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution."
113
+ )
114
  scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
115
  btn = gr.Button("Processar")
116
 
117
  with gr.Column():
118
+ img_upscaled = gr.Image(label="Imagem Super Resolvida")
119
+ img_basrelief = gr.Image(label="Resultado Bas-Relief")
120
  img_depth = gr.Image(label="Mapa de Profundidade")
121
 
122
  btn.click(