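# app.py for a Hugging Face Space: a Gradio text-to-image demo around
# FLUX.1-dev with optional LoRA adapters. The `spaces` import provides the
# ZeroGPU decorator used below, which only has an effect when running on Spaces.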
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
dtype = torch.bfloat16

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    device = "mps"
else:
    device = "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Initialize the pipeline globally so the model is loaded once per process
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=dtype
).to(device)
lora_weights = {
    "cajerky": {"path": "bryanbrunetti/cajerky"}
}
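# More adapters could be registered the same way: the key doubles as the
# trigger word to use in prompts, and "path" is a Hugging Face repo id.
# A hypothetical entry (the repo id below is illustrative, not a real model):
# lora_weights["watercolor"] = {"path": "some-user/watercolor-flux-lora"}

# @spaces.GPU requests a ZeroGPU slot for up to `duration` seconds per call
# when the Space runs on ZeroGPU hardware; elsewhere it is typically a no-op.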
@spaces.GPU(duration=120)
def infer(prompt, lora_models, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0,
          num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    global pipe

    # Load LoRAs if specified
    if lora_models:
        try:
            for lora_model in lora_models:
                print(f"loading LoRA: {lora_model}")
                pipe.load_lora_weights(lora_weights[lora_model]["path"])
        except Exception as e:
            # Drop any adapters that did load before the failure
            pipe.unload_lora_weights()
            return None, seed, f"Failed to load LoRA model: {str(e)}"

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)
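    # Seeding a torch.Generator pins the initial latents, so the same seed
    # and settings reproduce the same image on a given hardware/software stack.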
    try:
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=guidance_scale
        ).images[0]
        return image, seed, "Image generated successfully."
    except Exception as e:
        return None, seed, f"Error during image generation: {str(e)}"
    finally:
        # Unload LoRA weights even if generation fails, so adapters from
        # one call do not leak into the next
        if lora_models:
            pipe.unload_lora_weights()
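
# Illustrative smoke test outside the UI (hypothetical prompt and settings;
# on a ZeroGPU Space the @spaces.GPU decorator manages GPU allocation):
# image, used_seed, msg = infer("cajerky riding a bike", ["cajerky"],
#                               randomize_seed=True, width=512, height=512)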
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        # lora_model = gr.Text(
        #     label="LoRA Model ID (optional)",
        #     placeholder="Enter Hugging Face LoRA model ID",
        # )
        lora_models = gr.Dropdown(
            list(lora_weights.keys()),
            multiselect=True,
            info="Load LoRAs (optional); use each adapter's name in the prompt",
            label="Choose LoRAs",
        )
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    info="How closely the image follows the prompt",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    info="Higher = more detail",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
        output_message = gr.Textbox(label="Output Message")
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, lora_models, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed, output_message],
    )
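
# Launch the Gradio app; on a Space this serves the UI at the Space's URL.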
demo.launch()