File size: 4,128 Bytes
f7c1363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f7b856
 
f7c1363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import random
import numpy as np
import torch
import gradio as gr
from diffusers import StableDiffusionPipeline
import paramiko
from huggingface_hub import login

# Hugging Face token — required so from_pretrained can access the model repo.
HF_TOKEN = os.getenv('HF_TOKEN', '').strip()
if not HF_TOKEN:
    # Name the variable the user actually has to set (the old message said
    # "HUGGING_TOKEN", but the code reads HF_TOKEN — a misleading error).
    raise ValueError("HF_TOKEN is not set. Please set the token as an environment variable.")

# Authenticate against the Hugging Face Hub before loading the model.
login(token=HF_TOKEN)

# Configuration — all SFTP connection details come from environment variables.
STORAGE_DOMAIN = os.getenv('STORAGE_DOMAIN', '').strip()  # SFTP server domain
STORAGE_USER = os.getenv('STORAGE_USER', '').strip()  # SFTP user
STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip()  # SFTP password
STORAGE_PORT = int(os.getenv('STORAGE_PORT', '22').strip())  # SFTP port (defaults to 22)
STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip()  # Secret token — not referenced in this file chunk; verify against callers

# Model options — overridable via environment variables.
MODEL_REPO = os.getenv('MODEL_REPO', 'stabilityai/stable-diffusion-2-1-base')  # default model repo
TORCH_DTYPE = os.getenv('TORCH_DTYPE', 'float16')  # default precision ('float16' or anything else => float32)

# Load the model: prefer CUDA, fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Use float16 only when CUDA is available AND explicitly requested; CPU always gets float32.
torch_dtype = torch.float16 if torch.cuda.is_available() and TORCH_DTYPE == 'float16' else torch.float32


try:
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch_dtype
    ).to(device)
except Exception as e:
    # A load failure here is typically a missing/insufficient HF token for the repo,
    # so surface that hint alongside the original error.
    raise RuntimeError(f"Failed to load the model. Ensure the token has access to the repo. Error: {e}")

# Upper bounds used by the UI controls below.
MAX_SEED = np.iinfo(np.int32).max  # largest 32-bit signed integer
MAX_IMAGE_SIZE = 1344  # maximum width/height in pixels

# SFTP-Funktion
def upload_to_sftp(local_file, remote_path):
    """Versucht, eine Datei auf einen SFTP-Server hochzuladen."""
    try:
        transport = paramiko.Transport((STORAGE_DOMAIN, STORAGE_PORT))
        transport.connect(username=STORAGE_USER, password=STORAGE_PSWD)
        sftp = paramiko.SFTPClient.from_transport(transport)
        sftp.put(local_file, remote_path)
        sftp.close()
        transport.close()
        return True
    except Exception as e:
        return f"Error during SFTP upload: {e}"

# Inference function
def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed):
    """Generate an image from *prompt* and upload it to the SFTP server.

    Parameters:
        prompt: Text prompt for the diffusion pipeline.
        width, height: Output image size in pixels (from the UI sliders).
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        seed: RNG seed; ignored when randomize_seed is True.
        randomize_seed: Draw a fresh random seed when True.

    Returns:
        (status_message, seed) — the message is either the remote path on
        success or the error string from upload_to_sftp().
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Gradio Number/Slider widgets deliver floats; torch.manual_seed requires
    # an int and the pipeline expects integer sizes/step counts, so cast at
    # the boundary.
    seed = int(seed)
    generator = torch.manual_seed(seed)
    image = pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator
    ).images[0]

    # Save the image locally before uploading.
    local_file = f"/tmp/generated_image_{seed}.png"
    image.save(local_file)

    # Upload to the SFTP server.
    remote_path = f"/uploads/generated_image_{seed}.png"
    upload_result = upload_to_sftp(local_file, remote_path)

    # Remove the local copy only after a successful upload; on failure the
    # file is kept (deliberate — allows manual recovery), matching the
    # original behavior. upload_result is True or an error string, so use
    # an identity check rather than `== True`.
    if upload_result is True:
        os.remove(local_file)
        return f"Image successfully uploaded to {remote_path}", seed
    else:
        return upload_result, seed

# App title with model and precision information (shown as a Markdown heading).
APP_TITLE = f"### Stable Diffusion - {os.path.basename(MODEL_REPO)} ({TORCH_DTYPE} auf {device})"

# Gradio app: a simple vertical layout — prompt, generation controls,
# seed handling, and a text output showing the upload result.
with gr.Blocks() as demo:
    gr.Markdown(APP_TITLE)
    prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
    width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Width")
    height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Height")
    guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
    num_inference_steps = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
    seed = gr.Number(value=42, label="Seed")
    randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
    generate_button = gr.Button("Generate Image")
    output = gr.Text(label="Output")
    
    # Click event for generation: infer returns (status message, seed used);
    # the seed is written back into the Number widget so users can reproduce runs.
    generate_button.click(
        infer,
        inputs=[prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed],
        outputs=[output, seed]
    )

demo.launch()