satyanayak commited on
Commit
d6f9ffa
·
1 Parent(s): 339f63e

fresh pipeline for each concept

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -5,6 +5,7 @@ from diffusers import StableDiffusionPipeline
5
  import random
6
  from huggingface_hub import hf_hub_download
7
  import os
 
8
 
9
  # Initialize the model
10
  model_id = "CompVis/stable-diffusion-v1-4"
@@ -52,18 +53,14 @@ def generate_images(prompt):
52
  images = []
53
  failed_concepts = []
54
 
55
- # Load base pipeline
56
- pipe = StableDiffusionPipeline.from_pretrained(
57
- model_id,
58
- torch_dtype=torch.float16 if device == "cuda" else torch.float32
59
- ).to(device)
60
-
61
- # Store original tokenizer and text encoder state
62
- original_vocab_size = len(pipe.tokenizer)
63
- original_embedding_weights = pipe.text_encoder.get_input_embeddings().weight.data.clone()
64
-
65
  for concept in concepts:
66
  try:
 
 
 
 
 
 
67
  # Download and load concept embedding
68
  embed_path = download_concept_embedding(concept)
69
  if embed_path is None:
@@ -94,10 +91,9 @@ def generate_images(prompt):
94
 
95
  images.append(image)
96
 
97
- # Reset tokenizer and text encoder to original state
98
- pipe.tokenizer = pipe.tokenizer.__class__.from_pretrained(model_id)
99
- pipe.text_encoder.resize_token_embeddings(original_vocab_size)
100
- pipe.text_encoder.get_input_embeddings().weight.data.copy_(original_embedding_weights)
101
 
102
  except Exception as e:
103
  print(f"Error processing concept {concept}: {str(e)}")
 
5
  import random
6
  from huggingface_hub import hf_hub_download
7
  import os
8
+ from transformers import CLIPTokenizer
9
 
10
  # Initialize the model
11
  model_id = "CompVis/stable-diffusion-v1-4"
 
53
  images = []
54
  failed_concepts = []
55
 
 
 
 
 
 
 
 
 
 
 
56
  for concept in concepts:
57
  try:
58
+ # Create a fresh pipeline for each concept
59
+ pipe = StableDiffusionPipeline.from_pretrained(
60
+ model_id,
61
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
62
+ ).to(device)
63
+
64
  # Download and load concept embedding
65
  embed_path = download_concept_embedding(concept)
66
  if embed_path is None:
 
91
 
92
  images.append(image)
93
 
94
+ # Clean up to free memory
95
+ del pipe
96
+ torch.cuda.empty_cache()
 
97
 
98
  except Exception as e:
99
  print(f"Error processing concept {concept}: {str(e)}")