FLUXllama / app.py
ginipick's picture
Update app.py
8398621 verified
raw
history blame
5.46 kB
import spaces
import gradio as gr
import random
import os
import time
import torch
from diffusers import FluxPipeline
# Pick the compute device once at import time: CUDA when torch reports it as
# available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Generation settings used for every request.
DEFAULT_WIDTH = 1024
DEFAULT_HEIGHT = 1024
DEFAULT_NUM_INFERENCE_STEPS = 15
DEFAULT_GUIDANCE_SCALE = 3.5
DEFAULT_MAX_SEQUENCE_LENGTH = 512

# Access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_ACCESS_TOKEN")

# Module-level slot for the loaded pipeline (None until the first load).
CACHED_PIPE = None
def load_bnb_4bit_pipeline():
    """Return the shared 4-bit FLUX pipeline, building it on first use.

    On the first call the pipeline is created with
    ``FluxPipeline.from_pretrained`` from the ``derekl35/FLUX.1-dev-nf4``
    repository using ``torch.bfloat16``, model CPU offload is enabled, and the
    result is stored in the module-level ``CACHED_PIPE``. Later calls return
    that stored object without loading again.

    Returns:
        The pipeline object held in ``CACHED_PIPE``.

    Raises:
        Exception: Any error raised while loading is printed and then
            re-raised unchanged.
    """
    global CACHED_PIPE

    if CACHED_PIPE is None:
        print("Loading 4-bit BNB pipeline...")
        model_id = "derekl35/FLUX.1-dev-nf4"
        started = time.time()
        try:
            loaded = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
            loaded.enable_model_cpu_offload()
            elapsed = time.time() - started
            # Reserved CUDA memory on device 0, in GiB; reported as 0 on CPU.
            if DEVICE == "cuda":
                reserved_gb = torch.cuda.memory_reserved(0) / 1024**3
            else:
                reserved_gb = 0
            print(f"4-bit BNB pipeline loaded in {elapsed:.2f}s. Memory reserved: {reserved_gb:.2f} GB")
            CACHED_PIPE = loaded
        except Exception as e:
            print(f"Error loading 4-bit BNB pipeline: {e}")
            raise

    return CACHED_PIPE
@spaces.GPU(duration=240)
def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate one image from a text prompt with the 4-bit quantized pipeline.

    Args:
        prompt: Text description of the image. ``None``, empty, and
            whitespace-only prompts are rejected before the model is loaded.
        progress: Gradio progress tracker used to report the loading and
            generation stages.

    Returns:
        A ``(image, status)`` tuple. ``image`` is the first image returned by
        the pipeline, or ``None`` when the prompt is rejected or an error
        occurs. ``status`` is a human-readable message that includes the seed
        on success or the error text on failure.
    """
    # A whitespace-only prompt carries no content; reject it like an empty
    # one instead of running a full generation for it.
    if not prompt or (isinstance(prompt, str) and not prompt.strip()):
        return None, "Please enter a prompt."

    progress(0.2, desc="Loading 4-bit quantized model...")
    try:
        pipe = load_bnb_4bit_pipeline()

        pipe_kwargs = {
            "prompt": prompt,
            "height": DEFAULT_HEIGHT,
            "width": DEFAULT_WIDTH,
            "guidance_scale": DEFAULT_GUIDANCE_SCALE,
            "num_inference_steps": DEFAULT_NUM_INFERENCE_STEPS,
            "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH,
        }

        seed = random.getrandbits(64)
        print(f"Using seed: {seed}")
        # Seed a private generator instead of calling torch.manual_seed():
        # that call reseeds the process-wide default RNG as a side effect,
        # while a dedicated generator keeps this request's random stream tied
        # to the seed reported back to the user.
        generator = torch.Generator(device="cpu").manual_seed(seed)

        progress(0.5, desc="Generating image...")
        gen_start_time = time.time()
        image = pipe(**pipe_kwargs, generator=generator).images[0]
        gen_end_time = time.time()
        print(f"Image generated in {gen_end_time - gen_start_time:.2f} seconds")

        # Reserved CUDA memory on device 0, in GiB; reported as 0 on CPU.
        mem_reserved = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
        print(f"Memory reserved: {mem_reserved:.2f} GB")

        return image, f"Generation complete! (Seed: {seed})"
    except Exception as e:
        # Surface the failure in the status box rather than crashing the UI.
        print(f"Error during generation: {e}")
        return None, f"Error: {e}"
# Build the Gradio interface: header, badges, prompt row, outputs, event
# wiring and example prompts.
with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
    # Title banner.
    gr.HTML(
        "\n"
        "<div style='text-align: center; margin-bottom: 20px;'>\n"
        "<h1>FLUXllama</h1>\n"
        "<p>FLUX.1-dev 4-bit Quantized Version</p>\n"
        "</div>\n"
    )

    # Badge links.
    gr.HTML(
        "\n"
        "<div class='container' style='display:flex; justify-content:center; gap:12px; margin-bottom: 20px;'>\n"
        '<a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">\n'
        '<img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">\n'
        "</a>\n"
        '<a href="https://discord.gg/openfreeai" target="_blank">\n'
        '<img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">\n'
        "</a>\n"
        "</div>\n"
    )

    # Prompt entry and the generate button share one row.
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt",
            placeholder="e.g., A photorealistic portrait of an astronaut on Mars",
            lines=2,
            scale=4,
        )
        generate_button = gr.Button("Generate", variant="primary", scale=1)

    output_image = gr.Image(label="Generated Image (4-bit Quantized)", type="pil", height=600)
    status_text = gr.Textbox(label="Status", interactive=False, lines=1)

    # Clicking the button and pressing Enter in the prompt box both run the
    # same handler with the same inputs and outputs.
    for register in (generate_button.click, prompt_input.submit):
        register(
            fn=generate_image,
            inputs=[prompt_input],
            outputs=[output_image, status_text],
        )

    # Example prompts that fill the prompt box.
    sample_prompts = [
        "A photorealistic portrait of an astronaut on Mars",
        "Water-color painting of a cat wearing sunglasses",
        "Neo-tokyo cyberpunk cityscape at night, rain-soaked streets, 8K",
        "A majestic dragon flying over a medieval castle at sunset",
        "Abstract art representing the concept of time and space",
        "Detailed oil painting of a steampunk clockwork city",
        "Underwater scene with bioluminescent creatures in deep ocean",
        "Japanese garden in autumn with falling maple leaves",
    ]
    gr.Examples(examples=sample_prompts, inputs=prompt_input)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(share=True)