rynmurdock commited on
Commit
bf71575
Β·
1 Parent(s): de9a113
app.py CHANGED
@@ -106,14 +106,15 @@ def get_user_emb(embs, ys):
106
  if len(positives) == 0:
107
  positives = torch.zeros_like(im_emb)[None]
108
  else:
109
- embs = random.sample(positives, min(4, len(positives))) + positives[-4:]
 
110
  positives = torch.stack(embs, 1)
111
 
112
  negs = [e for e, ys in zip(embs, ys) if ys == 0]
113
  if len(negs) == 0:
114
  negatives = torch.zeros_like(im_emb)[None]
115
  else:
116
- negative_embs = random.sample(negs, min(4, len(negs))) + negs[-4:]
117
  negatives = torch.stack(negative_embs, 1)
118
  # if random.random() < .5:
119
  # negatives = torch.zeros_like(negatives)
 
106
  if len(positives) == 0:
107
  positives = torch.zeros_like(im_emb)[None]
108
  else:
109
+ # take the last 8 (TODO: verify this is chronological; it should be) and also k-8 random ones.
110
+ embs = random.sample(positives, k=min(k-8, len(positives))) + positives[-8:]
111
  positives = torch.stack(embs, 1)
112
 
113
  negs = [e for e, ys in zip(embs, ys) if ys == 0]
114
  if len(negs) == 0:
115
  negatives = torch.zeros_like(im_emb)[None]
116
  else:
117
+ negative_embs = random.sample(negs, min(k-4, len(negs))) + negs[-4:]
118
  negatives = torch.stack(negative_embs, 1)
119
  # if random.random() < .5:
120
  # negatives = torch.zeros_like(negatives)
config.py CHANGED
@@ -12,5 +12,5 @@ batch_size = 16
12
  number_k_clip_embed = 16 # divide by this to determine bundling together of sequences -> CLIP
13
  num_workers = 32
14
  seed = 107
15
- k = 8
16
  # TODO config option to swap to diffusion?
 
12
  number_k_clip_embed = 16 # divide by this to determine bundling together of sequences -> CLIP
13
  num_workers = 32
14
  seed = 107
15
+ k = 16
16
  # TODO config option to swap to diffusion?
last_epoch_ckpt/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ae34b5c319b9c804e1e82c93f78821b880553d2ac60ff628003175334ee9066d
3
  size 136790920
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33d7ca8a1d0f179ade0aa00cf9d622b0ac60ea2b58c79933a9212c54b5d6f719
3
  size 136790920
lightning_app_deprecated.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import random
4
+ import time
5
+ import torch
6
+ import glob
7
+
8
+ import config
9
+ from huggingface_hub import hf_hub_download
10
+ from diffusers import EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL, AutoPipelineForText2Image
11
+ from transformers import CLIPVisionModelWithProjection
12
+ from safetensors.torch import load_file
13
+
14
+ from model import get_model_and_tokenizer
15
+
16
+ model, tokenizer = get_model_and_tokenizer(config.model_path, 'cuda', torch.bfloat16)
17
+
18
+ del model.kandinsky_pipe
19
+ del tokenizer
20
+
21
+ torch.set_float32_matmul_precision('high')
22
+
23
+ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
24
+ sdxl_lightening = "ByteDance/SDXL-Lightning"
25
+ ckpt = "sdxl_lightning_8step_unet.safetensors"
26
+ unet = UNet2DConditionModel.from_config(model_id, subfolder="unet", low_cpu_mem_usage=True, device_map='cuda').to(torch.float16)
27
+ unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt)))
28
+
29
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map='cuda')
30
+ pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder, low_cpu_mem_usage=True)
31
+ pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl.bin')))
32
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
33
+ pipe.register_modules(image_encoder = image_encoder)
34
+ pipe.set_ip_adapter_scale(0.8)
35
+
36
+ #pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16, low_cpu_mem_usage=True)
37
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
38
+
39
+ pipe.to(device='cuda').to(dtype=config.dtype)
40
+ output_hidden_state = False
41
+
42
+
43
+ # TODO unify/merge origin and this
44
+ # TODO save & restart from (if it exists) dataframe parquet
45
+
46
+ device = "cuda"
47
+
48
+ k = config.k
49
+
50
+ import spaces
51
+ import matplotlib.pyplot as plt
52
+
53
+ import os
54
+ import gradio as gr
55
+ import pandas as pd
56
+ from apscheduler.schedulers.background import BackgroundScheduler
57
+
58
+ import random
59
+ import time
60
+ from PIL import Image
61
+ # from safety_checker_improved import maybe_nsfw
62
+
63
+
64
+ torch.set_grad_enabled(False)
65
+ torch.backends.cuda.matmul.allow_tf32 = True
66
+ torch.backends.cudnn.allow_tf32 = True
67
+
68
+ prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id', 'text', 'gemb'])
69
+
70
+ import spaces
71
+ start_time = time.time()
72
+
73
+ ####################### Setup Model
74
+ from diffusers import EulerDiscreteScheduler
75
+ from PIL import Image
76
+ import uuid
77
+
78
+
79
@spaces.GPU()
def generate_gpu(in_im_embs, prompt='the scene'):
    """Render one image conditioned on a pair of image embeddings.

    Args:
        in_im_embs: tensor that is reshaped to ``(2, 1, -1)`` and handed to the
            pipeline as the IP-Adapter image embeds. Callers in this file
            stack it as ``[negative, positive]``.
        prompt: text prompt forwarded to the pipeline.

    Returns:
        ``(image, im_emb)`` where ``image`` is the first image produced by the
        pipeline and ``im_emb`` is its embedding from ``pipe.encode_image``,
        detached, on CPU, as float32.
    """
    with torch.no_grad():
        # Single device transfer; the original moved the tensor twice and
        # sliced out two locals that were never used.
        in_im_embs = in_im_embs.to('cuda').view(2, 1, -1)
        images = pipe(
            prompt=prompt,
            guidance_scale=4,
            added_cond_kwargs={},
            ip_adapter_image_embeds=[in_im_embs],
            num_inference_steps=8,
        ).images[0]
        im_emb, _ = pipe.encode_image(
            images, 'cuda', 1, output_hidden_state
        )
        im_emb = im_emb.detach().to('cpu').to(torch.float32)
    return images, im_emb
94
+
95
+
96
def generate(in_im_embs, ):
    """Generate an image, write it to /tmp and return its path and embedding.

    Returns:
        ``(path, im_emb)``; ``path`` is ``None`` when the image is flagged as
        NSFW (the checker is currently disabled, so that never happens).
    """
    picture, picture_emb = generate_gpu(in_im_embs)

    # Safety checker (maybe_nsfw) is switched off, so nothing is ever flagged.
    flagged = False
    if flagged:
        gr.Warning("NSFW content detected.")
        # TODO: could register an automatic dislike on the backend instead;
        # that would need refactoring.
        return None, picture_emb

    # uuid4().hex is the dash-free form of str(uuid4()).
    out_path = f"/tmp/{uuid.uuid4().hex}.png"
    picture.save(out_path)
    return out_path, picture_emb
110
+
111
+
112
+ #######################
113
+
114
@spaces.GPU()
def sample_embs(prompt_embeds):
    """Sample a predicted image embedding from the conditioning embeddings.

    Args:
        prompt_embeds: conditioning tensor; dimension 1 is the number of
            conditioning embeddings and is zero-padded at the end up to ``k``.
            (assumes a 3-D ``(batch, n, dim)`` layout -- TODO confirm)

    Returns:
        ``model(...).predicted_image_embedding`` for a fresh random latent.

    Raises:
        ValueError: if dimension 1 is still not ``k`` after padding, i.e. more
            than ``k`` conditioning embeddings were supplied.
    """
    latent = torch.randn(prompt_embeds.shape[0], 1, prompt_embeds.shape[-1])

    missing = k - prompt_embeds.shape[1]
    if missing > 0:
        # F.pad pairs run from the last dim backwards: leave the feature dim
        # alone, append `missing` zero rows on dim 1.
        prompt_embeds = torch.nn.functional.pad(prompt_embeds, [0, 0, 0, missing])

    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if prompt_embeds.shape[1] != k:
        raise ValueError(
            f"The model is set to take `k` ({k}) cond image embeds "
            f"but got shape {tuple(prompt_embeds.shape)}"
        )

    image_embeds = model(latent.to('cuda'), prompt_embeds.to('cuda')).predicted_image_embedding
    return image_embeds
122
+
123
@spaces.GPU()
def get_user_emb(embs, ys):
    """Build the (negative, positive) embedding pair for one user.

    Args:
        embs: embeddings of the media the user has rated.
        ys: labels parallel to ``embs``; ``1`` marks a positive, ``0`` a
            negative. Other values are ignored.

    Returns:
        ``torch.stack([sample_embs(negatives), sample_embs(positives)])``.

    Each side is built from up to ``k - 4`` randomly sampled embeddings plus
    the last four in list order, so recent ratings may appear twice. An empty
    side falls back to a zero tensor shaped like the module-level ``im_emb``.
    """
    # Pair embeddings with labels once, up front. The original reassigned
    # `embs` to the sampled positives, so the negative filter below zipped
    # that sample against `ys` and selected misaligned embeddings.
    labelled = list(zip(embs, ys))

    liked = [emb for emb, label in labelled if label == 1]
    if len(liked) == 0:
        positives = torch.zeros_like(im_emb)[None]
    else:
        positive_embs = random.sample(liked, min(k - 4, len(liked))) + liked[-4:]
        positives = torch.stack(positive_embs, 1)

    disliked = [emb for emb, label in labelled if label == 0]
    if len(disliked) == 0:
        negatives = torch.zeros_like(im_emb)[None]
    else:
        negative_embs = random.sample(disliked, min(k - 4, len(disliked))) + disliked[-4:]
        negatives = torch.stack(negative_embs, 1)

    # Order matters: generate_gpu consumes index 0 as negative, 1 as positive.
    image_embeds = torch.stack([sample_embs(negatives), sample_embs(positives)])

    return image_embeds
144
+
145
+
146
def background_next_image():
    """Scheduler job: generate one new image for each user who has rated media.

    Reads and rebinds the module-level ``prevs_df``. Nothing is generated
    until at least four rows carry a rating. A user is skipped when ten or
    more unrated rows were already generated for them, or when they have
    rated fewer than four rows. Once the frame exceeds 500 rows, the row at
    position 6 is dropped and its image file deleted.
    """
    global prevs_df, glob_idx

    unrated_marker = {' ': ' '}
    touched = prevs_df[[r != unrated_marker for r in prevs_df['user:rating']]]
    if len(touched) < 4:
        time.sleep(.1)
        return

    for uid in set(touched['latest_user_to_rate'].to_list()):
        # Masks are rebuilt per user because prevs_df grows inside this loop.
        voted = [r.get(uid, None) is not None for r in prevs_df['user:rating']]
        rated_rows = prevs_df[voted]
        not_rated_rows = prevs_df[[not v for v in voted]]

        # Rows generated for this user that they have not rated yet.
        queued = not_rated_rows[[src == uid for src in not_rated_rows['from_user_id']]]

        # we don't compute more after n are in the queue for them
        if len(queued) >= 10:
            continue
        if len(rated_rows) < 4:
            continue

        glob_idx += 1

        ems = rated_rows['embeddings'].to_list()
        ys = [r[uid][0] for r in rated_rows['user:rating'].to_list()]

        emz = get_user_emb(ems, ys)
        img, embs = generate(emz)
        if not img:
            continue

        tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
        tmp_df['paths'] = [img]
        tmp_df['embeddings'] = [embs.to(torch.float32).to('cpu')]
        tmp_df['user:rating'] = [{' ': ' '}]
        tmp_df['from_user_id'] = [uid]
        tmp_df['text'] = ['']
        prevs_df = pd.concat((prevs_df, tmp_df))

        if len(prevs_df) <= 500:
            continue

        # Free storage: drop the row at position 6, keeping the first six rows
        # (presumably the calibration images -- verify the count).
        oldest_path = prevs_df.iloc[6]['paths']
        if not os.path.isfile(oldest_path):
            print("Error: %s file not found" % oldest_path)
        else:
            os.remove(oldest_path)
        prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:]))
200
+
201
def pluck_img(user_id):
    """Pick the unrated image most similar to the user's embedding.

    Blocks (polling every 0.1 s) until at least one row in ``prevs_df`` has no
    rating from ``user_id``.

    Args:
        user_id: key used inside each row's ``'user:rating'`` dict.

    Returns:
        The ``'paths'`` value of the best-scoring unrated row. If no score
        compares greater than the initial bound (e.g. every similarity is
        NaN), the first unrated row is returned instead of raising.
    """
    # TODO pluck images based on similarity but also based on diversity by cluster every few times.
    rated_rows = prevs_df[[r.get(user_id, None) is not None for r in prevs_df['user:rating']]]
    ems = rated_rows['embeddings'].to_list()
    ys = [r[user_id][0] for r in rated_rows['user:rating'].to_list()]
    # Moved to CPU once here rather than once per candidate in the loop below.
    user_emb = get_user_emb(ems, ys).detach().to('cpu')

    def _unrated_rows():
        # Re-reads the global each call so rows added by the background job are seen.
        return prevs_df[[r.get(user_id, 'gone') == 'gone' for r in prevs_df['user:rating']]]

    not_rated_rows = _unrated_rows()
    while len(not_rated_rows) == 0:
        time.sleep(.1)
        not_rated_rows = _unrated_rows()

    # NOTE could opt for only showing their own or prioritizing their own media.
    best_sim = float('-inf')
    best_path = None
    for path, emb in zip(not_rated_rows['paths'], not_rated_rows['embeddings']):
        sim = torch.cosine_similarity(emb.detach().to('cpu'), user_emb, -1)
        # With a stacked [negative, positive] user embedding, entry 1 is the
        # similarity to the positive half.
        if len(sim) > 1:
            sim = sim[1]
        score = float(sim.squeeze())
        if score > best_sim:
            best_sim = score
            best_path = path

    if best_path is None:
        # No candidate compared greater; avoid the original unbound `best_row`.
        best_path = not_rated_rows.iloc[0]['paths']
    return best_path
227
+
228
def next_image(calibrate_prompts, user_id):
    """Return the next image to show and the remaining calibration list.

    While fewer than five calibration items have been consumed (and some are
    left), a random one is popped from ``calibrate_prompts`` in place and its
    path is looked up in ``prevs_df``. Otherwise the image comes from
    ``pluck_img(user_id)``.

    Args:
        calibrate_prompts: list of calibration paths not yet shown to the user.
        user_id: forwarded to ``pluck_img``.

    Returns:
        ``(image, calibrate_prompts)``.
    """
    with torch.no_grad():
        shown = len(m_calibrate) - len(calibrate_prompts)
        # The truthiness guard covers fewer than five calibration images in
        # total: the original then called random.randint(0, -1) on the empty
        # list and raised ValueError.
        if calibrate_prompts and shown < 5:
            cal_video = calibrate_prompts.pop(random.randint(0, len(calibrate_prompts) - 1))
            image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
        else:
            # calibration is over (or exhausted): pick media by similarity.
            image = pluck_img(user_id)
    return image, calibrate_prompts
238
+
239
+
240
+
241
+
242
+
243
+
244
def start(_, calibrate_prompts, user_id, request: gr.Request):
    """Handle the Start button: assign a user id and show the first image.

    The incoming ``user_id`` is ignored; a new one is derived from the last
    seven characters of ``str(time.time())`` with the dot removed.

    Returns:
        A list of six button updates, the first image, the remaining
        calibration list and the new user id, in the order the click handler's
        outputs are wired.
    """
    stamp = str(time.time())
    user_id = int(stamp[-7:].replace('.', ''))
    first_image, remaining = next_image(calibrate_prompts, user_id)

    like_btn = gr.Button(value='πŸ‘', interactive=True)
    neither_btn = gr.Button(value='Neither (Space)', interactive=True, visible=False)
    dislike_btn = gr.Button(value='πŸ‘Ž', interactive=True)
    start_btn = gr.Button(value='Start', interactive=False)
    content_btn = gr.Button(value='πŸ‘ Content', interactive=True, visible=False)
    style_btn = gr.Button(value='πŸ‘ Style', interactive=True, visible=False)

    return [
        like_btn,
        neither_btn,
        dislike_btn,
        start_btn,
        content_btn,
        style_btn,
        first_image,
        remaining,
        user_id,
    ]
258
+
259
+
260
def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
    """Record the user's rating for ``img`` and return the next image.

    Args:
        img: path of the image that was shown (``None`` if it was withheld).
        choice: label of the button that was clicked.
        calibrate_prompts: remaining calibration paths, forwarded to
            ``next_image``.
        user_id: key under which the rating is stored in the row's
            ``'user:rating'`` dict.

    Returns:
        ``(next_img, calibrate_prompts)`` from ``next_image``.

    Raises:
        AssertionError: if ``choice`` is not one of the known button labels.
    """
    global prevs_df

    if choice == 'Neither (Space)':
        # Skip: nothing is recorded, just advance.
        img, calibrate_prompts = next_image(calibrate_prompts, user_id)
        return img, calibrate_prompts

    if choice == 'πŸ‘':
        choice = [1, 1]
    elif choice == 'πŸ‘Ž':
        choice = [0, 0]
    elif choice == 'πŸ‘ Style':
        choice = [0, 1]
    elif choice == 'πŸ‘ Content':
        choice = [1, 0]
    else:
        # Explicit raise (same exception type) so the check survives `python -O`.
        raise AssertionError(f'choice is {choice}')

    if img is None:
        # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
        # TODO skip allowing rating & just continue
        print('NSFW -- choice is disliked')
        choice = [0, 0]
        # No path to match; the original crashed here with `... in None`.
        row_mask = [False] * len(prevs_df)
    else:
        # Match on file name only: the UI hands back a copy under another folder.
        row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]

    matched_ratings = prevs_df.loc[row_mask, 'user:rating']
    # if it's still in the dataframe, add the choice
    if len(matched_ratings) > 0:
        # Positional access: the label lookup `[0]` only worked because every
        # row happens to carry index label 0. The dict is mutated in place.
        matched_ratings.iloc[0][user_id] = choice
        # Scalar assignment broadcasts even if several rows match.
        prevs_df.loc[row_mask, 'latest_user_to_rate'] = user_id
    else:
        print('Image apparently removed', img)

    img, calibrate_prompts = next_image(calibrate_prompts, user_id)
    return img, calibrate_prompts
292
+
293
+ css = '''.gradio-container{max-width: 700px !important}
294
+ #description{text-align: center}
295
+ #description h1, #description h3{display: block}
296
+ #description p{margin-top: 0}
297
+ .fade-in-out {animation: fadeInOut 3s forwards}
298
+ @keyframes fadeInOut {
299
+ 0% {
300
+ background: var(--bg-color);
301
+ }
302
+ 100% {
303
+ background: var(--button-secondary-background-fill);
304
+ }
305
+ }
306
+ '''
307
+ js_head = '''
308
+ <script>
309
+ document.addEventListener('keydown', function(event) {
310
+ if (event.key === 'a' || event.key === 'A') {
311
+ // Trigger click on 'dislike' if 'A' is pressed
312
+ document.getElementById('dislike').click();
313
+ } else if (event.key === ' ' || event.keyCode === 32) {
314
+ // Trigger click on 'neither' if Spacebar is pressed
315
+ document.getElementById('neither').click();
316
+ } else if (event.key === 'l' || event.key === 'L') {
317
+ // Trigger click on 'like' if 'L' is pressed
318
+ document.getElementById('like').click();
319
+ }
320
+ });
321
+ function fadeInOut(button, color) {
322
+ button.style.setProperty('--bg-color', color);
323
+ button.classList.remove('fade-in-out');
324
+ void button.offsetWidth; // This line forces a repaint by accessing a DOM property
325
+
326
+ button.classList.add('fade-in-out');
327
+ button.addEventListener('animationend', () => {
328
+ button.classList.remove('fade-in-out'); // Reset the animation state
329
+ }, {once: true});
330
+ }
331
+ document.body.addEventListener('click', function(event) {
332
+ const target = event.target;
333
+ if (target.id === 'dislike') {
334
+ fadeInOut(target, '#ff1717');
335
+ } else if (target.id === 'like') {
336
+ fadeInOut(target, '#006500');
337
+ } else if (target.id === 'neither') {
338
+ fadeInOut(target, '#cccccc');
339
+ }
340
+ });
341
+
342
+ </script>
343
+ '''
344
+
345
+ with gr.Blocks(head=js_head, css=css) as demo:
346
+ gr.Markdown('''# The Other Tiger
347
+ ### Generative Recommenders for Exporation of Possible Images
348
+
349
+ Explore the latent space using binary feedback.
350
+
351
+ [rynmurdock.github.io](https://rynmurdock.github.io/)
352
+ ''', elem_id="description")
353
+ user_id = gr.State()
354
+ # calibration videos -- this is a misnomer now :D
355
+ calibrate_prompts = gr.State( glob.glob('image_init/*') )
356
+ def l():
357
+ return None
358
+
359
+ with gr.Row(elem_id='output-image'):
360
+ img = gr.Image(
361
+ label='Lightning',
362
+ interactive=False,
363
+ elem_id="output_im",
364
+ type='filepath',
365
+ height=700,
366
+ width=700,
367
+ )
368
+
369
+
370
+
371
+ with gr.Row(equal_height=True):
372
+ b3 = gr.Button(value='πŸ‘Ž', interactive=False, elem_id="dislike")
373
+
374
+ b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
375
+
376
+ b1 = gr.Button(value='πŸ‘', interactive=False, elem_id="like")
377
+ with gr.Row(equal_height=True):
378
+ b6 = gr.Button(value='πŸ‘ Style', interactive=False, elem_id="dislike like", visible=False)
379
+
380
+ b5 = gr.Button(value='πŸ‘ Content', interactive=False, elem_id="like dislike", visible=False)
381
+
382
+ b1.click(
383
+ choose,
384
+ [img, b1, calibrate_prompts, user_id],
385
+ [img, calibrate_prompts, ],
386
+ )
387
+ b2.click(
388
+ choose,
389
+ [img, b2, calibrate_prompts, user_id],
390
+ [img, calibrate_prompts, ],
391
+ )
392
+ b3.click(
393
+ choose,
394
+ [img, b3, calibrate_prompts, user_id],
395
+ [img, calibrate_prompts, ],
396
+ )
397
+ b5.click(
398
+ choose,
399
+ [img, b5, calibrate_prompts, user_id],
400
+ [img, calibrate_prompts, ],
401
+ )
402
+ b6.click(
403
+ choose,
404
+ [img, b6, calibrate_prompts, user_id],
405
+ [img, calibrate_prompts, ],
406
+ )
407
+ with gr.Row():
408
+ b4 = gr.Button(value='Start')
409
+ b4.click(start,
410
+ [b4, calibrate_prompts, user_id],
411
+ [b1, b2, b3, b4, b5, b6, img, calibrate_prompts, user_id, ]
412
+ )
413
+ with gr.Row():
414
+ html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several images and then roam. When your media is generating, you may encounter others'.</ div><br><br><br>
415
+
416
+ <br><br>
417
+ <div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
418
+ </ div>''')
419
+
420
+ # TODO quiet logging
421
+ scheduler = BackgroundScheduler()
422
+ scheduler.add_job(func=background_next_image, trigger="interval", seconds=.2)
423
+ scheduler.start()
424
+
425
+ # TODO shouldn't call this before gradio launch, yeah?
426
@spaces.GPU()
def encode_space(x):
    """Embed one image with the prior pipeline's image encoder.

    Args:
        x: image accepted by ``model.prior_pipe.image_processor``.

    Returns:
        The encoder's ``"image_embeds"`` output, detached, on CPU, as float32.
    """
    prior = model.prior_pipe
    encoder = prior.image_encoder

    processed = prior.image_processor(x, return_tensors="pt")
    # First processed image, with a leading batch dimension restored.
    pixels = processed.pixel_values[0].unsqueeze(0)
    pixels = pixels.to(dtype=encoder.dtype, device=device)

    embedding = encoder(pixels)["image_embeds"]
    return embedding.detach().to('cpu').to(torch.float32)
436
+
437
+ # NOTE:
438
+ # media is moved into a random tmp folder so we need to parse filenames carefully.
439
+ # do not have any cases where a file name is the same or could be `in` another filename
440
+ # you also maybe can't use jpegs lmao
441
+
442
+ # prep our calibration videos
443
# Seed prevs_df with one unrated row per calibration image found on disk.
m_calibrate = glob.glob('image_init/*')
for im in m_calibrate:
    image = Image.open(im).convert('RGB')
    # NOTE: `im_emb` stays a module-level name on purpose; get_user_emb uses it
    # as the shape template for its all-zero fallback.
    im_emb = encode_space(image)

    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb', 'from_user_id'])
    row_values = (
        ('paths', im),
        ('embeddings', im_emb.detach().to('cpu')),
        # a fresh dict per row: ratings are later written into it in place
        ('user:rating', {' ': ' '}),
        ('text', ''),
        # seems to break things...
        ('from_user_id', 0),
        ('latest_user_to_rate', 0),
    )
    for column, value in row_values:
        tmp_df[column] = [value]
    prevs_df = pd.concat((prevs_df, tmp_df))
458
+
459
+ glob_idx = 0
460
+ demo.launch(share=True,)