# Web-view residue from the Hugging Face file page (not part of the program):
# t-montes's picture
# Update app.py
# 7c9e811 verified
# raw
# history blame
# 4.63 kB
import logging
import random
import warnings
import os
import gradio as gr
import spaces
import numpy as np
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from PIL import Image
from huggingface_hub import snapshot_download
# Page-level stylesheet handed to gr.Blocks: centres the element whose id is
# "col-container" and caps its width at 512px.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 512px;\n"
    "}\n"
)
# Choose the compute device once at import time: CUDA when torch reports it
# as available, otherwise the CPU. `power_device` is the matching
# human-readable label.
device = "cuda" if torch.cuda.is_available() else "cpu"
power_device = {"cuda": "GPU", "cpu": "CPU"}[device]
# Download the FLUX.1-dev base model snapshot into the local "FLUX.1-dev"
# directory and remember where it landed.
#
# The access token is read from the environment. The historical (misspelled)
# variable name "HUGGINFACE_TOKEN" is still checked first so existing
# deployments keep working; the correctly spelled "HUGGINGFACE_TOKEN" is
# accepted as a fallback. If neither is set, the token is None.
huggingface_token = os.getenv("HUGGINFACE_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    # Skip files that are not needed to run the model. The previous pattern
    # "*..gitattributes" (double dot) could never match ".gitattributes".
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
# Load pipeline
# Both the upscaler ControlNet and the base pipeline are loaded with
# bfloat16 weights and then moved to the device selected above.
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler",
    torch_dtype=torch.bfloat16,
)
controlnet = controlnet.to(device)

pipe = FluxControlNetPipeline.from_pretrained(
    model_path,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
).to(device)
# Largest seed value: upper bound (inclusive) for randomly drawn seeds and
# the maximum of the seed slider.
MAX_SEED = 10**6
# Maximum number of pixels allowed in the upscaled output (1024 x 1024).
MAX_PIXEL_BUDGET = 1024**2
def process_input(input_image, upscale_factor):
    """Fit ``input_image`` to the pixel budget and snap it to multiples of 8.

    If upscaling the image by ``upscale_factor`` would exceed
    ``MAX_PIXEL_BUDGET`` output pixels, the input is first shrunk so that the
    upscaled result fits the budget while keeping the original aspect ratio,
    and a warning is emitted. The (possibly shrunk) image is then resized so
    both sides are multiples of 8.

    Args:
        input_image: Image object exposing ``.size`` as ``(width, height)``
            and ``.resize((width, height))`` (a PIL image in the caller
            visible in this file).
        upscale_factor: Multiplier that will later be applied to both sides.

    Returns:
        A tuple ``(image, w_original, h_original, was_resized)`` where
        ``image`` is the prepared input, ``w_original`` / ``h_original`` are
        the dimensions of the image as passed in, and ``was_resized`` tells
        whether the pixel-budget shrink was applied.
    """
    w, h = input_image.size
    w_original, h_original = w, h
    aspect_ratio = w / h

    was_resized = False
    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        # Solve  (W*f) * (H*f) == budget  with  W / H == aspect_ratio:
        #   W*f = sqrt(budget * aspect_ratio),  H*f = sqrt(budget / aspect_ratio).
        # The previous formula (aspect_ratio * sqrt(budget), sqrt(budget) /
        # aspect_ratio) produced an aspect ratio of aspect_ratio**2, i.e. it
        # distorted every non-square image.
        # max(1, ...) keeps the intermediate resize valid for extreme ratios.
        target_w = max(1, int((MAX_PIXEL_BUDGET * aspect_ratio) ** 0.5 // upscale_factor))
        target_h = max(1, int((MAX_PIXEL_BUDGET / aspect_ratio) ** 0.5 // upscale_factor))
        warnings.warn(
            f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to ({target_w, target_h}) pixels."
        )
        input_image = input_image.resize((target_w, target_h))
        was_resized = True

    # Resize to multiple of 8 (rounding down), but never below 8 so that a
    # side shorter than 8 pixels does not collapse to an invalid size of 0.
    w, h = input_image.size
    w = max(8, w - w % 8)
    h = max(8, h - h % 8)

    return input_image.resize((w, h)), w_original, h_original, was_resized
@spaces.GPU
def infer(
    seed, randomize_seed, input_image_path, num_inference_steps, upscale_factor, controlnet_conditioning_scale
):
    """Upscale the image stored at ``input_image_path`` with the Flux ControlNet pipeline.

    Args:
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is truthy.
        randomize_seed: When truthy, a fresh seed is drawn from
            ``[0, MAX_SEED]``.
        input_image_path: Filesystem path of the image to upscale.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        upscale_factor: Multiplier applied to both image sides.
        controlnet_conditioning_scale: Conditioning scale passed to the
            pipeline.

    Returns:
        A tuple ``(input_image, upscaled_image, seed)``: the image as opened
        from disk, the pipeline output, and the seed that was actually used.

    Raises:
        gr.Error: If no input image was provided.

    Side effects:
        Writes the result to ``output.jpg`` in the current working directory.
    """
    # Fail with a readable UI message instead of an opaque error from
    # Image.open when the user presses "Run" without uploading an image.
    if not input_image_path:
        raise gr.Error("Please provide an input image before running the upscaler.")

    # These values are used as image dimensions / step counts, which must be
    # integers. NOTE(review): the sliders use step=1 and presumably already
    # deliver whole numbers; the coercion only guards against float values.
    upscale_factor = int(upscale_factor)
    num_inference_steps = int(num_inference_steps)

    # Load image
    input_image = Image.open(input_image_path)

    # Handle random seed if specified
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    # Keep a handle on the untouched input so it can be returned for display.
    true_input_image = input_image
    input_image, w_original, h_original, was_resized = process_input(input_image, upscale_factor)

    # Rescale with upscale factor
    w, h = input_image.size
    control_image = input_image.resize((w * upscale_factor, h * upscale_factor))

    generator = torch.Generator().manual_seed(seed)

    # Upscale
    image = pipe(
        prompt="",
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        height=control_image.size[1],
        width=control_image.size[0],
        generator=generator,
    ).images[0]

    # Resize output if initially resized, so the result has the size the user
    # asked for (original dimensions times the upscale factor).
    if was_resized:
        image = image.resize((w_original * upscale_factor, h_original * upscale_factor))

    image.save("output.jpg")
    return true_input_image, image, seed
# Gradio setup without ImageSlider
# Components are created in display order; the "Run" button feeds the
# control values into `infer` and shows the returned input image, upscaled
# image and seed.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        "\n"
        "# ⚡ Flux.1-dev Upscaler ControlNet ⚡\n"
        "This is an interactive demo of [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler).\n"
    )

    run_button = gr.Button(value="Run")

    # Inputs
    input_im = gr.Image(label="Input Image", type="filepath")
    num_inference_steps = gr.Slider(
        label="Number of Inference Steps",
        minimum=8,
        maximum=50,
        step=1,
        value=28,
    )
    upscale_factor = gr.Slider(
        label="Upscale Factor",
        minimum=1,
        maximum=4,
        step=1,
        value=4,
    )
    controlnet_conditioning_scale = gr.Slider(
        label="Controlnet Conditioning Scale",
        minimum=0.1,
        maximum=1.5,
        step=0.1,
        value=0.6,
    )
    seed = gr.Slider(
        label="Seed",
        minimum=0,
        maximum=MAX_SEED,
        step=1,
        value=42,
    )
    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

    # Outputs
    input_image_display = gr.Image(label="Input Image Display")
    output_image_display = gr.Image(label="Upscaled Image Display")
    used_seed_box = gr.Textbox(label="Used Seed")

    run_button.click(
        fn=infer,
        inputs=[
            seed,
            randomize_seed,
            input_im,
            num_inference_steps,
            upscale_factor,
            controlnet_conditioning_scale,
        ],
        outputs=[input_image_display, output_image_display, used_seed_box],
    )

# Enable request queuing, then start the server (no public share link).
demo.queue()
demo.launch(share=False, show_api=True, show_error=True)