File size: 5,514 Bytes
6612d96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef8d3f1
b141402
2a93323
b141402
6612d96
 
cd78290
6612d96
 
 
 
 
 
ef8d3f1
 
6612d96
 
 
cd78290
6612d96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfbaa8a
 
 
7afee21
6612d96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd78290
6612d96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef8d3f1
 
6612d96
 
2dbbced
6612d96
ef8d3f1
6612d96
 
2dbbced
7afee21
6612d96
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Standard library
import random

# Third-party
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import AutoPipelineForText2Image
from PIL import Image

# Upper bound for the seed slider and for random.randint when randomizing.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (pixels) intended for the width/height sliders.
MAX_IMAGE_SIZE = 2048

# Load Flex.2-preview with the custom pipeline shipped in the model repo
# ("pipeline.py"); run in bfloat16 on the GPU.
pipe = AutoPipelineForText2Image.from_pretrained(
    "ostris/Flex.2-preview",
    custom_pipeline="pipeline.py",
    torch_dtype=torch.bfloat16,
).to("cuda")

# def calculate_optimal_dimensions(image: Image.Image):
#     # Extract the original dimensions
#     original_width, original_height = image.size
    
#     # Set constants
#     MIN_ASPECT_RATIO = 9 / 16
#     MAX_ASPECT_RATIO = 16 / 9
#     FIXED_DIMENSION = 1024

#     # Calculate the aspect ratio of the original image
#     original_aspect_ratio = original_width / original_height

#     # Determine which dimension to fix
#     if original_aspect_ratio > 1:  # Wider than tall
#         width = FIXED_DIMENSION
#         height = round(FIXED_DIMENSION / original_aspect_ratio)
#     else:  # Taller than wide
#         height = FIXED_DIMENSION
#         width = round(FIXED_DIMENSION * original_aspect_ratio)

#     # Ensure dimensions are multiples of 8
#     width = (width // 8) * 8
#     height = (height // 8) * 8

#     # Enforce aspect ratio limits
#     calculated_aspect_ratio = width / height
#     if calculated_aspect_ratio > MAX_ASPECT_RATIO:
#         width = (height * MAX_ASPECT_RATIO // 8) * 8
#     elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
#         height = (width / MIN_ASPECT_RATIO // 8) * 8

#     # Ensure width and height remain above the minimum dimensions
#     width = max(width, 576) if width == FIXED_DIMENSION else width
#     height = max(height, 576) if height == FIXED_DIMENSION else height

#     return width, height

@spaces.GPU
def infer(edit_images, prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5,control_strength=0.5, control_stop=0.33, num_inference_steps=50, progress=gr.Progress(track_tqdm=True)):
    """Run Flex.2 inpainting on the editor's background image using the drawn mask.

    Args:
        edit_images: gr.ImageEditor value dict with "background" (PIL image)
            and "layers" (list of PIL mask layers; layer 0 is the brush mask).
        prompt: text prompt describing the desired inpainted content.
        seed: RNG seed for the generator (ignored when randomize_seed is True).
        randomize_seed: when True, draw a fresh seed from [0, MAX_SEED].
        width, height: output resolution passed to the pipeline.
        guidance_scale, control_strength, control_stop, num_inference_steps:
            forwarded directly to the Flex.2 pipeline.
        progress: Gradio progress tracker (tracks tqdm from diffusers).

    Returns:
        ((input_image, generated_image), seed) — a pair for gr.ImageSlider
        plus the seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    background = edit_images["background"].convert("RGB")
    # Layer 0 holds the user's brush strokes; white pixels mark the inpaint region.
    brush_mask = edit_images["layers"][0].convert("RGB")

    # CPU generator keeps results reproducible for a given seed.
    generator = torch.Generator("cpu").manual_seed(seed)

    result = pipe(
        prompt=prompt,
        inpaint_image=background,
        inpaint_mask=brush_mask,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        control_strength=control_strength,
        control_stop=control_stop,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )
    return (background, result.images[0]), seed
    
# NOTE(review): these example prompts are defined but never wired into the UI
# (no gr.Examples component below) — presumably leftover from a text-to-image
# variant of this demo; confirm before removing.
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

# Center the app column and cap its width.
css="""
#col-container {
    margin: 0 auto;
    max-width: 1000px;
}
"""

# Build the Gradio UI: image editor + prompt on the left, before/after slider
# on the right, advanced generation settings in a collapsible accordion.
with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        # Plain string (the original used an f-string with no placeholders);
        # typo fix: "devloped" -> "developed".
        gr.Markdown("""# Flex.2 Preview - Inpaint
Inpainting demo for Flex.2 Preview - Open Source 8B parameter Text to Image Diffusion Model with universal control and built-in inpainting support
trained and developed by [ostris](https://huggingface.co/ostris)
[[apache-2.0 license](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md)] [[model](https://huggingface.co/ostris/Flex.2-preview)]
        """)
        with gr.Row():
            with gr.Column():
                # White fixed-color brush: the drawn layer becomes the inpaint mask.
                edit_image = gr.ImageEditor(
                    label='Upload and draw mask for inpainting',
                    type='pil',
                    sources=["upload", "webcam"],
                    image_mode='RGB',
                    layers=False,
                    brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
                    height=600
                )
                prompt = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                run_button = gr.Button("Run")

            # Slider comparing the input image against the inpainted result.
            result = gr.ImageSlider(label="Generated Image", type="pil", image_mode='RGB')

        with gr.Accordion("Advanced Settings", open=False):

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                # Multiples of 64 keep dimensions compatible with the pipeline.
                height = gr.Slider(64, 2048, value=512, step=64, label="Height")
                width = gr.Slider(64, 2048, value=512, step=64, label="Width")

            with gr.Row():

                guidance_scale = gr.Slider(0.0, 20.0, value=3.5, step=0.1, label="Guidance Scale")
                control_strength = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Control Strength")
                control_stop = gr.Slider(0.0, 1.0, value=0.33, step=0.05, label="Control Stop")
                num_inference_steps = gr.Slider(1, 100, value=50, step=1, label="Inference Steps")

    # Input order must match infer()'s positional parameters.
    run_button.click(
        fn = infer,
        inputs = [edit_image, prompt, seed, randomize_seed, width, height, guidance_scale, control_strength, control_stop,  num_inference_steps],
        outputs = [result, seed]
    )


demo.launch()