Alibrown committed on
Commit
0f35108
·
verified ·
1 Parent(s): 8f3793b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -80
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # float16 +32
2
  import os
3
  import random
4
  import numpy as np
@@ -23,11 +22,35 @@ STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip() # SFTP Passwort
23
  STORAGE_PORT = int(os.getenv('STORAGE_PORT', '22').strip()) # SFTP Port
24
  STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip() # Secret Token
25
 
26
- # Modell-Konfiguration
27
- available_models = {
28
- "sd3-medium": "stabilityai/stable-diffusion-3-medium-diffusers",
29
- "sd2-base": "stabilityai/stable-diffusion-2-1-base"
30
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # SFTP-Funktion
33
  def upload_to_sftp(local_file, remote_path):
@@ -44,46 +67,8 @@ def upload_to_sftp(local_file, remote_path):
44
  print(f"Error during SFTP upload: {e}")
45
  return False
46
 
47
- # Modell laden Funktion
48
- def load_model(model_name, precision):
49
- device = "cuda" if torch.cuda.is_available() else "cpu"
50
- repo = available_models.get(model_name, available_models["sd3-medium"])
51
-
52
- try:
53
- # Wähle Präzision basierend auf Auswahl
54
- if precision == "float16":
55
- torch_dtype = torch.float16
56
- else: # float32
57
- torch_dtype = torch.float32
58
-
59
- pipe = StableDiffusionPipeline.from_pretrained(
60
- repo,
61
- torch_dtype=torch_dtype
62
- ).to(device)
63
-
64
- # Wenn auf CPU und Speicheroptimierung gewünscht
65
- if device == "cpu":
66
- pipe.enable_sequential_cpu_offload()
67
-
68
- return pipe
69
- except Exception as e:
70
- raise RuntimeError(f"Failed to load the model. Ensure the token has access to the repo. Error: {e}")
71
-
72
- # Maximalwerte
73
- MAX_SEED = np.iinfo(np.int32).max
74
- MAX_IMAGE_SIZE = 1344
75
-
76
- # Globale Pipe-Variable
77
- pipe = None
78
-
79
  # Inferenz-Funktion
80
- def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed, model_name, precision):
81
- global pipe
82
-
83
- # Prüfe, ob Modell neu geladen werden muss
84
- if pipe is None:
85
- pipe = load_model(model_name, precision)
86
-
87
  if randomize_seed:
88
  seed = random.randint(0, MAX_SEED)
89
 
@@ -109,41 +94,9 @@ def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, rand
109
  else:
110
  return "Failed to upload image", seed
111
 
112
- # Modell neu laden
113
- def reload_model(model_name, precision):
114
- global pipe
115
- pipe = load_model(model_name, precision)
116
- return f"Model loaded: {model_name} with {precision} precision"
117
-
118
  # Gradio-App
119
  with gr.Blocks() as demo:
120
- gr.Markdown("### Stable Diffusion - Test App")
121
-
122
- with gr.Row():
123
- with gr.Column():
124
- # Modell Auswahl
125
- model_name = gr.Radio(
126
- choices=list(available_models.keys()),
127
- value="sd3-medium",
128
- label="Model"
129
- )
130
-
131
- # Präzision Auswahl
132
- precision = gr.Radio(
133
- choices=["float16", "float32"],
134
- value="float16",
135
- label="Precision"
136
- )
137
-
138
- reload_button = gr.Button("Load/Reload Model")
139
- model_status = gr.Textbox(label="Model Status")
140
-
141
- # Modell laden Button
142
- reload_button.click(
143
- reload_model,
144
- inputs=[model_name, precision],
145
- outputs=[model_status]
146
- )
147
 
148
  with gr.Row():
149
  with gr.Column():
@@ -161,8 +114,7 @@ with gr.Blocks() as demo:
161
  infer,
162
  inputs=[
163
  prompt, width, height, guidance_scale,
164
- num_inference_steps, seed, randomize_seed,
165
- model_name, precision
166
  ],
167
  outputs=[output, seed]
168
  )
 
 
1
  import os
2
  import random
3
  import numpy as np
 
22
  STORAGE_PORT = int(os.getenv('STORAGE_PORT', '22').strip()) # SFTP Port
23
  STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip() # Secret Token
24
 
25
+ # Modell-Konfiguration und Device-Setup
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ print(f"Using device: {device}")
28
+
29
+ # Stelle fest, ob auf CPU oder GPU-System
30
+ is_gpu_available = torch.cuda.is_available()
31
+
32
+ # Modell laden - passend zur Hardware
33
+ repo = "stabilityai/stable-diffusion-3-medium-diffusers"
34
+
35
+ # Die Standard-Präzision basiert auf verfügbarer Hardware
36
+ DEFAULT_PRECISION = "float16" if is_gpu_available else "float32"
37
+ print(f"Default precision: {DEFAULT_PRECISION}")
38
+
39
+ # Modell beim Start laden
40
+ try:
41
+ # Wähle Präzision basierend auf Hardware
42
+ if DEFAULT_PRECISION == "float16":
43
+ pipe = StableDiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16).to(device)
44
+ else: # float32 für CPU
45
+ pipe = StableDiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float32).to(device)
46
+
47
+ print("Model loaded successfully")
48
+ except Exception as e:
49
+ raise RuntimeError(f"Failed to load the model. Ensure the token has access to the repo. Error: {e}")
50
+
51
+ # Maximalwerte
52
+ MAX_SEED = np.iinfo(np.int32).max
53
+ MAX_IMAGE_SIZE = 1344
54
 
55
  # SFTP-Funktion
56
  def upload_to_sftp(local_file, remote_path):
 
67
  print(f"Error during SFTP upload: {e}")
68
  return False
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  # Inferenz-Funktion
71
+ def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed):
 
 
 
 
 
 
72
  if randomize_seed:
73
  seed = random.randint(0, MAX_SEED)
74
 
 
94
  else:
95
  return "Failed to upload image", seed
96
 
 
 
 
 
 
 
97
  # Gradio-App
98
  with gr.Blocks() as demo:
99
+ gr.Markdown(f"### Stable Diffusion 3 - Test App (Running on {device.upper()} with {DEFAULT_PRECISION})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  with gr.Row():
102
  with gr.Column():
 
114
  infer,
115
  inputs=[
116
  prompt, width, height, guidance_scale,
117
+ num_inference_steps, seed, randomize_seed
 
118
  ],
119
  outputs=[output, seed]
120
  )