ds1david commited on
Commit
e87f39a
·
1 Parent(s): 98889c8

Trying fix variants

Browse files
Files changed (1) hide show
  1. app.py +42 -36
app.py CHANGED
@@ -5,28 +5,32 @@ from diffusers import StableDiffusionXLImg2ImgPipeline
5
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
6
  from PIL import Image, ImageEnhance, ImageOps
7
 
8
- device = "cpu" # or "cuda" if you have a GPU
9
- torch_dtype = torch.float32
 
10
 
11
- print("Loading SDXL Img2Img model...")
12
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
13
  "stabilityai/stable-diffusion-xl-base-1.0",
14
- torch_dtype=torch_dtype
 
 
15
  ).to(device)
16
 
17
- print("Loading bas-relief LoRA weights with PEFT...")
18
  pipe.load_lora_weights(
19
  "KappaNeuro/bas-relief",
20
  weight_name="BAS-RELIEF.safetensors",
 
21
  peft_backend="peft"
22
  )
23
 
24
- print("Loading DPT Depth Model...")
25
  feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
26
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
27
 
28
 
29
- def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
30
  d_min, d_max = depth_arr.min(), depth_arr.max()
31
  depth_stretched = (depth_arr - d_min) / (d_max - d_min + 1e-8)
32
  depth_stretched = (depth_stretched * 255).astype(np.uint8)
@@ -40,57 +44,59 @@ def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
40
  return depth_pil
41
 
42
 
43
- def generate_bas_relief_and_depth(input_image: Image.Image):
44
- # Redimensionar a imagem para o tamanho esperado
45
- input_image = input_image.resize((512, 512))
46
 
47
- # Prompt fixo para ativar o LoRA
48
- prompt = "BAS-RELIEF"
 
 
 
 
 
 
 
 
49
 
50
- print("Gerando imagem no estilo baixo-relevo...")
51
- result = pipe(
52
- prompt=prompt,
53
- image=input_image,
54
- strength=0.7, # Controla a intensidade da transformação
55
- num_inference_steps=15,
56
- guidance_scale=7.5
57
- )
58
- generated_image = result.images[0]
59
 
60
- print("Calculando mapa de profundidade...")
61
- inputs = feature_extractor(generated_image, return_tensors="pt").to(device)
62
  with torch.no_grad():
63
  outputs = depth_model(**inputs)
64
  predicted_depth = outputs.predicted_depth
65
 
66
  prediction = torch.nn.functional.interpolate(
67
  predicted_depth.unsqueeze(1),
68
- size=generated_image.size[::-1],
69
  mode="bicubic",
70
  align_corners=False
71
  ).squeeze()
72
 
73
- depth_map_pil = enhance_depth_map(prediction.cpu().numpy())
74
 
75
- return generated_image, depth_map_pil
76
 
77
 
78
- title = "Conversor para Baixo-relevo (SDXL + LoRA) com Mapa de Profundidade"
79
- description = (
80
- "Carrega stable-diffusion-xl-base-1.0 no CPU, aplica LoRA de 'KappaNeuro/bas-relief' "
81
- "para transformar imagens em baixo-relevo e calcula o mapa de profundidade correspondente."
 
82
  )
83
 
84
- iface = gr.Interface(
85
- fn=generate_bas_relief_and_depth,
86
  inputs=gr.Image(label="Imagem de Entrada", type="pil"),
87
  outputs=[
88
- gr.Image(label="Imagem em Baixo-relevo"),
89
  gr.Image(label="Mapa de Profundidade")
90
  ],
91
- title=title,
92
- description=description
 
93
  )
94
 
95
  if __name__ == "__main__":
96
- iface.launch()
 
5
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
6
  from PIL import Image, ImageEnhance, ImageOps
7
 
8
+ # Configuração de dispositivo e tipos de dados
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
11
 
12
+ print("Carregando modelo SDXL Img2Img...")
13
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
14
  "stabilityai/stable-diffusion-xl-base-1.0",
15
+ torch_dtype=torch_dtype,
16
+ variant="fp32",
17
+ use_safetensors=True
18
  ).to(device)
19
 
20
+ print("Carregando pesos LoRA para baixo-relevo...")
21
  pipe.load_lora_weights(
22
  "KappaNeuro/bas-relief",
23
  weight_name="BAS-RELIEF.safetensors",
24
+ adapter_name="bas_relief",
25
  peft_backend="peft"
26
  )
27
 
28
+ print("Carregando modelo de profundidade DPT...")
29
  feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
30
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
31
 
32
 
33
+ def melhorar_mapa_profundidade(depth_arr: np.ndarray) -> Image.Image:
34
  d_min, d_max = depth_arr.min(), depth_arr.max()
35
  depth_stretched = (depth_arr - d_min) / (d_max - d_min + 1e-8)
36
  depth_stretched = (depth_stretched * 255).astype(np.uint8)
 
44
  return depth_pil
45
 
46
 
47
+ def gerar_baixo_relevo_e_profundidade(imagem: Image.Image):
48
+ # Pré-processamento da imagem
49
+ imagem = imagem.convert("RGB").resize((512, 512))
50
 
51
+ # Geração da imagem em baixo-relevo
52
+ with torch.autocast(device, dtype=torch_dtype):
53
+ resultado = pipe(
54
+ prompt="BAS-RELIEF",
55
+ image=imagem,
56
+ strength=0.7,
57
+ num_inference_steps=15,
58
+ guidance_scale=7.5,
59
+ generator=torch.Generator(device=device).manual_seed(0)
60
+ )
61
 
62
+ imagem_gerada = resultado.images[0]
 
 
 
 
 
 
 
 
63
 
64
+ # Cálculo do mapa de profundidade
65
+ inputs = feature_extractor(imagem_gerada, return_tensors="pt").to(device)
66
  with torch.no_grad():
67
  outputs = depth_model(**inputs)
68
  predicted_depth = outputs.predicted_depth
69
 
70
  prediction = torch.nn.functional.interpolate(
71
  predicted_depth.unsqueeze(1),
72
+ size=imagem_gerada.size[::-1],
73
  mode="bicubic",
74
  align_corners=False
75
  ).squeeze()
76
 
77
+ mapa_profundidade = melhorar_mapa_profundidade(prediction.cpu().numpy())
78
 
79
+ return imagem_gerada, mapa_profundidade
80
 
81
 
82
+ # Interface Gradio
83
+ titulo = "Conversor para Baixo-relevo com Mapa de Profundidade"
84
+ descricao = (
85
+ "Carrega uma imagem para transformar em estilo baixo-relevo usando SDXL + LoRA "
86
+ "e gera o mapa de profundidade correspondente."
87
  )
88
 
89
+ interface = gr.Interface(
90
+ fn=gerar_baixo_relevo_e_profundidade,
91
  inputs=gr.Image(label="Imagem de Entrada", type="pil"),
92
  outputs=[
93
+ gr.Image(label="Baixo-relevo Gerado"),
94
  gr.Image(label="Mapa de Profundidade")
95
  ],
96
+ title=titulo,
97
+ description=descricao,
98
+ allow_flagging="never"
99
  )
100
 
101
  if __name__ == "__main__":
102
+ interface.launch(server_name="0.0.0.0" if torch.cuda.is_available() else None)