Alibrown commited on
Commit
e27ae0d
·
verified ·
1 Parent(s): 7ee11ab

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ import gradio as gr
6
+ from diffusers import StableDiffusionPipeline
7
+ import paramiko
8
+ from huggingface_hub import login
9
+
10
+ # Hugging Face Token
11
+ HF_TOKEN = os.getenv('HF_TOKEN', '').strip()
12
+ if not HF_TOKEN:
13
+ raise ValueError("HUGGING_TOKEN is not set. Please set the token as an environment variable.")
14
+
15
+ # Hugging Face Login
16
+ login(token=HF_TOKEN)
17
+
18
+ # Konfiguration
19
+ STORAGE_DOMAIN = os.getenv('STORAGE_DOMAIN', '').strip() # SFTP Server Domain
20
+ STORAGE_USER = os.getenv('STORAGE_USER', '').strip() # SFTP User
21
+ STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip() # SFTP Passwort
22
+ STORAGE_PORT = int(os.getenv('STORAGE_PORT', '22').strip()) # SFTP Port
23
+ STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip() # Secret Token
24
+
25
+ # Modell-Optionen - können angepasst werden
26
+ MODEL_REPO = os.getenv('MODEL_REPO', 'stabilityai/stable-diffusion-2-1-base') # Standard-Modell
27
+ TORCH_DTYPE = os.getenv('TORCH_DTYPE', 'float16') # Standard-Präzision
28
+
29
+ # Modell laden
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ torch_dtype = torch.float16 if TORCH_DTYPE == 'float16' else torch.float32
32
+
33
+ try:
34
+ pipe = StableDiffusionPipeline.from_pretrained(
35
+ MODEL_REPO,
36
+ torch_dtype=torch_dtype
37
+ ).to(device)
38
+ except Exception as e:
39
+ raise RuntimeError(f"Failed to load the model. Ensure the token has access to the repo. Error: {e}")
40
+
41
+ # Maximalwerte
42
+ MAX_SEED = np.iinfo(np.int32).max
43
+ MAX_IMAGE_SIZE = 1344
44
+
45
+ # SFTP-Funktion
46
+ def upload_to_sftp(local_file, remote_path):
47
+ """Versucht, eine Datei auf einen SFTP-Server hochzuladen."""
48
+ try:
49
+ transport = paramiko.Transport((STORAGE_DOMAIN, STORAGE_PORT))
50
+ transport.connect(username=STORAGE_USER, password=STORAGE_PSWD)
51
+ sftp = paramiko.SFTPClient.from_transport(transport)
52
+ sftp.put(local_file, remote_path)
53
+ sftp.close()
54
+ transport.close()
55
+ return True
56
+ except Exception as e:
57
+ return f"Error during SFTP upload: {e}"
58
+
59
+ # Inferenz-Funktion
60
+ def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed):
61
+ """Generiert ein Bild basierend auf dem Eingabe-Prompt und lädt es auf einen SFTP-Server hoch."""
62
+ if randomize_seed:
63
+ seed = random.randint(0, MAX_SEED)
64
+
65
+ generator = torch.manual_seed(seed)
66
+ image = pipe(
67
+ prompt,
68
+ guidance_scale=guidance_scale,
69
+ num_inference_steps=num_inference_steps,
70
+ width=width,
71
+ height=height,
72
+ generator=generator
73
+ ).images[0]
74
+
75
+ # Speichere Bild lokal
76
+ local_file = f"/tmp/generated_image_{seed}.png"
77
+ image.save(local_file)
78
+
79
+ # Hochladen zu SFTP
80
+ remote_path = f"/uploads/generated_image_{seed}.png"
81
+ upload_result = upload_to_sftp(local_file, remote_path)
82
+
83
+ # Entferne das lokale Bild, wenn der Upload erfolgreich war
84
+ if upload_result == True:
85
+ os.remove(local_file)
86
+ return f"Image successfully uploaded to {remote_path}", seed
87
+ else:
88
+ return upload_result, seed
89
+
90
+ # App-Titel mit Modell- und Präzisionsinformationen
91
+ APP_TITLE = f"### Stable Diffusion - {os.path.basename(MODEL_REPO)} ({TORCH_DTYPE} auf {device})"
92
+
93
+ # Gradio-App
94
+ with gr.Blocks() as demo:
95
+ gr.Markdown(APP_TITLE)
96
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
97
+ width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Width")
98
+ height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Height")
99
+ guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
100
+ num_inference_steps = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
101
+ seed = gr.Number(value=42, label="Seed")
102
+ randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
103
+ generate_button = gr.Button("Generate Image")
104
+ output = gr.Text(label="Output")
105
+
106
+ # Klick-Event für die Generierung
107
+ generate_button.click(
108
+ infer,
109
+ inputs=[prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed],
110
+ outputs=[output, seed]
111
+ )
112
+
113
+ demo.launch()