# FLUXllama: Hugging Face Space, running on ZeroGPU
import os
import random
import time

import gradio as gr
import spaces  # required for the @spaces.GPU decorator on ZeroGPU Spaces
import torch
from diffusers import FluxPipeline

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024
DEFAULT_GUIDANCE_SCALE = 3.5
DEFAULT_NUM_INFERENCE_STEPS = 15
DEFAULT_MAX_SEQUENCE_LENGTH = 512  # token budget for the T5 text encoder

# Optional token for gated models; None falls back to the cached CLI login.
HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")

# Cache for the pipeline, so it is only loaded once per process
CACHED_PIPE = None
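# Note (assumption about the ZeroGPU runtime): the Python process persists
# between requests while the GPU is attached only per decorated call, so a
# module-level cache avoids re-downloading and re-initializing the model on
# every generation.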
def load_bnb_4bit_pipeline():
    """Load the 4-bit (NF4) quantized pipeline, reusing the cached copy if present."""
    global CACHED_PIPE
    if CACHED_PIPE is not None:
        return CACHED_PIPE
    print("Loading 4-bit BNB pipeline...")
    MODEL_ID = "derekl35/FLUX.1-dev-nf4"
    start_time = time.time()
    try:
        pipe = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            token=HF_TOKEN,  # use the token read above; None means anonymous/cached login
        )
        if DEVICE == "cuda":
            # Keep sub-models in CPU RAM and stream them to the GPU on demand,
            # lowering peak VRAM at a small latency cost.
            pipe.enable_model_cpu_offload()
        end_time = time.time()
        mem_reserved = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
        print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
        CACHED_PIPE = pipe
        return pipe
    except Exception as e:
        print(f"Error loading 4-bit BNB pipeline: {e}")
        raise
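# For reference, a minimal sketch of how an NF4 checkpoint like the one above
# could be produced with diffusers' bitsandbytes integration. Assumptions:
# diffusers >= 0.31 with bitsandbytes installed, and access to the gated
# black-forest-labs/FLUX.1-dev weights. Illustrative only; this is not
# necessarily how derekl35/FLUX.1-dev-nf4 was built, and it is never called
# in this app.
def quantize_flux_nf4_sketch():
    from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # Quantize only the transformer, which holds the bulk of the parameters.
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )
    return FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )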
@spaces.GPU  # ZeroGPU: attach a GPU for the duration of this call (the reason for `import spaces`)
def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate an image using the 4-bit quantized model."""
    if not prompt:
        return None, "Please enter a prompt."
    progress(0.2, desc="Loading 4-bit quantized model...")
    try:
        # Load (or reuse) the 4-bit pipeline
        pipe = load_bnb_4bit_pipeline()

        # Set up generation parameters
        pipe_kwargs = {
            "prompt": prompt,
            "height": DEFAULT_HEIGHT,
            "width": DEFAULT_WIDTH,
            "guidance_scale": DEFAULT_GUIDANCE_SCALE,
            "num_inference_steps": DEFAULT_NUM_INFERENCE_STEPS,
            "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH,
        }

        # Draw a random seed; torch accepts seeds in the full uint64 range.
        seed = random.getrandbits(64)
        print(f"Using seed: {seed}")
        progress(0.5, desc="Generating image...")

        # Use a dedicated generator instead of mutating torch's global RNG state.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        gen_start_time = time.time()
        image = pipe(**pipe_kwargs, generator=generator).images[0]
        gen_end_time = time.time()
        print(f"Image generated in {gen_end_time - gen_start_time:.2f} seconds")
        mem_reserved = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
        print(f"Memory reserved: {mem_reserved:.2f} GB")
        return image, f"Generation complete! (Seed: {seed})"
    except Exception as e:
        print(f"Error during generation: {e}")
        return None, f"Error: {e}"
# Create Gradio interface
with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
    gr.HTML(
        """
        <div style='text-align: center; margin-bottom: 20px;'>
            <h1>FLUXllama</h1>
            <p>FLUX.1-dev 4-bit Quantized Version</p>
        </div>
        """
    )
    gr.HTML(
        """
        <div class='container' style='display:flex; justify-content:center; gap:12px; margin-bottom: 20px;'>
            <a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
                <img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
            </a>
            <a href="https://discord.gg/openfreeai" target="_blank">
                <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
            </a>
        </div>
        """
    )
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt",
            placeholder="e.g., A photorealistic portrait of an astronaut on Mars",
            lines=2,
            scale=4,
        )
        generate_button = gr.Button("Generate", variant="primary", scale=1)

    output_image = gr.Image(
        label="Generated Image (4-bit Quantized)",
        type="pil",
        height=600,
    )
    status_text = gr.Textbox(
        label="Status",
        interactive=False,
        lines=1,
    )

    # Connect components
    generate_button.click(
        fn=generate_image,
        inputs=[prompt_input],
        outputs=[output_image, status_text],
    )
    # Enter key to submit
    prompt_input.submit(
        fn=generate_image,
        inputs=[prompt_input],
        outputs=[output_image, status_text],
    )
    # Example prompts
    gr.Examples(
        examples=[
            "A photorealistic portrait of an astronaut on Mars",
            "Watercolor painting of a cat wearing sunglasses",
            "Neo-Tokyo cyberpunk cityscape at night, rain-soaked streets, 8K",
            "A majestic dragon flying over a medieval castle at sunset",
            "Abstract art representing the concept of time and space",
            "Detailed oil painting of a steampunk clockwork city",
            "Underwater scene with bioluminescent creatures in the deep ocean",
            "Japanese garden in autumn with falling maple leaves",
        ],
        inputs=prompt_input,
    )
if __name__ == "__main__":
    demo.launch(share=True)
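# Note: when hosted on Hugging Face Spaces, Gradio ignores share=True (the
# Space already exposes a public URL); the flag only matters when running the
# app locally, e.g. with `python app.py`.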