Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
af5ea8a
1
Parent(s):
f01b73c
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ from diffusers import DiffusionPipeline
|
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
from safetensors.torch import load_file
|
6 |
from share_btn import community_icon_html, loading_icon_html, share_js
|
|
|
7 |
|
8 |
import torch
|
9 |
import json
|
@@ -13,6 +14,20 @@ import gc
|
|
13 |
|
14 |
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
with open(lora_list, "r") as file:
|
17 |
data = json.load(file)
|
18 |
sdxl_loras = [
|
@@ -66,7 +81,7 @@ div#share-btn-container > div {flex-direction: row;background: black;align-items
|
|
66 |
|
67 |
original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
68 |
|
69 |
-
def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, progress=gr.Progress(track_tqdm=True)):
|
70 |
state_dict_1 = copy.deepcopy(shuffled_items[0]['state_dict'])
|
71 |
state_dict_2 = copy.deepcopy(shuffled_items[1]['state_dict'])
|
72 |
pipe = copy.deepcopy(original_pipe)
|
@@ -79,12 +94,15 @@ def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lor
|
|
79 |
|
80 |
if negative_prompt == "":
|
81 |
negative_prompt = None
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
84 |
del pipe
|
85 |
gc.collect()
|
86 |
torch.cuda.empty_cache()
|
87 |
-
return image, gr.update(visible=True)
|
88 |
|
89 |
def get_description(item):
|
90 |
trigger_word = item["trigger_word"]
|
@@ -108,6 +126,15 @@ def shuffle_images():
|
|
108 |
|
109 |
return title_1, prompt_description_1, repo_id_1, title_2, prompt_description_2, repo_id_2, prompt, two_shuffled_items, scale, scale
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
with gr.Blocks(css=css) as demo:
|
112 |
shuffled_items = gr.State()
|
113 |
title = gr.HTML(
|
@@ -147,9 +174,11 @@ with gr.Blocks(css=css) as demo:
|
|
147 |
community_icon = gr.HTML(community_icon_html)
|
148 |
loading_icon = gr.HTML(loading_icon_html)
|
149 |
share_button = gr.Button("Share to community", elem_id="share-btn")
|
150 |
-
|
151 |
with gr.Accordion("Advanced settings", open=False):
|
152 |
negative_prompt = gr.Textbox(label="Negative prompt")
|
|
|
|
|
153 |
with gr.Row():
|
154 |
lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
|
155 |
lora_2_scale = gr.Slider(label="LoRa 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
|
@@ -158,8 +187,10 @@ with gr.Blocks(css=css) as demo:
|
|
158 |
demo.load(shuffle_images, inputs=[], outputs=[lora_1, lora_1_prompt, lora_1_id, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale], queue=False, show_progress="hidden")
|
159 |
shuffle_button.click(shuffle_images, outputs=[lora_1, lora_1_prompt, lora_1_id, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale], queue=False, show_progress="hidden")
|
160 |
|
161 |
-
run_btn.click(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image, post_gen_info])
|
162 |
-
prompt.submit(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image, post_gen_info])
|
|
|
|
|
163 |
share_button.click(None, [], [], _js=share_js)
|
164 |
demo.queue()
|
165 |
demo.launch()
|
|
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
from safetensors.torch import load_file
|
6 |
from share_btn import community_icon_html, loading_icon_html, share_js
|
7 |
+
from uuid import uuid4
|
8 |
|
9 |
import torch
|
10 |
import json
|
|
|
14 |
|
15 |
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
|
16 |
|
17 |
+
IMAGE_DATASET_DIR = Path("image_dataset") / f"train-{uuid4()}"
|
18 |
+
IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
|
19 |
+
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl"
|
20 |
+
|
21 |
+
scheduler = CommitScheduler(
|
22 |
+
repo_id="multimodalart/lora-fusing-preferences",
|
23 |
+
repo_type="dataset",
|
24 |
+
folder_path=IMAGE_DATASET_DIR,
|
25 |
+
path_in_repo=IMAGE_DATASET_DIR.name,
|
26 |
+
every=10
|
27 |
+
)
|
28 |
+
|
29 |
+
client = InferenceClient()
|
30 |
+
|
31 |
with open(lora_list, "r") as file:
|
32 |
data = json.load(file)
|
33 |
sdxl_loras = [
|
|
|
81 |
|
82 |
original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
83 |
|
84 |
+
def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, seed, progress=gr.Progress(track_tqdm=True)):
|
85 |
state_dict_1 = copy.deepcopy(shuffled_items[0]['state_dict'])
|
86 |
state_dict_2 = copy.deepcopy(shuffled_items[1]['state_dict'])
|
87 |
pipe = copy.deepcopy(original_pipe)
|
|
|
94 |
|
95 |
if negative_prompt == "":
|
96 |
negative_prompt = None
|
97 |
+
|
98 |
+
if(seed < 0):
|
99 |
+
seed = random.randint(0, 2147483647)
|
100 |
+
generator = torch.Generator(device="cuda").manual_seed(seed)
|
101 |
+
image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768, generator=generator).images[0]
|
102 |
del pipe
|
103 |
gc.collect()
|
104 |
torch.cuda.empty_cache()
|
105 |
+
return image, gr.update(visible=True), seed
|
106 |
|
107 |
def get_description(item):
|
108 |
trigger_word = item["trigger_word"]
|
|
|
126 |
|
127 |
return title_1, prompt_description_1, repo_id_1, title_2, prompt_description_2, repo_id_2, prompt, two_shuffled_items, scale, scale
|
128 |
|
129 |
+
def save_preferences(lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, generated_image, thumbs_direction, seed):
|
130 |
+
image_path = IMAGE_DATASET_DIR / f"{uuid4()}.png"
|
131 |
+
with scheduler.lock:
|
132 |
+
generated_image.save(image_path)
|
133 |
+
with IMAGE_JSONL_PATH.open("a") as f:
|
134 |
+
json.dump({"prompt": prompt, "file_name":image_path.name, "lora_1_id": lora_2_id, "lora_1_scale": lora_1_scale, "lora_2_id": lora_2_id, "lora_2_scale": lora_2_scale, "thumbs_direction": thumbs_direction, "seed": seed}, f)
|
135 |
+
f.write("\n")
|
136 |
+
return gr.update(visible=True)
|
137 |
+
|
138 |
with gr.Blocks(css=css) as demo:
|
139 |
shuffled_items = gr.State()
|
140 |
title = gr.HTML(
|
|
|
174 |
community_icon = gr.HTML(community_icon_html)
|
175 |
loading_icon = gr.HTML(loading_icon_html)
|
176 |
share_button = gr.Button("Share to community", elem_id="share-btn")
|
177 |
+
post_eval = gr.Markdown("Thanks for evaluating. The dataset with evaluations is [here](#)", visible=False)
|
178 |
with gr.Accordion("Advanced settings", open=False):
|
179 |
negative_prompt = gr.Textbox(label="Negative prompt")
|
180 |
+
seed = gr.Slider(label="Seed", info="-1 denotes a random seed", minimum=-1, maximum=2147483647, value=-1)
|
181 |
+
last_used_seed = gr.Slider(label="Last used seed", info="The seed used in the last generation", minimum=0, maximum=2147483647, value=-1, interactive=False)
|
182 |
with gr.Row():
|
183 |
lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
|
184 |
lora_2_scale = gr.Slider(label="LoRa 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
|
|
|
187 |
demo.load(shuffle_images, inputs=[], outputs=[lora_1, lora_1_prompt, lora_1_id, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale], queue=False, show_progress="hidden")
|
188 |
shuffle_button.click(shuffle_images, outputs=[lora_1, lora_1_prompt, lora_1_id, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale], queue=False, show_progress="hidden")
|
189 |
|
190 |
+
run_btn.click(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image, post_gen_info, last_used_seed])
|
191 |
+
prompt.submit(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image, post_gen_info, last_used_seed])
|
192 |
+
thumbs_up.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("up"), seed], outputs=[post_eval])
|
193 |
+
thumbs_down.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("down"), seed], outputs=[post_eval])
|
194 |
share_button.click(None, [], [], _js=share_js)
|
195 |
demo.queue()
|
196 |
demo.launch()
|