Alibrown commited on
Commit
8f3793b
·
verified ·
1 Parent(s): b192fe5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # float16 +32
2
+ import os
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import gradio as gr
7
+ from diffusers import StableDiffusionPipeline
8
+ import paramiko
9
+ from huggingface_hub import login
10
+
11
+ # Hugging Face Token
12
+ HF_TOKEN = os.getenv('HF_TOKEN', '').strip()
13
+ if not HF_TOKEN:
14
+ raise ValueError("HUGGING_TOKEN is not set. Please set the token as an environment variable.")
15
+
16
+ # Hugging Face Login
17
+ login(token=HF_TOKEN)
18
+
19
+ # Konfiguration
20
+ STORAGE_DOMAIN = os.getenv('STORAGE_DOMAIN', '').strip() # SFTP Server Domain
21
+ STORAGE_USER = os.getenv('STORAGE_USER', '').strip() # SFTP User
22
+ STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip() # SFTP Passwort
23
+ STORAGE_PORT = int(os.getenv('STORAGE_PORT', '22').strip()) # SFTP Port
24
+ STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip() # Secret Token
25
+
26
+ # Modell-Konfiguration
27
+ available_models = {
28
+ "sd3-medium": "stabilityai/stable-diffusion-3-medium-diffusers",
29
+ "sd2-base": "stabilityai/stable-diffusion-2-1-base"
30
+ }
31
+
32
+ # SFTP-Funktion
33
+ def upload_to_sftp(local_file, remote_path):
34
+ try:
35
+ transport = paramiko.Transport((STORAGE_DOMAIN, STORAGE_PORT))
36
+ transport.connect(username=STORAGE_USER, password=STORAGE_PSWD)
37
+ sftp = paramiko.SFTPClient.from_transport(transport)
38
+ sftp.put(local_file, remote_path)
39
+ sftp.close()
40
+ transport.close()
41
+ print(f"File {local_file} successfully uploaded to {remote_path}")
42
+ return True
43
+ except Exception as e:
44
+ print(f"Error during SFTP upload: {e}")
45
+ return False
46
+
47
+ # Modell laden Funktion
48
+ def load_model(model_name, precision):
49
+ device = "cuda" if torch.cuda.is_available() else "cpu"
50
+ repo = available_models.get(model_name, available_models["sd3-medium"])
51
+
52
+ try:
53
+ # Wähle Präzision basierend auf Auswahl
54
+ if precision == "float16":
55
+ torch_dtype = torch.float16
56
+ else: # float32
57
+ torch_dtype = torch.float32
58
+
59
+ pipe = StableDiffusionPipeline.from_pretrained(
60
+ repo,
61
+ torch_dtype=torch_dtype
62
+ ).to(device)
63
+
64
+ # Wenn auf CPU und Speicheroptimierung gewünscht
65
+ if device == "cpu":
66
+ pipe.enable_sequential_cpu_offload()
67
+
68
+ return pipe
69
+ except Exception as e:
70
+ raise RuntimeError(f"Failed to load the model. Ensure the token has access to the repo. Error: {e}")
71
+
72
+ # Maximalwerte
73
+ MAX_SEED = np.iinfo(np.int32).max
74
+ MAX_IMAGE_SIZE = 1344
75
+
76
+ # Globale Pipe-Variable
77
+ pipe = None
78
+
79
+ # Inferenz-Funktion
80
+ def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed, model_name, precision):
81
+ global pipe
82
+
83
+ # Prüfe, ob Modell neu geladen werden muss
84
+ if pipe is None:
85
+ pipe = load_model(model_name, precision)
86
+
87
+ if randomize_seed:
88
+ seed = random.randint(0, MAX_SEED)
89
+
90
+ generator = torch.manual_seed(seed)
91
+ image = pipe(
92
+ prompt,
93
+ guidance_scale=guidance_scale,
94
+ num_inference_steps=num_inference_steps,
95
+ width=width,
96
+ height=height,
97
+ generator=generator
98
+ ).images[0]
99
+
100
+ # Speichere Bild lokal
101
+ local_file = f"/tmp/generated_image_{seed}.png"
102
+ image.save(local_file)
103
+
104
+ # Hochladen zu SFTP
105
+ remote_path = f"/uploads/generated_image_{seed}.png"
106
+ if upload_to_sftp(local_file, remote_path):
107
+ os.remove(local_file)
108
+ return f"Image uploaded to {remote_path}", seed
109
+ else:
110
+ return "Failed to upload image", seed
111
+
112
+ # Modell neu laden
113
+ def reload_model(model_name, precision):
114
+ global pipe
115
+ pipe = load_model(model_name, precision)
116
+ return f"Model loaded: {model_name} with {precision} precision"
117
+
118
+ # Gradio-App
119
+ with gr.Blocks() as demo:
120
+ gr.Markdown("### Stable Diffusion - Test App")
121
+
122
+ with gr.Row():
123
+ with gr.Column():
124
+ # Modell Auswahl
125
+ model_name = gr.Radio(
126
+ choices=list(available_models.keys()),
127
+ value="sd3-medium",
128
+ label="Model"
129
+ )
130
+
131
+ # Präzision Auswahl
132
+ precision = gr.Radio(
133
+ choices=["float16", "float32"],
134
+ value="float16",
135
+ label="Precision"
136
+ )
137
+
138
+ reload_button = gr.Button("Load/Reload Model")
139
+ model_status = gr.Textbox(label="Model Status")
140
+
141
+ # Modell laden Button
142
+ reload_button.click(
143
+ reload_model,
144
+ inputs=[model_name, precision],
145
+ outputs=[model_status]
146
+ )
147
+
148
+ with gr.Row():
149
+ with gr.Column():
150
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
151
+ width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Width")
152
+ height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Height")
153
+ guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
154
+ num_inference_steps = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
155
+ seed = gr.Number(value=42, label="Seed")
156
+ randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
157
+ generate_button = gr.Button("Generate Image")
158
+ output = gr.Text(label="Output")
159
+
160
+ generate_button.click(
161
+ infer,
162
+ inputs=[
163
+ prompt, width, height, guidance_scale,
164
+ num_inference_steps, seed, randomize_seed,
165
+ model_name, precision
166
+ ],
167
+ outputs=[output, seed]
168
+ )
169
+
170
+ demo.launch()