Alibrown commited on
Commit
e57b5e3
·
verified ·
1 Parent(s): 0f35108

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -42
app.py CHANGED
@@ -16,35 +16,25 @@ if not HF_TOKEN:
16
  login(token=HF_TOKEN)
17
 
18
  # Konfiguration
19
- STORAGE_DOMAIN = os.getenv('STORAGE_DOMAIN', '').strip() # SFTP Server Domain
20
- STORAGE_USER = os.getenv('STORAGE_USER', '').strip() # SFTP User
21
- STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip() # SFTP Passwort
22
- STORAGE_PORT = int(os.getenv('STORAGE_PORT', '22').strip()) # SFTP Port
23
- STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip() # Secret Token
24
 
25
- # Modell-Konfiguration und Device-Setup
26
- device = "cuda" if torch.cuda.is_available() else "cpu"
27
- print(f"Using device: {device}")
28
-
29
- # Stelle fest, ob auf CPU oder GPU-System
30
- is_gpu_available = torch.cuda.is_available()
31
 
32
- # Modell laden - passend zur Hardware
33
- repo = "stabilityai/stable-diffusion-3-medium-diffusers"
34
-
35
- # Die Standard-Präzision basiert auf verfügbarer Hardware
36
- DEFAULT_PRECISION = "float16" if is_gpu_available else "float32"
37
- print(f"Default precision: {DEFAULT_PRECISION}")
38
 
39
- # Modell beim Start laden
40
  try:
41
- # Wähle Präzision basierend auf Hardware
42
- if DEFAULT_PRECISION == "float16":
43
- pipe = StableDiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16).to(device)
44
- else: # float32 für CPU
45
- pipe = StableDiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float32).to(device)
46
-
47
- print("Model loaded successfully")
48
  except Exception as e:
49
  raise RuntimeError(f"Failed to load the model. Ensure the token has access to the repo. Error: {e}")
50
 
@@ -94,28 +84,25 @@ def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, rand
94
  else:
95
  return "Failed to upload image", seed
96
 
 
 
 
97
  # Gradio-App
98
  with gr.Blocks() as demo:
99
- gr.Markdown(f"### Stable Diffusion 3 - Test App (Running on {device.upper()} with {DEFAULT_PRECISION})")
100
-
101
- with gr.Row():
102
- with gr.Column():
103
- prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
104
- width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Width")
105
- height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Height")
106
- guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
107
- num_inference_steps = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
108
- seed = gr.Number(value=42, label="Seed")
109
- randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
110
- generate_button = gr.Button("Generate Image")
111
- output = gr.Text(label="Output")
112
 
113
  generate_button.click(
114
  infer,
115
- inputs=[
116
- prompt, width, height, guidance_scale,
117
- num_inference_steps, seed, randomize_seed
118
- ],
119
  outputs=[output, seed]
120
  )
121
 
 
16
  login(token=HF_TOKEN)
17
 
18
  # Konfiguration
19
+ STORAGE_DOMAIN = os.getenv('STORAGE_DOMAIN', '').strip() # SFTP Server Domain
20
+ STORAGE_USER = os.getenv('STORAGE_USER', '').strip() # SFTP User
21
+ STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip() # SFTP Passwort
22
+ STORAGE_PORT = int(os.getenv('STORAGE_PORT', '22').strip()) # SFTP Port
23
+ STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip() # Secret Token
24
 
25
+ # Modell-Optionen - können angepasst werden
26
+ MODEL_REPO = os.getenv('MODEL_REPO', 'stabilityai/stable-diffusion-2-1-base') # Standard-Modell
27
+ TORCH_DTYPE = os.getenv('TORCH_DTYPE', 'float16') # Standard-Präzision
 
 
 
28
 
29
+ # Modell laden
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ torch_dtype = torch.float16 if TORCH_DTYPE == 'float16' else torch.float32
 
 
 
32
 
 
33
  try:
34
+ pipe = StableDiffusionPipeline.from_pretrained(
35
+ MODEL_REPO,
36
+ torch_dtype=torch_dtype
37
+ ).to(device)
 
 
 
38
  except Exception as e:
39
  raise RuntimeError(f"Failed to load the model. Ensure the token has access to the repo. Error: {e}")
40
 
 
84
  else:
85
  return "Failed to upload image", seed
86
 
87
+ # App-Titel mit Modell- und Präzisionsinformationen
88
+ APP_TITLE = f"### Stable Diffusion - {os.path.basename(MODEL_REPO)} ({TORCH_DTYPE} auf {device})"
89
+
90
  # Gradio-App
91
  with gr.Blocks() as demo:
92
+ gr.Markdown(APP_TITLE)
93
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
94
+ width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Width")
95
+ height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Height")
96
+ guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
97
+ num_inference_steps = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
98
+ seed = gr.Number(value=42, label="Seed")
99
+ randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
100
+ generate_button = gr.Button("Generate Image")
101
+ output = gr.Text(label="Output")
 
 
 
102
 
103
  generate_button.click(
104
  infer,
105
+ inputs=[prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed],
 
 
 
106
  outputs=[output, seed]
107
  )
108