SeemG commited on
Commit
6f85631
·
verified ·
1 Parent(s): 2a51e24

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +3 -1
utils.py CHANGED
@@ -275,7 +275,9 @@ def gen_image_as_per_prompt(prompt, style, seed, custom_loss=None):
275
  # replacement_token_embedding = text_encoder.get_input_embeddings()(torch.tensor(2368, device=torch_device))
276
 
277
  # Insert this into the token embeddings (
278
- token_embeddings[0, torch.where(input_ids[0] == 6829)] = replacement_token_embedding.to(torch_device)
 
 
279
 
280
  # get pos embed
281
  pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
 
275
  # replacement_token_embedding = text_encoder.get_input_embeddings()(torch.tensor(2368, device=torch_device))
276
 
277
  # Insert this into the token embeddings (
278
+ indices = torch.where(input_ids[0] == 6829)[0] # Extract indices where the condition is true
279
+ if indices.numel() > 0: # Check if any indices are found
280
+ token_embeddings[0, indices] = replacement_token_embedding.to(torch_device)
281
 
282
  # get pos embed
283
  pos_emb_layer = text_encoder.text_model.embeddings.position_embedding