"""img2img pipeline: real-time FLUX.1-schnell with pruna-quantized components."""

import gc

import torch
from diffusers import AutoencoderTiny, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux_img2img import FluxImg2ImgPipeline
from PIL import Image
from pruna import SmashConfig, smash
from pruna.telemetry import set_telemetry_metrics
from pydantic import BaseModel, Field
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from config import Args
from util import ParamsModel

# Disable pruna telemetry for this session and persist that choice as the default.
set_telemetry_metrics(False)
set_telemetry_metrics(False, set_as_default=True)

# Intel Extension for PyTorch is optional; importing it registers its
# optimizations as a side effect when it is installed.
try:
    import intel_extension_for_pytorch as ipex  # noqa: F401
except ImportError:
    pass

model_path = "black-forest-labs/FLUX.1-schnell"
taesd_path = "madebyollin/taef1"

default_prompt = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
page_content = """
<h1 class="text-3xl font-bold">Real-Time FLUX</h1>
"""


def flush():
    """Drop dangling Python references, then release cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


class Pipeline:
    class Info(BaseModel):
        name: str = "img2img"
        title: str = "Image-to-Image FLUX"
        description: str = "Generates an image from an input image and a text prompt"
        input_mode: str = "image"
        page_content: str = page_content

    class InputParams(ParamsModel):
        prompt: str = Field(
            default_prompt,
            title="Prompt",
            field="textarea",
            id="prompt",
        )
        seed: int = Field(
            2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
        )
        steps: int = Field(
            1, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
        )
        # Width and height are fixed for this pipeline, so both fields are
        # disabled and hidden in the UI; the bounds are informational only.
        width: int = Field(
            1024, min=64, max=2048, title="Width", disabled=True, hide=True, id="width"
        )
        height: int = Field(
            1024, min=64, max=2048, title="Height", disabled=True, hide=True, id="height"
        )
        strength: float = Field(
            0.5,
            min=0.25,
            max=1.0,
            step=0.001,
            title="Strength",
            field="range",
            hide=True,
            id="strength",
        )
        guidance: float = Field(
            3.5,
            min=0,
            max=20,
            step=0.001,
            title="Guidance",
            hide=True,
            field="range",
            id="guidance",
        )

    def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
        print("Loading model")

        model_id = model_path
        model_revision = "refs/pr/1"
        text_model_id = "openai/clip-vit-large-patch14"
        model_data_type = torch.bfloat16

        # CLIP tokenizer and text encoder produce the pooled prompt embedding.
        tokenizer = CLIPTokenizer.from_pretrained(
            text_model_id, torch_dtype=model_data_type
        )
        text_encoder = CLIPTextModel.from_pretrained(
            text_model_id, torch_dtype=model_data_type
        )

        # T5 tokenizer and encoder produce the sequence prompt embedding.
        tokenizer_2 = T5TokenizerFast.from_pretrained(
            model_id,
            subfolder="tokenizer_2",
            torch_dtype=model_data_type,
            revision=model_revision,
        )
        text_encoder_2 = T5EncoderModel.from_pretrained(
            model_id,
            subfolder="text_encoder_2",
            torch_dtype=model_data_type,
            revision=model_revision,
        )

        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler", revision=model_revision
        )
        transformer = FluxTransformer2DModel.from_pretrained(
            model_id,
            subfolder="transformer",
            torch_dtype=model_data_type,
            revision=model_revision,
        )

        # Tiny autoencoder (TAESD for FLUX) for fast latent decoding.
        vae = AutoencoderTiny.from_pretrained(
            taesd_path, torch_dtype=torch.bfloat16
        )

        # Quantize the two largest components (the transformer and the T5
        # encoder) to 4-bit weights via pruna's quanto backend.
        smash_config = SmashConfig()
        smash_config["quantizer"] = "quanto"
        smash_config["quanto_calibrate"] = False
        smash_config["quanto_weight_bits"] = "qint4"

        transformer = smash(
            model=transformer,
            smash_config=smash_config,
        )
        text_encoder_2 = smash(
            model=text_encoder_2,
            smash_config=smash_config,
        )

        # Assemble the img2img pipeline from the individually loaded,
        # quantized components.
        pipe = FluxImg2ImgPipeline(
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            vae=vae,
            transformer=transformer,
        )

        # Move every model component onto the target device.
        pipe.text_encoder.to(device)
        pipe.vae.to(device)
        pipe.transformer.to(device)
        pipe.text_encoder_2.to(device)

        self.pipe = pipe
        self.pipe.set_progress_bar_config(disable=True)

    def predict(self, params: "Pipeline.InputParams") -> Image.Image:
        # Seed the default generator so each request is reproducible.
        generator = torch.manual_seed(params.seed)

        results = self.pipe(
            image=params.image,
            prompt=params.prompt,
            generator=generator,
            strength=params.strength,
            num_inference_steps=params.steps,
            guidance_scale=params.guidance,
            width=params.width,
            height=params.height,
        )
        return results.images[0]
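

# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch, not part of the serving app. It assumes
# `config.Args` can be instantiated without required arguments and that
# `ParamsModel` accepts the extra `image` field that `predict` reads; adjust
# both to the real signatures in this repo before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = Pipeline(Args(), device, torch.bfloat16)

    # A blank input image is enough to exercise the full img2img path.
    params = Pipeline.InputParams(image=Image.new("RGB", (1024, 1024)))
    image = pipeline.predict(params)
    image.save("smoke_test.png")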