Nupur Kumari commited on
Commit
61ea052
1 Parent(s): 6d6f59f

custom-diffusion-space

Browse files
Files changed (1) hide show
  1. inference.py +22 -2
inference.py CHANGED
@@ -12,7 +12,27 @@ import torch
12
  from diffusers import StableDiffusionPipeline
13
  sys.path.insert(0, 'custom-diffusion')
14
 
15
- from src import diffuser_training
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  class InferencePipeline:
@@ -48,7 +68,7 @@ class InferencePipeline:
48
  model_id, torch_dtype=torch.float16)
49
  pipe = pipe.to(self.device)
50
 
51
- diffuser_training.load_model(pipe.text_encoder, pipe.tokenizer, pipe.unet, weight_path, '<new1>')
52
 
53
  self.pipe = pipe
54
 
 
12
  from diffusers import StableDiffusionPipeline
13
  sys.path.insert(0, 'custom-diffusion')
14
 
15
+
16
+ def load_model(text_encoder, tokenizer, unet, save_path, modifier_token, freeze_model='crossattn_kv'):
17
+ logger.info("loading embeddings")
18
+ st = torch.load(save_path)
19
+ if 'text_encoder' in st:
20
+ text_encoder.load_state_dict(st['text_encoder'])
21
+ if modifier_token in st:
22
+ _ = tokenizer.add_tokens(modifier_token)
23
+ modifier_token_id = tokenizer.convert_tokens_to_ids(modifier_token)
24
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
25
+ text_encoder.resize_token_embeddings(len(tokenizer))
26
+ token_embeds = text_encoder.get_input_embeddings().weight.data
27
+ token_embeds[modifier_token_id] = st[modifier_token]
28
+ print(st.keys())
29
+ for name, params in unet.named_parameters():
30
+ if freeze_model == 'crossattn':
31
+ if 'attn2' in name:
32
+ params.data.copy_(st['unet'][f'{name}'])
33
+ else:
34
+ if 'attn2.to_k' in name or 'attn2.to_v' in name:
35
+ params.data.copy_(st['unet'][f'{name}'])
36
 
37
 
38
  class InferencePipeline:
 
68
  model_id, torch_dtype=torch.float16)
69
  pipe = pipe.to(self.device)
70
 
71
+ load_model(pipe.text_encoder, pipe.tokenizer, pipe.unet, weight_path, '<new1>')
72
 
73
  self.pipe = pipe
74