# Stable Diffusion demo with selectable float16 / float32 precision
import os
import random
import numpy as np
import torch
import gradio as gr
from diffusers import DiffusionPipeline
import paramiko
from huggingface_hub import login
# Hugging Face token
HF_TOKEN = os.getenv('HF_TOKEN', '').strip()
if not HF_TOKEN:
    raise ValueError("HF_TOKEN is not set. Please set the token as an environment variable.")
# Hugging Face Login
login(token=HF_TOKEN)
# SFTP configuration
STORAGE_DOMAIN = os.getenv('STORAGE_DOMAIN', '').strip()  # SFTP server domain
STORAGE_USER = os.getenv('STORAGE_USER', '').strip()  # SFTP user
STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip()  # SFTP password
STORAGE_PORT = int(os.getenv('STORAGE_PORT', '22').strip())  # SFTP port
STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip()  # Secret token (not used below)
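# In a Hugging Face Space these are typically provided as repository
# secrets, e.g. (hypothetical values):
#   STORAGE_DOMAIN=sftp.example.com  STORAGE_USER=uploader  STORAGE_PORT=22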
# Model configuration (keys are UI labels, values are Hub repo ids)
available_models = {
    "sd3-medium": "stabilityai/stable-diffusion-3-medium-diffusers",
    "sd2-base": "stabilityai/stable-diffusion-2-1-base"
}
# SFTP upload helper
def upload_to_sftp(local_file, remote_path):
    try:
        transport = paramiko.Transport((STORAGE_DOMAIN, STORAGE_PORT))
        transport.connect(username=STORAGE_USER, password=STORAGE_PSWD)
        sftp = paramiko.SFTPClient.from_transport(transport)
        sftp.put(local_file, remote_path)
        sftp.close()
        transport.close()
        print(f"File {local_file} successfully uploaded to {remote_path}")
        return True
    except Exception as e:
        print(f"Error during SFTP upload: {e}")
        return False
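# Note: paramiko.Transport above performs no host key verification, so the
# upload is exposed to man-in-the-middle attacks. A minimal alternative
# sketch using paramiko.SSHClient with the system known_hosts file; the
# function name is hypothetical and it is not wired into the app:
def upload_to_sftp_verified(local_file, remote_path):
    client = paramiko.SSHClient()
    client.load_system_host_keys()  # unknown hosts are rejected by default
    client.connect(STORAGE_DOMAIN, port=STORAGE_PORT,
                   username=STORAGE_USER, password=STORAGE_PSWD)
    try:
        sftp = client.open_sftp()
        sftp.put(local_file, remote_path)
        sftp.close()
    finally:
        client.close()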
# Model loading helper
def load_model(model_name, precision):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    repo = available_models.get(model_name, available_models["sd3-medium"])
    try:
        # Pick the dtype from the UI selection; float16 inference is
        # effectively unsupported on CPU, so fall back to float32 there
        if precision == "float16" and device == "cuda":
            torch_dtype = torch.float16
        else:
            torch_dtype = torch.float32
        # DiffusionPipeline dispatches to the right pipeline class per model
        # (StableDiffusion3Pipeline for sd3-medium, StableDiffusionPipeline
        # for sd2-base); a hard-coded StableDiffusionPipeline cannot load SD3.
        # enable_sequential_cpu_offload() is deliberately not called: it
        # offloads weights from a GPU and assumes a CUDA device exists.
        pipe = DiffusionPipeline.from_pretrained(
            repo,
            torch_dtype=torch_dtype
        ).to(device)
        return pipe
    except Exception as e:
        raise RuntimeError(f"Failed to load the model. Ensure the token has access to the repo. Error: {e}")
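# A rough way to see what the precision switch buys: the weights' memory
# footprint is parameter count times bytes per element, so float16 halves
# float32. Minimal sketch; weights_footprint_mb is a hypothetical helper
# and is not called anywhere in this app.
def weights_footprint_mb(pipeline):
    total_bytes = 0
    for component in pipeline.components.values():
        if isinstance(component, torch.nn.Module):
            total_bytes += sum(p.numel() * p.element_size() for p in component.parameters())
    return total_bytes / 1024**2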
# Upper bounds for seed and image size
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1344
# Global pipeline handle plus the (model, precision) it was loaded with
pipe = None
pipe_config = None
# Inference
def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed, model_name, precision):
    global pipe, pipe_config
    # (Re)load the pipeline when it is missing or the selection changed
    if pipe is None or pipe_config != (model_name, precision):
        pipe = load_model(model_name, precision)
        pipe_config = (model_name, precision)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)  # gr.Number delivers floats
    generator = torch.manual_seed(seed)
    image = pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator
    ).images[0]
    # Save the image locally, then push it to the SFTP server
    local_file = f"/tmp/generated_image_{seed}.png"
    image.save(local_file)
    remote_path = f"/uploads/generated_image_{seed}.png"
    if upload_to_sftp(local_file, remote_path):
        os.remove(local_file)
        return f"Image uploaded to {remote_path}", seed
    else:
        return "Failed to upload image", seed
# Explicit (re)load from the UI
def reload_model(model_name, precision):
    global pipe, pipe_config
    pipe = load_model(model_name, precision)
    pipe_config = (model_name, precision)
    return f"Model loaded: {model_name} with {precision} precision"
# Gradio app
with gr.Blocks() as demo:
    gr.Markdown("### Stable Diffusion - Test App")
    with gr.Row():
        with gr.Column():
            # Model selection
            model_name = gr.Radio(
                choices=list(available_models.keys()),
                value="sd3-medium",
                label="Model"
            )
            # Precision selection
            precision = gr.Radio(
                choices=["float16", "float32"],
                value="float16",
                label="Precision"
            )
            reload_button = gr.Button("Load/Reload Model")
            model_status = gr.Textbox(label="Model Status")
            # Wire the load/reload button
            reload_button.click(
                reload_model,
                inputs=[model_name, precision],
                outputs=[model_status]
            )
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
            width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Width")
            height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Height")
            guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
            num_inference_steps = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
            seed = gr.Number(value=42, label="Seed")
            randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
            generate_button = gr.Button("Generate Image")
            output = gr.Text(label="Output")
            generate_button.click(
                infer,
                inputs=[
                    prompt, width, height, guidance_scale,
                    num_inference_steps, seed, randomize_seed,
                    model_name, precision
                ],
                outputs=[output, seed]
            )
demo.launch()