Spaces:
Running
on
A10G
Running
on
A10G
rynmurdock
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -7,34 +7,17 @@ import numpy as np
|
|
7 |
from sklearn.svm import LinearSVC
|
8 |
from sklearn import preprocessing
|
9 |
import pandas as pd
|
10 |
-
import kornia
|
11 |
-
import torchvision
|
12 |
|
13 |
import random
|
14 |
import time
|
15 |
|
16 |
-
|
17 |
-
from diffusers.models import ImageProjection
|
18 |
-
from patch_sdxl import SDEmb
|
19 |
import torch
|
20 |
-
|
21 |
|
22 |
prompt_list = [p for p in list(set(
|
23 |
pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
|
24 |
|
25 |
-
|
26 |
-
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
27 |
-
lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
|
28 |
-
|
29 |
-
pipe = SDEmb.from_pretrained(model_id, variant="fp16")
|
30 |
-
pipe.load_lora_weights(lcm_lora_id)
|
31 |
-
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
32 |
-
pipe.to(device=DEVICE, dtype=torch.float16)
|
33 |
-
|
34 |
-
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
calibrate_prompts = [
|
39 |
"4k photo",
|
40 |
'surrealist art',
|
@@ -57,20 +40,6 @@ ys = []
|
|
57 |
|
58 |
start_time = time.time()
|
59 |
|
60 |
-
output_hidden_state = False if isinstance(pipe.unet.encoder_hid_proj, ImageProjection) else True
|
61 |
-
|
62 |
-
|
63 |
-
transform = kornia.augmentation.RandomResizedCrop(size=(224, 224), scale=(.3, .5))
|
64 |
-
nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
65 |
-
def patch_encode_image(image):
|
66 |
-
image = torch.tensor(torchvision.transforms.functional.pil_to_tensor(image).to(torch.float16)).repeat(16, 1, 1, 1).to(DEVICE)
|
67 |
-
image = image / 255
|
68 |
-
patches = nom(transform(image))
|
69 |
-
output, _ = pipe.encode_image(
|
70 |
-
patches, DEVICE, 1, output_hidden_state
|
71 |
-
)
|
72 |
-
return output.mean(0, keepdim=True)
|
73 |
-
|
74 |
|
75 |
glob_idx = 0
|
76 |
|
@@ -96,7 +65,6 @@ def next_image():
|
|
96 |
pooled_embeds, _ = pipe.encode_image(
|
97 |
image[0], DEVICE, 1, output_hidden_state
|
98 |
)
|
99 |
-
#pooled_embeds = patch_encode_image(image[0])
|
100 |
|
101 |
embs.append(pooled_embeds)
|
102 |
return image[0]
|
@@ -131,19 +99,10 @@ def next_image():
|
|
131 |
prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
|
132 |
print(prompt)
|
133 |
|
134 |
-
image =
|
135 |
-
|
136 |
-
|
137 |
-
height=1024,
|
138 |
-
width=1024,
|
139 |
-
num_inference_steps=8,
|
140 |
-
guidance_scale=0,
|
141 |
-
).images
|
142 |
-
|
143 |
-
im_emb, _ = pipe.encode_image(
|
144 |
-
image[0], DEVICE, 1, output_hidden_state
|
145 |
)
|
146 |
-
#im_emb = patch_encode_image(image[0])
|
147 |
|
148 |
embs.append(im_emb)
|
149 |
|
|
|
7 |
from sklearn.svm import LinearSVC
|
8 |
from sklearn import preprocessing
|
9 |
import pandas as pd
|
|
|
|
|
10 |
|
11 |
import random
|
12 |
import time
|
13 |
|
14 |
+
import replicate
|
|
|
|
|
15 |
import torch
|
16 |
+
import pickle
|
17 |
|
18 |
prompt_list = [p for p in list(set(
|
19 |
pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
calibrate_prompts = [
|
22 |
"4k photo",
|
23 |
'surrealist art',
|
|
|
40 |
|
41 |
start_time = time.time()
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
glob_idx = 0
|
45 |
|
|
|
65 |
pooled_embeds, _ = pipe.encode_image(
|
66 |
image[0], DEVICE, 1, output_hidden_state
|
67 |
)
|
|
|
68 |
|
69 |
embs.append(pooled_embeds)
|
70 |
return image[0]
|
|
|
99 |
prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
|
100 |
print(prompt)
|
101 |
|
102 |
+
image, im_emb = replicate.run(
|
103 |
+
"rynmurdock/zahir:43177e0594f3bc2e3560170ff0ffb6d1cacdddda1be25fbcd4348ef02b0b7d0f",
|
104 |
+
input={"prompt": prompt, 'im_emg': pickle.dumps(im_emb)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
)
|
|
|
106 |
|
107 |
embs.append(im_emb)
|
108 |
|