satyanayak commited on
Commit
6df8da7
·
1 Parent(s): 18adfb4

fixing the path of concept model's bin file

Browse files
Files changed (2) hide show
  1. app.py +53 -29
  2. requirements.txt +2 -1
app.py CHANGED
@@ -3,6 +3,8 @@ import torch
3
  from torch import autocast
4
  from diffusers import StableDiffusionPipeline
5
  import random
 
 
6
 
7
  # Initialize the model
8
  model_id = "CompVis/stable-diffusion-v1-4"
@@ -17,6 +19,19 @@ concepts = [
17
  "sd-concepts-library/wu-shi"
18
  ]
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
21
  loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
22
 
@@ -43,36 +58,45 @@ def generate_images(prompt):
43
  ).to(device)
44
 
45
  for concept in concepts:
46
- # Load concept embedding
47
- token = load_learned_embed_in_clip(
48
- f"{concept}/blob/main/learned_embeds.bin",
49
- pipe.text_encoder,
50
- pipe.tokenizer
51
- )
52
-
53
- # Generate random seed
54
- seed = random.randint(1, 999999)
55
- generator = torch.Generator(device=device).manual_seed(seed)
56
-
57
- # Add concept token to prompt
58
- concept_prompt = f"{token} {prompt}"
59
-
60
- # Generate image
61
- with autocast(device):
62
- image = pipe(
63
- concept_prompt,
64
- num_inference_steps=50,
65
- generator=generator,
66
- guidance_scale=7.5
67
- ).images[0]
68
-
69
- images.append(image)
70
-
71
- # Clear concept from pipeline
72
- pipe.tokenizer.remove_tokens([token])
73
- pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
 
 
 
 
 
 
 
 
 
74
 
75
- return images
76
 
77
  # Create Gradio interface
78
  iface = gr.Interface(
 
3
  from torch import autocast
4
  from diffusers import StableDiffusionPipeline
5
  import random
6
+ from huggingface_hub import hf_hub_download
7
+ import os
8
 
9
  # Initialize the model
10
  model_id = "CompVis/stable-diffusion-v1-4"
 
19
  "sd-concepts-library/wu-shi"
20
  ]
21
 
22
+ def download_concept_embedding(concept_name):
23
+ try:
24
+ # Download the learned_embeds.bin file from the Hub
25
+ embed_path = hf_hub_download(
26
+ repo_id=concept_name,
27
+ filename="learned_embeds.bin",
28
+ repo_type="model"
29
+ )
30
+ return embed_path
31
+ except Exception as e:
32
+ print(f"Error downloading {concept_name}: {str(e)}")
33
+ return None
34
+
35
  def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
36
  loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
37
 
 
58
  ).to(device)
59
 
60
  for concept in concepts:
61
+ # Download and load concept embedding
62
+ embed_path = download_concept_embedding(concept)
63
+ if embed_path is None:
64
+ continue
65
+
66
+ try:
67
+ token = load_learned_embed_in_clip(
68
+ embed_path,
69
+ pipe.text_encoder,
70
+ pipe.tokenizer
71
+ )
72
+
73
+ # Generate random seed
74
+ seed = random.randint(1, 999999)
75
+ generator = torch.Generator(device=device).manual_seed(seed)
76
+
77
+ # Add concept token to prompt
78
+ concept_prompt = f"{token} {prompt}"
79
+
80
+ # Generate image
81
+ with autocast(device):
82
+ image = pipe(
83
+ concept_prompt,
84
+ num_inference_steps=50,
85
+ generator=generator,
86
+ guidance_scale=7.5
87
+ ).images[0]
88
+
89
+ images.append(image)
90
+
91
+ # Clear concept from pipeline
92
+ pipe.tokenizer.remove_tokens([token])
93
+ pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
94
+
95
+ except Exception as e:
96
+ print(f"Error processing concept {concept}: {str(e)}")
97
+ continue
98
 
99
+ return images if images else [None] * 5
100
 
101
  # Create Gradio interface
102
  iface = gr.Interface(
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  torch
2
  diffusers
3
  transformers
4
- gradio
 
 
1
  torch
2
  diffusers
3
  transformers
4
+ gradio
5
+ huggingface_hub