rynmurdock commited on
Commit
447c576
Β·
verified Β·
1 Parent(s): e73e95e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -43
app.py CHANGED
@@ -8,10 +8,15 @@ from sklearn.svm import LinearSVC
8
  from sklearn import preprocessing
9
  import pandas as pd
10
 
 
 
 
 
 
 
11
  import random
12
  import time
13
 
14
- import replicate
15
  import torch
16
  from urllib.request import urlopen
17
 
@@ -24,11 +29,49 @@ prompt_list = [p for p in list(set(
24
 
25
  start_time = time.time()
26
 
27
- # TODO add to state instead of shared across all
28
- glob_idx = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
- deployment = replicate.deployments.get("rynmurdock/zahir-deployment")
 
 
 
32
 
33
  def next_image(embs, ys, calibrate_prompts):
34
  global glob_idx
@@ -46,26 +89,11 @@ def next_image(embs, ys, calibrate_prompts):
46
  print('######### Calibrating with sample prompts #########')
47
  prompt = calibrate_prompts.pop(0)
48
  print(prompt)
49
-
50
- prediction = deployment.predictions.create(
51
- input={"prompt": prompt,}
52
- )
53
- prediction.wait()
54
- output = prediction.output
55
-
56
- # output = replicate.run(
57
- # "rynmurdock/zahir:42c58addd49ab57f1e309f0b9a0f271f483bbef0470758757c623648fe989e42",
58
- # input={"prompt": prompt,}
59
- # )
60
-
61
- response = requests.get(output['file1'])
62
- image = Image.open(BytesIO(response.content))
63
-
64
- embs.append(torch.tensor([float(i) for i in urlopen(output['file2']).read().decode('utf-8').split(', ')]).unsqueeze(0))
65
  return image, embs, ys, calibrate_prompts
66
  else:
67
  print('######### Roaming #########')
68
-
69
  # sample only as many negatives as there are positives
70
  indices = range(len(ys))
71
  pos_indices = [i for i in indices if ys[i] == 1]
@@ -93,28 +121,8 @@ def next_image(embs, ys, calibrate_prompts):
93
  im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16)
94
  prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
95
  print(prompt)
96
-
97
- im_emb_st = str(im_emb[0].cpu().detach().tolist())[1:-1]
98
-
99
- prediction = deployment.predictions.create(
100
- input={"prompt": prompt, 'im_emb': im_emb_st}
101
- )
102
- prediction.wait()
103
- output = prediction.output
104
-
105
- # output = replicate.run(
106
- # "rynmurdock/zahir:42c58addd49ab57f1e309f0b9a0f271f483bbef0470758757c623648fe989e42",
107
- # input={"prompt": prompt, 'im_emb': im_emb_st}
108
- # )
109
-
110
- response = requests.get(output['file1'])
111
- image = Image.open(BytesIO(response.content))
112
-
113
-
114
- im_emb = torch.tensor([float(i) for i in urlopen(output['file2']).read().decode('utf-8').split(', ')]).unsqueeze(0)
115
  embs.append(im_emb)
116
-
117
- torch.save(lin_class.coef_, f'./{start_time}.pt')
118
  return image, embs, ys, calibrate_prompts
119
 
120
 
@@ -195,6 +203,6 @@ with gr.Blocks(css=css) as demo:
195
  [b4, embs, ys, calibrate_prompts],
196
  [b1, b2, b3, b4, img, embs, ys, calibrate_prompts])
197
  with gr.Row():
198
- html = gr.HTML('''<div style='text-align:center; font-size:32'>You will callibrate for several prompts and then roam.</ div>''')
199
 
200
  demo.launch() # Share your demo with just 1 extra parameter πŸš€
 
8
  from sklearn import preprocessing
9
  import pandas as pd
10
 
11
+ from diffusers import LCMScheduler
12
+ from diffusers.models import ImageProjection
13
+ from patch_sdxl import SDEmb
14
+ import torch
15
+ import spaces
16
+
17
  import random
18
  import time
19
 
 
20
  import torch
21
  from urllib.request import urlopen
22
 
 
29
 
30
  start_time = time.time()
31
 
32
+ ####################### Setup Model
33
+ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
34
+ lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
35
+ pipe = SDEmb.from_pretrained(model_id, variant="fp16")
36
+ pipe.load_lora_weights(lcm_lora_id)
37
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
38
+ pipe.to(device='cuda', dtype=torch.float16)
39
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
40
+ output_hidden_state = False
41
+ #######################
42
+
43
+ @spaces.GPU
44
+ def predict(
45
+ prompt,
46
+ im_emb=None,
47
+ ):
48
+ """Run a single prediction on the model"""
49
+ with torch.no_grad():
50
+ if im_emb == None:
51
+ im_emb = torch.zeros(1, 1280, dtype=torch.float16, device='cuda')
52
+ else:
53
+ im_emb = torch.tensor([float(i) for i in im_emb.split(', ')]).unsqueeze(0).to(dtype=torch.float16).to('cuda')
54
+ image = pipe(
55
+ prompt=prompt,
56
+ ip_adapter_emb=im_emb,
57
+ height=1024,
58
+ width=1024,
59
+ num_inference_steps=8,
60
+ guidance_scale=0,
61
+ ).images[0]
62
+ im_emb, _ = pipe.encode_image(
63
+ image, 'cuda', 1, output_hidden_state
64
+ )
65
+ return image, im_emb.to(DEVICE)
66
+
67
+
68
+
69
 
70
 
71
+
72
+
73
+ # TODO add to state instead of shared across all
74
+ glob_idx = 0
75
 
76
  def next_image(embs, ys, calibrate_prompts):
77
  global glob_idx
 
89
  print('######### Calibrating with sample prompts #########')
90
  prompt = calibrate_prompts.pop(0)
91
  print(prompt)
92
+ image, img_emb = predict(prompt)
93
+ embs.append(img_emb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  return image, embs, ys, calibrate_prompts
95
  else:
96
  print('######### Roaming #########')
 
97
  # sample only as many negatives as there are positives
98
  indices = range(len(ys))
99
  pos_indices = [i for i in indices if ys[i] == 1]
 
121
  im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16)
122
  prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
123
  print(prompt)
124
+ image, im_emb = predict(prompt, img_emb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  embs.append(im_emb)
 
 
126
  return image, embs, ys, calibrate_prompts
127
 
128
 
 
203
  [b4, embs, ys, calibrate_prompts],
204
  [b1, b2, b3, b4, img, embs, ys, calibrate_prompts])
205
  with gr.Row():
206
+ html = gr.HTML('''<div style='text-align:center; font-size:32'>You will calibrate for several prompts and then roam.</ div>''')
207
 
208
  demo.launch() # Share your demo with just 1 extra parameter πŸš€