rezashkv commited on
Commit
809edc0
·
verified ·
1 Parent(s): 31b3bcc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -6
README.md CHANGED
@@ -96,11 +96,13 @@ from transformers import AutoTokenizer, AutoModel
96
  import torch
97
 
98
  prompt_encoder_model_name_or_path = "sentence-transformers/all-mpnet-base-v2"
99
- prompt_encoder_tokenizer = AutoTokenizer.from_pretrained(prompt_encoder_model_name_or_path)
100
- prompt_encoder = AutoModel.from_pretrained(prompt_encoder_model_name_or_path)
101
-
102
  aptp_model_name_or_path = f"rezashkv/APTP"
103
  aptp_variant = "APTP-Base-CC3M"
 
 
 
 
 
104
  hyper_net = HyperStructure.from_pretrained(aptp_model_name_or_path, subfolder=f"{aptp_variant}/hypernet")
105
  quantizer = StructureVectorQuantizer.from_pretrained(aptp_model_name_or_path, subfolder=f"{aptp_variant}/quantizer")
106
 
@@ -110,15 +112,14 @@ prompt_embedding = get_mpnet_embeddings(prompts, prompt_encoder, prompt_encoder_
110
  arch_embedding = hyper_net(prompt_embedding)
111
  expert_id = quantizer.get_cosine_sim_min_encoding_indices(arch_embedding)[0].item()
112
 
113
- sd_model_name_or_path = "stabilityai/stable-diffusion-2-1"
114
-
115
  unet = UNet2DConditionModelPruned.from_pretrained(aptp_model_name_or_path,
116
  subfolder=f"{aptp_variant}/arch{expert_id}/checkpoint-30000/unet")
117
-
118
  noise_scheduler = PNDMScheduler.from_pretrained(sd_model_name_or_path, subfolder="scheduler")
 
119
  pipeline = StableDiffusionPipeline.from_pretrained(sd_model_name_or_path, unet=unet, scheduler=noise_scheduler)
120
 
121
  pipeline.to('cuda')
 
122
  generator = torch.Generator(device='cuda').manual_seed(43)
123
 
124
  image = pipeline(
 
96
  import torch
97
 
98
  prompt_encoder_model_name_or_path = "sentence-transformers/all-mpnet-base-v2"
 
 
 
99
  aptp_model_name_or_path = f"rezashkv/APTP"
100
  aptp_variant = "APTP-Base-CC3M"
101
+ sd_model_name_or_path = "stabilityai/stable-diffusion-2-1"
102
+
103
+ prompt_encoder = AutoModel.from_pretrained(prompt_encoder_model_name_or_path)
104
+ prompt_encoder_tokenizer = AutoTokenizer.from_pretrained(prompt_encoder_model_name_or_path)
105
+
106
  hyper_net = HyperStructure.from_pretrained(aptp_model_name_or_path, subfolder=f"{aptp_variant}/hypernet")
107
  quantizer = StructureVectorQuantizer.from_pretrained(aptp_model_name_or_path, subfolder=f"{aptp_variant}/quantizer")
108
 
 
112
  arch_embedding = hyper_net(prompt_embedding)
113
  expert_id = quantizer.get_cosine_sim_min_encoding_indices(arch_embedding)[0].item()
114
 
 
 
115
  unet = UNet2DConditionModelPruned.from_pretrained(aptp_model_name_or_path,
116
  subfolder=f"{aptp_variant}/arch{expert_id}/checkpoint-30000/unet")
 
117
  noise_scheduler = PNDMScheduler.from_pretrained(sd_model_name_or_path, subfolder="scheduler")
118
+
119
  pipeline = StableDiffusionPipeline.from_pretrained(sd_model_name_or_path, unet=unet, scheduler=noise_scheduler)
120
 
121
  pipeline.to('cuda')
122
+
123
  generator = torch.Generator(device='cuda').manual_seed(43)
124
 
125
  image = pipeline(