File size: 4,572 Bytes
43a3c7f
 
 
663e209
43a3c7f
56a7978
43a3c7f
663e209
bf4853f
 
 
 
 
 
 
 
 
 
43a3c7f
 
663e209
 
 
 
43a3c7f
56a7978
ee4120c
56a7978
 
43a3c7f
b6b5406
bf4853f
663e209
 
 
bf4853f
 
663e209
bf4853f
56a7978
ee4120c
663e209
 
 
43a3c7f
 
 
 
663e209
 
 
 
 
 
 
 
 
 
 
bf4853f
663e209
 
 
 
 
 
43a3c7f
663e209
43a3c7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf4853f
 
 
 
56a7978
 
663e209
43a3c7f
663e209
43a3c7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf4853f
43a3c7f
 
 
 
 
 
bf4853f
43a3c7f
 
 
663e209
bf4853f
663e209
 
43a3c7f
663e209
43a3c7f
 
 
bf4853f
43a3c7f
663e209
43a3c7f
663e209
43a3c7f
 
bf4853f
56a7978
663e209
 
 
bf4853f
663e209
43a3c7f
 
56a7978
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline

# Precision used when loading the FLUX pipeline weights.
dtype = torch.bfloat16

# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
# torch.backends.mps.is_available() already implies the PyTorch install was
# built with MPS support, so no separate is_built() check is needed here.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Largest seed the UI offers (max signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width/height sliders.
MAX_IMAGE_SIZE = 2048

# Initialize the pipeline globally so every request reuses the loaded weights.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype).to(device)

# Selectable LoRAs: UI name -> {"path": value passed to pipe.load_lora_weights}.
# NOTE(review): "path" values look like Hugging Face Hub repo ids — confirm.
lora_weights = {
    "cajerky": {"path": "bryanbrunetti/cajerky"},
}


@spaces.GPU(duration=120)
def infer(prompt, lora_models, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0,
          num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """Generate one image from *prompt* using the global ``pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        lora_models: Keys into ``lora_weights`` whose LoRAs are loaded before
            generation; a falsy value means no LoRA is used.
        seed: Seed for the torch generator (replaced when *randomize_seed*).
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Requested output size in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps forwarded to the pipeline.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars. Not referenced directly.

    Returns:
        ``(image, seed, message)``: the first generated image (``None`` on
        failure), the seed that was used, and a status message for the UI.
        LoRA-loading and generation errors are reported through *message*
        instead of being raised.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; manual_seed requires an int.
    seed = int(seed)

    try:
        if lora_models:
            try:
                for lora_model in lora_models:
                    print(f"loading LoRA: {lora_model}")
                    pipe.load_lora_weights(lora_weights[lora_model]["path"])
            except Exception as e:
                return None, seed, f"Failed to load LoRA model: {str(e)}"

        generator = torch.Generator().manual_seed(seed)
        try:
            image = pipe(
                prompt=prompt,
                width=int(width),
                height=int(height),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        except Exception as e:
            return None, seed, f"Error during image generation: {str(e)}"
        return image, seed, "Image generated successfully."
    finally:
        # `pipe` is shared across requests: always drop LoRA weights, even after
        # a partial load or a failed generation, so they don't leak into the
        # next call. A cleanup failure is logged rather than masking the result.
        if lora_models:
            try:
                pipe.unload_lora_weights()
            except Exception as e:
                print(f"Failed to unload LoRA weights: {e}")


# Center the main column and cap its width; applied via elem_id="col-container".
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 520px;\n"
    "}\n"
)

# Build the Gradio UI: prompt row, LoRA picker, result image, advanced
# generation settings and a status textbox, all wired to `infer`.
# Component creation order inside each context manager determines layout.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        # Multi-select over the names in `lora_weights`; the selected names
        # are passed to `infer` as `lora_models`.
        lora_models = gr.Dropdown(list(lora_weights.keys()), multiselect=True,
                                  info="Load LoRA (optional) use the name in the prompt", label="Choose LoRAs")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    info="How close to follow prompt",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    info="higher = more details",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

        # Status line: `infer` writes its success/error message here.
        output_message = gr.Textbox(label="Output Message")

    # Run generation from either the button or submitting the prompt box.
    # `seed` is also an output so the slider shows the seed `infer` used
    # (which differs from the input when "Randomize seed" is checked).
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, lora_models, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed, output_message]
    )

demo.launch()