wondervictor commited on
Commit
24f7cff
·
verified ·
1 Parent(s): ca4dba4

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +0 -3
autoregressive/models/generate.py CHANGED
@@ -137,15 +137,12 @@ def decode_n_tokens(
137
 
138
  @torch.no_grad()
139
  def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
140
- condition = condition.to(torch.float32)
141
- print(condition)
142
  if condition is not None:
143
  with torch.no_grad():
144
  print(model.adapter.model.embeddings.patch_embeddings.projection.weight)
145
  condition = model.adapter(condition)
146
  print(condition)
147
  condition = model.adapter_mlp(condition)
148
- print(condition)
149
  if model.model_type == 'c2i':
150
  if cfg_scale > 1.0:
151
  cond_null = torch.ones_like(cond) * model.num_classes
 
137
 
138
  @torch.no_grad()
139
  def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
 
 
140
  if condition is not None:
141
  with torch.no_grad():
142
  print(model.adapter.model.embeddings.patch_embeddings.projection.weight)
143
  condition = model.adapter(condition)
144
  print(condition)
145
  condition = model.adapter_mlp(condition)
 
146
  if model.model_type == 'c2i':
147
  if cfg_scale > 1.0:
148
  cond_null = torch.ones_like(cond) * model.num_classes