ds1david committed
Commit 3920f5c · 1 Parent(s): 1eb87a5
Files changed (2):
  1. app.py  +77 -55
  2. requirements.txt  +2 -0
app.py CHANGED
@@ -3,83 +3,105 @@ import torch
 import jax
 import numpy as np
 from PIL import Image
+import pickle
+from huggingface_hub import hf_hub_download
 from diffusers import StableDiffusionXLImg2ImgPipeline
 from transformers import DPTFeatureExtractor, DPTForDepthEstimation
-from super_resolve import process as thera_process  # assumes the Thera imports
+from model import build_thera  # import from the original Thera code
 
-# Configuration
-DEVICE = "cpu"  # or "cuda" if available
-JAX_DEVICE = jax.devices("cpu")[0]  # use CPU for JAX
+# Configure devices
+JAX_DEVICE = jax.devices("cpu")[0]
+TORCH_DEVICE = "cpu"
 
-# 1. Load Thera models (EDSR/RDN)
-# (implement following the original Thera code)
-model_edsr, params_edsr = None, None  # load via pickle/HF Hub
 
-# 2. Load SDXL Img2Img + LoRA
-print("Loading SDXL Img2Img with LoRA...")
+# 1. Load Thera models ---------------------------------------------------------------------------
+def load_thera_model(repo_id, filename):
+    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
+    with open(model_path, 'rb') as fh:
+        check = pickle.load(fh)
+        params, backbone, size = check['model'], check['backbone'], check['size']
+        model = build_thera(3, backbone, size)
+    return model, params
+
+
+print("Loading Thera EDSR...")
+model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
+
+# 2. Load SDXL + LoRA ----------------------------------------------------------------------------
+print("Loading SDXL + LoRA...")
 pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
     torch_dtype=torch.float32
-).to(DEVICE)
+).to(TORCH_DEVICE)
 pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
 
-# 3. Load depth model
-print("Loading DPT...")
+# 3. Load depth model ----------------------------------------------------------------------------
+print("Loading DPT Depth...")
 feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
-depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(DEVICE)
-
-
-def enhance_depth_map(depth_arr):
-    depth_normalized = (depth_arr - depth_arr.min()) / (depth_arr.max() - depth_arr.min() + 1e-8)
-    return Image.fromarray((depth_normalized * 255).astype(np.uint8))
+depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
 
 
+# Main pipeline ----------------------------------------------------------------------------------
 def full_pipeline(image, prompt, scale_factor=2.0):
-    # 1. Super-resolution with Thera
-    source = np.array(image) / 255.0
-    target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
-    upscaled = thera_process(source, model_edsr, params_edsr, target_shape, do_ensemble=True)
-    upscaled_pil = Image.fromarray((upscaled * 255).astype(np.uint8))
-
-    # 2. Generate bas-relief with SDXL Img2Img
-    full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
-    bas_relief = pipe(
-        prompt=full_prompt,
-        image=upscaled_pil,
-        strength=0.7,
-        num_inference_steps=25,
-        guidance_scale=7.5
-    ).images[0]
-
-    # 3. Compute depth map
-    inputs = feature_extractor(bas_relief, return_tensors="pt").to(DEVICE)
-    with torch.no_grad():
-        outputs = depth_model(**inputs)
-        depth = outputs.predicted_depth
-
-    depth_map = torch.nn.functional.interpolate(
-        depth.unsqueeze(1),
-        size=bas_relief.size[::-1],
-        mode="bicubic"
-    ).squeeze().cpu().numpy()
-
-    return upscaled_pil, bas_relief, enhance_depth_map(depth_map)
-
-
-# Gradio interface
-with gr.Blocks(title="Super Resolution + Bas-Relief") as app:
-    gr.Markdown("## 📈 Super Resolution + 🗿 Bas-Relief + 🗺️ Depth Map")
+    try:
+        # 1. Super-resolution with Thera
+        source = np.array(image.convert("RGB")) / 255.0
+        target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
+
+        # Convert to a JAX array
+        source_jax = jax.device_put(source, JAX_DEVICE)
+
+        # Process with Thera
+        upscaled = model_edsr.apply(
+            params_edsr,
+            source_jax,
+            target_shape,
+            do_ensemble=True
+        )
+        upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
+
+        # 2. Generate bas-relief
+        full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
+        bas_relief = pipe(
+            prompt=full_prompt,
+            image=upscaled_pil,
+            strength=0.7,
+            num_inference_steps=25,
+            guidance_scale=7.5
+        ).images[0]
+
+        # 3. Compute depth map
+        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
+        with torch.no_grad():
+            outputs = depth_model(**inputs)
+            depth = outputs.predicted_depth
+
+        depth_map = torch.nn.functional.interpolate(
+            depth.unsqueeze(1),
+            size=bas_relief.size[::-1],
+            mode="bicubic"
+        ).squeeze().cpu().numpy()
+
+        return upscaled_pil, bas_relief, (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
+
+    except Exception as e:
+        raise gr.Error(f"Processing error: {str(e)}")
+
+
+# Gradio interface -------------------------------------------------------------------------------
+with gr.Blocks(title="Super Res + Bas-Relief") as app:
+    gr.Markdown("## 🔍 Super Resolution + 🗿 Bas-Relief + 🗺️ Depth")
 
     with gr.Row():
         with gr.Column():
             img_input = gr.Image(type="pil", label="Input Image")
-            prompt = gr.Textbox("ancient sculpture, marble", label="Relief Description")
+            prompt = gr.Textbox("insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution.", label="Description")
             scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
             btn = gr.Button("Process")
 
         with gr.Column():
-            img_upscaled = gr.Image(label="Super-Resolved Image")
-            img_basrelief = gr.Image(label="Sculptural Relief")
+            img_upscaled = gr.Image(label="Super-Resolved")
+            img_basrelief = gr.Image(label="Bas-Relief")
             img_depth = gr.Image(label="Depth Map")
 
     btn.click(
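A note on the changed return value: the removed enhance_depth_map helper converted the depth array to an 8-bit PIL image, whereas the new full_pipeline returns a float array normalized to [0, 1] and leaves rendering to gr.Image. For callers outside Gradio, a minimal sketch of the equivalent conversion (assuming a 2-D float depth_map like the one produced above; the helper name is ours, not part of the commit) could look like this:

import numpy as np
from PIL import Image

def depth_to_image(depth_map: np.ndarray) -> Image.Image:
    # Rescale defensively to [0, 1] (the pipeline already normalizes),
    # then quantize to 8-bit grayscale, as the removed enhance_depth_map did.
    d = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
    return Image.fromarray((d * 255).astype(np.uint8))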
requirements.txt CHANGED
@@ -7,6 +7,7 @@ diffusers
 einops==0.6.1
 flax==0.6.10
 flaxmodels==0.1.3
+huggingface_hub
 jax==0.4.11
 jaxlib==0.4.11+cuda11.cudnn86
 jaxtyping==0.2.20
@@ -24,6 +25,7 @@ opt-einsum==3.3.0
 optax==0.2.0
 orbax-checkpoint==0.2.4
 peft
+pillow
 scipy==1.10.1
 timm==0.9.6
 torch
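For completeness, a hypothetical smoke test of the updated pipeline outside the Gradio UI. The module name app, the test image path, and the output filenames are all assumptions for illustration; importing app loads Thera EDSR, SDXL + LoRA, and DPT at import time, which is slow on CPU.

# Hypothetical smoke test for this commit; not part of the repository.
from PIL import Image
import numpy as np

import app  # assumption: the Space's app.py is importable as a module

image = Image.open("test_input.png").convert("RGB")  # assumed local test image
upscaled, relief, depth = app.full_pipeline(image, "ancient sculpture, marble", scale_factor=2.0)

upscaled.save("upscaled.png")
relief.save("bas_relief.png")
# depth is a float array in [0, 1]; quantize it to save as an 8-bit grayscale PNG
Image.fromarray((np.clip(depth, 0, 1) * 255).astype(np.uint8)).save("depth.png")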