# SD3-dev1 / app.py
# Hugging Face file-view header: "Create app.py" by Alibrown, commit 8f3793b (verified), 5.57 kB.
# Stable Diffusion test app with selectable float16 / float32 precision.
import os
import random

import gradio as gr
import numpy as np
import paramiko
import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from huggingface_hub import login
# Hugging Face token, read from the HF_TOKEN environment variable.
HF_TOKEN = os.getenv('HF_TOKEN', '').strip()
if not HF_TOKEN:
    # The message names the variable that is actually read above, so the
    # operator knows exactly which setting is missing.
    raise ValueError("HF_TOKEN is not set. Please set the token as an environment variable.")

# Hugging Face login
login(token=HF_TOKEN)
# Configuration: SFTP storage settings, read from environment variables.
STORAGE_DOMAIN = os.getenv('STORAGE_DOMAIN', '').strip()  # SFTP server domain
STORAGE_USER = os.getenv('STORAGE_USER', '').strip()  # SFTP user
STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip()  # SFTP password
# An unset, empty, or whitespace-only STORAGE_PORT falls back to port 22
# instead of crashing at import time with int('').
STORAGE_PORT = int(os.getenv('STORAGE_PORT', '').strip() or 22)  # SFTP port
STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip()  # Secret token
# Model configuration: maps the key shown in the UI to its Hugging Face
# repository id. Insertion order is the order the choices are listed in.
available_models = {
    "sd3-medium": "stabilityai/stable-diffusion-3-medium-diffusers",
    "sd2-base": "stabilityai/stable-diffusion-2-1-base",
}
# SFTP upload helper
def upload_to_sftp(local_file, remote_path):
    """Upload ``local_file`` to ``remote_path`` on the configured SFTP server.

    Connects with the module-level STORAGE_DOMAIN / STORAGE_PORT /
    STORAGE_USER / STORAGE_PSWD settings.

    Args:
        local_file: Path of the file on the local filesystem.
        remote_path: Destination path on the SFTP server.

    Returns:
        True when the upload succeeded, False when any step failed. Errors
        are printed rather than raised, so callers only inspect the result.
    """
    try:
        transport = paramiko.Transport((STORAGE_DOMAIN, STORAGE_PORT))
        try:
            transport.connect(username=STORAGE_USER, password=STORAGE_PSWD)
            sftp = paramiko.SFTPClient.from_transport(transport)
            try:
                sftp.put(local_file, remote_path)
            finally:
                # Close the SFTP channel even when the transfer fails.
                sftp.close()
        finally:
            # Always release the underlying connection; previously a failed
            # connect() or put() left the transport open.
            transport.close()
        print(f"File {local_file} successfully uploaded to {remote_path}")
        return True
    except Exception as e:
        # Best-effort upload: report the problem and signal failure.
        print(f"Error during SFTP upload: {e}")
        return False
# Model loading
def load_model(model_name, precision):
    """Load the diffusion pipeline for ``model_name`` at the given precision.

    Args:
        model_name: Key into ``available_models``. Unknown keys fall back to
            the "sd3-medium" repository.
        precision: "float16" selects ``torch.float16``; any other value
            selects ``torch.float32``.

    Returns:
        The loaded pipeline, moved to CUDA when available and to the CPU
        otherwise.

    Raises:
        RuntimeError: If the pipeline cannot be loaded; the original error is
            attached as ``__cause__``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    repo = available_models.get(model_name, available_models["sd3-medium"])
    # Choose the tensor precision based on the selection.
    torch_dtype = torch.float16 if precision == "float16" else torch.float32
    try:
        # DiffusionPipeline picks the pipeline class declared by the
        # repository itself, so the SD3 and SD2 entries both load. The fixed
        # StableDiffusionPipeline class cannot load the SD3 repository.
        loaded = DiffusionPipeline.from_pretrained(
            repo,
            torch_dtype=torch_dtype,
        ).to(device)
        # No sequential CPU offload here: it streams weights from the CPU to
        # an accelerator, and the original only enabled it on CPU-only hosts
        # where there is no accelerator to offload from.
        return loaded
    except Exception as e:
        raise RuntimeError(
            f"Failed to load the model. Ensure the token has access to the repo. Error: {e}"
        ) from e
# Upper limits used by the UI and the seed generator.
MAX_IMAGE_SIZE = 1344
MAX_SEED = np.iinfo(np.int32).max  # largest signed 32-bit integer

# Global pipeline handle; stays None until a model has been loaded.
pipe = None
# Inference
def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed, model_name, precision):
    """Generate one image for ``prompt`` and upload it to the SFTP server.

    Args:
        prompt: Text prompt passed to the pipeline.
        width: Output image width in pixels.
        height: Output image height in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is set. ``None`` is treated like a request for a random seed.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        model_name: Model key, used only if no pipeline is loaded yet.
        precision: Precision name, used only if no pipeline is loaded yet.

    Returns:
        A ``(status_message, seed)`` tuple, where ``seed`` is the integer
        seed that was actually used.
    """
    global pipe
    # Load lazily on first use. An already loaded pipeline is reused as is;
    # switching model or precision goes through reload_model.
    if pipe is None:
        pipe = load_model(model_name, precision)

    if randomize_seed or seed is None:
        seed = random.randint(0, MAX_SEED)
    # The UI number field may deliver a float; normalise it so the file names
    # below read "..._42.png" rather than "..._42.0.png".
    seed = int(seed)

    # A dedicated generator keeps the run reproducible without reseeding
    # torch's process-wide default generator.
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator
    ).images[0]

    local_file = f"/tmp/generated_image_{seed}.png"
    remote_path = f"/uploads/generated_image_{seed}.png"
    try:
        # Save the image locally, then upload it to SFTP.
        image.save(local_file)
        uploaded = upload_to_sftp(local_file, remote_path)
    finally:
        # Remove the temporary file whether or not the upload worked, so
        # failed uploads do not pile up in /tmp.
        if os.path.exists(local_file):
            os.remove(local_file)

    if uploaded:
        return f"Image uploaded to {remote_path}", seed
    return "Failed to upload image", seed
# Reload the model
def reload_model(model_name, precision):
    """Replace the global pipeline with a freshly loaded one.

    Args:
        model_name: Model key forwarded to ``load_model``.
        precision: Precision name forwarded to ``load_model``.

    Returns:
        A status string naming the loaded model and precision.
    """
    global pipe
    fresh_pipeline = load_model(model_name, precision)
    pipe = fresh_pipeline
    status = f"Model loaded: {model_name} with {precision} precision"
    return status
# Gradio app: a model-selection row on top, generation controls below.
with gr.Blocks() as demo:
    gr.Markdown("### Stable Diffusion - Test App")
    with gr.Row():
        with gr.Column():
            # Model selection (choices are the keys of available_models)
            model_name = gr.Radio(
                choices=list(available_models.keys()),
                value="sd3-medium",
                label="Model"
            )
            # Precision selection
            precision = gr.Radio(
                choices=["float16", "float32"],
                value="float16",
                label="Precision"
            )
            reload_button = gr.Button("Load/Reload Model")
            model_status = gr.Textbox(label="Model Status")
    # Load-model button: calls reload_model and shows its status string.
    reload_button.click(
        reload_model,
        inputs=[model_name, precision],
        outputs=[model_status]
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
            width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Width")
            height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Height")
            guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
            num_inference_steps = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
            seed = gr.Number(value=42, label="Seed")
            randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
            generate_button = gr.Button("Generate Image")
            output = gr.Text(label="Output")
    # Generate button: runs infer. `seed` is also listed as an output so the
    # seed that infer returns is written back into the Seed field.
    generate_button.click(
        infer,
        inputs=[
            prompt, width, height, guidance_scale,
            num_inference_steps, seed, randomize_seed,
            model_name, precision
        ],
        outputs=[output, seed]
    )
# Launch the app when this file is executed.
demo.launch()