Pusheen committed
Commit 2fe6c8a · verified · 1 Parent(s): e0e53b2

Update gligen/ldm/models/diffusion/plms.py

Files changed (1):
  1. gligen/ldm/models/diffusion/plms.py +68 -62
gligen/ldm/models/diffusion/plms.py CHANGED
@@ -3,7 +3,6 @@ import numpy as np
 from tqdm import tqdm
 from functools import partial
 from copy import deepcopy
-
 from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
 import math
 from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att, caculate_loss_LoCo_V2
@@ -58,14 +57,14 @@ class PLMSSampler(object):
 
 
     # @torch.no_grad()
-    def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type=None):
+    def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='LoCo'):
         self.make_schedule(ddim_num_steps=S)
         # import pdb; pdb.set_trace()
         return self.plms_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0, loss_type=loss_type)
 
 
     # @torch.no_grad()
-    def plms_sampling(self, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type=None):
+    def plms_sampling(self, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='LoCo'):
 
         b = shape[0]
 
@@ -82,7 +81,6 @@ class PLMSSampler(object):
         if self.alpha_generator_func != None:
             alphas = self.alpha_generator_func(len(time_range))
 
-
         for i, step in enumerate(time_range):
 
             # set alpha and restore first conv layer
@@ -104,7 +102,14 @@ class PLMSSampler(object):
             # three loss types
             if loss_type !=None and loss_type!='standard':
                 if input['object_position'] != []:
-                    x = self.update_loss_LoCo( input,i, index, ts, time_factor = time_factor)
+                    # if loss_type=='SAR_CAR':
+                    #     x = self.update_loss_self_cross( input,i, index, ts )
+                    # elif loss_type=='SAR':
+                    #     x = self.update_only_self( input,i, index, ts )
+                    # elif loss_type=='CAR':
+                    #     x = self.update_loss_only_cross( input,i, index, ts )
+                    # elif loss_type=='LoCo':
+                    x = self.update_loss_LoCo( input,i, index, ts, )
             input["x"] = x
             img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
             input["x"] = img
@@ -113,60 +118,67 @@ class PLMSSampler(object):
                 old_eps.pop(0)
 
         return img
-
-    def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
+
+    def update_loss_LoCo(self, input,index1, index, ts, type_loss='self_accross'):
+
+        # loss_scale = 30
+        # max_iter = 5
+        #print('time_factor is: ', time_factor)
         if index1 < 10:
-            loss_scale = 3
+            loss_scale = 8
             max_iter = 5
         elif index1 < 20:
-            loss_scale = 2
-            max_iter = 3
+            loss_scale = 5
+            max_iter = 5
         else:
             loss_scale = 1
             max_iter = 1
-
         loss_threshold = 0.1
+
         max_index = 30
         x = deepcopy(input["x"])
         iteration = 0
         loss = torch.tensor(10000)
         input["timesteps"] = ts
 
-        print("optimize", index1)
+        # print("optimize", index1)
         while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
-            print('iter', iteration)
+            # print('iter', iteration)
             x = x.requires_grad_(True)
+            # print('x shape', x.shape)
             input['x'] = x
             e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
+
             bboxes = input['boxes']
             object_positions = input['object_position']
-            loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
-                                           object_positions=object_positions, t = index1)*loss_scale
-            loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
-                                                object_positions=object_positions, t = index1)*loss_scale
-            loss = loss1 + loss2
-            print('AR loss:', loss, 'SAR:', loss1, 'CAR:', loss2)
-            hh = torch.autograd.backward(loss)
-            grad_cond = x.grad
-            x = x - grad_cond
+            loss2 = caculate_loss_LoCo_V2(att_second,att_first,att_third, bboxes=bboxes,
+                                          object_positions=object_positions, t = index1)*loss_scale
+            # loss = loss2
+            # loss.requires_grad_(True)
+            #print('LoCo loss', loss)
+
+
+
+            hh = torch.autograd.backward(loss2, retain_graph=True)
+            grad_cond = x.grad
+            x = x - grad_cond
             x = x.detach()
             iteration += 1
             torch.cuda.empty_cache()
         return x
-
-    def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):
-
+
+    def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
         if index1 < 10:
-            loss_scale = 3
-            max_iter = 5
+            loss_scale = 4
+            max_iter = 1
         elif index1 < 20:
-            loss_scale = 2
-            max_iter = 5
+            loss_scale = 3
+            max_iter = 1
         else:
             loss_scale = 1
             max_iter = 1
+
         loss_threshold = 0.1
-
         max_index = 30
         x = deepcopy(input["x"])
         iteration = 0
@@ -174,68 +186,64 @@ class PLMSSampler(object):
         input["timesteps"] = ts
 
         print("optimize", index1)
+        self.model.train()
         while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
             print('iter', iteration)
+            # import pdb; pdb.set_trace()
             x = x.requires_grad_(True)
-            print('x shape', x.shape)
             input['x'] = x
             e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
-
-            bboxes = input['boxes']
+            bboxes = input['boxes_att']
             object_positions = input['object_position']
+            loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
+                                           object_positions=object_positions, t = index1)*loss_scale
             loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
-                                                object_positions=object_positions, t = index1)*loss_scale
-            loss = loss2
-            print('loss', loss)
-            hh = torch.autograd.backward(loss, retain_graph=True)
-            grad_cond = x.grad
-            x = x - grad_cond
+                                                object_positions=object_positions, t = index1)*loss_scale
+            loss = loss1 + loss2
+            print('loss', loss, loss1, loss2)
+            # hh = torch.autograd.backward(loss, retain_graph=True)
+            grad_cond = torch.autograd.grad(loss.requires_grad_(True), [x])[0]
+            # grad_cond = x.grad
+            x = x - grad_cond
             x = x.detach()
             iteration += 1
-            torch.cuda.empty_cache()
+
+
         return x
 
-    def update_loss_LoCo(self, input,index1, index, ts, time_factor, type_loss='self_accross'):
-
-        # loss_scale = 30
-        # max_iter = 5
-        #print('time_factor is: ', time_factor)
+    def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):
+
         if index1 < 10:
-            loss_scale = 8
+            loss_scale = 3
             max_iter = 5
         elif index1 < 20:
-            loss_scale = 5
+            loss_scale = 2
             max_iter = 5
         else:
            loss_scale = 1
            max_iter = 1
        loss_threshold = 0.1
-
+
        max_index = 30
        x = deepcopy(input["x"])
        iteration = 0
        loss = torch.tensor(10000)
        input["timesteps"] = ts
 
-        # print("optimize", index1)
+        print("optimize", index1)
        while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
-            # print('iter', iteration)
+            print('iter', iteration)
            x = x.requires_grad_(True)
-            # print('x shape', x.shape)
            input['x'] = x
            e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
 
            bboxes = input['boxes']
            object_positions = input['object_position']
-            loss2 = caculate_loss_LoCo_V2(att_second,att_first,att_third, bboxes=bboxes,
+            loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
                                           object_positions=object_positions, t = index1)*loss_scale
-            # loss = loss2
-            # loss.requires_grad_(True)
-            #print('LoCo loss', loss)
-
-
-
-            hh = torch.autograd.backward(loss2, retain_graph=True)
+            loss = loss2
+            print('loss', loss)
+            hh = torch.autograd.backward(loss)
            grad_cond = x.grad
            x = x - grad_cond
            x = x.detach()
@@ -286,7 +294,7 @@ class PLMSSampler(object):
     def p_sample_plms(self, input, t, index, guidance_scale=1., uc=None, old_eps=None, t_next=None):
         x = deepcopy(input["x"])
         b = x.shape[0]
-
+        self.model.eval()
         def get_model_output(input):
             e_t, first, second, third,_,_,_ = self.model(input)
             if uc is not None and guidance_scale != 1:
@@ -335,5 +343,3 @@ class PLMSSampler(object):
         x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
 
         return x_prev, pred_x0, e_t
-
-
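
The heart of this commit is the new `update_loss_LoCo` method: before selected PLMS steps, the latent is pushed down the gradient of an attention-based localization loss, with a larger scale and more iterations early in sampling. Below is a minimal standalone sketch of that pattern; `model` and `loco_loss` are hypothetical stand-ins for the repository's UNet wrapper and `caculate_loss_LoCo_V2`, not its actual API.

    import torch

    def loss_guided_update(x, model, loco_loss, step_idx,
                           loss_threshold=0.1, max_index=30):
        # Schedule mirrors the committed code: strong, repeated updates
        # early in sampling; a single unscaled update later on.
        if step_idx < 10:
            loss_scale, max_iter = 8, 5
        elif step_idx < 20:
            loss_scale, max_iter = 5, 5
        else:
            loss_scale, max_iter = 1, 1

        loss = torch.tensor(1e4)
        iteration = 0
        while loss.item() > loss_threshold and iteration < max_iter and step_idx < max_index:
            x = x.detach().requires_grad_(True)
            att_maps = model(x)                       # hypothetical: forward pass returning attention maps
            loss = loco_loss(att_maps) * loss_scale   # hypothetical localization loss on those maps
            loss.backward()
            x = (x - x.grad).detach()                 # one gradient step directly on the latent
            iteration += 1
        return x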
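Two gradient-extraction idioms coexist in the new methods: `update_loss_LoCo` calls `torch.autograd.backward(...)` and then reads `x.grad`, while `update_loss_self_cross` uses `torch.autograd.grad(...)`, which returns the gradient without populating `.grad`. For a single leaf tensor they compute the same thing; a self-contained illustration with placeholder values:

    import torch

    x = torch.randn(4, requires_grad=True)
    loss = (x ** 2).sum()

    # Idiom 1 (as in update_loss_LoCo): populate x.grad, then read it.
    torch.autograd.backward(loss, retain_graph=True)
    g1 = x.grad

    # Idiom 2 (as in update_loss_self_cross): request the gradient directly;
    # x.grad is left untouched by this call.
    g2 = torch.autograd.grad(loss, [x])[0]

    assert torch.allclose(g1, g2)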
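Finally, `p_sample_plms` (which now switches the model back to `eval()` after the train-mode loss updates) branches on `uc is not None and guidance_scale != 1` inside `get_model_output`, presumably applying the usual classifier-free guidance combination of conditional and unconditional noise predictions. Schematically, under that assumption:

    import torch

    def guided_eps(e_t_cond: torch.Tensor, e_t_uncond: torch.Tensor,
                   guidance_scale: float) -> torch.Tensor:
        # Classifier-free guidance: extrapolate from the unconditional
        # noise prediction toward the conditional one.
        return e_t_uncond + guidance_scale * (e_t_cond - e_t_uncond)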