# utils/ai_generator_diffusers_flux.py
import os
import torch
import accelerate
import transformers
import safetensors
import xformers
from diffusers import FluxPipeline
from diffusers.utils import load_image
# from huggingface_hub import hf_hub_download
from PIL import Image
from tempfile import NamedTemporaryFile
from src.condition import Condition
import utils.constants as constants
from utils.image_utils import (
crop_and_resize_image,
)
from utils.version_info import (
versions_html,
get_torch_info,
get_diffusers_version,
get_transformers_version,
get_xformers_version
)
from utils.lora_details import get_trigger_words
from utils.color_utils import detect_color_format
# import utils.misc as misc
from pathlib import Path
import warnings
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
#print(torch.__version__) # Ensure it's 2.0 or newer
#print(torch.cuda.is_available()) # Ensure CUDA is available
def generate_image_from_text(
text,
model_name="black-forest-labs/FLUX.1-dev",
lora_weights=None,
conditioned_image=None,
image_width=1344,
image_height=848,
guidance_scale=3.5,
num_inference_steps=50,
seed=0,
additional_parameters=None
):
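    """Generate an image from a text prompt with a Flux pipeline.

    Optionally loads the LoRA weights listed in ``lora_weights`` and conditions
    generation on ``conditioned_image``. Returns the generated PIL image.
    """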
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device:{device}\nmodel_name:{model_name}\n")
pipe = FluxPipeline.from_pretrained(
model_name,
torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
).to(device)
pipe.enable_model_cpu_offload()
# Load and apply LoRA weights
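    # constants.LORA_DETAILS is expected to map each LoRA repo id to a list of
    # configs with optional "weight_name" and "adapter_name" keys (assumed from the usage below)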
if lora_weights:
for lora_weight in lora_weights:
lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
if lora_configs:
for config in lora_configs:
weight_name = config.get("weight_name")
adapter_name = config.get("adapter_name")
pipe.load_lora_weights(
lora_weight,
weight_name=weight_name,
adapter_name=adapter_name,
use_auth_token=constants.HF_API_TOKEN
)
else:
pipe.load_lora_weights(lora_weight, use_auth_token=constants.HF_API_TOKEN)
generator = torch.Generator(device=device).manual_seed(seed)
conditions = []
if conditioned_image is not None:
conditioned_image = crop_and_resize_image(conditioned_image, 1024, 1024)
condition = Condition("subject", conditioned_image)
conditions.append(condition)
generate_params = {
"prompt": text,
"height": image_height,
"width": image_width,
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
"generator": generator,
"conditions": conditions if conditions else None
}
if additional_parameters:
generate_params.update(additional_parameters)
generate_params = {k: v for k, v in generate_params.items() if v is not None}
result = pipe(**generate_params)
image = result.images[0]
return image
def generate_image_lowmem(
text,
neg_prompt=None,
model_name="black-forest-labs/FLUX.1-dev",
lora_weights=None,
conditioned_image=None,
image_width=1344,
image_height=848,
guidance_scale=3.5,
num_inference_steps=50,
seed=0,
true_cfg_scale=1.0,
additional_parameters=None
):
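    """Memory-conscious variant of ``generate_image_from_text``.

    Runs under ``torch.no_grad()``, enables CPU offload, VAE tiling and
    memory-efficient attention, supports an optional negative prompt via
    ``true_cfg_scale``, and releases the pipeline and CUDA cache before
    returning the generated PIL image.
    """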
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device:{device}\nmodel_name:{model_name}\n")
print(f"\n {get_torch_info()}\n")
# Disable gradient calculations
with torch.no_grad():
# Initialize the pipeline inside the context manager
pipe = FluxPipeline.from_pretrained(
model_name,
            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
).to(device)
        # Offload idle sub-models to CPU to reduce peak VRAM usage (can be skipped when VRAM is plentiful)
pipe.enable_model_cpu_offload()
# alternative version that may be more efficient
# pipe.enable_sequential_cpu_offload()
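        # Prefer PyTorch's built-in flash scaled-dot-product attention when available;
        # otherwise fall back to xFormers memory-efficient attention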
flash_attention_enabled = torch.backends.cuda.flash_sdp_enabled()
        if not flash_attention_enabled:
#Enable xFormers memory-efficient attention (optional)
pipe.enable_xformers_memory_efficient_attention()
print("\nEnabled xFormers memory-efficient attention.\n")
else:
            pipe.attn_implementation = "flash_attention_2"
print("\nEnabled flash_attention_2.\n")
pipe.enable_vae_tiling()
# Load LoRA weights
if lora_weights:
for lora_weight in lora_weights:
lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
if lora_configs:
for config in lora_configs:
# Load LoRA weights with optional weight_name and adapter_name
weight_name = config.get("weight_name")
adapter_name = config.get("adapter_name")
if weight_name and adapter_name:
pipe.load_lora_weights(
lora_weight,
weight_name=weight_name,
adapter_name=adapter_name,
use_auth_token=constants.HF_API_TOKEN
)
else:
pipe.load_lora_weights(
lora_weight,
use_auth_token=constants.HF_API_TOKEN
)
# Apply 'pipe' configurations if present
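                        # Each entry maps a pipeline method name to its kwargs,
                        # e.g. {"set_adapters": {...}} (structure assumed from constants.LORA_DETAILS)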
if 'pipe' in config:
pipe_config = config['pipe']
for method_name, params in pipe_config.items():
method = getattr(pipe, method_name, None)
if method:
print(f"Applying pipe method: {method_name} with params: {params}")
method(**params)
else:
print(f"Method {method_name} not found in pipe.")
else:
pipe.load_lora_weights(lora_weight, use_auth_token=constants.HF_API_TOKEN)
generator = torch.Generator(device=device).manual_seed(seed)
conditions = []
if conditioned_image is not None:
conditioned_image = crop_and_resize_image(conditioned_image, 1024, 1024)
condition = Condition("subject", conditioned_image)
conditions.append(condition)
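        # A negative prompt only takes effect when true classifier-free guidance is
        # enabled, so raise true_cfg_scale above 1.0 whenever one is provided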
        if neg_prompt is not None:
            true_cfg_scale = 1.1
generate_params = {
"prompt": text,
"negative_prompt": neg_prompt,
"true_cfg_scale": true_cfg_scale,
"height": image_height,
"width": image_width,
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
"generator": generator,
"conditions": conditions if conditions else None
}
if additional_parameters:
generate_params.update(additional_parameters)
generate_params = {k: v for k, v in generate_params.items() if v is not None}
# Generate the image
result = pipe(**generate_params)
image = result.images[0]
# Clean up
del result
del conditions
del generator
# Delete the pipeline and clear cache
del pipe
torch.cuda.empty_cache()
print(torch.cuda.memory_summary(device=None, abbreviated=False))
return image
def generate_ai_image_local(
map_option,
prompt_textbox_value,
neg_prompt_textbox_value,
model="black-forest-labs/FLUX.1-dev",
lora_weights=None,
conditioned_image=None,
height=512,
width=896,
num_inference_steps=50,
guidance_scale=3.5,
seed=777
):
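    """Resolve the prompt, negative prompt, LoRA trigger words and parameter
    overrides for ``map_option``, generate the image locally, and save it to a
    tracked temporary PNG file.

    Returns the temporary file path, or ``None`` if generation failed.
    """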
try:
if map_option != "Prompt":
prompt = constants.PROMPTS[map_option]
negative_prompt = constants.NEGATIVE_PROMPTS.get(map_option, "")
else:
prompt = prompt_textbox_value
negative_prompt = neg_prompt_textbox_value or ""
#full_prompt = f"{prompt} {negative_prompt}"
additional_parameters = {}
if lora_weights:
for lora_weight in lora_weights:
lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
for config in lora_configs:
if 'parameters' in config:
additional_parameters.update(config['parameters'])
elif 'trigger_words' in config:
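                        # Prepend the adapter's trigger words so the LoRA is activated by the prompt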
trigger_words = get_trigger_words(lora_weight)
prompt = f"{trigger_words} {prompt}"
for key, value in additional_parameters.items():
if key in ['height', 'width', 'num_inference_steps', 'max_sequence_length']:
additional_parameters[key] = int(value)
elif key in ['guidance_scale','true_cfg_scale']:
additional_parameters[key] = float(value)
height = additional_parameters.get('height', height)
width = additional_parameters.get('width', width)
num_inference_steps = additional_parameters.get('num_inference_steps', num_inference_steps)
guidance_scale = additional_parameters.get('guidance_scale', guidance_scale)
print("Generating image with the following parameters:")
print(f"Model: {model}")
print(f"LoRA Weights: {lora_weights}")
print(f"Prompt: {prompt}")
print(f"Neg Prompt: {negative_prompt}")
print(f"Height: {height}")
print(f"Width: {width}")
print(f"Number of Inference Steps: {num_inference_steps}")
print(f"Guidance Scale: {guidance_scale}")
print(f"Seed: {seed}")
print(f"Additional Parameters: {additional_parameters}")
image = generate_image_lowmem(
text=prompt,
model_name=model,
neg_prompt=negative_prompt,
lora_weights=lora_weights,
conditioned_image=conditioned_image,
image_width=width,
image_height=height,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
seed=seed,
additional_parameters=additional_parameters
)
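        # Save the result to a temporary PNG and track it for later cleanup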
with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
image.save(tmp.name, format="PNG")
constants.temp_files.append(tmp.name)
print(f"Image saved to {tmp.name}")
return tmp.name
except Exception as e:
print(f"Error generating AI image: {e}")
return None
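

if __name__ == "__main__":
    # Minimal local smoke test -- a sketch, not part of the HexaGrid app flow.
    # Assumes a CUDA-capable GPU, access to the FLUX.1-dev weights, and that
    # constants.HF_API_TOKEN / constants.PROMPTS are configured as elsewhere in the repo.
    output_path = generate_ai_image_local(
        map_option="Prompt",
        prompt_textbox_value="A hexagonal tiled landscape at golden hour, highly detailed",
        neg_prompt_textbox_value="blurry, low quality",
        height=512,
        width=896,
        num_inference_steps=30,
        seed=777,
    )
    print(f"Generated image: {output_path}")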