Spaces:
Running
on
A10G
Running
on
A10G
rynmurdock
committed on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
# Torch device that the pipeline is moved to and on which the tensors below are created.
DEVICE = 'cpu'
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
from sklearn.svm import LinearSVC
|
8 |
+
from sklearn import preprocessing
|
9 |
+
import pandas as pd
|
10 |
+
import kornia
|
11 |
+
import torchvision
|
12 |
+
|
13 |
+
import random
|
14 |
+
import time
|
15 |
+
|
16 |
+
from diffusers import LCMScheduler
|
17 |
+
from diffusers.models import ImageProjection
|
18 |
+
from patch_sdxl import SDEmb
|
19 |
+
import torch
|
20 |
+
|
21 |
+
|
22 |
+
# Unique prompts taken from the second column of the CSV; non-string cells
# (e.g. NaN produced by empty rows) are dropped.
prompt_list = [
    p
    for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]
|
24 |
+
|
25 |
+
|
26 |
+
# Hub identifiers of the base model and of the LoRA weights loaded on top of it.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
lcm_lora_id = "latent-consistency/lcm-lora-sdxl"

# Build the patched SDXL pipeline (SDEmb comes from patch_sdxl), attach the
# LoRA weights, and switch to the LCM scheduler derived from the pipeline's
# current scheduler config.
pipe = SDEmb.from_pretrained(model_id, variant="fp16")
pipe.load_lora_weights(lcm_lora_id)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to(device=DEVICE, dtype=torch.float16)

# The IP-Adapter is loaded after the pipeline has been moved/cast above;
# keep this ordering.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl.bin",
)
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
# Prompts shown one by one (front first) before roaming starts; next_image()
# pops from this list, so it shrinks as calibration progresses.
calibrate_prompts = [
    '4k photo',
    'surrealist art',
    'a psychedelic, fractal view',
    'a beautiful collage',
    'an intricate portrait',
    'an impressionist painting',
    'abstract art',
    'an eldritch image',
    'a sketch',
    'a city full of darkness and graffiti',
    'a black & white photo',
    'a brilliant, timeless tarot card of the world',
    'a photo of a woman',
    '',
]
|
54 |
+
|
55 |
+
# Embedding of every generated image, in generation order.
embs = []
# Ratings collected by choose(): 1 for 'Like', 0 for 'Dislike'.
ys = []

# Launch timestamp; used as the file name of the saved preference direction.
start_time = time.time()

# Forwarded as the last argument of pipe.encode_image(); False only when the
# UNet's encoder_hid_proj is an ImageProjection.
output_hidden_state = not isinstance(pipe.unet.encoder_hid_proj, ImageProjection)


# Random 224x224 crops with scale=(.3, .5), used by patch_encode_image().
transform = kornia.augmentation.RandomResizedCrop(size=(224, 224), scale=(.3, .5))
# Per-channel normalisation applied to the crops.
# NOTE(review): the constants look like the CLIP image mean/std — confirm.
nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
65 |
+
def patch_encode_image(image):
    """Embed an image as the mean embedding of 16 random crops.

    Args:
        image: image accepted by ``torchvision``'s ``pil_to_tensor``
            (presumably a PIL image with three channels — confirm).

    Returns:
        The first output of ``pipe.encode_image`` for the 16 crops, averaged
        over dimension 0 with that dimension kept.
    """
    # pil_to_tensor already returns a tensor, so no torch.tensor() wrapper is
    # needed (wrapping an existing tensor only copies it and emits a warning).
    pixels = torchvision.transforms.functional.pil_to_tensor(image).to(torch.float16)
    # Sixteen identical copies so the random crop yields sixteen different patches.
    batch = pixels.repeat(16, 1, 1, 1).to(DEVICE)
    batch = batch / 255
    patches = nom(transform(batch))
    output, _ = pipe.encode_image(patches, DEVICE, 1, output_hidden_state)
    return output.mean(0, keepdim=True)
|
73 |
+
|
74 |
+
|
75 |
+
# Number of next_image() calls so far; its parity picks the roaming prompt.
glob_idx = 0
|
76 |
+
|
77 |
+
def _generate_and_record(prompt, ip_adapter_emb):
    """Generate one image, append its embedding to ``embs``, and return the image."""
    print(prompt)
    image = pipe(
        prompt=prompt,
        ip_adapter_emb=ip_adapter_emb,
        height=1024,
        width=1024,
        num_inference_steps=8,
        guidance_scale=0,
    ).images[0]
    image_emb, _ = pipe.encode_image(image, DEVICE, 1, output_hidden_state)
    embs.append(image_emb)
    return image


def _balanced_training_set():
    """Return (embeddings, labels) with equally many liked and disliked samples.

    Both lists are empty when either class has no rated sample yet.
    """
    rated = range(len(ys))
    liked = [i for i in rated if ys[i] == 1]
    disliked = [i for i in rated if ys[i] == 0]
    # Sample only as many of the larger class as there are in the smaller one.
    per_class = min(len(liked), len(disliked))
    disliked = random.sample(disliked, per_class)
    liked = random.sample(liked, per_class)
    chosen = disliked + liked
    return [embs[i] for i in chosen], [ys[i] for i in chosen]


def _fit_preference_direction(sample_embs, sample_ys):
    """Fit a linear SVM on standardised embeddings; return its unit-norm weights as a (1, D) tensor."""
    features = torch.stack([e[0].detach().cpu() for e in sample_embs]).numpy()
    features = preprocessing.StandardScaler().fit_transform(features)
    labels = np.array(sample_ys)
    print(features.shape, labels.shape)

    classifier = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(features, labels)
    weights = torch.tensor(classifier.coef_, dtype=torch.double).flatten()
    return (weights / weights.norm()).unsqueeze(0)


def next_image():
    """Generate the next image to show and record its embedding in ``embs``.

    While ``calibrate_prompts`` is non-empty, the front prompt is popped and
    rendered with an all-zero IP-Adapter embedding. Afterwards ("roaming"), a
    linear classifier is fit on a class-balanced sample of the rated
    embeddings; its normalised weight vector is used as the IP-Adapter
    embedding and saved to ``./{start_time}.pt``. The roaming prompt
    alternates between ``'an image'`` and a random entry of ``prompt_list``.

    If the ratings so far contain only likes or only dislikes, no classifier
    can be fit; the image is then generated with the zero embedding and
    nothing is saved.

    Returns:
        The first image produced by the pipeline call.
    """
    global glob_idx
    glob_idx = glob_idx + 1
    with torch.no_grad():
        neutral_emb = torch.zeros(1, 1, 1280, device=DEVICE, dtype=torch.float16)

        if calibrate_prompts:
            print('######### Calibrating with sample prompts #########')
            return _generate_and_record(calibrate_prompts.pop(0), neutral_emb)

        print('######### Roaming #########')
        sample_embs, sample_ys = _balanced_training_set()

        rng_prompt = random.choice(prompt_list)
        prompt = 'an image' if glob_idx % 2 == 0 else rng_prompt

        if not sample_embs:
            # One class is missing: stacking an empty list and fitting a
            # single-class SVM would both raise, so roam without guidance.
            return _generate_and_record(prompt, neutral_emb)

        direction = _fit_preference_direction(sample_embs, sample_ys)
        im_emb = direction.to(device=DEVICE, dtype=torch.float16)
        image = _generate_and_record(prompt, im_emb)

        torch.save(direction, f'./{start_time}.pt')
        return image
|
152 |
+
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
|
161 |
+
def start(_):
    """Handle the Start click: unlock the rating buttons and show the first image.

    The argument (the clicked button's value) is ignored. The returned list
    holds, in order, the enabled 'Like', 'Neither' and 'Dislike' buttons, the
    disabled 'Start' button, and the image returned by ``next_image()``.
    """
    rating_buttons = [
        gr.Button(value=label, interactive=True)
        for label in ('Like', 'Neither', 'Dislike')
    ]
    start_button = gr.Button(value='Start', interactive=False)
    return [*rating_buttons, start_button, next_image()]
|
169 |
+
|
170 |
+
|
171 |
+
def choose(choice):
    """Record the rating for the current image and return the next one.

    'Neither' discards the most recent embedding without storing a label;
    'Like' stores 1 and any other value stores 0 in ``ys``.
    """
    if choice == 'Neither':
        # Unrated image: drop its embedding so embs and ys stay aligned.
        embs.pop(-1)
        return next_image()

    label = 1 if choice == 'Like' else 0
    ys.append(label)
    return next_image()
|
181 |
+
|
182 |
+
# Fixed-size, centred container for the generated image.
css = "div#output-image {height: 768px !important; width: 768px !important; margin:auto;}"
with gr.Blocks(css=css) as demo:
    with gr.Row():
        # Banner fixes: font-size needs a unit to take effect, and the end
        # tag must not contain a space after "</".
        html = gr.HTML('''<div style='text-align:center; font-size:32px'>You will calibrate for several prompts and then roam.</div>''')
    with gr.Row(elem_id='output-image'):
        img = gr.Image(interactive=False, elem_id='output-image',)
    with gr.Row(equal_height=True):
        # Rating buttons stay disabled until start() re-enables them.
        b3 = gr.Button(value='Dislike', interactive=False,)
        b2 = gr.Button(value='Neither', interactive=False,)
        b1 = gr.Button(value='Like', interactive=False,)
        # Each button passes itself as the input (so choose() receives its
        # value) and updates the image.
        for rating_button in (b1, b2, b3):
            rating_button.click(choose, [rating_button], [img])
    with gr.Row():
        b4 = gr.Button(value='Start')
        # Output order must match the list returned by start().
        b4.click(start, [b4], [b1, b2, b3, b4, img])

demo.launch()  # Share your demo with just 1 extra parameter 🚀
|