# Stable Diffusion test app: selectable model and float16/float32 precision,
# image generation via diffusers, and upload of results over SFTP.
import os
import random
import numpy as np
import torch
import gradio as gr
from diffusers import DiffusionPipeline
import paramiko
from huggingface_hub import login

# Hugging Face Token
HF_TOKEN = os.getenv('HF_TOKEN', '').strip()
if not HF_TOKEN:
    raise ValueError("HUGGING_TOKEN is not set. Please set the token as an environment variable.")

# Hugging Face Login
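# login() authenticates this process with the Hugging Face Hub so that gated model
# repositories can be downloaded with the token above.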
login(token=HF_TOKEN)

# Configuration
STORAGE_DOMAIN = os.getenv('STORAGE_DOMAIN', '').strip()  # SFTP server domain
STORAGE_USER = os.getenv('STORAGE_USER', '').strip()  # SFTP user
STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip()  # SFTP password
STORAGE_PORT = int(os.getenv('STORAGE_PORT', '22').strip())  # SFTP port
STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip()  # Secret token (currently unused in this file)

# Model configuration
available_models = {
    "sd3-medium": "stabilityai/stable-diffusion-3-medium-diffusers",
    "sd2-base": "stabilityai/stable-diffusion-2-1-base"
}
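# Note: stabilityai/stable-diffusion-3-medium-diffusers is a gated repository; the
# HF_TOKEN account must have accepted its license, otherwise loading will fail.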

# SFTP upload helper
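# Note: paramiko.Transport is used with password authentication only; the server's
# host key is not verified here, which is only acceptable for a trusted host.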
def upload_to_sftp(local_file, remote_path):
    try:
        transport = paramiko.Transport((STORAGE_DOMAIN, STORAGE_PORT))
        transport.connect(username=STORAGE_USER, password=STORAGE_PSWD)
        sftp = paramiko.SFTPClient.from_transport(transport)
        sftp.put(local_file, remote_path)
        sftp.close()
        transport.close()
        print(f"File {local_file} successfully uploaded to {remote_path}")
        return True
    except Exception as e:
        print(f"Error during SFTP upload: {e}")
        return False

# Model loading helper
def load_model(model_name, precision):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    repo = available_models.get(model_name, available_models["sd3-medium"])
    
    try:
        # Select the torch dtype based on the chosen precision
        if precision == "float16":
            torch_dtype = torch.float16
        else:  # float32
            torch_dtype = torch.float32
            
        # DiffusionPipeline resolves the correct pipeline class for the repo
        # (StableDiffusion3Pipeline for sd3-medium, StableDiffusionPipeline for sd2-base)
        pipe = DiffusionPipeline.from_pretrained(
            repo, 
            torch_dtype=torch_dtype
        ).to(device)
        
        # On CPU, attention slicing lowers peak memory use at some cost in speed
        # (enable_sequential_cpu_offload() is not used here because it requires a CUDA device)
        if device == "cpu":
            pipe.enable_attention_slicing()
            
        return pipe
    except Exception as e:
        raise RuntimeError(f"Failed to load the model. Ensure the token has access to the repo. Error: {e}")

# Maximum values
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1344
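# 1344 px is used here as a practical upper bound; the sliders below step in multiples
# of 64, which keeps requested sizes compatible with both models.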

# Global pipeline cache
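# The pipeline is created lazily on the first call to infer() and replaced via reload_model()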
pipe = None

# Inference function
def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed, model_name, precision):
    global pipe
    
    # Load the model lazily on first use; switching model or precision requires the
    # Load/Reload button, which rebuilds the pipeline via reload_model()
    if pipe is None:
        pipe = load_model(model_name, precision)
    
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    
    # torch.manual_seed() returns the default CPU generator, which diffusers accepts
    # for reproducible sampling even when the pipeline itself runs on GPU
    generator = torch.manual_seed(seed)
    # Cast slider values to int, since the pipeline expects integer sizes and step counts
    image = pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator
    ).images[0]
    
    # Save the image locally
    local_file = f"/tmp/generated_image_{seed}.png"
    image.save(local_file)
    
    # Upload via SFTP
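    # Note: the remote /uploads directory must already exist; sftp.put() does not create it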
    remote_path = f"/uploads/generated_image_{seed}.png"
    if upload_to_sftp(local_file, remote_path):
        os.remove(local_file)
        return f"Image uploaded to {remote_path}", seed
    else:
        return "Failed to upload image", seed

# Load or reload the model with the selected settings
def reload_model(model_name, precision):
    global pipe
    pipe = load_model(model_name, precision)
    return f"Model loaded: {model_name} with {precision} precision"

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("### Stable Diffusion - Test App")
    
    with gr.Row():
        with gr.Column():
            # Model selection
            model_name = gr.Radio(
                choices=list(available_models.keys()),
                value="sd3-medium",
                label="Model"
            )
            
            # Precision selection
            precision = gr.Radio(
                choices=["float16", "float32"],
                value="float16",
                label="Precision"
            )
            
            reload_button = gr.Button("Load/Reload Model")
            model_status = gr.Textbox(label="Model Status")
            
            # Wire the Load/Reload button to reload_model
            reload_button.click(
                reload_model,
                inputs=[model_name, precision],
                outputs=[model_status]
            )
    
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
            width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Width")
            height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Height")
            guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
            num_inference_steps = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
            seed = gr.Number(value=42, label="Seed")
            randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
            generate_button = gr.Button("Generate Image")
            output = gr.Text(label="Output")
    
    generate_button.click(
        infer,
        inputs=[
            prompt, width, height, guidance_scale, 
            num_inference_steps, seed, randomize_seed,
            model_name, precision
        ],
        outputs=[output, seed]
    )
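    # The seed component appears in both inputs and outputs, so a randomized seed is
    # written back into the UI after each generation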

demo.launch()