satyanayak commited on
Commit
c33b978
·
1 Parent(s): f67fd7b

remove_token method removed

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -58,6 +58,10 @@ def generate_images(prompt):
58
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
59
  ).to(device)
60
 
 
 
 
 
61
  for concept in concepts:
62
  try:
63
  # Download and load concept embedding
@@ -90,9 +94,10 @@ def generate_images(prompt):
90
 
91
  images.append(image)
92
 
93
- # Clear concept from pipeline
94
- pipe.tokenizer.remove_tokens([token])
95
- pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
 
96
 
97
  except Exception as e:
98
  print(f"Error processing concept {concept}: {str(e)}")
 
58
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
59
  ).to(device)
60
 
61
+ # Store original tokenizer and text encoder state
62
+ original_vocab_size = len(pipe.tokenizer)
63
+ original_embedding_weights = pipe.text_encoder.get_input_embeddings().weight.data.clone()
64
+
65
  for concept in concepts:
66
  try:
67
  # Download and load concept embedding
 
94
 
95
  images.append(image)
96
 
97
+ # Reset tokenizer and text encoder to original state
98
+ pipe.tokenizer = pipe.tokenizer.__class__.from_pretrained(model_id)
99
+ pipe.text_encoder.resize_token_embeddings(original_vocab_size)
100
+ pipe.text_encoder.get_input_embeddings().weight.data.copy_(original_embedding_weights)
101
 
102
  except Exception as e:
103
  print(f"Error processing concept {concept}: {str(e)}")