File size: 3,793 Bytes
5b8270b
 
 
 
 
 
 
 
d4cae39
5b8270b
 
 
 
 
 
d4cae39
5b8270b
 
 
 
 
 
 
 
 
 
 
 
 
d4cae39
5b8270b
d4cae39
 
 
 
 
 
 
 
5b8270b
d4cae39
5b8270b
d4cae39
5b8270b
 
d4cae39
 
5b8270b
 
d4cae39
5b8270b
 
 
 
d4cae39
5b8270b
 
 
 
 
 
 
 
 
 
d4cae39
 
 
 
 
 
 
 
 
 
 
 
5b8270b
d4cae39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b8270b
 
 
 
 
 
 
d4cae39
 
 
 
 
 
5b8270b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import numpy as np

import spaces
import torch
import random
from PIL import Image

from pipeline_flux_kontext import FluxKontextPipeline
from diffusers import FluxTransformer2DModel
from diffusers.utils import load_image

from huggingface_hub import hf_hub_download


# Download the Kontext transformer weights (a single .safetensors file) from the Hub.
kontext_path = hf_hub_download(repo_id="diffusers/kontext-v2", filename="dev-opt-2-a-3.safetensors")

# Upper bound used when randomizing the seed (largest signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max

# Build the transformer from the downloaded checkpoint, then assemble the full
# FLUX.1-dev pipeline around it. bfloat16 halves memory; the pipeline runs on CUDA.
# NOTE(review): executed at import time — module load requires network + a GPU.
transformer = FluxTransformer2DModel.from_single_file(kontext_path, torch_dtype=torch.bfloat16)
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16).to("cuda")

@spaces.GPU
def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
    """Run one Kontext image-editing pass and return the edited image.

    Args:
        input_image: PIL image to edit (any mode; converted to RGB below).
        prompt: Text instruction describing the edit.
        seed: Generator seed; ignored when ``randomize_seed`` is True.
        randomize_seed: When True, draw a fresh seed in [0, MAX_SEED].
        guidance_scale: Classifier-free guidance strength passed to the pipeline.
        progress: Gradio progress tracker mirroring the pipeline's tqdm bars.

    Returns:
        Tuple of (edited PIL image, seed actually used,
        gr.update making the "reuse" button visible).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; torch's manual_seed requires an int.
    seed = int(seed)

    # The pipeline expects a 3-channel image regardless of the upload's mode
    # (e.g. RGBA screenshots, palette GIFs, grayscale scans).
    input_image = input_image.convert("RGB")

    image = pipe(
        image=input_image,
        prompt=prompt,
        guidance_scale=guidance_scale,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]
    return image, seed, gr.update(visible=True)

# Centers the app column and caps its width; targets the elem_id set on the
# top-level gr.Column below.
css="""
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
"""

# UI layout. Component creation order inside the context managers defines the
# on-screen layout, so statement order here is load-bearing.
with gr.Blocks(css=css) as demo:
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX.1 Kontext [dev]
        """)

        with gr.Row():
            # Left column: image upload, prompt entry, and advanced knobs.
            with gr.Column():
                input_image = gr.Image(label="Upload the image for editing", type="pil")
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
                        container=False,
                    )
                    run_button = gr.Button("Run", scale=0)
                with gr.Accordion("Advanced Settings", open=False):
            
                    # Seed slider; the effective seed is written back here after
                    # each run so a randomized seed can be reproduced.
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=2.5,
                    )       
                    
            # Right column: output image plus a button (hidden until the first
            # result) that feeds the result back in as the next input.
            with gr.Column():
                result = gr.Image(label="Result", show_label=False, interactive=False)
                reuse_button = gr.Button("Reuse this image", visible=False)
        
        

    # Run on button click or pressing Enter in the prompt box; infer returns
    # (image, seed, button-visibility update) matching the outputs list.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
        inputs = [input_image, prompt, seed, randomize_seed, guidance_scale],
        outputs = [result, seed, reuse_button]
    )
    # Copy the generated result back into the input slot for iterative editing.
    reuse_button.click(
        fn = lambda image: image,
        inputs = [result],
        outputs = [input_image]
    )

demo.launch()