leemonz commited on
Commit
8fc6445
·
verified ·
1 Parent(s): 55d4811

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -4,20 +4,29 @@ import gradio as gr
4
 
5
  # Modelo base
6
  BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
7
- # LoRA de ejemplo (puedes cambiarlo por el tuyo)
8
  LORA_MODEL = "nerijs/pixel-art-xl"
9
 
 
 
 
 
 
 
 
 
10
  print("Cargando modelo base...")
11
  pipe = StableDiffusionXLPipeline.from_pretrained(
12
  BASE_MODEL,
13
- torch_dtype=torch.float16,
14
- variant="fp16",
15
- use_safetensors=True
16
- ).to("cuda")
 
17
 
18
  print("Cargando LoRA...")
19
  pipe.load_lora_weights(LORA_MODEL)
20
- pipe.fuse_lora(lora_scale=0.8) # Ajusta el peso del LoRA
21
 
22
  def generar(prompt):
23
  with torch.inference_mode():
@@ -28,7 +37,7 @@ demo = gr.Interface(
28
  fn=generar,
29
  inputs=gr.Textbox(label="Prompt", placeholder="Escribe tu prompt aquí..."),
30
  outputs=gr.Image(label="Imagen generada"),
31
- title="Generador con LoRA en T4 Gratis"
32
  )
33
 
34
  if __name__ == "__main__":
 
4
 
5
  # Modelo base
6
  BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
7
+ # LoRA de ejemplo
8
  LORA_MODEL = "nerijs/pixel-art-xl"
9
 
10
+ # Detectar dispositivo
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ print(f"Usando dispositivo: {device}")
13
+
14
+ # Ajuste de dtype y memoria
15
+ dtype = torch.float16 if device == "cuda" else torch.float32
16
+ low_mem = True if device == "cpu" else False
17
+
18
  print("Cargando modelo base...")
19
  pipe = StableDiffusionXLPipeline.from_pretrained(
20
  BASE_MODEL,
21
+ torch_dtype=dtype,
22
+ variant="fp16" if device == "cuda" else None,
23
+ use_safetensors=True,
24
+ low_cpu_mem_usage=low_mem
25
+ ).to(device)
26
 
27
  print("Cargando LoRA...")
28
  pipe.load_lora_weights(LORA_MODEL)
29
+ pipe.fuse_lora(lora_scale=0.8)
30
 
31
  def generar(prompt):
32
  with torch.inference_mode():
 
37
  fn=generar,
38
  inputs=gr.Textbox(label="Prompt", placeholder="Escribe tu prompt aquí..."),
39
  outputs=gr.Image(label="Imagen generada"),
40
+ title="Generador con LoRA (CPU/GPU)"
41
  )
42
 
43
  if __name__ == "__main__":