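# Gradio demo: image upscaling with the Flux.1-dev ControlNet upscaler
# (jasperai/Flux.1-dev-Controlnet-Upscaler on top of black-forest-labs/FLUX.1-dev).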
import os
import random
import warnings

import gradio as gr
import spaces
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from huggingface_hub import snapshot_download
from PIL import Image

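# Narrow, centered layout for the demo column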
css = """
#col-container {
    margin: 0 auto;
    max-width: 512px;
}
"""

# Check for GPU availability
if torch.cuda.is_available():
    power_device = "GPU"
    device = "cuda"
else:
    power_device = "CPU"
    device = "cpu"

# Download the FLUX.1-dev base weights (gated repo; a Hugging Face token is required)
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)

# Load pipeline
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
).to(device)
pipe = FluxControlNetPipeline.from_pretrained(
    model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to(device)

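# Upper bound for the seed slider, and a hard cap on total output pixels (1 megapixel)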
MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 1024 * 1024

def process_input(input_image, upscale_factor):
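    """Shrink the input so the upscaled output stays within MAX_PIXEL_BUDGET,
    then round its dimensions down to multiples of 8."""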
    w, h = input_image.size
    w_original, h_original = w, h
    aspect_ratio = w / h
    was_resized = False

    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        # Pick the largest aspect-preserving size whose upscaled area fits the budget
        new_w = int((MAX_PIXEL_BUDGET * aspect_ratio) ** 0.5 / upscale_factor)
        new_h = int((MAX_PIXEL_BUDGET / aspect_ratio) ** 0.5 / upscale_factor)
        warnings.warn(
            f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). "
            f"Resizing input to {new_w}x{new_h} pixels."
        )
        input_image = input_image.resize((new_w, new_h))
        was_resized = True

    # Round dimensions down to the nearest multiple of 8, as the pipeline expects
    w, h = input_image.size
    w = w - w % 8
    h = h - h % 8
    return input_image.resize((w, h)), w_original, h_original, was_resized

@spaces.GPU
def infer(
    seed, randomize_seed, input_image_path, num_inference_steps, upscale_factor, controlnet_conditioning_scale
):
    # Load the image and normalize to RGB (uploads may be RGBA or palette-based)
    input_image = Image.open(input_image_path).convert("RGB")
    
    # Handle random seed if specified
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    
    true_input_image = input_image
    input_image, w_original, h_original, was_resized = process_input(input_image, upscale_factor)
    
    # Rescale with upscale factor
    w, h = input_image.size
    control_image = input_image.resize((w * upscale_factor, h * upscale_factor))
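    # Seeded generator so results are reproducible for a given seed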
    generator = torch.Generator().manual_seed(seed)

    # Upscale: the control image fully conditions generation, so the prompt stays empty
    image = pipe(
        prompt="",
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        height=control_image.size[1],
        width=control_image.size[0],
        generator=generator,
    ).images[0]

    # If the input was shrunk to fit the pixel budget, stretch the result
    # back to the originally requested output size
    if was_resized:
        image = image.resize((w_original * upscale_factor, h_original * upscale_factor))
    image.save("output.jpg")  # keep a copy of the result on disk
    
    return true_input_image, image, seed

# Gradio UI (plain before/after images instead of an ImageSlider)
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        f"""
    # ⚡ Flux.1-dev Upscaler ControlNet ⚡
    This is an interactive demo of [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler).
    Currently running on **{power_device}**.
    """
    )

    run_button = gr.Button(value="Run")
    input_im = gr.Image(label="Input Image", type="filepath")
    num_inference_steps = gr.Slider(label="Number of Inference Steps", minimum=8, maximum=50, step=1, value=28)
    upscale_factor = gr.Slider(label="Upscale Factor", minimum=1, maximum=4, step=1, value=4)
    controlnet_conditioning_scale = gr.Slider(label="Controlnet Conditioning Scale", minimum=0.1, maximum=1.5, step=0.1, value=0.6)
    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
    
    input_image_display = gr.Image(label="Input Image Display")
    output_image_display = gr.Image(label="Upscaled Image Display")
    used_seed = gr.Textbox(label="Used Seed")

    run_button.click(
        infer,
        inputs=[seed, randomize_seed, input_im, num_inference_steps, upscale_factor, controlnet_conditioning_scale],
        outputs=[input_image_display, output_image_display, used_seed],
    )

demo.queue().launch(share=False, show_api=True, show_error=True)
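
# To run locally: set HUGGINGFACE_TOKEN in the environment (required to download
# the gated FLUX.1-dev weights) and execute this script.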