import spaces
import os
import json
import torch
import gradio as gr
from diffusers import DiffusionPipeline


# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Use the 'waffles' environment variable as the access token
hf_token = os.getenv('waffles')

# Ensure the token is loaded correctly
if not hf_token:
    raise ValueError("Hugging Face API token not found. Please set the 'waffles' environment variable.")

# Load LoRAs from JSON file
with open('loras.json', 'r') as f:
    loras = json.load(f)
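
# Each entry in loras.json is expected to look roughly like the sketch below
# (inferred from the keys used in run_lora; the trigger word and weights
# filename are placeholder values):
# {
#     "repo": "XLabs-AI/flux-RealismLora",   # Hugging Face repo id (required)
#     "trigger_word": "realistic style",     # appended to every prompt (required)
#     "weights": "lora.safetensors"          # specific weight file to load (optional)
# }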

# Initialize the base model with authentication and specify the device
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    token=hf_token
).to(device)

MAX_SEED = 2**32 - 1

@spaces.GPU(duration=190)
def run_lora(prompt, cfg_scale, steps, selected_repo, randomize_seed, seed, width, height, lora_scale):
    if not selected_repo:
        raise gr.Error("You must select a LoRA before proceeding.")

    selected_lora = next((lora for lora in loras if lora["repo"] == selected_repo), None)
    if not selected_lora:
        raise gr.Error("Selected LoRA not found.")

    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]

    # Move the pipeline back to the GPU (it is offloaded to CPU at the end of each run)
    pipe.to(device)

    # Load LoRA weights
    if "weights" in selected_lora:
        pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
    else:
        pipe.load_lora_weights(lora_path)
        
    # Randomize the seed if requested, otherwise keep the user-provided value
    if randomize_seed:
        seed = torch.randint(0, MAX_SEED, (1,)).item()

    # Generate image
    generator = torch.Generator(device=device).manual_seed(seed)
    image = pipe(
        prompt=f"{prompt} {trigger_word}",
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]

    # Offload the pipeline to CPU and unload LoRA weights to free GPU memory
    pipe.to("cpu")
    pipe.unload_lora_weights()

    return image, seed

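# Build the Gradio UI: prompt and LoRA selection inputs, image/seed outputs,
# and advanced generation settings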
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Enter your prompt")
            lora_dropdown = gr.Dropdown(
                label="Select LoRA",
                choices=[lora["repo"] for lora in loras],
                value="XLabs-AI/flux-RealismLora",
            )
        with gr.Column(scale=1):
            generate_button = gr.Button("Generate", variant="primary")

    with gr.Row():
        result = gr.Image(label="Generated Image")
        seed = gr.Number(label="Seed", value=0, interactive=False)

    with gr.Accordion("Advanced Settings", open=False):
        cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
        steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
        width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)

    generate_button.click(
        run_lora,
        inputs=[prompt, cfg_scale, steps, lora_dropdown, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )

app.launch()