Spaces:
Running
on
Zero
Running
on
Zero
Update autoregressive/models/generate.py
Browse files
autoregressive/models/generate.py
CHANGED
@@ -137,15 +137,12 @@ def decode_n_tokens(
|
|
137 |
|
138 |
@torch.no_grad()
|
139 |
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
|
140 |
-
condition = condition.to(torch.float32)
|
141 |
-
print(condition)
|
142 |
if condition is not None:
|
143 |
with torch.no_grad():
|
144 |
print(model.adapter.model.embeddings.patch_embeddings.projection.weight)
|
145 |
condition = model.adapter(condition)
|
146 |
print(condition)
|
147 |
condition = model.adapter_mlp(condition)
|
148 |
-
print(condition)
|
149 |
if model.model_type == 'c2i':
|
150 |
if cfg_scale > 1.0:
|
151 |
cond_null = torch.ones_like(cond) * model.num_classes
|
|
|
137 |
|
138 |
@torch.no_grad()
|
139 |
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
|
|
|
|
|
140 |
if condition is not None:
|
141 |
with torch.no_grad():
|
142 |
print(model.adapter.model.embeddings.patch_embeddings.projection.weight)
|
143 |
condition = model.adapter(condition)
|
144 |
print(condition)
|
145 |
condition = model.adapter_mlp(condition)
|
|
|
146 |
if model.model_type == 'c2i':
|
147 |
if cfg_scale > 1.0:
|
148 |
cond_null = torch.ones_like(cond) * model.num_classes
|