lmzjms committed
Commit a0f9d55
1 Parent(s): c2c1dca

Update ldm/modules/encoders/modules.py

Files changed (1)
  1. ldm/modules/encoders/modules.py +37 -0
ldm/modules/encoders/modules.py CHANGED
@@ -310,5 +310,42 @@ class FrozenFLANEmbedder(AbstractEncoder):
        z = outputs.last_hidden_state
        return z

+    def encode(self, text):
+        return self(text)
+
+class FrozenGlobalNormOpenCLIPEmbedder(AbstractEncoder):
+    """
+    Uses the OpenCLIP transformer encoder for text
+    """
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", freeze=True, delvisual=True):
+        super().__init__()
+        model, _, preprocess = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+        if delvisual:
+            del model.visual
+            del preprocess
+        else:
+            self.preprocess = preprocess
+        self.model = model
+
+        self.device = device
+        if freeze:
+            self.freeze()
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        tokens = open_clip.tokenize(text)
+        z = self.model.encode_text(tokens.to(self.device))
+        z /= z.norm(dim=-1, keepdim=True)
+        return z.unsqueeze(1)
+
+    def forward_img(self, image):
+        z = self.model.encode_image(image.to(self.device))
+        z /= z.norm(dim=-1, keepdim=True)
+        return z.unsqueeze(1)
+
    def encode(self, text):
        return self(text)
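
For reference, a minimal sketch of how the added class could be exercised, assuming open_clip and torch are installed and the laion2b_s32b_b79k weights can be downloaded; the prompts below are illustrative, not from the commit. Note that __init__ builds the model on CPU and only the device argument controls where tokens are sent, so passing device="cpu" keeps everything on one device here.

import torch

from ldm.modules.encoders.modules import FrozenGlobalNormOpenCLIPEmbedder

# Build the embedder on CPU; the visual tower is deleted by default
# (delvisual=True), so only the text transformer is kept in memory.
embedder = FrozenGlobalNormOpenCLIPEmbedder(device="cpu")

# encode() simply forwards to __call__, which tokenizes and embeds the text.
with torch.no_grad():
    z = embedder.encode(["a dog barking", "rain on a tin roof"])

# For ViT-H-14 this yields unit-norm embeddings of shape (2, 1, 1024):
# one L2-normalized vector per prompt, with a singleton sequence axis
# added by unsqueeze(1).
print(z.shape)

The forward pass L2-normalizes the pooled text embedding (hence "GlobalNorm" in the class name) and unsqueezes a length-1 sequence axis, presumably so the single global vector can stand in where downstream layers expect a (batch, seq, dim) tensor.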