AlekseyCalvin committed
Commit 40c98fe · verified · 1 parent: d552a7b

Create app7.py

Files changed (1): app7.py (+213, -0)
app7.py ADDED
@@ -0,0 +1,213 @@
+ import os
+ import gradio as gr
+ import json
+ import logging
+ import torch
+ from PIL import Image
+ import spaces
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForText2Image
+ import copy
+ import random
+ import time
+ from diffusers.models.transformers import FluxTransformer2DModel
+ import safetensors.torch
+ from transformers import CLIPModel, CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPConfig, T5EncoderModel, T5Tokenizer
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
+ from huggingface_hub import HfFileSystem, ModelCard, login
+ from safetensors.torch import load_file
+
+ hf_token = os.environ.get("HF_TOKEN")
+ login(token=hf_token)
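+ # Note: this assumes an HF_TOKEN secret is configured for the Space;
+ # login() raises if the token is missing or invalid, so gated models such
+ # as FLUX.1-dev below cannot load without it.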
+
+ torch.set_float32_matmul_precision("medium")
+
+ # Load the LoRA catalog from a JSON file
+ with open('loras.json', 'r') as f:
+     loras = json.load(f)
+
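+ # Hypothetical sketch of the schema this app expects in loras.json, inferred
+ # from the keys read below ("image", "title", "repo", "trigger_word", plus the
+ # optional "weights", "aspect", and "trigger_position"); values are
+ # illustrative only:
+ # [
+ #   {
+ #     "image": "https://example.com/preview.png",
+ #     "title": "Example LoRA",
+ #     "repo": "user/example-flux-lora",
+ #     "trigger_word": "EXAMPLE_STYLE",
+ #     "weights": "example.safetensors",
+ #     "aspect": "portrait",
+ #     "trigger_position": "prepend"
+ #   }
+ # ]
+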
+ # Initialize the base model
+ dtype = torch.bfloat16
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ base_model = "John6666/hyper-flux1-dev-fp8-flux"
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+ good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=good_vae).to(device)
+
+ # Swap in LongCLIP so prompts can run up to 248 tokens instead of CLIP's usual 77
+ model_id = "zer0int/LongCLIP-GmP-ViT-L-14"
+ config = CLIPConfig.from_pretrained(model_id)
+ config.text_config.max_position_embeddings = 248
+ clip_model = CLIPModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, config=config, ignore_mismatched_sizes=True)
+ clip_processor = CLIPProcessor.from_pretrained(model_id, padding="max_length", max_length=248)
+ pipe.tokenizer = clip_processor.tokenizer
+ pipe.text_encoder = clip_model.text_model
+ pipe.tokenizer_max_length = 248
+ pipe.text_encoder.to(dtype=torch.bfloat16)  # cast weights; assigning to .dtype directly has no effect
+
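+ # ignore_mismatched_sizes=True is what lets the 248-position embedding table
+ # declared above load against a checkpoint whose stored shapes differ;
+ # without it, from_pretrained() would raise on the size mismatch.
+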
+ MAX_SEED = 2**32 - 1
+
+ # Context manager that times a block and prints the elapsed wall-clock time
+ class calculateDuration:
+     def __init__(self, activity_name=""):
+         self.activity_name = activity_name
+
+     def __enter__(self):
+         self.start_time = time.time()
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.end_time = time.time()
+         self.elapsed_time = self.end_time - self.start_time
+         if self.activity_name:
+             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+         else:
+             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+
+
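+ # Usage sketch for the timer above (illustrative only):
+ #
+ #     with calculateDuration("Loading model"):
+ #         time.sleep(1.0)   # stand-in for real work
+ #
+ # prints e.g. "Elapsed time for Loading model: 1.000123 seconds" on exit.
+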
+ def update_selection(evt: gr.SelectData, width, height):
+     selected_lora = loras[evt.index]
+     new_placeholder = f"Prompt with activator word(s): '{selected_lora['trigger_word']}'! "
+     lora_repo = selected_lora["repo"]
+     lora_trigger = selected_lora['trigger_word']
+     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}). Prompt using: '{lora_trigger}'!"
+     # Some LoRAs declare a preferred aspect ratio; apply it to the size sliders
+     if "aspect" in selected_lora:
+         if selected_lora["aspect"] == "portrait":
+             width = 768
+             height = 1024
+         elif selected_lora["aspect"] == "landscape":
+             width = 1024
+             height = 768
+     return (
+         gr.update(placeholder=new_placeholder),
+         updated_text,
+         evt.index,
+         width,
+         height,
+     )
+
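+ # The tuple returned above must match, in order, the outputs wired up in
+ # gallery.select() near the bottom of the file:
+ # [prompt, selected_info, selected_index, width, height].
+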
+ @spaces.GPU()
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
+     pipe.to("cuda")
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+
+     with calculateDuration("Generating image"):
+         # Generate image; the trigger word is already merged into prompt_mash
+         # by run_lora(), so it is not appended again here
+         image = pipe(
+             prompt=prompt_mash,
+             num_inference_steps=steps,
+             guidance_scale=cfg_scale,
+             width=width,
+             height=height,
+             generator=generator,
+             joint_attention_kwargs={"scale": lora_scale},
+         ).images[0]
+     return image
+
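+ # joint_attention_kwargs={"scale": lora_scale} is the diffusers mechanism for
+ # scaling LoRA strength at inference time in Flux pipelines: 0 effectively
+ # disables the adapter, 1.0 applies it at full trained strength.
+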
+ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
+     if selected_index is None:
+         raise gr.Error("You must select a LoRA before proceeding.")
+
+     selected_lora = loras[selected_index]
+     lora_path = selected_lora["repo"]
+     trigger_word = selected_lora['trigger_word']
+     # Merge the trigger word into the prompt, honoring an optional
+     # "trigger_position" hint ("prepend" vs. append; default is prepend)
+     if trigger_word:
+         if "trigger_position" in selected_lora:
+             if selected_lora["trigger_position"] == "prepend":
+                 prompt_mash = f"{trigger_word} {prompt}"
+             else:
+                 prompt_mash = f"{prompt} {trigger_word}"
+         else:
+             prompt_mash = f"{trigger_word} {prompt}"
+     else:
+         prompt_mash = prompt
+
+     # Load LoRA weights
+     with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
+         if "weights" in selected_lora:
+             pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
+         else:
+             pipe.load_lora_weights(lora_path)
+
+     # Randomize the seed if requested (keep it fixed for reproducibility)
+     with calculateDuration("Randomizing seed"):
+         if randomize_seed:
+             seed = random.randint(0, MAX_SEED)
+
+     image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
+     pipe.to("cpu")
+     pipe.unload_lora_weights()
+     return image, seed
+
+ run_lora.zerogpu = True
+
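+ # The zerogpu attribute appears to be a hint for the Spaces ZeroGPU runtime,
+ # complementing the @spaces.GPU() decorator on generate_image() so a GPU is
+ # attached only while a generation request is actually running.
+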
+ css = '''
+ #gen_btn{height: 100%}
+ #title{text-align: center}
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
+ #title img{width: 100px; margin-right: 0.5em}
+ #gallery .grid-wrap{height: 10vh}
+ '''
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
+     title = gr.HTML(
+         """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> SOONfactory </h1>""",
+         elem_id="title",
+     )
+     # Info blob describing the project
+     info_blob = gr.HTML(
+         """<div id="info_blob"> Generative Models Celebrating the Unique Style & Sensibility of the Bay Area-based artist Jacqueline Trosclair (known to her friends as "Jax", "Starlic Jorca", & in countless ways)... </div>"""
+     )
+
+     # Info blob with usage instructions
+     info_blob = gr.HTML(
+         """<div id="info_blob"> To create new art via a generative model variant inspired by Jax's artworks, or a model merge in the style of her favorite artist Unica Zürn, choose a version below. </div>"""
+     )
+     selected_index = gr.State(None)
+     with gr.Row():
+         with gr.Column(scale=3):
+             prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Select a LoRA/style & type a prompt!")
+         with gr.Column(scale=1, elem_id="gen_column"):
+             generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
+     with gr.Row():
+         with gr.Column(scale=3):
+             selected_info = gr.Markdown("")
+             gallery = gr.Gallery(
+                 [(item["image"], item["title"]) for item in loras],
+                 label="LoRA Inventory",
+                 allow_preview=False,
+                 columns=3,
+                 elem_id="gallery"
+             )
+
+         with gr.Column(scale=4):
+             result = gr.Image(label="Generated Image")
+
+     with gr.Row():
+         with gr.Accordion("Advanced Settings", open=True):
+             with gr.Column():
+                 with gr.Row():
+                     cfg_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=20, step=0.5, value=3.0)
+                     steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=12)
+
+                 with gr.Row():
+                     width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
+                     height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1088)
+
+                 with gr.Row():
+                     randomize_seed = gr.Checkbox(True, label="Randomize seed")
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+                     lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2.0, step=0.01, value=1.05)
+
+     gallery.select(
+         update_selection,
+         inputs=[width, height],
+         outputs=[prompt, selected_info, selected_index, width, height]
+     )
+
+     gr.on(
+         triggers=[generate_button.click, prompt.submit],
+         fn=run_lora,
+         inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
+         outputs=[result, seed]
+     )
+
+ # launch() is called once here; a second app.launch() after it would be redundant
+ app.queue(default_concurrency_limit=2).launch(show_error=True)