# Stable Diffusion demo app with selectable float16/float32 precision
import os
import random
import numpy as np
import torch
import gradio as gr
# DiffusionPipeline resolves the correct pipeline class per repo
# (StableDiffusion3Pipeline for SD3, StableDiffusionPipeline for SD 2.1)
# from its model_index.json; the plain StableDiffusionPipeline cannot load SD3.
from diffusers import DiffusionPipeline
import paramiko
from huggingface_hub import login
# Hugging Face token
HF_TOKEN = os.getenv('HF_TOKEN', '').strip()
if not HF_TOKEN:
    raise ValueError("HF_TOKEN is not set. Please set the token as an environment variable.")

# Hugging Face login
login(token=HF_TOKEN)
# Configuration
STORAGE_DOMAIN = os.getenv('STORAGE_DOMAIN', '').strip()     # SFTP server domain
STORAGE_USER = os.getenv('STORAGE_USER', '').strip()         # SFTP user
STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip()         # SFTP password
STORAGE_PORT = int(os.getenv('STORAGE_PORT', '22').strip())  # SFTP port
STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip()     # Secret token
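# A minimal sketch of how these variables might be provided before launch;
# host name and credentials below are placeholders, not real values:
#
#   export STORAGE_DOMAIN=sftp.example.com
#   export STORAGE_USER=uploader
#   export STORAGE_PSWD=********
#   export STORAGE_PORT=22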
# Model configuration
available_models = {
    "sd3-medium": "stabilityai/stable-diffusion-3-medium-diffusers",
    "sd2-base": "stabilityai/stable-diffusion-2-1-base"
}
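# Note: stabilityai/stable-diffusion-3-medium-diffusers is a gated repository,
# so the account behind HF_TOKEN must have accepted its license for the
# download to succeed.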
# SFTP upload helper
def upload_to_sftp(local_file, remote_path):
    try:
        transport = paramiko.Transport((STORAGE_DOMAIN, STORAGE_PORT))
        transport.connect(username=STORAGE_USER, password=STORAGE_PSWD)
        sftp = paramiko.SFTPClient.from_transport(transport)
        sftp.put(local_file, remote_path)
        sftp.close()
        transport.close()
        print(f"File {local_file} successfully uploaded to {remote_path}")
        return True
    except Exception as e:
        print(f"Error during SFTP upload: {e}")
        return False
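# Hypothetical usage; note that paramiko's SFTPClient.put() does not create
# missing directories, so the remote folder (here "/uploads") must already
# exist on the server:
#
#   upload_to_sftp("/tmp/test.png", "/uploads/test.png")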
# Model loading function
def load_model(model_name, precision):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    repo = available_models.get(model_name, available_models["sd3-medium"])
    try:
        # Select dtype based on the chosen precision
        if precision == "float16":
            torch_dtype = torch.float16
        else:  # float32
            torch_dtype = torch.float32
        pipe = DiffusionPipeline.from_pretrained(
            repo,
            torch_dtype=torch_dtype
        ).to(device)
        # Memory optimisation for CPU runs: enable_sequential_cpu_offload()
        # (used originally) offloads weights *from* an accelerator and fails
        # without CUDA, so attention slicing is used instead where available.
        if device == "cpu" and hasattr(pipe, "enable_attention_slicing"):
            pipe.enable_attention_slicing()
        return pipe
    except Exception as e:
        raise RuntimeError(f"Failed to load the model. Ensure the token has access to the repo. Error: {e}")
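# A minimal usage sketch (float16 halves memory but is mainly useful on GPU;
# on CPU, float32 is generally the safer choice):
#
#   pipe = load_model("sd2-base", "float32")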
# Limits: MAX_SEED is the largest 32-bit signed integer (2**31 - 1 = 2147483647)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1344

# Global pipeline variable
pipe = None
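# The pipeline is created lazily: it stays None until either the first
# inference call or a click on "Load/Reload Model" instantiates it.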
# Inference function
def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed, model_name, precision):
    global pipe
    # Load the model lazily on the first call; use the reload button
    # to switch model or precision afterwards
    if pipe is None:
        pipe = load_model(model_name, precision)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio may deliver numeric inputs as floats, so cast before seeding
    seed = int(seed)
    generator = torch.manual_seed(seed)
    image = pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=int(width),
        height=int(height),
        generator=generator
    ).images[0]
    # Save the image locally
    local_file = f"/tmp/generated_image_{seed}.png"
    image.save(local_file)
    # Upload via SFTP
    remote_path = f"/uploads/generated_image_{seed}.png"
    if upload_to_sftp(local_file, remote_path):
        os.remove(local_file)
        return f"Image uploaded to {remote_path}", seed
    else:
        return "Failed to upload image", seed
# Reload the model
def reload_model(model_name, precision):
    global pipe
    pipe = load_model(model_name, precision)
    return f"Model loaded: {model_name} with {precision} precision"
# Gradio app
with gr.Blocks() as demo:
    gr.Markdown("### Stable Diffusion - Test App")
    with gr.Row():
        with gr.Column():
            # Model selection
            model_name = gr.Radio(
                choices=list(available_models.keys()),
                value="sd3-medium",
                label="Model"
            )
            # Precision selection
            precision = gr.Radio(
                choices=["float16", "float32"],
                value="float16",
                label="Precision"
            )
            reload_button = gr.Button("Load/Reload Model")
            model_status = gr.Textbox(label="Model Status")

            # Load/reload the model on click
            reload_button.click(
                reload_model,
                inputs=[model_name, precision],
                outputs=[model_status]
            )
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
            width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Width")
            height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Height")
            guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
            num_inference_steps = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
            seed = gr.Number(value=42, label="Seed")
            randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
            generate_button = gr.Button("Generate Image")
            output = gr.Text(label="Output")

            generate_button.click(
                infer,
                inputs=[
                    prompt, width, height, guidance_scale,
                    num_inference_steps, seed, randomize_seed,
                    model_name, precision
                ],
                outputs=[output, seed]
            )

demo.launch()
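# On Hugging Face Spaces the bare launch() above is sufficient; for other
# hosts you could pass e.g. demo.launch(server_name="0.0.0.0") to listen on
# all interfaces (a deployment assumption, not part of the original app).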