# backend.py
"""Model loading and image generation backend for the FLUX app.

``ModelManager`` owns a text-to-image pipeline, an image-to-image pipeline
that shares most of its components, and the two autoencoders they use.
"""
import time
import types

import torch
from diffusers import (
    DiffusionPipeline,
    AutoencoderTiny,
    AutoencoderKL,
    AutoPipelineForImage2Image,
)

from flux_app.config import DTYPE, DEVICE, BASE_MODEL, TAEF1_MODEL, MAX_SEED  # Absolute import
from flux_app.utilities import calculate_shift, retrieve_timesteps, load_image_from_path, calculateDuration  # Absolute import
from flux_app.lora_handling import flux_pipe_call_that_returns_an_iterable_of_images  # Absolute import


class ModelManager:
    """Load the FLUX pipelines once and run text-to-image / image-to-image generation.

    ``__init__`` calls :meth:`initialize_models` eagerly, so constructing a
    ``ModelManager`` loads every model via ``from_pretrained`` and moves them
    to ``DEVICE``.
    """

    def __init__(self):
        # Text-to-image pipeline; constructed with ``taef1`` as its VAE.
        self.pipe = None
        # Image-to-image pipeline; shares transformer/text encoders/tokenizers with ``pipe``.
        self.pipe_i2i = None
        # Full AutoencoderKL loaded from the "vae" subfolder of BASE_MODEL.
        self.good_vae = None
        # AutoencoderTiny loaded from TAEF1_MODEL.
        self.taef1 = None
        self.initialize_models()

    def initialize_models(self):
        """Initialize the diffusion pipelines and autoencoders on ``DEVICE``.

        The image-to-image pipeline is given the text-to-image pipeline's
        transformer, text encoders and tokenizers explicitly, so those
        components are shared between the two pipelines rather than loaded a
        second time. It decodes with ``good_vae``, whereas ``pipe`` is built
        with the tiny ``taef1`` autoencoder.
        """
        self.taef1 = AutoencoderTiny.from_pretrained(TAEF1_MODEL, torch_dtype=DTYPE).to(DEVICE)
        self.good_vae = AutoencoderKL.from_pretrained(
            BASE_MODEL, subfolder="vae", torch_dtype=DTYPE
        ).to(DEVICE)
        self.pipe = DiffusionPipeline.from_pretrained(
            BASE_MODEL, torch_dtype=DTYPE, vae=self.taef1
        ).to(DEVICE)
        self.pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
            BASE_MODEL,
            vae=self.good_vae,
            transformer=self.pipe.transformer,
            text_encoder=self.pipe.text_encoder,
            tokenizer=self.pipe.tokenizer,
            text_encoder_2=self.pipe.text_encoder_2,
            tokenizer_2=self.pipe.tokenizer_2,
            torch_dtype=DTYPE,
        )
        # Attach the module-level streaming function as a bound method of this
        # pipeline instance, so it receives the pipeline as its first argument.
        self.pipe.flux_pipe_call_that_returns_an_iterable_of_images = types.MethodType(
            flux_pipe_call_that_returns_an_iterable_of_images, self.pipe
        )

    @staticmethod
    def _make_generator(seed):
        """Return a ``torch.Generator`` on ``DEVICE`` seeded with ``seed``.

        ``seed`` is coerced with ``int()`` because ``Generator.manual_seed``
        only accepts integers; a float (e.g. from a UI number widget) would
        otherwise raise.
        """
        return torch.Generator(device=DEVICE).manual_seed(int(seed))

    def generate_image(self, prompt_mash, steps, seed, cfg_scale, width, height, lora_scale):
        """Generate an image with the text-to-image pipeline, streaming results.

        This is a generator: it yields every image produced by the pipeline's
        ``flux_pipe_call_that_returns_an_iterable_of_images`` call, in order.
        The whole iteration is timed under ``calculateDuration``.

        Args:
            prompt_mash: Prompt passed to the pipeline as ``prompt``.
            steps: Number of inference steps (``num_inference_steps``).
            seed: Seed for the random generator; coerced to ``int``.
            cfg_scale: Passed as ``guidance_scale``.
            width: Output width.
            height: Output height.
            lora_scale: Forwarded as ``joint_attention_kwargs={"scale": lora_scale}``.

        Yields:
            Images produced with ``output_type="pil"``. ``good_vae`` is also
            passed to the call (presumably for a full-quality final decode —
            confirm in ``flux_app.lora_handling``).
        """
        self.pipe.to(DEVICE)
        generator = self._make_generator(seed)
        with calculateDuration("Generating image"):
            yield from self.pipe.flux_pipe_call_that_returns_an_iterable_of_images(
                prompt=prompt_mash,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                joint_attention_kwargs={"scale": lora_scale},
                output_type="pil",
                good_vae=self.good_vae,
            )

    def generate_image_to_image(self, prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
        """Generate an image with the image-to-image pipeline.

        Args:
            prompt_mash: Prompt passed to the pipeline as ``prompt``.
            image_input_path: Path of the source image, loaded with
                ``load_image_from_path``.
            image_strength: Passed as ``strength``.
            steps: Number of inference steps (``num_inference_steps``).
            cfg_scale: Passed as ``guidance_scale``.
            width: Output width.
            height: Output height.
            lora_scale: Forwarded as ``joint_attention_kwargs={"scale": lora_scale}``.
            seed: Seed for the random generator; coerced to ``int``.

        Returns:
            The first image (``.images[0]``) of the pipeline output, requested
            with ``output_type="pil"``.
        """
        generator = self._make_generator(seed)
        self.pipe_i2i.to(DEVICE)
        image_input = load_image_from_path(image_input_path)
        with calculateDuration("Generating image to image"):
            final_image = self.pipe_i2i(
                prompt=prompt_mash,
                image=image_input,
                strength=image_strength,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                joint_attention_kwargs={"scale": lora_scale},
                output_type="pil",
            ).images[0]
        return final_image